API documentation
malpolon.models
malpolon.models.model_builder
This module provides classes to build your PyTorch models.
The classes in this module let you select a model from a provider (timm, torchvision…), retrieve it with or without pre-trained weights, and modify it by adding or removing layers.
- Author: Titouan Lorieul <titouan.lorieul@gmail.com>
Theo Larcher <theo.larcher@inria.fr>
- malpolon.models.model_builder.change_first_convolutional_layer_modifier(model: nn.Module, num_input_channels: int, new_conv_layer_init_func: Optional[Callable[[nn.Conv2d, nn.Conv2d], None]] = None) → nn.Module
Remove the first registered convolutional layer of a model and replace it with a new convolutional layer with the provided number of input channels.
- Parameters:
model (torch.nn.Module) – Model to adapt.
num_input_channels (integer) – Number of input channels, used to update the first convolutional layer.
new_conv_layer_init_func (callable) – Function defining how to initialize the new convolutional layer.
- Returns:
model – Reference to the model object given as input, with its first convolutional layer replaced.
- Return type:
torch.nn.Module
- malpolon.models.model_builder.change_last_layer_modifier(model: Module, num_outputs: int, flatten: bool = False) → Module
Remove the last registered linear layer of a model and replace it with a new dense layer with the provided number of outputs.
- Parameters:
model (torch.nn.Module) – Model to adapt.
num_outputs (integer) – Number of outputs of the new output layer.
flatten (boolean) – If True, adds a nn.Flatten layer to squeeze the last dimension. Can be useful when num_outputs=1.
- Returns:
model – Reference to the model object given as input.
- Return type:
torch.nn.Module
- malpolon.models.model_builder.change_last_layer_to_identity_modifier(model: Module) → Module
Remove the last linear layer of a model and replace it with an nn.Identity layer.
- Parameters:
model (torch.nn.Module) – Model to adapt.
- Returns:
num_features – Size of the feature space.
- Return type:
int
- malpolon.models.model_builder.malpolon_model_provider(model_name: str, *model_args: Any, **model_kwargs: Any) → nn.Module
Return a model from Malpolon’s models list.
This method uses Malpolon’s internal model listing to retrieve a model.
- Parameters:
model_name (str) – name of the model to retrieve from Malpolon’s models list
- Returns:
model object
- Return type:
nn.Module
- malpolon.models.model_builder.timm_model_provider(model_name: str, *model_args: Any, **model_kwargs: Any) → nn.Module
Return a model from timm’s library.
This method uses timm’s API to retrieve a model from its library.
- Parameters:
model_name (str) – name of the model to retrieve from timm’s library
- Returns:
model object
- Return type:
nn.Module
- Raises:
ValueError – if the model name is not listed in TIMM’s library
- malpolon.models.model_builder.torchvision_model_provider(model_name: str, *model_args: Any, **model_kwargs: Any) → nn.Module
Return a model from torchvision’s library.
This method uses torchvision’s API to retrieve a model from its library.
- Parameters:
model_name (str) – name of the model to retrieve from torchvision’s library
- Returns:
model object
- Return type:
nn.Module
malpolon.models.standard_prediction_systems
This module provides classes wrapping PyTorch Lightning training modules.
- Author: Titouan Lorieul <titouan.lorieul@gmail.com>
Theo Larcher <theo.larcher@inria.fr>
- class malpolon.models.standard_prediction_systems.ClassificationSystem(model: Union[torch.nn.Module, Mapping], lr: float = 0.01, weight_decay: float = 0, momentum: float = 0.9, nesterov: bool = True, metrics: Optional[dict[str, Callable]] = None, task: str = 'classification_binary', loss_kwargs: Optional[dict] = {}, hparams_preprocess: bool = True, checkpoint_path: Optional[str] = None)
Bases:
GenericPredictionSystem
Classification task class.
- __init__(model: Union[torch.nn.Module, Mapping], lr: float = 0.01, weight_decay: float = 0, momentum: float = 0.9, nesterov: bool = True, metrics: Optional[dict[str, Callable]] = None, task: str = 'classification_binary', loss_kwargs: Optional[dict] = {}, hparams_preprocess: bool = True, checkpoint_path: Optional[str] = None)
Class constructor.
- Parameters:
model (Union[torch.nn.Module, Mapping]) – model to use, either a torch model object or a mapping (dictionary from the config file) used to load and build the model
lr (float) – learning rate
weight_decay (float) – weight decay
momentum (float) – value of momentum
nesterov (bool) – if True, uses Nesterov’s momentum
metrics (dict) – dictionary containing the metrics to compute. Keys must match the metrics’ names and have a subkey whose value is each metric’s functional method. This subkey is either created from the malpolon.models.utils.FMETRICS_CALLABLES constant or supplied directly by the user.
task (str, optional) – Machine learning task (used to format labels accordingly), by default ‘classification_binary’. The value determines the loss to be selected: if ‘multilabel’ or ‘binary’ is in the task, BCEWithLogitsLoss is selected; otherwise CrossEntropyLoss is used.
loss_kwargs (Optional[dict], optional) – loss parameters, by default {}
hparams_preprocess (bool, optional) – if True, performs preprocessing operations on the hyperparameters, by default True
checkpoint_path (Optional[str], optional) – path to the model checkpoint to load either to resume a previous training, perform transfer learning or run in prediction mode (inference), by default None
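The loss-selection rule stated for task can be sketched as a plain substring test. select_loss_name is a hypothetical helper, not part of Malpolon’s API; it returns the name of the torch loss class instead of instantiating it, so the sketch runs without PyTorch installed.

```python
def select_loss_name(task: str) -> str:
    """Return the torch loss class name implied by `task` (hypothetical helper)."""
    # Substring test: any task mentioning 'multilabel' or 'binary' gets a
    # sigmoid-based loss; every other task gets a softmax-based loss.
    if "multilabel" in task or "binary" in task:
        return "BCEWithLogitsLoss"
    return "CrossEntropyLoss"


for task in ("classification_binary", "classification_multilabel", "classification_multiclass"):
    print(f"{task} -> {select_loss_name(task)}")
```

Because the rule is a substring match, any custom task name that happens to contain “binary” or “multilabel” is also routed to BCEWithLogitsLoss.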
- class malpolon.models.standard_prediction_systems.GenericPredictionSystem(model: Union[torch.nn.Module, Mapping], loss: torch.nn.modules.loss._Loss, optimizer: torch.optim.Optimizer, metrics: Optional[dict[str, Callable]] = None, save_hyperparameters: Optional[bool] = True)
Bases:
LightningModule
Generic prediction system providing standard methods.
- Parameters:
model (torch.nn.Module) – Model to use.
loss (torch.nn.modules.loss._Loss) – Loss used to fit the model.
optimizer (torch.optim.Optimizer) – Optimization algorithm used to train the model.
metrics (dict) – Dictionary containing the metrics to monitor during the training and to compute at test time.
save_hyperparameters (bool) – Save arguments to hparams attribute.
- configure_optimizers() → Optimizer
- download_weights(url: str, out_path: str, filename: str, md5: Optional[str] = None)
Download pretrained weights from a remote repository.
Downloads the weights and adjusts self.checkpoint_path accordingly. This method is intended for transfer learning, or for resuming a model’s training later on and/or on a different machine. The downloaded content can either be a single file or a pre-zipped directory containing all training files, in which case checkpoint_path is updated to point inside the unzipped folder.
- Parameters:
url (str) – URL of the file or directory to download
out_path (str) – local root path where to extract the downloaded content
filename (str) – name of the file (in case of a single file download) or the directory (in case of a zip download) on local disk
md5 (Optional[str], optional) – checksum value to verify the integrity of the downloaded content, by default None
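The md5 parameter implies an integrity check on the downloaded file. The following is a minimal sketch of such a check using hashlib, not Malpolon’s implementation; file_md5 and verify_download are hypothetical helper names.

```python
import hashlib
import tempfile
from pathlib import Path
from typing import Optional


def file_md5(path: Path, chunk_size: int = 1 << 20) -> str:
    """Compute the MD5 hex digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # Chunked reads keep memory usage flat for large checkpoint archives.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: Path, md5: Optional[str] = None) -> bool:
    """Return True when no checksum is given or when the file matches it."""
    if md5 is None:
        return True
    return file_md5(path) == md5.lower()


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "weights.ckpt"
    ckpt.write_bytes(b"hello")
    ok = verify_download(ckpt, "5d41402abc4b2a76b9719d911017c592")  # MD5 of b"hello"
    bad = verify_download(ckpt, "0" * 32)
print(ok, bad)  # True False
```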
- forward(x: Any) → Any
- predict(datamodule, trainer)
Predict a model’s output on the test dataset.
This method performs inference on the test dataset using only PyTorch Lightning tools.
- Parameters:
datamodule (pl.LightningDataModule) – PyTorch Lightning datamodule handling all train/test/val datasets.
trainer (pl.Trainer) – PyTorch Lightning trainer in charge of running the model in training and inference mode.
- Returns:
Predicted tensor values.
- Return type:
array
- predict_point(checkpoint_path: str, data: Union[Tensor, tuple[Any, Any]], state_dict_replace_key: Optional[list[str, str]] = None, ckpt_transform: Callable = None)
Predict a model’s output on 1 data point.
Works like predict(), but on a single data point and using native PyTorch tools.
- Parameters:
checkpoint_path (str) – path to the model’s checkpoint to load.
data (Union[Tensor, tuple[Any, Any]]) – data point to perform inference on.
state_dict_replace_key (Optional[list[str, str]], optional) – list of values used to call the static method state_dict_replace_key(). Defaults to None.
ckpt_transform (Callable, optional) – callable function applied to the loaded checkpoint object. Use this to modify the structure of the loaded model’s checkpoint on the fly. Defaults to None.
remove_model_prefix (bool, optional) – if True, removes the “model.” prefix from the keys of the loaded checkpoint. Defaults
- Returns:
Predicted tensor value.
- Return type:
array
- predict_step(batch, batch_idx, dataloader_idx=0)
- remove_state_dict_prefix(state_dict: dict, prefix: str = 'model.')
Remove a prefix from the keys of a state_dict.
This method is intended to remove the “model.” prefix that appears in the keys of a state_dict when PyTorch Lightning saves a LightningModule’s checkpoint. A LightningModule holds the network in its model attribute, so the network’s parameters are referenced in the LightningModule state_dict as “model.<model_state_dict_key>”, and the save_checkpoint method saves that state_dict as a whole (which also allows saving more hyperparameters). Removing the prefix is useful when loading a state_dict directly into a model object instead of a LightningModule.
- Parameters:
state_dict (dict) – Model state_dict
prefix (str) – Prefix to remove from the state_dict keys.
- Returns:
State_dict with new keys.
- Return type:
dict
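Prefix removal can be sketched with plain dictionaries. remove_prefix is a hypothetical helper, not Malpolon’s code, and the string values stand in for tensors so that the sketch runs without PyTorch.

```python
from collections import OrderedDict
from typing import Mapping


def remove_prefix(state_dict: Mapping, prefix: str = "model.") -> OrderedDict:
    """Strip `prefix` from the start of every key that has it (hypothetical helper)."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Keys as saved by a LightningModule whose network is stored in `self.model`.
lightning_sd = OrderedDict([
    ("model.conv1.weight", "w0"),
    ("model.bn1.weight", "w1"),
    ("model.layer1.0.submodel.weight", "w2"),
])
plain_sd = remove_prefix(lightning_sd)
print(list(plain_sd))  # ['conv1.weight', 'bn1.weight', 'layer1.0.submodel.weight']
```

Only a leading match is removed. A naive substring replacement of “model.” would also turn “model.layer1.0.submodel.weight” into “layer1.0.subweight”.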
- static state_dict_replace_key(state_dict: dict, replace: Optional[list[str]] = ['.', ''])
Replace keys in a state_dict dictionary.
A state_dict is usually an OrderedDict whose keys are the model’s module names. This method lets you change these names by replacing a given string, or token, with another.
- Parameters:
state_dict (dict) – Model state_dict
replace (Optional[list[str]], optional) – Tokens to replace in the state_dict module names. The first element is the token to look for while the second is the replacement value. By default [‘.’, ‘’].
- Returns:
State_dict with new keys.
- Return type:
dict
Examples
After a training session, I loaded a ResNet-18 model from a checkpoint, but the names of the model’s modules have been altered with the prefix “model.”:
>>> sd = model.state_dict()
>>> print(list(sd)[:2])
['model.conv1.weight', 'model.bn1.weight']
To remove this prefix:
>>> sd = GenericPredictionSystem.state_dict_replace_key(sd, ['model.', ''])
>>> print(list(sd)[:2])
['conv1.weight', 'bn1.weight']
- test_step(batch: tuple[Any, Any], batch_idx: int) → Union[Tensor, dict[str, Any]]
- training_step(batch: tuple[Any, Any], batch_idx: int) → Union[Tensor, dict[str, Any]]
- validation_step(batch: tuple[Any, Any], batch_idx: int) → Union[Tensor, dict[str, Any]]
malpolon.models.utils
This file compiles useful functions related to models.
- Author: Theo Larcher <theo.larcher@inria.fr>
Titouan Lorieul <titouan.lorieul@gmail.com>
- class malpolon.models.utils.CrashHandler(trainer)
Bases:
object
Saves the model in case of an unexpected crash or a user interruption.
- save_checkpoint()
Save the latest checkpoint.
- signal_handler(sig, frame)
Attempt to save the latest checkpoint in case of crash.
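A signal-based crash handler of this kind can be sketched with the standard signal module. CrashSaver and its save_fn argument are hypothetical stand-ins: Malpolon’s CrashHandler is built from a trainer, whereas here any zero-argument callable plays the role of saving the latest checkpoint.

```python
import signal
import sys
from typing import Callable


class CrashSaver:
    """Save a checkpoint when SIGINT or SIGTERM arrives, then exit (hypothetical sketch)."""

    def __init__(self, save_fn: Callable[[], None]):
        self.save_fn = save_fn

    def install(self) -> None:
        # signal.signal only works from the main thread of the main interpreter.
        for sig in (signal.SIGINT, signal.SIGTERM):
            signal.signal(sig, self.signal_handler)

    def signal_handler(self, sig, frame) -> None:
        try:
            self.save_fn()
        finally:
            # Exit even if saving raised, using the conventional 128 + signum status.
            sys.exit(128 + int(sig))


# Simulate a received SIGTERM by calling the handler directly (no install needed).
saved = []
saver = CrashSaver(lambda: saved.append("last.ckpt"))
try:
    saver.signal_handler(signal.SIGTERM, None)
except SystemExit as exc:
    exit_code = exc.code
print(saved, exit_code)
```

Keeping install() separate from the constructor makes the handler testable without touching the process’s real signal table.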
- malpolon.models.utils.check_loss(loss: _Loss) → _Loss
Ensure input loss is a pytorch loss.
- Args:
loss (nn.modules.loss._Loss): input loss.
- Raises:
ValueError: if input loss isn’t a pytorch loss object.
- Returns:
nn.modules.loss._Loss: the pytorch input loss itself.
- malpolon.models.utils.check_metric(metrics: OmegaConf) → OmegaConf
Ensure user’s model metrics are valid.
Users can either choose from a list of predefined metrics or define their own custom metrics. This function binds the user’s metrics to their corresponding callable functions by reading the values in metrics, a dict-like structure returned by Hydra when reading the config file. If the user chose predefined metrics, the function automatically binds the corresponding callable function from torchmetrics. If the user chose custom metrics, the function checks that they also provided the callable function that computes the metric.
- Parameters:
metrics (OmegaConf) – user’s input metrics, read from the config file via hydra, in a dict-like structure
- Returns:
user’s metrics with their corresponding callable function
- Return type:
OmegaConf
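The binding logic described above can be sketched with plain dictionaries. The registry PREDEFINED_METRICS, the bind_metrics helper and the “callable”/“kwargs” subkey names are hypothetical; they stand in for torchmetrics functions and Hydra’s OmegaConf structure.

```python
from typing import Callable, Dict, Mapping


def _accuracy(preds, targets):
    # Fraction of positions where the prediction equals the target.
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


# Hypothetical registry standing in for the predefined torchmetrics callables.
PREDEFINED_METRICS: Dict[str, Callable] = {"accuracy": _accuracy}


def bind_metrics(metrics: Mapping[str, Mapping]) -> Dict[str, dict]:
    """Attach a callable to every metric entry (hypothetical sketch of the binding)."""
    bound = {}
    for name, cfg in metrics.items():
        entry = dict(cfg)
        if name in PREDEFINED_METRICS:
            # Predefined metric: bind the registered callable automatically.
            entry["callable"] = PREDEFINED_METRICS[name]
        elif not callable(entry.get("callable")):
            # Custom metric: the user must have supplied the callable themselves.
            raise ValueError(f"Custom metric '{name}' must provide a callable.")
        bound[name] = entry
    return bound


bound = bind_metrics({
    "accuracy": {"kwargs": {}},
    "error_rate": {"callable": lambda p, t: 1 - _accuracy(p, t)},
})
acc = bound["accuracy"]["callable"]([1, 0, 1], [1, 1, 1])
err = bound["error_rate"]["callable"]([1, 0, 1], [1, 1, 1])

try:
    bind_metrics({"custom_f1": {}})
    rejected = False
except ValueError:
    rejected = True
print(acc, err, rejected)
```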
- malpolon.models.utils.check_model(model: Union[Module, Mapping]) → Module
Ensure input model is a pytorch model.
- Args:
model (Union[nn.Module, Mapping]): input model.
- Raises:
ValueError: if input model isn’t a pytorch model object.
- Returns:
nn.Module: the pytorch input model itself.
- malpolon.models.utils.check_optimizer(optimizer: Optimizer) → Optimizer
Ensure input optimizer is a pytorch optimizer.
- Args:
optimizer (optim.Optimizer): input optimizer.
- Raises:
ValueError: if input optimizer isn’t a pytorch optimizer object.
- Returns:
optim.Optimizer: the pytorch input optimizer itself.
malpolon.models.custom_models.glc2024_multimodal_ensemble_model_modality_specific
malpolon.models.custom_models.glc2024_multimodal_ensemble_model
This module provides a Multimodal Ensemble model for GeoLifeCLEF2024 data.
- Author: Lukas Picek <lukas.picek@inria.fr>
Theo Larcher <theo.larcher@inria.fr>
License: GPLv3
Python version: 3.10.6
- class malpolon.models.custom_models.glc2024_multimodal_ensemble_model.MultimodalEnsemble(num_classes: int = 11255, pretrained: bool = False, **kwargs)
Bases:
Module
Multimodal ensemble model processing Sentinel-2A, Landsat & Bioclimatic data.
Inherits torch nn.Module.
- __init__(num_classes: int = 11255, pretrained: bool = False, **kwargs)
Class constructor.
- Parameters:
num_classes (int, optional) – number of classes, by default 11255
pretrained (bool, optional) – if True, downloads the model’s weights from our remote storage platform, by default False
- forward(x, y, z)
malpolon.models.custom_models.glc2024_pre_extracted_prediction_system
This module provides a classification system for the Multimodal Ensemble model on GeoLifeCLEF2024 pre-extracted data.
- Author: Lukas Picek <lukas.picek@inria.fr>
Theo Larcher <theo.larcher@inria.fr>
License: GPLv3
Python version: 3.10.6
- class malpolon.models.custom_models.glc2024_pre_extracted_prediction_system.ClassificationSystemGLC24(model: Union[Module, Mapping], lr: float = 0.01, weight_decay: float = 0, momentum: float = 0.9, nesterov: bool = True, metrics: Optional[dict[str, Callable]] = None, task: str = 'classification_multilabel', loss_kwargs: Optional[dict] = {}, hparams_preprocess: bool = True, weights_dir: str = 'outputs/glc24_cnn_multimodal_ensemble/', checkpoint_path: Optional[str] = None, num_classes: Optional[int] = None)
Bases:
ClassificationSystem
Classification task class for GLC24_pre-extracted.
Inherits ClassificationSystem.
- __init__(model: Union[Module, Mapping], lr: float = 0.01, weight_decay: float = 0, momentum: float = 0.9, nesterov: bool = True, metrics: Optional[dict[str, Callable]] = None, task: str = 'classification_multilabel', loss_kwargs: Optional[dict] = {}, hparams_preprocess: bool = True, weights_dir: str = 'outputs/glc24_cnn_multimodal_ensemble/', checkpoint_path: Optional[str] = None, num_classes: Optional[int] = None)
Class constructor.
- Parameters:
model (Union[torch.nn.Module, Mapping]) – model to use, either a torch model object, or a mapping (dictionary from config file) used to load and build the model
lr (float, optional) – learning rate, by default 1e-2
weight_decay (float, optional) – weight decay, by default 0
momentum (float) – value of momentum
nesterov (bool) – if True, uses Nesterov’s momentum
metrics (dict) – dictionary containing the metrics to compute. Keys must match the metrics’ names and have a subkey whose value is each metric’s functional method. This subkey is either created from the malpolon.models.utils.FMETRICS_CALLABLES constant or supplied directly by the user.
task (str, optional) – Machine learning task (used to format labels accordingly), by default ‘classification_multilabel’. The value determines the loss to be selected: if ‘multilabel’ or ‘binary’ is in the task, BCEWithLogitsLoss is selected; otherwise CrossEntropyLoss is used.
loss_kwargs (Optional[dict], optional) – loss parameters, by default {}
hparams_preprocess (bool, optional) – if True performs preprocessing operations on the hyperparameters, by default True
weights_dir (str, optional) – directory where to download the model weights, by default ‘outputs/glc24_cnn_multimodal_ensemble/’
checkpoint_path (Optional[str], optional) – path to the model checkpoint to load either to resume a previous training, perform transfer learning or run in prediction mode (inference), by default None
num_classes (int, optional) – number of classes for the classification task, by default None
- configure_optimizers()
Override default optimizer and scheduler.
By default, SGD is selected and the scheduler is handled by PyTorch Lightning’s default one.
- Returns:
dictionary containing keys for optimizer and scheduler, passed on to PyTorch Lightning
- Return type:
(dict)
- forward(x, y, z)
- predict_step(batch, batch_idx, dataloader_idx=0)
malpolon.models.custom_models.multi_modal
This module provides classes for advanced model building.
- Author: Titouan Lorieul <titouan.lorieul@gmail.com>
Theo Larcher <theo.larcher@inria.fr>
- class malpolon.models.custom_models.multi_modal.HomogeneousMultiModalModel(modality_names: list, modalities_model: dict, aggregator_model: Union[nn.Module, Mapping])
Bases:
MultiModalModel
Straightforward multi-modal model.
- __init__(modality_names: list, modalities_model: dict, aggregator_model: Union[nn.Module, Mapping])
Class constructor.
- Parameters:
modality_names (list) – list of modality names
modalities_model (dict) – dictionary of modality names and their respective models to pass on to the model builder
aggregator_model (Union[nn.Module, Mapping]) – Model strategy to aggregate the features from each modality. Can either be a PyTorch module (in this case, the module is called directly), or a mapping in the same fashion as for building the modality models, in which case the model builder is called again.
- class malpolon.models.custom_models.multi_modal.MultiModalModel(modality_models: Union[nn.Module, Mapping], aggregator_model: Union[nn.Module, Mapping])
Bases:
Module
Base multi-modal model.
This class builds one model per modality from the values passed in the config file, splits the training routine per modality, and then aggregates the features from every modality after each forward pass.
- __init__(modality_models: Union[nn.Module, Mapping], aggregator_model: Union[nn.Module, Mapping])
Class constructor.
- Parameters:
modality_models (Union[nn.Module, Mapping]) – dictionary of modality names and their respective models to pass on to the model builder
aggregator_model (Union[nn.Module, Mapping]) – Model strategy to aggregate the features from each modality. Can either be a PyTorch module (in this case, the module is called directly), or a mapping in the same fashion as for building the modality models, in which case the model builder is called again.
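The per-modality forward pass followed by aggregation can be sketched without PyTorch. MultiModalSketch and concatenate are hypothetical names; plain callables stand in for the modality models and the aggregator model, and lists of floats stand in for feature tensors.

```python
from typing import Any, Callable, Dict, List, Sequence

Features = List[float]


class MultiModalSketch:
    """One model per modality plus an aggregator (hypothetical, framework-free sketch)."""

    def __init__(
        self,
        modality_models: Dict[str, Callable[[Any], Features]],
        aggregator: Callable[[List[Features]], Features],
    ):
        self.modality_models = modality_models
        self.aggregator = aggregator

    def forward(self, x: Sequence[Any]) -> Features:
        # x holds one input per modality, in the same order as modality_models.
        if len(x) != len(self.modality_models):
            raise ValueError("Expected one input per modality.")
        features = [model(x_i) for model, x_i in zip(self.modality_models.values(), x)]
        # The per-modality features are merged into a single representation.
        return self.aggregator(features)


def concatenate(feature_list: List[Features]) -> Features:
    return [value for feats in feature_list for value in feats]


model = MultiModalSketch(
    {"rgb": lambda x: [sum(x)], "bioclim": lambda x: [max(x), min(x)]},
    aggregator=concatenate,
)
out = model.forward([[1.0, 2.0], [5.0, -1.0]])
print(out)  # [3.0, 5.0, -1.0]
```

Concatenation is only one possible aggregation strategy; the aggregator_model parameter leaves that choice to the user.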
- forward(x: list[Any]) → Any
- class malpolon.models.custom_models.multi_modal.ParallelMultiModalModelStrategy(accelerator=None, parallel_devices=None, checkpoint_io=None, precision_plugin=None)
Bases:
SingleDeviceStrategy
Model parallelism strategy for multi-modal models.
WARNING: STILL UNDER DEVELOPMENT.
- batch_to_device(batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) → Any
TODO: Docstring.
- model_to_device() → None
TODO: Docstring.
- strategy_name = 'parallel_multi_modal_model'
malpolon.data
malpolon.data.data_module
This module provides a base class for data modules.
- Author: Theo Larcher <theo.larcher@inria.fr>
Titouan Lorieul <titouan.lorieul@gmail.com>
- class malpolon.data.data_module.BaseDataModule(train_batch_size: int = 32, inference_batch_size: int = 256, num_workers: int = 8)
Bases:
LightningDataModule, ABC
Base class for data modules.
This class inherits PyTorch Lightning’s LightningDataModule class and provides a base class for data modules by redefining the step methods and adding new data manipulation methods.
- export_predict_csv(predictions: Union[Tensor, np.ndarray], probas: Union[Tensor, np.ndarray] = None, single_point_query: dict = None, out_name: str = 'predictions', out_dir: str = './', return_csv: bool = False, top_k: int = None, **kwargs: Any) → Any
Export predictions to csv file.
This method exports predictions to a CSV file, either for a single-point query or for the whole test dataset. It is adapted to a classification task with an observations file and multi-modal data. Keys in the CSV file match the ones used to instantiate the RasterTorchGeoDataset class, that is: observation_id, lon, lat, target_species_id. The crs key is also mandatory for single-point queries.
- Parameters:
predictions (Union[Tensor, np.ndarray]) – model’s predictions.
probas (Union[Tensor, np.ndarray], optional) – predictions’ raw logits or logits passed through an activation function, by default None
single_point_query (dict, optional) – query dictionary of the single-point prediction. The ‘target_species_id’ key is mandatory and expects a list of numpy arrays of species ids. The ‘predictions’ and ‘probas’ keys expect numpy arrays of predictions and probabilities. By default None (predictions on the whole test dataset)
out_name (str, optional) – output CSV file name, by default “predictions”
out_dir (str, optional) – output directory name, by default “./”
return_csv (bool, optional) – if True, the method returns the CSV as a pandas DataFrame, by default False
top_k (int, optional) – number of top predictions to return, by default None (max number of predictions)
- Returns:
CSV content as a pandas DataFrame if return_csv is True
- Return type:
pandas.DataFrame
- export_predict_csv_basic(predictions: Union[Tensor, np.ndarray], targets: Union[np.ndarray, list], probas: Union[Tensor, np.ndarray] = None, ids: Union[np.ndarray, list] = None, out_name: str = 'predictions', out_dir: str = './', return_csv: bool = False, top_k: int = None, **kwargs: Any)
Export predictions to csv file.
Exports predictions, probabilities and ids to a csv file.
- Parameters:
predictions (Union[Tensor, np.ndarray]) – model’s predictions.
targets (Union[np.ndarray, list]) – target species ids
probas (Union[Tensor, np.ndarray], optional) – predictions’ raw logits or logits passed through an activation function, by default None
ids (Union[np.ndarray, list], optional) – ids of the observations, by default None
out_name (str, optional) – output CSV file name, by default “predictions”
out_dir (str, optional) – output directory name, by default “./”
return_csv (bool, optional) – if True, the method returns the CSV as a pandas DataFrame, by default False
top_k (int, optional) – number of top predictions to return, by default None (max number of predictions)
- Returns:
CSV content as a pandas DataFrame if return_csv is True
- Return type:
pandas.DataFrame
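A CSV export with a top_k cut-off can be sketched with the standard csv module. predictions_to_csv and its column names are hypothetical; the sketch assumes each row of predictions and probas is already sorted by decreasing probability and stores each list as a space-separated string in a single cell.

```python
import csv
import io
from typing import Optional, Sequence


def predictions_to_csv(
    ids: Sequence,
    predictions: Sequence[Sequence[int]],
    probas: Sequence[Sequence[float]],
    targets: Sequence,
    top_k: Optional[int] = None,
) -> str:
    """Serialize per-observation predictions to CSV text (hypothetical sketch)."""
    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerow(["observation_id", "predictions", "probas", "target_species_id"])
    for obs_id, preds, probs, target in zip(ids, predictions, probas, targets):
        # Slicing with top_k=None keeps the full list.
        preds, probs = list(preds)[:top_k], list(probs)[:top_k]
        writer.writerow([
            obs_id,
            " ".join(str(p) for p in preds),
            " ".join(f"{p:.4f}" for p in probs),
            target,
        ])
    return buffer.getvalue()


text = predictions_to_csv(
    ids=[10, 11],
    predictions=[[3, 1, 2], [0, 2, 1]],
    probas=[[0.7, 0.2, 0.1], [0.5, 0.3, 0.2]],
    targets=[3, 2],
    top_k=2,
)
print(text)
```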
- abstract get_dataset(split: str, transform: Callable, **kwargs: Any) → Dataset
Return the dataset corresponding to the split.
- Parameters:
split (str) – Type of dataset. Values must be one of [“train”, “val”, “test”]
transform (Callable) – data transforms to apply when loading the dataset
- Returns:
dataset corresponding to the split
- Return type:
Dataset
- get_test_dataset() → Dataset
Call self.get_dataset to return the test dataset.
- Returns:
test dataset
- Return type:
Dataset
- get_train_dataset() → Dataset
Call self.get_dataset to return the train dataset.
- Returns:
train dataset
- Return type:
Dataset
- get_val_dataset() → Dataset
Call self.get_dataset to return the validation dataset.
- Returns:
validation dataset
- Return type:
Dataset
- predict_dataloader() → DataLoader
Return predict dataloader instantiated with class attributes.
- Returns:
predict dataloader
- Return type:
DataLoader
- predict_logits_to_class(predictions: Tensor, classes: Union[np.ndarray, Tensor], activation_fn: torch.nn.modules.activation = Softmax(dim=1)) → Tensor
Convert the model’s predictions to class labels.
This method applies an activation function to the model’s predictions and returns the corresponding class labels.
- Parameters:
predictions (Tensor) – model’s predictions (raw logits)
classes (Union[np.ndarray, Tensor]) – classes labels
activation_fn (torch.nn.modules.activation, optional) – activation function to apply to the model’s predictions, by default torch.nn.Softmax(dim=1)
- Returns:
class labels and corresponding probabilities
- Return type:
tuple[np.ndarray, np.ndarray]
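Applying a softmax over the class dimension and ranking the labels can be sketched in pure Python. softmax and logits_to_classes are hypothetical helpers, and returning labels sorted by decreasing probability is an assumption of this sketch rather than documented behaviour.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the largest logit before exponentiating to avoid overflow.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logits_to_classes(
    batch_logits: Sequence[Sequence[float]], classes: Sequence
) -> Tuple[List[list], List[List[float]]]:
    """Per row: class labels and probabilities sorted by decreasing probability."""
    labels, probas = [], []
    for logits in batch_logits:
        probs = softmax(logits)  # one row = one sample, as with Softmax(dim=1)
        order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
        labels.append([classes[i] for i in order])
        probas.append([probs[i] for i in order])
    return labels, probas


labels, probas = logits_to_classes([[0.0, 2.0, 1.0]], ["a", "b", "c"])
print(labels)  # [['b', 'c', 'a']]
```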
- prepare_data() → None
Prepare data.
Called once on CPU. Class states defined here are lost afterwards. This method is intended for data downloading, tokenization, permanent transformation…
- setup(stage: Optional[str] = None) → None
Register the correct datasets to the class attributes.
Depending on the trainer’s stage, this method will retrieve the train, val or test dataset and register it as a class attribute. The “predict” stage calls for the test dataset.
- Parameters:
stage (Optional[str], optional) – trainer’s stage, by default None (train)
- test_dataloader() → DataLoader
Return test dataloader instantiated with class attributes.
- Returns:
test dataloader
- Return type:
DataLoader
- abstract property test_transform: Callable
Return test data transforms.
- Returns:
test data transforms
- Return type:
Callable
- train_dataloader() → DataLoader
Return train dataloader instantiated with class attributes.
- Returns:
train dataloader
- Return type:
DataLoader
- abstract property train_transform: Callable
Return train data transforms.
- Returns:
train data transforms
- Return type:
Callable
- val_dataloader() → DataLoader
Return validation dataloader instantiated with class attributes.
- Returns:
Validation dataloader
- Return type:
DataLoader
malpolon.data.environmental_raster
Custom classes to handle environmental rasters without torchgeo.
Author: Titouan Lorieul <titouan.lorieul@gmail.com>
- class malpolon.data.environmental_raster.PatchExtractor(root_path: Union[str, Path], size: int = 256)
Bases:
object
Handles the loading and extraction of an environmental tensor from multiple rasters given GPS coordinates.
- Parameters:
root_path (string or pathlib.Path) – Path to the folder containing all the rasters.
size (integer) – Size in pixels (size x size) of the patches to extract around each location.
- __getitem__(coordinates: Coordinates) → npt.NDArray[np.float32]
Extract the patches around the given GPS coordinates for all the previously loaded rasters.
- Parameters:
coordinates (tuple containing two floats) – GPS coordinates (latitude, longitude)
- Returns:
patch – Extracted patches around the given coordinates.
- Return type:
3d array of floats, [n_rasters, size, size], or 1d array of floats, [n_rasters,], if size == 1
- __len__() → int
Return the number of variables/rasters loaded.
- Returns:
n_rasters – Number of loaded rasters
- Return type:
integer
- add_all_bioclimatic_rasters(**kwargs: Any) → None
Add all bioclimatic variables (rasters) available.
- Parameters:
kwargs (dict) – Updates the default arguments passed to Raster (nan, out_of_bounds, etc.)
- add_all_pedologic_rasters(**kwargs: Any) → None
Add all pedologic variables (rasters) available.
- Parameters:
kwargs (dict) – Updates the default arguments passed to Raster (nan, out_of_bounds, etc.)
- add_all_rasters(**kwargs: Any) → None
Add all variables (rasters) available.
- Parameters:
kwargs (dict) – Updates the default arguments passed to Raster (nan, out_of_bounds, etc.)
- append(raster_name: str, **kwargs: Any) → None
Load and append a single raster to the rasters already loaded.
Can be useful to load only a subset of rasters or to pass configurations specific to each raster.
- Parameters:
raster_name (string) – Name of the raster to load, should be a subfolder of root_path.
kwargs (dict) – Updates the default arguments passed to Raster (nan, out_of_bounds, etc.)
- clean() → None
Remove all rasters from the extractor.
- plot(coordinates: Coordinates, return_fig: bool = False, n_cols: int = 5, fig: Optional[plt.Figure] = None, resolution: float = 1.0) → Optional[plt.Figure]
Plot an environmental tensor (only works if size > 1).
- Parameters:
coordinates (tuple containing two floats) – GPS coordinates (latitude, longitude)
return_fig (boolean) – If True, returns the created plt.Figure object
n_cols (integer) – Number of columns to use
fig (plt.Figure or None) – If not None, use the given plt.Figure object instead of creating a new one
resolution (float) – Resolution of the created figure
- Returns:
fig – If return_fig is True, the used plt.Figure object
- Return type:
plt.Figure
- class malpolon.data.environmental_raster.Raster(path: Union[str, Path], country: str, size: int = 256, nan: Optional[float] = nan, out_of_bounds: str = 'error')
Bases:
object
Loads a GeoTIFF file and extracts patches for a single environmental raster.
- Parameters:
path (string / pathlib.Path) – Path to the folder containing all the rasters.
country (string, either "FR" or "USA") – Which country to load raster from.
size (integer) – Size in pixels (size x size) of the patch to extract around each location.
nan (float or None) – Value to use to replace missing data in original rasters, if None, leaves default values.
out_of_bounds (string, either "error", "warn" or "ignore") – If “error”, raises an exception if the requested location is out of the bounds of the rasters. Set to “warn” to only produce a warning, or to “ignore” to silently ignore it and return a patch filled with missing data.
- __getitem__(coordinates: Coordinates) → Patch
Extract the patch around the given GPS coordinates.
- Parameters:
coordinates (tuple containing two floats) – GPS coordinates (latitude, longitude)
- Returns:
patch – Extracted patch around the given coordinates.
- Return type:
2d array of floats, [size, size], or single float if size == 1
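Extracting a size x size window around a location, with the three out_of_bounds behaviours, can be sketched on a small in-memory grid. extract_patch and fill_value are hypothetical; the sketch skips the GPS-to-pixel conversion (it takes row/column indices directly) and pads only the cells that fall outside the raster.

```python
import math
import warnings
from typing import List, Sequence, Union


def extract_patch(
    grid: Sequence[Sequence[float]],
    row: int,
    col: int,
    size: int = 3,
    fill_value: float = math.nan,
    out_of_bounds: str = "error",
) -> Union[float, List[List[float]]]:
    """Return the size x size window centred on pixel (row, col) (hypothetical sketch)."""
    n_rows, n_cols = len(grid), len(grid[0])
    top, left = row - size // 2, col - size // 2
    inside = 0 <= top and 0 <= left and top + size <= n_rows and left + size <= n_cols
    if not inside:
        message = f"Patch centred on ({row}, {col}) exceeds the raster bounds"
        if out_of_bounds == "error":
            raise ValueError(message)
        if out_of_bounds == "warn":
            warnings.warn(message)
    # Cells falling outside the raster are padded with fill_value.
    patch = [
        [
            grid[r][c] if 0 <= r < n_rows and 0 <= c < n_cols else fill_value
            for c in range(left, left + size)
        ]
        for r in range(top, top + size)
    ]
    # Mirror the documented behaviour: a 1x1 patch is returned as a single value.
    return patch[0][0] if size == 1 else patch


grid = [[r * 4 + c for c in range(4)] for r in range(4)]
centre = extract_patch(grid, 1, 1, size=3)  # [[0, 1, 2], [4, 5, 6], [8, 9, 10]]
corner = extract_patch(grid, 0, 0, size=2, out_of_bounds="ignore")
try:
    extract_patch(grid, 0, 0, size=2)
    raised = False
except ValueError:
    raised = True
```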
- __len__() → int
Return the number of bands in the raster.
Should always be equal to 1.
- Returns:
n_bands – Number of bands in the raster
- Return type:
integer
malpolon.data.get_jpeg_patches_stats
Script / module used to compute the mean and std on JPEG files.
When dealing with a large number of files, it should be run only once, and the statistics should be stored in a separate .csv for later use.
Author: Theo Larcher <theo.larcher@inria.fr>
- malpolon.data.get_jpeg_patches_stats.standardize(root_path: str = 'sample_data/SatelliteImages/', ext: str = ['jpeg', 'jpg'], output: str = 'root_path')
Perform standardization over images.
Computes, returns and stores the mean and standard deviation of an image dataset organized inside a root directory, for instance to normalize inputs before training a deep learning model.
- Args:
- root_path (str): root dir. containing the images.
Defaults to ‘sample_data/SatelliteImages/’.
- ext (str, optional): the images extensions to consider.
Defaults to [‘jpeg’, ‘jpg’].
- output (str, optional): output path where to save the csv containing
the mean and std of the dataset. If None: doesn’t output anything. Defaults to root_path.
- Returns:
(tuple): tuple of mean and std of the jpeg images.
- malpolon.data.get_jpeg_patches_stats.standardize_by_parts(fps_fp: str, output: str = 'glc23_stats.csv', max_imgs_per_computation: int = 100000)
Perform standardization over images part by part.
With too many images, memory can overflow. This function addresses this problem by performing the computation in parts. Downside: the computed standard deviation is only an approximation of the true value, obtained as a mean over the parts.
- Args:
- fps_fp (str): file path to a text file containing the paths to
the images.
- output (str, optional): output path where to save the csv containing
the mean and std of the dataset. If None: doesn’t output anything. Defaults to ‘glc23_stats.csv’.
- max_imgs_per_computation (int, optional): maximum number of images to hold in memory.
Defaults to 100000.
- Returns:
(tuple): tuple of mean and std of the jpeg images.
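If the part-wise approximation averages the per-part standard deviations, as the description above suggests (an assumption), it loses the spread between parts. Keeping a running count, sum and sum of squares per part yields the exact global mean and standard deviation at the same memory cost. The sketch below uses plain float lists instead of image pixels, and exact_mean_std and averaged_std are hypothetical names.

```python
import math
from typing import Iterable, Sequence, Tuple


def exact_mean_std(parts: Iterable[Sequence[float]]) -> Tuple[float, float]:
    """Population mean and std of all values, combining per-part sums exactly."""
    n = total = total_sq = 0.0
    for part in parts:
        # Only three running numbers are kept, so parts can be loaded one at a time.
        n += len(part)
        total += sum(part)
        total_sq += sum(v * v for v in part)
    mean = total / n
    # Var = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
    return mean, math.sqrt(max(total_sq / n - mean * mean, 0.0))


def averaged_std(parts: Iterable[Sequence[float]]) -> float:
    """Mean of the per-part population stds, i.e. a part-wise approximation."""
    stds = []
    for part in parts:
        m = sum(part) / len(part)
        stds.append(math.sqrt(sum((v - m) ** 2 for v in part) / len(part)))
    return sum(stds) / len(stds)


parts = [[0.0, 2.0], [10.0, 12.0]]
mean, std = exact_mean_std(parts)  # mean 6.0, std sqrt(26) ≈ 5.099
approx = averaged_std(parts)       # 1.0: the spread between parts is lost
```

The sum-of-squares formula can lose precision when the mean is large compared with the spread; the pairwise update of (count, mean, M2) by Chan et al. is the numerically safer variant of the same idea.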
malpolon.data.utils
This file compiles useful functions related to data and file handling.
Author: Theo Larcher <theo.larcher@inria.fr>
- malpolon.data.utils.get_files_path_recursively(path, *args, suffix='') list
Retrieve the paths of specific files recursively from a directory.
Retrieve the path of all files with one of the given extension names, in the given directory and all its subdirectories, recursively. The extension names should be given as a list of strings. The search for extension names is case-sensitive.
- Parameters:
path (str) – root directory from which to search for files recursively
*args (list) – list of file extensions to be considered.
- Returns:
list of paths of every file in the directory and all its subdirectories.
- Return type:
list
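A minimal sketch of a recursive, case-sensitive extension search built on os.walk. find_files is a hypothetical name, and the sketch ignores the suffix parameter, whose behavior is not described above.

```python
import os


def find_files(root: str, *extensions: str) -> list[str]:
    """Return sorted paths of files under `root` whose extension is in `extensions`."""
    matches = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            # splitext keeps the leading dot (".tif"), so strip it before comparing
            ext = os.path.splitext(name)[1].lstrip(".")
            if ext in extensions:  # plain string comparison: case-sensitive
                matches.append(os.path.join(dirpath, name))
    return sorted(matches)
```

With this sketch, find_files("data", "tif") would skip a file named scene.TIF, matching the case-sensitive behavior described above.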
- malpolon.data.utils.is_bbox_contained(bbox1: Union[Iterable, BoundingBox], bbox2: Union[Iterable, BoundingBox], method: str = 'shapely') bool
Determine if a 2D bbox is included inside another.
Returns a boolean answering the question “Is bbox1 contained inside bbox2?”. With methods ‘shapely’ and ‘manual’, bounding boxes must follow the format: [xmin, ymin, xmax, ymax]. With method ‘torchgeo’, bounding boxes must be of type: torchgeo.datasets.utils.BoundingBox.
- Parameters:
bbox1 (Union[Iterable, BoundingBox]) – Bounding box n°1.
bbox2 (Union[Iterable, BoundingBox]) – Bounding box n°2.
method (str) – Method to use for comparison. Can take any value in [‘shapely’, ‘manual’, ‘torchgeo’], by default ‘shapely’.
- Returns:
True if bbox1 ⊂ bbox2, False otherwise.
- Return type:
boolean
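A sketch of what a ‘manual’ comparison could look like for [xmin, ymin, xmax, ymax] boxes. bbox_contains is hypothetical, and treating boundaries as inclusive is an assumption; the ‘shapely’ and ‘torchgeo’ methods may handle touching edges differently.

```python
from typing import Sequence


def bbox_contains(inner: Sequence[float], outer: Sequence[float]) -> bool:
    """Return True if `inner` lies inside `outer` (both [xmin, ymin, xmax, ymax])."""
    ixmin, iymin, ixmax, iymax = inner
    oxmin, oymin, oxmax, oymax = outer
    # Inclusive comparison: a box sharing an edge with `outer` still counts as contained
    return oxmin <= ixmin and oymin <= iymin and ixmax <= oxmax and iymax <= oymax
```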
- malpolon.data.utils.is_point_in_bbox(point: Iterable, bbox: Iterable, method: str = 'shapely') bool
Determine if a 2D point is included inside a 2D bounding box.
Returns a boolean answering the question “Is point contained inside bbox?”. Point must follow the format: [x, y]. Bounding box must follow the format: [xmin, ymin, xmax, ymax].
- Parameters:
point (Iterable) – Point in the format [x, y].
bbox (Iterable) – Bounding box in the format [xmin, ymin, xmax, ymax].
method (str) – Method to use for comparison. Can take any value in [‘shapely’, ‘manual’], by default ‘shapely’.
- Returns:
True if point ∈ bbox, False otherwise.
- Return type:
boolean
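A hypothetical ‘manual’ point test following the formats above. Edges are included here by assumption; shapely's contains predicate, for instance, returns False for a point lying exactly on the boundary, so results may differ on edges depending on the method.

```python
from typing import Sequence


def point_in_bbox(point: Sequence[float], bbox: Sequence[float]) -> bool:
    """Return True if point [x, y] lies in bbox [xmin, ymin, xmax, ymax], edges included."""
    x, y = point
    xmin, ymin, xmax, ymax = bbox
    return xmin <= x <= xmax and ymin <= y <= ymax
```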
- malpolon.data.utils.split_obs_per_species_frequency(input_path: str, output_name: str, val_ratio: float = 0.05)
Split an obs csv into train and val subsets.
Performs a split with equal proportions of classes in train and val (if possible, depending on the number of occurrences per species). Species with too few occurrences in the obs file are not included in the val split.
The val proportion is defined by the val_ratio argument.
Input csv is expected to have at least the following columns: [‘speciesId’]
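A minimal sketch of such a per-species (stratified) split on rows held as dictionaries, assuming val_ratio is applied to each species separately and rounded down, so that species with too few occurrences end up entirely in train. split_per_species and its seed parameter are hypothetical.

```python
import random
from collections import defaultdict


def split_per_species(rows: list[dict], val_ratio: float = 0.05,
                      seed: int = 0) -> tuple[list[dict], list[dict]]:
    """Split observations into (train, val), taking about val_ratio of each species."""
    by_species = defaultdict(list)
    for row in rows:
        by_species[row["speciesId"]].append(row)
    rng = random.Random(seed)
    train, val = [], []
    for occurrences in by_species.values():
        rng.shuffle(occurrences)
        n_val = int(len(occurrences) * val_ratio)  # 0 for rare species
        val.extend(occurrences[:n_val])
        train.extend(occurrences[n_val:])
    return train, val
```

With val_ratio=0.1, a species observed 40 times contributes 4 rows to val, while a species observed once stays in train.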
- malpolon.data.utils.split_obs_spatially(input_path: str, spacing: float = 0.16666666666666666, plot: bool = False, val_size: float = 0.15)
Perform a spatial train/val split on the input csv file.
- Parameters:
input_path (str) – obs CSV input file’s path
spacing (float, optional) – size of the spatial split in degrees (or whatever unit the coordinates are in), by default 10/60
plot (bool, optional) – if True, plots the train/val split on a 2D map, by default False
val_size (float, optional) – size of the validation split, by default 0.15
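One common way to build a spatial split, sketched here under the assumption that the grid works this way: points are binned into square cells of side spacing, and whole cells are assigned to val, so nearby points never straddle the split. spatial_split is hypothetical, and in this sketch val_size is a fraction of grid cells rather than of points.

```python
import math
import random


def spatial_split(points: list[tuple[float, float]], spacing: float = 10 / 60,
                  val_size: float = 0.15, seed: int = 0) -> list[str]:
    """Label each (lon, lat) point 'train' or 'val', assigning whole grid cells to val."""
    # Integer (column, row) index of the grid cell containing each point
    point_cells = [(math.floor(lon / spacing), math.floor(lat / spacing)) for lon, lat in points]
    unique_cells = sorted(set(point_cells))
    random.Random(seed).shuffle(unique_cells)
    n_val_cells = round(len(unique_cells) * val_size)
    val_cells = set(unique_cells[:n_val_cells])
    return ["val" if cell in val_cells else "train" for cell in point_cells]
```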
- malpolon.data.utils.to_one_hot_encoding(labels_predict: int | list, labels_target: list) list
Return a one-hot encoding of class-index predicted labels.
Converts a single label value or a vector of labels into a vector of one-hot encoded labels. The order of the encoded labels follows that of the input labels_target.
- Parameters:
labels_predict (int | list) – Labels to convert to one-hot encoding.
labels_target (list) – All existing labels, in the right order.
- Returns:
One-hot encoded labels.
- Return type:
list
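A minimal sketch, assuming that a list of predicted labels yields a single multi-hot vector aligned with labels_target. to_one_hot is a hypothetical name.

```python
from typing import Union


def to_one_hot(labels_predict: Union[int, list], labels_target: list) -> list:
    """Return a 0/1 vector aligned with labels_target (several 1s if several labels)."""
    if isinstance(labels_predict, int):
        labels_predict = [labels_predict]
    wanted = set(labels_predict)
    return [1 if label in wanted else 0 for label in labels_target]
```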
malpolon.data.datasets
malpolon.data.datasets.torchgeo_datasets
This module provides raster related classes based on torchgeo.
Author: Theo Larcher <theo.larcher@inria.fr>
- class malpolon.data.datasets.torchgeo_datasets.RasterBioclim(root: str = 'data', labels_name: str = None, split: str = None, crs: Any = None, res: float = None, bands: Sequence[str] = None, transform: Callable[..., Any] = None, transform_target: Callable[..., Any] = None, patch_size: int | float | tuple = 256, query_units: str = 'pixel', query_crs: int | str | CRS = 'self', obs_data_columns: Dict = {'index': 'surveyId', 'species_id': 'speciesId', 'split': 'subset', 'x': 'lon', 'y': 'lat'}, task: str = 'multiclass', binary_positive_classes: list = [], cache: TYPE_CHECKING = True, **kwargs)
Bases: RasterTorchGeoDataset
Raster dataset adapted for CHELSA Bioclimatic data.
Inherits RasterTorchGeoDataset.
- all_bands: list[str] = ['bio_1', 'bio_2', 'bio_3', 'bio_4']
Names of all available bands in the dataset
- date_format = '%Y%m%dT%H%M%S'
Date format string used to parse date from filename.
Not used if filename_regex does not contain a date group.
- filename_glob = 'bio_*.tif'
Glob expression used to search for files.
This expression should be specific enough that it will not pick up files from other datasets. It should not include a file extension, as the dataset may be in a different file format than what it was originally downloaded as.
- filename_regex = '(?P<band>bio_[\\d])'
Regular expression used to extract date from filename.
The expression should use named groups. The expression may contain any number of groups. The following groups are specifically searched for by the base class:
- date: used to calculate mint and maxt for index insertion
When separate_files is True, the following additional groups are searched for to find other files:
- band: replaced with requested band name
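The glob and the regex work together: the glob selects candidate files, then the named band group of the regex extracts the band name. The sketch below mimics this with the RasterBioclim values, assuming the regex is applied with re.match to the file name (an assumption about the base class). index_bands is hypothetical.

```python
import fnmatch
import re

filename_glob = "bio_*.tif"
filename_regex = re.compile(r"(?P<band>bio_[\d])")


def index_bands(filenames: list[str]) -> dict[str, str]:
    """Map band names to file names, mimicking glob-then-regex file discovery."""
    bands = {}
    for name in filenames:
        if not fnmatch.fnmatch(name, filename_glob):
            continue  # the glob filters candidate files first
        match = filename_regex.match(name)
        if match:
            bands[match.group("band")] = name  # the named group gives the band
    return bands
```

Because [\d] matches a single digit, a file such as bio_12.tif would be indexed as band bio_1 under these assumptions.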
- is_image = True
True if dataset contains imagery, False if dataset contains mask
- plot(sample: Patches)
Plot all layers of a given patch.
A patch is selected based on a key matching the associated provider’s __get__() method.
- Args:
item (dict): provider’s get index.
- plot_bands = 'all_bands'
- separate_files = True
True if data is stored in a separate file for each band, else False.
- class malpolon.data.datasets.torchgeo_datasets.RasterTorchGeoDataset(root: str = 'data', labels_name: Optional[str] = None, split: Optional[str] = None, crs: Optional[Any] = None, res: Optional[float] = None, bands: Optional[Sequence[str]] = None, transform: Optional[Callable] = None, transform_target: Optional[Callable] = None, patch_size: Union[int, float, tuple] = 256, query_units: str = 'pixel', query_crs: Union[int, str, CRS] = 'self', obs_data_columns: dict = {'index': 'surveyId', 'species_id': 'speciesId', 'split': 'subset', 'x': 'lon', 'y': 'lat'}, task: str = 'multiclass', binary_positive_classes: list = [], cache: bool = True)
Bases: RasterDataset
Generic torchgeo based raster datasets.
Datasets based on this class return patches from raster files and can be queried by either a torchgeo BoundingBox object, a tuple of coordinates in the dataset’s CRS or a dictionary specifying coordinates and the wanted CRS. Additionally one can specify the desired size and units of the wanted patch even if they don’t match the dataset’s.
RasterTorchGeoDataset inherits torchgeo’s RasterDataset class.
- __getitem__(query: Union[int, dict, tuple, list, set, BoundingBox]) Dict[str, Any]
Query an item from the dataset.
Supports querying the dataset with coordinates in the dataset’s CRS or in another CRS. The dataset is always queried with a torchgeo BoundingBox because it is itself a torchgeo dataset, but the query in this getter method can be passed as a tuple, list, set, dict or BoundingBox.
Use case 1: query is a [list, tuple, set] of 2 elements: lon, lat. Here the CRS and units system are by default those of the dataset.
Use case 2: query is a torchgeo BoundingBox. Here the CRS and units system are by default those of the dataset.
Use case 3: query is a dict containing the following necessary keys: {‘lon’, ‘lat’}, and optional keys: {‘crs’, ‘units’, ‘size’}, whose values default to those of the dataset.
In use case 3, if the ‘crs’ key is registered and differs from the dataset’s CRS, the coordinates of the point are projected into the dataset’s CRS and the value of the key is overwritten by said CRS.
Use cases 1 and 3 make it easy to query the dataset using only a point and a bounding box (bbox) size, in the desired input CRS.
The unit of measurement of the bbox can be set to [‘m’, ‘meters’, ‘metres’] even if the dataset’s unit is different, as the points will be projected in the nearest meter-based CRS (see self.point_to_bbox()). Note that depending on your dataset’s CRS, querying a meter-based bbox may result in rectangular patches because of deformations.
- Parameters:
query (Union[dict, tuple, BoundingBox]) –
item query containing geographical coordinates. It can be of different types for different uses. One can query a patch by providing a BoundingBox built with the torchgeo.datasets.BoundingBox constructor, or by giving a center and a size.
— BoundingBox strategy —
Must follow : BoundingBox(minx, maxx, miny, maxy, mint, maxt)
— Point strategy —
If tuple, must follow: (lon, lat), and the CRS of the coordinates will be assumed to be the dataset’s. If dict, must follow: {‘lon’: lon, ‘lat’: lat, <‘crs’: crs>}, and the coordinates’ CRS can be specified; if it is not, it is assumed to be equal to the dataset’s. In both cases, a BoundingBox is generated to pursue the query.
- Returns:
dataset patch.
- Return type:
Dict[str, Any]
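Use cases 1 and 3 can be thought of as normalizing a point query into a single description before a BoundingBox is built. A hypothetical sketch of that step, assuming the defaults come from the dataset; BoundingBox queries and the projection step are left out, and sets are not accepted because they do not guarantee element order.

```python
from typing import Any, Union


def normalize_query(query: Union[tuple, list, dict], dataset_crs: Any,
                    default_units: str = "pixel",
                    default_size: Union[int, float, tuple] = 256) -> dict:
    """Turn a point query (use case 1 or 3) into one dict of lon, lat, crs, units, size."""
    if isinstance(query, dict):
        missing = [key for key in ("lon", "lat") if key not in query]
        if missing:
            raise KeyError(f"missing required keys: {missing}")
        return {
            "lon": query["lon"],
            "lat": query["lat"],
            # Optional keys fall back to the dataset's values
            "crs": query.get("crs", dataset_crs),
            "units": query.get("units", default_units),
            "size": query.get("size", default_size),
        }
    if isinstance(query, (tuple, list)) and len(query) == 2:
        lon, lat = query
        return {"lon": lon, "lat": lat, "crs": dataset_crs,
                "units": default_units, "size": default_size}
    raise TypeError("expected a (lon, lat) pair or a dict with 'lon' and 'lat' keys")
```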
- __init__(root: str = 'data', labels_name: Optional[str] = None, split: Optional[str] = None, crs: Optional[Any] = None, res: Optional[float] = None, bands: Optional[Sequence[str]] = None, transform: Optional[Callable] = None, transform_target: Optional[Callable] = None, patch_size: Union[int, float, tuple] = 256, query_units: str = 'pixel', query_crs: Union[int, str, CRS] = 'self', obs_data_columns: dict = {'index': 'surveyId', 'species_id': 'speciesId', 'split': 'subset', 'x': 'lon', 'y': 'lat'}, task: str = 'multiclass', binary_positive_classes: list = [], cache: bool = True) None
Class constructor.
- Parameters:
root (str, optional) – path to the directory containing the data and labels, by default “data”
labels_name (str, optional) – labels file name, by default None
split (str, optional) – dataset subset desired for labels selection, by default None
crs (Any | None, optional) – coordinate reference system (CRS) to warp to (defaults to the CRS of the first file found), by default None
res (float | None, optional) – resolution of the dataset in units of CRS (defaults to the resolution of the first file found), by default None
bands (Sequence[str] | None, optional) – bands to return (defaults to all bands), by default None
transform (Callable | None, optional) – a callable function that takes an input sample and returns a transformed version, by default None
transform_target (Callable | None, optional) – a callable function that takes an input target and returns a transformed version, by default None
patch_size (Union[int, float, tuple], optional) – size of the 2D extracted patches. Patches can either be square (int/float value) or rectangular (tuple of int/float), by default 256 (square)
query_units (str, optional) – unit system of the dataset’s queries, by default ‘pixel’
query_crs (Union[int, str, CRS], optional) – CRS of the dataset’s queries, by default ‘self’ (same as dataset’s CRS)
obs_data_columns (dict, optional) –
this dictionary allows users to match the dataset attributes with custom column names of their obs data file, by default:
{'x': 'lon', 'y': 'lat', 'index': 'surveyId', 'species_id': 'speciesId', 'split': 'subset'}
Here’s a description of the keys:
‘x’, ‘y’: coordinates of the obs points (by default ‘lon’, ‘lat’ as per the WGS84 system)
‘index’: obs ID over which to iterate during the training loop
‘species_id’: species ID (label) associated with the obs points
‘split’: dataset split column name
task (str, optional) – machine learning task (used to format labels accordingly), by default ‘multiclass’
binary_positive_classes (list, optional) – labels’ classes to consider valid in the case of binary classification with multi-class labels (defaults to all 0), by default []
cache (bool, optional) – if True, cache file handle to speed up repeated sampling, by default True
- __len__() int
- coords_transform(lon: Union[int, float], lat: Union[int, float], input_crs: Union[str, int, CRS] = '4326', output_crs: Union[str, int, CRS] = 'self') tuple[float, float]
Transform coordinates from one CRS to another.
- Parameters:
lon (Union[int, float]) – longitude
lat (Union[int, float]) – latitude
input_crs (Union[str, int, CRS], optional) – Input CRS, by default “4326”
output_crs (Union[str, int, CRS], optional) – Output CRS, by default “self”
- Returns:
Transformed coordinates.
- Return type:
tuple
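To illustrate what a CRS transform does, the sketch below hard-codes one well-known case: WGS84 (EPSG:4326) longitude/latitude to Web Mercator (EPSG:3857) metres, using the spherical Mercator formulas. This is a swapped-in illustration; coords_transform presumably relies on a projection library to support arbitrary CRS pairs. wgs84_to_web_mercator is hypothetical.

```python
import math

EARTH_RADIUS_M = 6378137.0  # WGS84 semi-major axis, used as the sphere radius by EPSG:3857


def wgs84_to_web_mercator(lon: float, lat: float) -> tuple[float, float]:
    """Project WGS84 degrees to Web Mercator metres (valid for |lat| < ~85.05)."""
    x = EARTH_RADIUS_M * math.radians(lon)
    y = EARTH_RADIUS_M * math.log(math.tan(math.pi / 4 + math.radians(lat) / 2))
    return x, y
```

Longitude 180 maps to x of about 20037508.34 m, the usual half-width of the Web Mercator extent.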
- get_label(df: DataFrame, query_lon: float, query_lat: float, obs_id: Optional[int] = None) Union[ndarray, int]
Return the label(s) matching the query coordinates.
This method takes into account the fact that several labels can match a single coordinate set. For that reason, the labels are chosen according to the value of the ‘obs_id’ parameter (matching the observation_id column of the labels DataFrame). If no value is given to ‘obs_id’, all matching labels are returned.
- Parameters:
df (pd.DataFrame) – dataset DataFrame composed of columns: [‘lon’, ‘lat’, ‘observation_id’]
query_lon (float) – longitude value of the query point.
query_lat (float) – latitude value of the query point.
obs_id (int, optional) – observation ID tied to the query point, by default None
- Returns:
target label(s).
- Return type:
Union[np.ndarray, int]
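A sketch of the matching logic on rows held as dictionaries, assuming exact coordinate equality and a hypothetical species_id label key (the label column name is not given above). Returning a bare value for a single match mirrors the Union[np.ndarray, int] return type, with a list standing in for the array. get_labels is hypothetical.

```python
from typing import Optional, Union


def get_labels(rows: list[dict], query_lon: float, query_lat: float,
               obs_id: Optional[int] = None) -> Union[int, list[int]]:
    """Return the label(s) of rows located exactly at (query_lon, query_lat)."""
    matches = [r for r in rows if r["lon"] == query_lon and r["lat"] == query_lat]
    if obs_id is not None:
        # Several observations can share coordinates: keep only the requested one
        matches = [r for r in matches if r["observation_id"] == obs_id]
    labels = [r["species_id"] for r in matches]
    return labels[0] if len(labels) == 1 else labels
```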
- point_to_bbox(lon: Union[int, float], lat: Union[int, float], size: Optional[Union[tuple, int]] = None, units: str = 'crs', crs: Union[int, str] = 'self') BoundingBox
Convert a geographical point to a torchgeo BoundingBox.
This method converts a 2D point into a 2D torchgeo bounding box (bbox). If ‘size’ is in the CRS’ unit system, the bbox is computed directly from the point’s coordinates. If ‘size’ is in pixels, ‘size’ is multiplied by the resolution of the dataset. If ‘size’ is in meters and the dataset’s unit system isn’t, the point is projected into the nearest meter-based CRS (from a list defined as a constant at the beginning of this module), the bbox vertices’ min and max are computed in this reference system, then they are projected back into the input CRS ‘crs’. If the dataset’s CRS doesn’t match an EPSG code but is instead built from a WKT, the nearest meter-based CRS will always be EPSG:3035.
By default, ‘size’ is set to the dataset’s ‘patch_size’ value via None.
- Parameters:
lon (Union[int, float]) – longitude
lat (Union[int, float]) – latitude
size (Union[tuple, int], optional) – Patch size. If passed as an int, the patch will be square. If passed as a tuple (width, height), it can be rectangular. By default None.
units (str, optional) – The coordinates’ unit system, must have a value in [‘pixel’, ‘crs’]. The size of the bbox will adapt to the unit. If ‘pixel’ is selected, the bbox size will be multiplied by the dataset resolution. Selecting ‘crs’ will not modify the bbox size. In that case the returned bbox will be of size: (size[0], size[1]) <metric_of_the_dataset (usually meters)>. Defaults to ‘crs’.
crs (Union[int, str]) – CRS of the point’s lon/lat coordinates, by default ‘self’.
- Returns:
Corresponding torchgeo BoundingBox.
- Return type:
BoundingBox
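For the ‘pixel’ and ‘crs’ units, the computation reduces to centering a box on the point. The sketch assumes the point is the box center, uses a hypothetical BBox named tuple in place of torchgeo's BoundingBox (same field order as the constructor quoted in __getitem__), and an unbounded time range from 0 to infinity, which is an assumption; the meter-based reprojection path is not covered. point_to_bbox_sketch is hypothetical.

```python
import math
from typing import NamedTuple, Union


class BBox(NamedTuple):
    """Stand-in with the field order of torchgeo's BoundingBox."""
    minx: float
    maxx: float
    miny: float
    maxy: float
    mint: float
    maxt: float


def point_to_bbox_sketch(lon: float, lat: float, size: Union[int, float, tuple],
                         units: str = "crs", res: float = 1.0) -> BBox:
    """Build a box centered on (lon, lat); `res` is the dataset resolution in CRS units."""
    width, height = (size, size) if isinstance(size, (int, float)) else size
    if units == "pixel":
        # Pixel sizes are converted to CRS units through the resolution
        width, height = width * res, height * res
    elif units != "crs":
        raise ValueError("this sketch only handles 'pixel' and 'crs' units")
    return BBox(lon - width / 2, lon + width / 2,
                lat - height / 2, lat + height / 2,
                0.0, math.inf)  # unbounded time range (assumption)
```

For instance, a 4-pixel square at a resolution of 0.5 around (10, 20) spans 9 to 11 in x and 19 to 21 in y.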
malpolon.data.datasets.torchgeo_sentinel2
This module provides Sentinel-2 related classes based on torchgeo.
Sentinel-2 data is queried from Microsoft Planetary Computer (MPC).
Author: Theo Larcher <theo.larcher@inria.fr>
- class malpolon.data.datasets.torchgeo_sentinel2.RasterSentinel2(root: str = 'data', labels_name: Optional[str] = None, split: Optional[str] = None, crs: Optional[Any] = None, res: Optional[float] = None, bands: Optional[Sequence[str]] = None, transform: Optional[Callable] = None, transform_target: Optional[Callable] = None, patch_size: Union[int, float, tuple] = 256, query_units: str = 'pixel', query_crs: Union[int, str, CRS] = 'self', obs_data_columns: dict = {'index': 'surveyId', 'species_id': 'speciesId', 'split': 'subset', 'x': 'lon', 'y': 'lat'}, task: str = 'multiclass', binary_positive_classes: list = [], cache: bool = True)
Bases: RasterTorchGeoDataset
Raster dataset adapted for Sentinel-2 data.
Inherits RasterTorchGeoDataset.
- all_bands: list[str] = ['B02', 'B03', 'B04', 'B08']
Names of all available bands in the dataset
- date_format = '%Y%m%dT%H%M%S'
Date format string used to parse date from filename.
Not used if filename_regex does not contain a date group.
- filename_glob = 'T*_B0*_10m.tif'
Glob expression used to search for files.
This expression should be specific enough that it will not pick up files from other datasets. It should not include a file extension, as the dataset may be in a different file format than what it was originally downloaded as.
- filename_regex = 'T31TEJ_20190801T104031_(?P<band>B0[\\d])'
Regular expression used to extract date from filename.
The expression should use named groups. The expression may contain any number of groups. The following groups are specifically searched for by the base class:
- date: used to calculate mint and maxt for index insertion
When separate_files is True, the following additional groups are searched for to find other files:
- band: replaced with requested band name
- is_image = True
True if dataset contains imagery, False if dataset contains mask
- plot(sample: Patches) Figure
Plot a 3-band dataset patch (sample).
Plots a dataset sample by selecting the 3 bands indicated in the plot_bands variable (in the same order). By default, the method plots the RGB bands.
- Parameters:
sample (Patches) – dataset’s patch to plot
- Returns:
matplotlib figure containing the plot
- Return type:
Figure
- plot_bands = ['B04', 'B03', 'B02']
- separate_files = True
True if data is stored in a separate file for each band, else False.
- class malpolon.data.datasets.torchgeo_sentinel2.RasterSentinel2GLC23(root: str = 'data', labels_name: Optional[str] = None, split: Optional[str] = None, crs: Optional[Any] = None, res: Optional[float] = None, bands: Optional[Sequence[str]] = None, transform: Optional[Callable] = None, transform_target: Optional[Callable] = None, patch_size: Union[int, float, tuple] = 256, query_units: str = 'pixel', query_crs: Union[int, str, CRS] = 'self', obs_data_columns: dict = {'index': 'surveyId', 'species_id': 'speciesId', 'split': 'subset', 'x': 'lon', 'y': 'lat'}, task: str = 'multiclass', binary_positive_classes: list = [], cache: bool = True)
Bases: RasterSentinel2
Adaptation of RasterSentinel2 for new GLC23 observations.
- all_bands: list[str] = ['red', 'green', 'blue', 'nir']
Names of all available bands in the dataset
- filename_glob = '*.tif'
Glob expression used to search for files.
This expression should be specific enough that it will not pick up files from other datasets. It should not include a file extension, as the dataset may be in a different file format than what it was originally downloaded as.
- filename_regex = '(?P<band>red|green|blue|nir)_2021'
Regular expression used to extract date from filename.
The expression should use named groups. The expression may contain any number of groups. The following groups are specifically searched for by the base class:
- date: used to calculate mint and maxt for index insertion
When separate_files is True, the following additional groups are searched for to find other files:
- band: replaced with requested band name
- plot_bands = ['red', 'green', 'blue']
- class malpolon.data.datasets.torchgeo_sentinel2.Sentinel2GeoSampler(dataset: GeoDataset, size: Union[Tuple[float, float], float], length: Optional[int] = None, roi: Optional[BoundingBox] = None, units: str = 'pixel', crs: str = 'crs')
Bases: GeoSampler
Custom sampler for RasterSentinel2.
This custom sampler is used by RasterSentinel2 to query the dataset with the fully constructed dictionary. The sampler is passed to and used by PyTorch dataloaders in the training/inference workflow.
Inherits GeoSampler.
- NOTE: this sampler is compatible with any class inheriting RasterTorchGeoDataset’s __getitem__ method, so its name may become irrelevant as more dataset-specific classes inheriting RasterTorchGeoDataset are created.
- __len__() int
Return the number of samples in a single epoch.
- Returns:
length of the epoch
- class malpolon.data.datasets.torchgeo_sentinel2.Sentinel2TorchGeoDataModule(dataset_path: str, labels_name: str = 'labels.csv', train_batch_size: int = 32, inference_batch_size: int = 16, num_workers: int = 8, size: int = 200, units: str = 'pixel', crs: int = 4326, binary_positive_classes: list = [], task: str = 'classification_multiclass', dataset_kwargs: dict = {}, download_data_sample: bool = False, **kwargs)
Bases: BaseDataModule
Data module for Sentinel-2A dataset.
- __init__(dataset_path: str, labels_name: str = 'labels.csv', train_batch_size: int = 32, inference_batch_size: int = 16, num_workers: int = 8, size: int = 200, units: str = 'pixel', crs: int = 4326, binary_positive_classes: list = [], task: str = 'classification_multiclass', dataset_kwargs: dict = {}, download_data_sample: bool = False, **kwargs)
Class constructor.
- Parameters:
dataset_path (str) – path to the directory containing the data
labels_name (str, optional) – labels file name, by default ‘labels.csv’
train_batch_size (int, optional) – train batch size, by default 32
inference_batch_size (int, optional) – inference batch size, by default 16
num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process, by default 8
size (int, optional) – size of the 2D extracted patches. Patches can either be square (int/float value) or rectangular (tuple of int/float), by default 200 (square)
units (str, optional) – The queries’ unit system, must have a value in [‘pixel’, ‘crs’, ‘m’, ‘meter’, ‘metre’]. This sets the unit you want your query to be performed in, even if it doesn’t match the dataset’s unit system, by default ‘pixel’
crs (int, optional) – The queries’ coordinate reference system (CRS). This argument sets the CRS of the dataset’s queries. The value should be equal to the CRS of your observations. It takes any EPSG integer code, by default 4326
binary_positive_classes (list, optional) – labels’ classes to consider valid in the case of binary classification with multi-class labels (defaults to all 0), by default []
task (str, optional) – machine learning task (used to format labels accordingly), by default ‘classification_multiclass’
dataset_kwargs (dict, optional) – additional keyword arguments for the dataset, by default {}
download_data_sample (bool, optional) – whether to download a sample of Sentinel-2 data, by default False
- download_data_sample()
Download 4 Sentinel-2A tiles from MPC.
This method is useful to quickly download a sample of Sentinel-2A tiles via Microsoft Planetary Computer (MPC). The references of the downloaded tiles are specified by the tile_id and timestamp variables. Tiles that are already found locally are not downloaded again.
- get_dataset(split, transform, **kwargs)
- predict_dataloader() DataLoader
- test_dataloader() DataLoader
- property test_transform
- train_dataloader() DataLoader
- property train_transform
- val_dataloader() DataLoader
malpolon.data.datasets.torchgeo_concat
This module provides torchgeo-based classes to concatenate several datasets into a single one.
NOTE: “unused” imports are necessary because they are evaluated in the eval() function. These classes are passed by the user in the config file along with their arguments.
Author: Theo Larcher <theo.larcher@inria.fr>
- class malpolon.data.datasets.torchgeo_concat.ConcatPatchRasterDataset(datasets: list[dict[torch.utils.data.dataset.Dataset]], split: str, transform: Callable, task: str)
Bases: Dataset
Concatenation dataset.
This class concatenates multiple datasets into a single one with a single sampler. It is useful when you want to train a model on multiple datasets at the same time (e.g. to train on both rasters and pre-extracted jpeg patches). In the case of RasterTorchGeoDataset, the __getitem__ method calls a private method _default_sample_to_getitem to convert the iterating index to the correct index for the dataset. This is necessary because the RasterTorchGeoDataset class uses a custom dict-based sampler but the other classes don’t.
The minimum required class arguments (i.e. observation_ids, targets, coordinates) are taken from the first dataset in the list.
Target labels are taken from the first dataset in the list.
All datasets must return tensors of the same shape.
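A minimal sketch of the concatenation, assuming patches are stacked along the channel (first) axis, as torch.cat(..., dim=0) would do on tensors, and that the target comes from the first dataset. Nested lists stand in for tensors, and concat_samples is hypothetical.

```python
from typing import Any

Channel = list[list[float]]  # one 2D band, as rows of pixel values


def concat_samples(samples: list[tuple[list[Channel], Any]]) -> tuple[list[Channel], Any]:
    """Stack the channels of every (patch, target) sample; keep the first target."""
    first_patch, first_target = samples[0]
    height, width = len(first_patch[0]), len(first_patch[0][0])
    channels: list[Channel] = []
    for patch, _target in samples:  # targets of later datasets are ignored
        for channel in patch:
            if (len(channel), len(channel[0])) != (height, width):
                raise ValueError("all datasets must return patches of the same spatial shape")
            channels.append(channel)
    return channels, first_target
```

For example, a 3-channel RGB patch and a 1-channel near-infrared patch of the same size yield a 4-channel patch labeled with the RGB dataset's target.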
- __getitem__(idx: int) Tuple[Patches, Targets]
Query an item from the dataset.
- Parameters:
idx (int) – item index (standard int).
- Returns:
concatenated data and corresponding label(s).
- Return type:
Tuple[Patches, Targets]
- __init__(datasets: list[dict[torch.utils.data.dataset.Dataset]], split: str, transform: Callable, task: str) None
Class constructor.
- Parameters:
datasets (list[dict[Dataset]]) – list of dictionaries with keys ‘callable’ and ‘kwargs’ on which to call the datasets.
transform (Callable) – data transform callable function.
task (str) – deep learning task.
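A sketch of how such ‘callable’/‘kwargs’ dictionaries could be turned into dataset instances. The module resolves callables with eval() (see the note at the top of the module); this sketch swaps that for an explicit name-to-class registry, an assumption made for safety. build_datasets is hypothetical.

```python
from typing import Any, Callable


def build_datasets(specs: list[dict[str, Any]],
                   registry: dict[str, Callable[..., Any]]) -> list[Any]:
    """Instantiate one object per {'callable': name, 'kwargs': {...}} spec."""
    datasets = []
    for spec in specs:
        factory = registry[spec["callable"]]  # explicit lookup instead of eval()
        datasets.append(factory(**spec.get("kwargs", {})))
    return datasets
```

A registry restricted to known dataset classes fails with a KeyError on unknown names instead of evaluating arbitrary strings from a config file.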
- __len__() int
- class malpolon.data.datasets.torchgeo_concat.ConcatTorchGeoDataModule(dataset_kwargs: list[dict[torch.utils.data.dataset.Dataset, Any]], dataset_path: str = 'dataset/', labels_name: str = 'labels.csv', train_batch_size: int = 32, inference_batch_size: int = 256, num_workers: int = 8, binary_positive_classes: list = [], task: str = 'classification_multiclass', **kwargs)
Bases: BaseDataModule
Data module to handle concatenation dataset.
Inherits BaseDataModule
- __init__(dataset_kwargs: list[dict[torch.utils.data.dataset.Dataset, Any]], dataset_path: str = 'dataset/', labels_name: str = 'labels.csv', train_batch_size: int = 32, inference_batch_size: int = 256, num_workers: int = 8, binary_positive_classes: list = [], task: str = 'classification_multiclass', **kwargs)
Class constructor.
- Parameters:
dataset_kwargs (list[dict[Dataset, Any]]) – list of dictionaries with keys ‘callable’ and ‘kwargs’ on which to call the datasets.
dataset_path (str) – path to the directory containing the data
labels_name (str, optional) – labels file name, by default ‘labels.csv’
train_batch_size (int, optional) – train batch size, by default 32
inference_batch_size (int, optional) – inference batch size, by default 256
num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process, by default 8
binary_positive_classes (list, optional) – labels’ classes to consider valid in the case of binary classification with multi-class labels (defaults to all 0), by default []
task (str, optional) – machine learning task (used to format labels accordingly), by default ‘classification_multiclass’
- get_dataset(split, transform=None, **kwargs) Dataset
- predict_dataloader() DataLoader
- test_dataloader() DataLoader
- property test_transform
- train_dataloader() DataLoader
- property train_transform
- val_dataloader() DataLoader
malpolon.data.datasets.geolifeclef2022
GeoLifeCLEF2022: datasets
- class malpolon.data.datasets.geolifeclef2022.GeoLifeCLEF2022Dataset(root: Union[str, Path], subset: str, *, region: str = 'both', patch_data: str = 'all', use_rasters: bool = True, patch_extractor: Optional[PatchExtractor] = None, use_localisation: bool = False, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, **kwargs)
Bases: Dataset
PyTorch dataset handler for the GeoLifeCLEF 2022 dataset.
- Parameters:
root (string or pathlib.Path) – Root directory of dataset.
subset (string, either "train", "val", "train+val" or "test") – Use the given subset (“train+val” is the complete training data).
region (string, either "both", "fr" or "us") – Load the observations of both France and US or only a single region.
patch_data (string or list of string) – Specifies what type of patch data to load, possible values: ‘all’, ‘rgb’, ‘near_ir’, ‘landcover’ or ‘altitude’.
use_rasters (boolean (optional)) – If True, extracts patches from environmental rasters.
patch_extractor (PatchExtractor object (optional)) – Patch extractor to use if rasters are used.
use_localisation (boolean) – If True, also returns the localisation as a tuple (latitude, longitude).
transform (callable (optional)) – A function/transform that takes a list of arrays and returns a transformed version.
target_transform (callable (optional)) – A function/transform that takes in the target and transforms it.
- __getitem__(index: int) Union[dict[str, Patches], tuple[dict[str, Patches], Targets]]
Return a dataset item.
- Args:
index (int): dataset id.
- Returns:
- Union[dict[str, Patches], tuple[dict[str, Patches], Targets]]:
data and labels corresponding to the dataset id.
- __len__() int
Return the number of observations in the dataset.
- class malpolon.data.datasets.geolifeclef2022.MiniGeoLifeCLEF2022Dataset(root: Union[str, Path], subset: str, *, patch_data: str = 'all', use_rasters: bool = True, patch_extractor: Optional[PatchExtractor] = None, use_localisation: bool = False, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, **kwargs)
Bases: GeoLifeCLEF2022Dataset
PyTorch dataset handler for a subset of the GeoLifeCLEF 2022 dataset.
It is restricted to France and to the 100 most represented plant species.
- Parameters:
root (string or pathlib.Path) – Root directory of dataset.
subset (string, either "train", "val", "train+val" or "test") – Use the given subset (“train+val” is the complete training data).
patch_data (string or list of string) – Specifies what type of patch data to load, possible values: ‘all’, ‘rgb’, ‘near_ir’, ‘landcover’ or ‘altitude’.
use_rasters (boolean (optional)) – If True, extracts patches from environmental rasters.
patch_extractor (PatchExtractor object (optional)) – Patch extractor to use if rasters are used.
use_localisation (boolean) – If True, also returns the localisation as a tuple (latitude, longitude).
transform (callable (optional)) – A function/transform that takes a list of arrays and returns a transformed version.
target_transform (callable (optional)) – A function/transform that takes in the target and transforms it.
- __getitem__(index: int) Union[dict[str, Patches], tuple[dict[str, Patches], Targets]]
Return a dataset item.
- Args:
index (int): dataset id.
- Returns:
- Union[dict[str, Patches], tuple[dict[str, Patches], Targets]]:
data and labels corresponding to the dataset id.
- __len__() int
Return the number of observations in the dataset.
Pytorch dataset handler for a subset of GeoLifeCLEF 2022 dataset.
It consists in a restriction to France and to the 100 most present plant species.
- param root:
Root directory of dataset.
- type root:
string or pathlib.Path
- param subset:
Use the given subset (“train+val” is the complete training data).
- type subset:
string, either “train”, “val”, “train+val” or “test”
- param patch_data:
Specifies what type of patch data to load, possible values: ‘all’, ‘rgb’, ‘near_ir’, ‘landcover’ or ‘altitude’.
- type patch_data:
string or list of string
- param use_rasters:
If True, extracts patches from environmental rasters.
- type use_rasters:
boolean (optional)
- param patch_extractor:
Patch extractor to use if rasters are used.
- type patch_extractor:
PatchExtractor object (optional)
- param use_localisation:
If True, returns also the localisation as a tuple (latitude, longitude).
- type use_localisation:
boolean
- param transform:
A function/transform that takes a list of arrays and returns a transformed version.
- type transform:
callable (optional)
- param target_transform:
A function/transform that takes in the target and transforms it.
- type target_transform:
callable (optional)
- param download:
If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
- type download:
boolean (optional)
- malpolon.data.datasets.geolifeclef2022.MicroGeoLifeCLEF2022Dataset.__getitem__(self, index)
Return a MicroGeoLifeCLEF dataset item.
- Args:
index (int): dataset id.
- Returns:
(tuple): data and labels.
- malpolon.data.datasets.geolifeclef2022.MicroGeoLifeCLEF2022Dataset.__len__(self)
Return the number of observations in the dataset.
malpolon.data.datasets.geolifeclef2023
GeoLifeCLEF2023: datasets
- class malpolon.data.datasets.geolifeclef2023.PatchesDataset(occurrences: str, providers: Iterable, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, id_name: str = 'glcID', label_name: str = 'speciesId', item_columns: Iterable = ['lat', 'lon', 'patchID'])
Bases:
Dataset
Patches dataset class.
This class provides a PyTorch-friendly dataset to handle patch data and labels. The data can be .jpeg or .tif files of variable depth (i.e. multi-spectral data). Each __getitem__ call returns a patch extracted from the dataset.
- Args:
Dataset (Dataset): PyTorch Dataset class.
- __getitem__(index)
Return a dataset element.
Returns an element from a dataset id (0 to n) with its label.
- Args:
index (int): dataset id.
- Returns:
(tuple): tuple of data patch (tensor) and label (int).
- __len__()
Return the size of the dataset.
- Returns:
(int): number of occurrences.
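The way a patch dataset combines an occurrences table, the item_columns argument and a provider can be sketched as follows. This is a simplified, hypothetical re-implementation (ToyPatchesDataset and constant_provider are illustrative names, not malpolon code), assuming each occurrence row is reduced to the columns listed in item_columns, passed to the provider, and paired with the label found under label_name. The real class takes a path to an occurrences file; a list of dicts is used here to keep the example self-contained.

```python
from typing import Callable, Optional


class ToyPatchesDataset:
    """Hypothetical, simplified stand-in for a patch dataset (not malpolon's class)."""

    def __init__(self, occurrences: list, provider: Callable[[dict], list],
                 label_name: str = "speciesId",
                 item_columns: tuple = ("lat", "lon", "patchID"),
                 transform: Optional[Callable] = None):
        self.occurrences = occurrences
        self.provider = provider
        self.label_name = label_name
        self.item_columns = item_columns
        self.transform = transform

    def __len__(self) -> int:
        # One element per occurrence (row of the occurrences table)
        return len(self.occurrences)

    def __getitem__(self, index: int):
        row = self.occurrences[index]
        # Build the provider query from the selected columns only
        item = {col: row[col] for col in self.item_columns}
        patch = self.provider(item)
        if self.transform is not None:
            patch = self.transform(patch)
        return patch, row[self.label_name]


def constant_provider(item: dict) -> list:
    # Hypothetical provider: a 1-channel 2x2 "patch" filled with the patch id
    return [[[item["patchID"]] * 2] * 2]


occurrences = [
    {"glcID": 0, "lat": 43.6, "lon": 3.9, "patchID": 7, "speciesId": 12},
    {"glcID": 1, "lat": 48.8, "lon": 2.3, "patchID": 8, "speciesId": 5},
]
ds = ToyPatchesDataset(occurrences, constant_provider)
patch, label = ds[1]
```

Keeping the provider behind a plain callable interface is what lets the same dataset class serve rasters, JPEG patches or any combination wrapped in a meta provider.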
- class malpolon.data.datasets.geolifeclef2023.PatchesDatasetMultiLabel(occurrences: str, providers: Iterable, n_classes: Union[int, str] = 'max', id_getitem: str = 'patchId', **kwargs)
Bases:
PatchesDataset
Multilabel patches dataset.
Like PatchesDataset but provides one-hot encoded labels.
- Args:
PatchesDataset (PatchesDataset): pytorch friendly dataset class.
- __getitem__(index)
Return a dataset element.
Returns an element from a dataset id (0 to n) with the labels in one-hot encoding.
- Args:
index (int): dataset id.
- Returns:
(tuple): tuple of data patch (tensor) and labels (list).
- __len__()
Return the size of the dataset.
- Returns:
(int): number of occurrences.
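When an occurrence is associated with several species, the "one-hot" labels above are in effect a multi-hot vector. A minimal sketch, assuming labels are integer class ids in [0, n_classes); the helper name multi_hot is hypothetical:

```python
def multi_hot(labels: list, n_classes: int) -> list:
    """Encode a list of class ids as a multi-hot vector of length n_classes."""
    vec = [0] * n_classes
    for label in labels:
        if not 0 <= label < n_classes:
            raise ValueError(f"label {label} outside [0, {n_classes})")
        vec[label] = 1
    return vec


# An occurrence observed with species 1 and 3 among 5 classes
encoded = multi_hot([1, 3], n_classes=5)
```

Such a vector is the usual target for multilabel losses like binary cross-entropy, where each class is predicted independently.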
- class malpolon.data.datasets.geolifeclef2023.TimeSeriesDataset(occurrences: str, providers: Iterable, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, id_name: str = 'glcID', label_name: str = 'speciesId', item_columns: Iterable = ['timeSerieID'])
Bases:
Dataset
Time series dataset.
Like PatchesDataset but adapted to time series data, which is formatted as .csv files where each row is an occurrence and each column is a timestamp. Time series have one additional dimension compared to visual patches. Time series do not all have the same length and come with a configurable no_data_value argument.
- Args:
Dataset (Dataset): pytorch Dataset class.
- __getitem__(index)
Return a time series.
Returns a tuple of time series data and its label. The data has the following shape: [1, n_bands, max_length_ts]
- Args:
index (int): dataset id.
- Returns:
(tuple): tuple of time series data (tensor) and labels (list).
- __len__()
Return the size of the dataset.
- Returns:
(int): number of occurrences.
GeoLifeCLEF2023: providers
- class malpolon.data.datasets.geolifeclef2023.PatchProvider(size: int = 128, normalize: bool = False)
Bases:
object
Parent class for all GLC23 patch data providers.
This class implements the concrete and abstract methods shared by all patch providers. In particular, the plot method is designed to accommodate all cases of the GLC23 patch datasets.
- abstract __getitem__(item)
Return a patch (parent class).
- __len__()
Return the size of the provider.
- Returns:
(int): number of layers (depth).
- class malpolon.data.datasets.geolifeclef2023.MetaPatchProvider(providers: Iterable[Callable], transform: Optional[Callable] = None)
Bases:
PatchProvider
Parent class for patch providers.
This class interfaces patch providers with patch datasets.
- Args:
(PatchProvider): inherits PatchProvider.
- __getitem__(item)
Return a patch.
This getter is used by a patch dataset class and calls each provider’s getter method to return concatenated patches.
- Args:
item (dict): provider index.
- Returns:
(array): concatenated patch from all providers.
- __len__()
Return the size of the provider.
- Returns:
(int): number of layers (depth).
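"Concatenated patches" can be illustrated with a short sketch, assuming every provider returns a channels-first array of shape (C_i, H, W) with identical H and W. The function concat_providers and the two toy providers are hypothetical, not part of malpolon:

```python
import numpy as np


def concat_providers(providers, item: dict) -> np.ndarray:
    """Query every provider with the same item and stack results along the channel axis."""
    patches = [np.asarray(p(item)) for p in providers]
    return np.concatenate(patches, axis=0)


def rgb_provider(item):
    # Hypothetical provider returning a 3-channel 4x4 patch
    return np.zeros((3, 4, 4))


def altitude_provider(item):
    # Hypothetical provider returning a single-channel 4x4 patch
    return np.ones((1, 4, 4))


patch = concat_providers([rgb_provider, altitude_provider], {"lat": 43.6, "lon": 3.9})
```

Under this assumption the depth of the combined patch is the sum of the providers' depths (3 + 1 = 4 here), which is presumably what __len__ reports as the number of layers.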
- class malpolon.data.datasets.geolifeclef2023.RasterPatchProvider(raster_path: str, size: int = 128, spatial_noise: int = 0, normalize: bool = True, fill_zero_if_error: bool = False, nan_value: Union[int, float] = 0)
Bases:
PatchProvider
Patch provider for .tif raster files.
This class handles rasters stored as .tif files and returns patches from them based on a dict key containing (lon, lat) coordinates.
- Args:
(PatchProvider): inherits PatchProvider.
- __getitem__(item)
Return a patch from a .tif raster.
This getter returns a patch of size self.size from a .tif raster loaded in self.data, using GPS coordinates expressed in EPSG:4326.
- Args:
- item (dict): dictionary containing at least latitude and longitude
keys ({‘lat’: lat, ‘lon’:lon})
- Returns:
(array): the environmental vector (size>1 or size=1).
- __len__()
Return the size of the provider.
- Returns:
(int): number of layers (depth).
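The core of such a getter is converting (lon, lat) into pixel indices and slicing a size × size window. A minimal sketch, assuming a north-up raster in EPSG:4326 described by its top-left corner (x_min, y_max) and a square pixel resolution, with no rotation. The function extract_patch and its arguments are hypothetical and ignore the spatial noise, NaN handling and normalisation options of the real class:

```python
import numpy as np


def extract_patch(raster: np.ndarray, x_min: float, y_max: float, res: float,
                  lon: float, lat: float, size: int) -> np.ndarray:
    """Return a (size, size) window of a 2-D raster centred on (lon, lat)."""
    # Columns grow eastwards from x_min, rows grow southwards from y_max
    col = int((lon - x_min) / res)
    row = int((y_max - lat) / res)
    half = size // 2
    r0, c0 = row - half, col - half
    if r0 < 0 or c0 < 0 or r0 + size > raster.shape[0] or c0 + size > raster.shape[1]:
        raise IndexError("patch falls outside the raster")
    return raster[r0:r0 + size, c0:c0 + size]


# 10x10 raster covering lon [0, 10) and lat (0, 10] with 1-degree pixels
raster = np.arange(100).reshape(10, 10)
patch = extract_patch(raster, x_min=0.0, y_max=10.0, res=1.0, lon=5.5, lat=4.5, size=2)
```

Raising on out-of-bounds windows is one design choice; the real class exposes a fill_zero_if_error option that suggests zero-filling as an alternative.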
- class malpolon.data.datasets.geolifeclef2023.MultipleRasterPatchProvider(rasters_folder: str, select: Optional[Iterable[str]] = None, **kwargs)
Bases:
PatchProvider
Patch provider for multiple PatchProviders.
This provider is useful when having to load several patch modalities through RasterPatchProvider, by selecting the desired data in the ‘select’ argument of the constructor.
- Args:
(PatchProvider): inherits PatchProvider.
- __getitem__(item)
Return multiple patches.
Returns multiple patches from multiple raster providers in a 3-dimensional numpy array.
- Args:
item (dict): providers index.
- Returns:
(array): array of patches.
- __len__()
Return the size of the provider.
- Returns:
(int): number of layers (depth).
- class malpolon.data.datasets.geolifeclef2023.JpegPatchProvider(root_path: str, select: Optional[Iterable[str]] = None, normalize: bool = False, patch_transform: Optional[Callable] = None, size: int = 128, dataset_stats: str = 'jpeg_patches_stats.csv', id_getitem: str = 'patchID')
Bases:
PatchProvider
JPEG patches provider for GLC23.
Provides tensors of multi-modal patches from JPEG patch files of rasters of the GLC23 challenge.
- Attributes:
(PatchProvider): inherits PatchProvider.
- __getitem__(item)
Return a tensor composed of every channel of a JPEG patch.
Looks for every spectral band listed in self.channels and returns a 3-dimensional patch concatenated into 1 tensor. The index used to query the right patch is a dictionary with at least one key/value pair: {‘patchID’: <patchID_value>}.
- Args:
- item (dict): dictionary containing the patchID necessary to
identify the JPEG patch to return.
- Raises:
KeyError: the ‘patchID’ key is missing from item.
Exception: item is not a dictionary as expected.
- Returns:
(tensor): multi-channel patch tensor.
- __len__()
Return the size of the provider.
- Returns:
(int): number of layers (depth).
- class malpolon.data.datasets.geolifeclef2023.TimeSeriesProvider(root_path: str, eos_replace_value: Union[int, float] = -1, transform: Optional[Iterable[Callable]] = None)
Bases:
object
Provide time series data.
This provider is the parent class of time series providers. It handles time series data stored as .csv files where each file has values for a single spectral band (red, green, infra-red etc…).
- abstract __getitem__(item)
Return a time series (parent class).
- __len__()
Return the size of the provider.
- Returns:
(int): number of layers (depth).
- class malpolon.data.datasets.geolifeclef2023.MetaTimeSeriesProvider(providers: Iterable[Callable], transform: Optional[Iterable[Callable]] = None)
Bases:
TimeSeriesProvider
Time Series provider called by TimeSeriesDataset to handle TS providers.
This TS provider handles all TS providers passed in TimeSeriesDataset to provide multi-modal time series objects.
- Args:
(TimeSeriesProvider) : inherits TimeSeriesProvider.
- __getitem__(item)
Return the time series from all TS providers.
This getter is called by a TimeSeriesDataset and returns a 3-dimensional array containing time series from all providers. The time series is selected using the TS index item, which is a dictionary containing at least 1 key/value pair: {‘timeSerieID’: <timeSerieID_value>}.
- Args:
item (dict): time series index.
- Returns:
(array): array containing the time series from all TS providers.
- __len__()
Return the size of the provider.
- Returns:
(int): number of layers (depth).
- class malpolon.data.datasets.geolifeclef2023.CSVTimeSeriesProvider(ts_data_path: str, normalize: bool = False, ts_id: str = 'timeSerieID', features_col: list = [], eos_replace_value: Union[int, float] = -1, transform: Optional[Iterable[Callable]] = None)
Bases:
TimeSeriesProvider
Implement TimeSeriesProvider for .csv time series.
Only loads time series from a single .csv file. If the time series of an occurrence is shorter than the longest one in the .csv file, the remaining columns are filled with the ‘eos’ string.
- Args:
(TimeSeriesProvider) : inherits TimeSeriesProvider.
- __getitem__(item)
Return a time series.
This getter returns a time series occurrence as a tensor-like array, based on the item index, which is a dictionary containing at least 1 key/value pair: {‘timeSeriesID’: <timeSeriesID_value>}.
- Args:
item (dict): time series index.
- Returns:
(array): time series occurrence.
- __len__()
Return the size of the time series.
- Returns:
(int): length of the longest sequence.
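The ‘eos’ padding convention and the eos_replace_value argument can be illustrated with a minimal sketch, assuming a row is read as a list of strings whose trailing cells hold the ‘eos’ marker. The helper parse_row is hypothetical:

```python
def parse_row(cells: list, eos_replace_value: float = -1) -> list:
    """Convert one CSV row to floats, replacing 'eos' padding cells."""
    return [eos_replace_value if c == "eos" else float(c) for c in cells]


values = parse_row(["0.12", "0.30", "eos", "eos"])
```

Replacing the marker with a numeric sentinel keeps the whole row numeric, so it can be turned into an array or tensor of fixed length.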
- class malpolon.data.datasets.geolifeclef2023.MultipleCSVTimeSeriesProvider(root_path: str, select: list = [], normalize: bool = False, ts_id: str = 'timeSerieID', features_col: list = [], eos_replace_value: Union[int, float] = -1, transform: Optional[Iterable[Callable]] = None)
Bases:
TimeSeriesProvider
Like CSVTimeSeriesProvider but with several .csv files.
- Args:
(TimeSeriesProvider) : inherits TimeSeriesProvider
- __getitem__(item)
Return multiple time series.
Like CSVTimeSeriesProvider but concatenates the time series into a tensor-like 3-dimensional array to provide a multi-modal time series array. This array has the shape (1, n_modalities, max_sequence_length).
- Args:
item (dict): time series index.
- Returns:
(array): multi-modal time series array.
- __len__()
Return the size of the provider.
- Returns:
(int): number of layers (depth).
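To obtain a (1, n_modalities, max_sequence_length) array, series of different lengths must first be brought to a common length. A sketch under the assumption that shorter series are padded with eos_replace_value; stack_modalities is a hypothetical helper, not the provider's implementation:

```python
import numpy as np


def stack_modalities(series: list, eos_replace_value: float = -1) -> np.ndarray:
    """Pad each modality to the longest length and stack into (1, n_modalities, max_len)."""
    max_len = max(len(s) for s in series)
    out = np.full((len(series), max_len), eos_replace_value, dtype=np.float32)
    for i, s in enumerate(series):
        out[i, :len(s)] = s
    # Add the leading singleton dimension
    return out[np.newaxis]


arr = stack_modalities([[0.1, 0.2, 0.3], [0.5, 0.6]])
```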
malpolon.data.datasets.geolifeclef2024
GeoLifeCLEF2024: datasets
- class malpolon.data.datasets.geolifeclef2024.PatchesDataset(occurrences: str, providers: Iterable, transform: Optional[Callable] = None, transform_target: Optional[Callable] = None, id_name: str = 'surveyId', labels_name: str = 'speciesId', item_columns: Iterable = ['lat', 'lon', 'surveyId'], split: Optional[str] = None, **kwargs)
Bases:
Dataset
Patches dataset class.
This class provides a PyTorch-friendly dataset to handle patch data and labels. The data can be .jpeg or .tif files of variable depth (i.e. multi-spectral data). Each __getitem__ call returns a patch extracted from the dataset.
- Args:
Dataset (Dataset): PyTorch Dataset class.
- __getitem__(index)
Return a dataset element.
Returns an element from a dataset id (0 to n) with its label.
- Args:
index (int): dataset id.
- Returns:
(tuple): tuple of data patch (tensor) and label (int).
- __len__()
Return the size of the dataset.
- Returns:
(int): number of occurrences.
- class malpolon.data.datasets.geolifeclef2024.PatchesDatasetMultiLabel(occurrences: str, providers: Iterable, n_classes: Union[int, str] = 'max', id_getitem: str = 'surveyId', **kwargs)
Bases:
PatchesDataset
Multilabel patches dataset.
Like PatchesDataset but provides one-hot encoded labels.
- Args:
PatchesDataset (PatchesDataset): pytorch friendly dataset class.
- __getitem__(index)
Return a dataset element.
Returns an element from a dataset id (0 to n) with the labels in one-hot encoding.
- Args:
index (int): dataset id.
- Returns:
(tuple): tuple of data patch (tensor) and labels (list).
- __len__()
Return the size of the dataset.
- Returns:
(int): number of occurrences.
- class malpolon.data.datasets.geolifeclef2024.TimeSeriesDataset(occurrences: str, providers: Iterable, transform: Optional[Callable] = None, transform_target: Optional[Callable] = None, id_name: str = 'surveyId', labels_name: str = 'speciesId', item_columns: Iterable = ['timeSerieID'])
Bases:
Dataset
Time series dataset.
Like PatchesDataset but adapted to time series data, which is formatted as .csv files where each row is an occurrence and each column is a timestamp. Time series have one additional dimension compared to visual patches. Time series do not all have the same length and come with a configurable no_data_value argument.
- Args:
Dataset (Dataset): pytorch Dataset class.
- __getitem__(index)
Return a time series.
Returns a tuple of time series data and its label. The data has the following shape: [1, n_bands, max_length_ts]
- Args:
index (int): dataset id.
- Returns:
(tuple): tuple of time series data (tensor) and labels (list).
- __len__()
Return the size of the dataset.
- Returns:
(int): number of occurrences.
GeoLifeCLEF2024: providers
- class malpolon.data.datasets.geolifeclef2024.PatchProvider(size: int = 128, normalize: bool = False)
Bases:
object
Parent class for all GLC23 patch data providers.
This class implements the concrete and abstract methods shared by all patch providers. In particular, the plot method is designed to accommodate all cases of the GLC23 patch datasets.
- abstract __getitem__(item)
Return a patch (parent class).
- __len__()
Return the size of the provider.
- Returns:
(int): number of layers (depth).
- class malpolon.data.datasets.geolifeclef2024.MetaPatchProvider(providers: Iterable[Callable], transform: Optional[Callable] = None)
Bases:
PatchProvider
Parent class for patch providers.
This class interfaces patch providers with patch datasets.
- Args:
(PatchProvider): inherits PatchProvider.
- __getitem__(item)
Return a patch.
This getter is used by a patch dataset class and calls each provider’s getter method to return concatenated patches.
- Args:
item (dict): provider index.
- Returns:
(array): concatenated patch from all providers.
- __len__()
Return the size of the provider.
- Returns:
(int): number of layers (depth).
- class malpolon.data.datasets.geolifeclef2024.RasterPatchProvider(raster_path: str, size: int = 128, spatial_noise: int = 0, normalize: bool = True, fill_zero_if_error: bool = False, nan_value: Union[int, float] = 0)
Bases:
PatchProvider
Patch provider for .tif raster files.
This class handles rasters stored as .tif files and returns patches from them based on a dict key containing (lon, lat) coordinates.
- Args:
(PatchProvider): inherits PatchProvider.
- __getitem__(item)
Return a patch from a .tif raster.
This getter returns a patch of size self.size from a .tif raster loaded in self.data, using GPS coordinates expressed in EPSG:4326.
- Args:
- item (dict): dictionary containing at least latitude and longitude
keys ({‘lat’: lat, ‘lon’:lon})
- Returns:
(array): the environmental vector (size>1 or size=1).
- __len__()
Return the size of the provider.
- Returns:
(int): number of layers (depth).
- class malpolon.data.datasets.geolifeclef2024.MultipleRasterPatchProvider(rasters_folder: str, select: Optional[Iterable[str]] = None, **kwargs)
Bases:
PatchProvider
Patch provider for multiple PatchProviders.
This provider is useful when having to load several patch modalities through RasterPatchProvider, by selecting the desired data in the ‘select’ argument of the constructor.
- Args:
(PatchProvider): inherits PatchProvider.
- __getitem__(item)
Return multiple patches.
Returns multiple patches from multiple raster providers in a 3-dimensional numpy array.
- Args:
item (dict): providers index.
- Returns:
(array): array of patches.
- __len__()
Return the size of the provider.
- Returns:
(int): number of layers (depth).
- class malpolon.data.datasets.geolifeclef2024.JpegPatchProvider(root_path: str, select: Optional[Iterable[str]] = None, normalize: bool = False, transform: Optional[Callable] = None, size: int = 128, dataset_stats: str = 'jpeg_patches_stats.csv', id_getitem: str = 'surveyId')
Bases:
PatchProvider
JPEG patches provider for GLC23.
Provides tensors of multi-modal patches from JPEG patch files of rasters of the GLC23 challenge.
Image patches are expected to be named by a patch ID and arranged in folders and sub-folders in the following way: root_path/YZ/WX/patch_id.jpeg with patch_id being the value ABCDWXYZ.
- Attributes:
(PatchProvider): inherits PatchProvider.
- __getitem__(item)
Return a tensor composed of every channel of a JPEG patch.
Looks for every spectral band listed in self.channels and returns a 3-dimensional patch concatenated into 1 tensor. The index used to query the right patch is a dictionary with at least one key/value pair: {‘surveyId’: <surveyId_value>}.
- Args:
- item (dict): dictionary containing the surveyId necessary to
identify the JPEG patch to return.
- Raises:
KeyError: the ‘surveyId’ key is missing from item.
Exception: item is not a dictionary as expected.
- Returns:
(tensor): multi-channel patch tensor.
- __len__()
Return the size of the provider.
- Returns:
(int): number of layers (depth).
- class malpolon.data.datasets.geolifeclef2024.TimeSeriesProvider(root_path: str, eos_replace_value: Union[int, float] = -1, transform: Optional[Iterable[Callable]] = None)
Bases:
object
Provide time series data.
This provider is the parent class of time series providers. It handles time series data stored as .csv files where each file has values for a single spectral band (red, green, infra-red etc…).
- abstract __getitem__(item)
Return a time series (parent class).
- __len__()
Return the size of the provider.
- Returns:
(int): number of layers (depth).
- class malpolon.data.datasets.geolifeclef2024.MetaTimeSeriesProvider(providers: Iterable[Callable], transform: Optional[Iterable[Callable]] = None)
Bases:
TimeSeriesProvider
Time Series provider called by TimeSeriesDataset to handle TS providers.
This TS provider handles all TS providers passed in TimeSeriesDataset to provide multi-modal time series objects.
- Args:
(TimeSeriesProvider) : inherits TimeSeriesProvider.
- __getitem__(item)
Return the time series from all TS providers.
This getter is called by a TimeSeriesDataset and returns a 3-dimensional array containing time series from all providers. The time series is selected using the TS index item, which is a dictionary containing at least 1 key/value pair: {‘timeSerieID’: <timeSerieID_value>}.
- Args:
item (dict): time series index.
- Returns:
(array): array containing the time series from all TS providers.
- __len__()
Return the size of the provider.
- Returns:
(int): number of layers (depth).
- class malpolon.data.datasets.geolifeclef2024.CSVTimeSeriesProvider(ts_data_path: str, normalize: bool = False, ts_id: str = 'timeSerieID', features_col: list = [], eos_replace_value: Union[int, float] = -1, transform: Optional[Iterable[Callable]] = None)
Bases:
TimeSeriesProvider
Implement TimeSeriesProvider for .csv time series.
Only loads time series from a single .csv file. If the time series of an occurrence is shorter than the longest one in the .csv file, the remaining columns are filled with the ‘eos’ string.
- Args:
(TimeSeriesProvider) : inherits TimeSeriesProvider.
- __getitem__(item)
Return a time series.
This getter returns a time series occurrence as a tensor-like array, based on the item index, which is a dictionary containing at least 1 key/value pair: {‘timeSeriesID’: <timeSeriesID_value>}.
- Args:
item (dict): time series index.
- Returns:
(array): time series occurrence.
- __len__()
Return the size of the time series.
- Returns:
(int): length of the longest sequence.
- class malpolon.data.datasets.geolifeclef2024.MultipleCSVTimeSeriesProvider(root_path: str, select: list = [], normalize: bool = False, ts_id: str = 'timeSerieID', features_col: list = [], eos_replace_value: Union[int, float] = -1, transform: Optional[Iterable[Callable]] = None)
Bases:
TimeSeriesProvider
Like CSVTimeSeriesProvider but with several .csv files.
- Args:
(TimeSeriesProvider) : inherits TimeSeriesProvider
- __getitem__(item)
Return multiple time series.
Like CSVTimeSeriesProvider but concatenates the time series into a tensor-like 3-dimensional array to provide a multi-modal time series array. This array has the shape (1, n_modalities, max_sequence_length).
- Args:
item (dict): time series index.
- Returns:
(array): multi-modal time series array.
- __len__()
Return the size of the provider.
- Returns:
(int): number of layers (depth).
malpolon.data.datasets.geolifeclef2024_pre_extracted
- malpolon.data.datasets.geolifeclef2024_pre_extracted.construct_patch_path(data_path, survey_id)
Construct the patch file path.
The file path is reconstructed from survey_id as ‘./CD/AB/XXXXABCD.jpeg’, where ABCD are the last four digits of the id.
- Parameters:
data_path (str) – root path
survey_id (int) – observation id
- Returns:
patch path
- Return type:
(str)
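The folder layout described above can be reproduced with a short sketch. construct_patch_path_sketch is a hypothetical re-implementation of the documented naming scheme, not the function itself:

```python
import os


def construct_patch_path_sketch(data_path: str, survey_id: int) -> str:
    """Build data_path/CD/AB/<survey_id>.jpeg for a survey id ending in ...ABCD."""
    sid = str(survey_id)
    # Last two digits first, then the two digits before them
    return os.path.join(data_path, sid[-2:], sid[-4:-2], f"{sid}.jpeg")


path = construct_patch_path_sketch("patches", 3018575)
```

For survey id 3018575 this yields patches/75/85/3018575.jpeg on POSIX systems. Splitting on trailing digits spreads files evenly across at most 100 × 100 sub-folders, which keeps individual directories small.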
- malpolon.data.datasets.geolifeclef2024_pre_extracted.load_landsat(path, transform=None)
Load Landsat pre-extracted time series data.
Loads pre-extracted time series data from Landsat satellite time series, stored as torch tensors.
- Parameters:
path (str) – path to data cube
transform (callable, optional) – data transform, by default None
- Returns:
numpy array of loaded transformed data
- Return type:
(array)
- malpolon.data.datasets.geolifeclef2024_pre_extracted.load_bioclim(path, transform=None)
Load Bioclim pre-extracted time series data.
Loads pre-extracted time series data from bioclim environmental time series, stored as torch tensors.
- Parameters:
path (str) – path to data cube
transform (callable, optional) – data transform, by default None
- Returns:
numpy array of loaded transformed data
- Return type:
(array)
- malpolon.data.datasets.geolifeclef2024_pre_extracted.load_sentinel(path, survey_id, transform=None)
Load Sentinel-2A pre-extracted patch data.
Loads pre-extracted data from Sentinel-2A satellite image patches, stored as image patches.
- Parameters:
path (str) – path to data cube
survey_id (str) – observation id which identifies the patch to load
transform (callable, optional) – data transform, by default None
- Returns:
numpy array of loaded transformed data
- Return type:
(array)
GeoLifeCLEF2024 pre-extracted: datamodules
- class malpolon.data.datasets.geolifeclef2024_pre_extracted.GLC24Datamodule(data_paths: dict, metadata_paths: dict, num_classes: int, train_batch_size: int = 64, inference_batch_size: int = 16, num_workers: int = 16, sampler: Optional[Callable] = None, dataset_kwargs: dict = {}, download_data: bool = False, task: str = 'classification_multilabel', **kwargs)
Bases:
BaseDataModule
Data module for GeoLifeCLEF 2024 dataset.
- __init__(data_paths: dict, metadata_paths: dict, num_classes: int, train_batch_size: int = 64, inference_batch_size: int = 16, num_workers: int = 16, sampler: Optional[Callable] = None, dataset_kwargs: dict = {}, download_data: bool = False, task: str = 'classification_multilabel', **kwargs)
Class constructor.
- Parameters:
data_paths (dict) – a 2-level dictionary containing data paths. 1st level keys: “train” and “test”, each containing another dictionary with keys: “landsat_data_dir”, “bioclim_data_dir”, “sentinel_data_dir” and values: the corresponding data paths as strings.
metadata_paths (dict) – a dictionary containing the paths to the observations (or “metadata”) as values for keys “train”, “test”, “val”
num_classes (int) – number of classes to train on.
train_batch_size (int, optional) – training batch size, by default 64
inference_batch_size (int, optional) – inference batch size, by default 16
num_workers (int, optional) – number of PyTorch workers, by default 16
sampler (Callable, optional) – dataloader sampler to use, by default None (standard iteration)
dataset_kwargs (dict, optional) – additional keyword arguments to pass to the dataset, by default {}
download_data (bool, optional) – if True, offers to download the pre-extracted data from Seafile, by default False
task (str, optional) – Task to perform. Can take values in [‘classification_multiclass’, ‘classification_multilabel’], by default ‘classification_multilabel’
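For reference, the nested dictionaries described above could look as follows. All directory and file names are placeholders chosen for illustration, not paths shipped with malpolon; the value 11255 is the default num_classes of TrainDataset below:

```python
# 1st level: split; 2nd level: one directory per data modality
data_paths = {
    "train": {
        "landsat_data_dir": "data/landsat/train/",
        "bioclim_data_dir": "data/bioclim/train/",
        "sentinel_data_dir": "data/sentinel/train/",
    },
    "test": {
        "landsat_data_dir": "data/landsat/test/",
        "bioclim_data_dir": "data/bioclim/test/",
        "sentinel_data_dir": "data/sentinel/test/",
    },
}

# Observation ("metadata") files, one per split
metadata_paths = {
    "train": "data/metadata/train.csv",
    "val": "data/metadata/val.csv",
    "test": "data/metadata/test.csv",
}

# These would then be passed to the constructor, e.g.:
# GLC24Datamodule(data_paths=data_paths, metadata_paths=metadata_paths, num_classes=11255)
```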
- download()
Download the GeoLifeCLEF 2024 dataset.
- get_dataset(split: str, transform: Callable, **kwargs)
Dataset getter.
- Parameters:
split (str) – dataset split to get, can take values in [‘train’, ‘val’, ‘test’]
transform (Callable) – transform functions to apply to the data
- Returns:
dataset class to return
- Return type:
Union[TrainDataset, TestDataset]
- predict_dataloader() DataLoader
- test_dataloader() DataLoader
- property test_transform
Return the test transform functions for each data modality.
The normalization values are computed from the test dataset (pre-extracted values) for each modality.
- Returns:
dictionary of transform functions for each data modality.
- Return type:
(dict)
- train_dataloader() DataLoader
- property train_transform
Return the training transform functions for each data modality.
The normalization values are computed from the training dataset (pre-extracted values) for each modality.
- Returns:
dictionary of transform functions for each data modality.
- Return type:
(dict)
- val_dataloader() DataLoader
- class malpolon.data.datasets.geolifeclef2024_pre_extracted.GLC24DatamoduleHabitats(**kwargs)
Bases:
GLC24Datamodule
GLC24 pre-extracted datamodule for habitat classification.
Inherits GLC24Datamodule.
- download()
- get_dataset(split, transform, **kwargs)
GeoLifeCLEF2024 pre-extracted: datasets
- class malpolon.data.datasets.geolifeclef2024_pre_extracted.TrainDataset(metadata: DataFrame, num_classes: int = 11255, bioclim_data_dir: Optional[str] = None, landsat_data_dir: Optional[str] = None, sentinel_data_dir: Optional[str] = None, transform: Optional[Callable] = None, task: str = 'classification_multilabel', **kwargs)
Bases:
Dataset
Train dataset with training transform functions.
Inherits Dataset.
- Returns:
tuple of data samples (landsat, bioclim, sentinel), label tensor (speciesId) and surveyId
- Return type:
(tuple)
- __getitem__(idx)
- __len__()
- class malpolon.data.datasets.geolifeclef2024_pre_extracted.TestDataset(metadata: DataFrame, num_classes: int = 11255, bioclim_data_dir: Optional[str] = None, landsat_data_dir: Optional[str] = None, sentinel_data_dir: Optional[str] = None, transform: Optional[Callable] = None, task: str = 'classification_multilabel')
Bases:
TrainDataset
Test dataset with test transform functions.
Inherits TrainDataset.
- Parameters:
TrainDataset (Dataset) – inherits TrainDataset attributes and __len__() method
- __getitem__(idx)
- __len__()
- class malpolon.data.datasets.geolifeclef2024_pre_extracted.TrainDatasetHabitat(metadata, classes, bioclim_data_dir=None, landsat_data_dir=None, sentinel_data_dir=None, transform=None, task='classification_multilabel')
Bases:
TrainDataset
GLC24 pre-extracted train dataset for habitat classification.
Inherits TrainDataset.
- __getitem__(idx)
- __len__()
- class malpolon.data.datasets.geolifeclef2024_pre_extracted.TestDatasetHabitat(metadata, classes, bioclim_data_dir=None, landsat_data_dir=None, sentinel_data_dir=None, transform=None, task='classification_multilabel')
Bases:
TestDataset
GLC24 pre-extracted test dataset for habitat classification.
Inherits TestDataset.
- __getitem__(idx)
- __len__()
malpolon.plot
malpolon.plot.history
Utilities used for plotting purposes.
Author: Titouan Lorieul <titouan.lorieul@gmail.com>
- malpolon.plot.history.escape_tex(s: str) str
Escape special characters for LaTeX rendering.
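A generic way to escape LaTeX special characters is a single-pass translation table, sketched below. The character set and the name escape_tex_sketch are assumptions for illustration; malpolon's escape_tex may cover fewer characters:

```python
# Map each LaTeX special character to its escaped form
_TEX_ESCAPES = str.maketrans({
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
})


def escape_tex_sketch(s: str) -> str:
    """Escape LaTeX special characters in a single pass (no double escaping)."""
    return s.translate(_TEX_ESCAPES)


label = escape_tex_sketch("val_top_k 50%")
```

Using str.translate rather than chained str.replace calls matters here: the backslashes introduced by one replacement are never escaped again by another.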
- malpolon.plot.history.plot_history(df_metrics: pd.DataFrame, *, fig: Optional[plt.Figure] = None, axes: Optional[list[plt.Axis]] = None) tuple[plt.Figure, list[plt.Axis]]
Plot model training history.
- Parameters:
df_metrics (pd.DataFrame) – metrics monitored during training.
fig (plt.Figure) – figure to use for plotting.
axes (list of plt.Axis) – axes to use for plotting.
- Returns:
fig (plt.Figure) – figure used for plotting.
axes (list of plt.Axis) – axes used for plotting.
- malpolon.plot.history.plot_metric(df_metrics: pd.DataFrame, metric: str, ax: plt.Axis) plt.Axis
Plot specific metric monitored during model training history.
- Parameters:
df_metrics (pd.DataFrame) – metrics monitored during training.
metric (str) – name of the metric to plot.
ax (plt.Axis) – axis to use for plotting.
- Returns:
ax – axis used for plotting.
- Return type:
plt.Axis
malpolon.plot.map
Utilities for plotting maps.
Author: Titouan Lorieul <titouan.lorieul@gmail.com>
- malpolon.plot.map.plot_map(*, region: Optional[str] = None, extent: Optional[npt.ArrayLike] = None, ax: Optional[plt.Axes] = None) plt.Axes
Plot a map on which to show the observations.
- Parameters:
region (string, either "fr" or "us") – Region to show, France or US.
extent (array-like of form [longitude min, longitude max, latitude min, latitude max]) – Explicit extent of the area to show, e.g., for zooming.
ax (plt.Axes) – Provide an Axes to use instead of creating one.
- Returns:
Returns the used Axes.
- Return type:
plt.Axes
- malpolon.plot.map.plot_observation_dataset(*, df: DataFrame, obs_data_columns: dict = {'index': 'surveyId', 'species_id': 'speciesId', 'split': 'subset', 'x': 'lon', 'y': 'lat'}, show_map: bool = False) Axes
Plot observations on a map from an observation dataset.
This method expects a pandas DataFrame with columns containing coordinates, species ids, and dataset split information (‘train’, ‘test’ or ‘val’). Users can specify the names of the columns containing this information if they do not match the default names.
- Parameters:
df (pd.DataFrame) – observation dataset
obs_data_columns (dict, optional) – dictionary matching custom dataframe keys with necessary keys, by default {‘x’: ‘lon’, ‘y’: ‘lat’, ‘index’: ‘surveyId’, ‘species_id’: ‘speciesId’, ‘split’: ‘subset’}
show_map (bool, optional) – if True, displays the map, by default False
- Returns:
map’s ax object
- Return type:
plt.Axes
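Before calling plot_observation_dataset on a DataFrame with non-default column names, it can help to build obs_data_columns explicitly and check that every referenced column exists. A sketch; the column names ‘longitude’ and ‘latitude’ and the helper missing_columns are illustrative, and the plotting call is left commented out because it requires malpolon:

```python
import pandas as pd

df = pd.DataFrame({
    "surveyId": [1, 2],
    "speciesId": [10, 11],
    "subset": ["train", "test"],
    "longitude": [3.9, 2.3],
    "latitude": [43.6, 48.8],
})

# Map the keys expected by plot_observation_dataset to this DataFrame's column names
obs_data_columns = {
    "index": "surveyId",
    "species_id": "speciesId",
    "split": "subset",
    "x": "longitude",
    "y": "latitude",
}


def missing_columns(df: pd.DataFrame, mapping: dict) -> list:
    """Return the mapped column names that are absent from df."""
    return [col for col in mapping.values() if col not in df.columns]


assert missing_columns(df, obs_data_columns) == []
# plot_observation_dataset(df=df, obs_data_columns=obs_data_columns, show_map=True)
```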
- malpolon.plot.map.plot_observation_map(*, longitudes: npt.ArrayLike, latitudes: npt.ArrayLike, ax: Optional[plt.Axes] = None, **kwargs) plt.Axes
Plot observations on a map.
- Parameters:
longitudes (array-like) – Longitudes of the observations.
latitudes (array-like) – Latitudes of the observations.
ax (plt.Axes) – Provide an Axes to use instead of creating one.
kwargs – Additional arguments to pass to plt.scatter.
- Returns:
Returns the used Axes.
- Return type:
plt.Axes
malpolon.logging
- class malpolon.logging.Summary
Bases:
Callback
Log model summary at the beginning of training.
Note: multiple validation data loaders and combined datasets are not handled yet.
- on_train_start(trainer: pl.Trainer, pl_module: GenericPredictionSystem) None
- malpolon.logging.str_object(obj: Any) str
Format an object to string.
Formats an object for printing by returning a string containing its class name and attributes (both names and values).
- Parameters:
obj (object) – object to print.
- Returns:
string containing the class name and attributes.
- Return type:
str
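One possible formatting of this kind is sketched below. The output format "ClassName(attr=value, ...)" and the name str_object_sketch are assumptions for illustration, not necessarily what malpolon produces:

```python
from typing import Any


def str_object_sketch(obj: Any) -> str:
    """Return 'ClassName(attr1=value1, ...)' from an object's instance attributes."""
    attrs = ", ".join(f"{name}={value!r}" for name, value in vars(obj).items())
    return f"{type(obj).__name__}({attrs})"


class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y


text = str_object_sketch(Point(1, 2))
```

Because vars() reads the instance __dict__, this sketch fails on objects that define __slots__ or on built-in types without a __dict__.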
malpolon.check_install
This module checks the installation of PyTorch and GPU libraries.
Author: Titouan Lorieul <titouan.lorieul@gmail.com>
- malpolon.check_install.print_cuda_info()
Print information about the CUDA/PyTorch installation.