API documentation

malpolon.models

malpolon.models.model_builder

This module provides classes to build your PyTorch models.

Classes listed in this module let you select a model from your provider (timm, torchvision…), retrieve it with or without pre-trained weights, and modify it by adding or removing layers.

Author: Titouan Lorieul <titouan.lorieul@gmail.com>

Theo Larcher <theo.larcher@inria.fr>

malpolon.models.model_builder.change_first_convolutional_layer_modifier(model: nn.Module, num_input_channels: int, new_conv_layer_init_func: Optional[Callable[[nn.Conv2d, nn.Conv2d], None]] = None) nn.Module

Remove the first registered convolutional layer of a model and replace it with a new convolutional layer with the provided number of input channels.

Parameters:
  • model (torch.nn.Module) – Model to adapt.

  • num_input_channels (integer) – Number of input channels, used to update the first convolutional layer.

  • new_conv_layer_init_func (callable) – Function defining how to initialize the new convolutional layer.

Returns:

model – The adapted model, with its first convolutional layer replaced.

Return type:

torch.nn.Module

malpolon.models.model_builder.change_last_layer_modifier(model: Module, num_outputs: int, flatten: bool = False) Module

Remove the last registered linear layer of a model and replace it with a new dense layer with the provided number of outputs.

Parameters:
  • model (torch.nn.Module) – Model to adapt.

  • num_outputs (integer) – Number of outputs of the new output layer.

  • flatten (boolean) – If True, adds an nn.Flatten layer to squeeze the last dimension. Can be useful when num_outputs=1.

Returns:

model – Reference to the model object given as input.

Return type:

torch.nn.Module

malpolon.models.model_builder.change_last_layer_to_identity_modifier(model: Module) Module

Remove the last linear layer of a model and replace it with an nn.Identity layer.

Parameters:

model (torch.nn.Module) – Model to adapt.

Returns:

num_features – Size of the feature space.

Return type:

int

malpolon.models.model_builder.timm_model_provider(model_name: str, *model_args: Any, **model_kwargs: Any) nn.Module

Return a model from timm’s library.

This method uses timm’s API to retrieve a model from its library.

Parameters:

model_name (str) – name of the model to retrieve from timm’s library

Returns:

model object

Return type:

nn.Module

Raises:

ValueError – if the model name is not listed in TIMM’s library

malpolon.models.model_builder.torchvision_model_provider(model_name: str, *model_args: Any, **model_kwargs: Any) nn.Module

Return a model from torchvision’s library.

This method uses torchvision’s API to retrieve a model from its library.

Parameters:

model_name (str) – name of the model to retrieve from torchvision’s library

Returns:

model object

Return type:

nn.Module

malpolon.models.standard_prediction_systems

This module provides classes wrapping PyTorch Lightning training modules.

Author: Titouan Lorieul <titouan.lorieul@gmail.com>

Theo Larcher <theo.larcher@inria.fr>

class malpolon.models.standard_prediction_systems.ClassificationSystem(model: Union[torch.nn.Module, Mapping], lr: float = 0.01, weight_decay: float = 0, momentum: float = 0.9, nesterov: bool = True, metrics: Optional[dict[str, Callable]] = None, task: str = 'classification_binary', hparams_preprocess: bool = True)

Bases: GenericPredictionSystem

Classification task class.

__init__(model: Union[torch.nn.Module, Mapping], lr: float = 0.01, weight_decay: float = 0, momentum: float = 0.9, nesterov: bool = True, metrics: Optional[dict[str, Callable]] = None, task: str = 'classification_binary', hparams_preprocess: bool = True)

Class constructor.

Parameters:
  • model (Union[torch.nn.Module, Mapping]) – model to use

  • lr (float) – learning rate

  • weight_decay (float) – weight decay

  • momentum (float) – value of momentum

  • nesterov (bool) – if True, uses Nesterov’s momentum

  • metrics (dict) – dictionary containing the metrics to compute. Keys must match metrics’ names and have a subkey with each metric’s functional methods as value. This subkey is either created from the malpolon.models.utils.FMETRICS_CALLABLES constant or supplied directly by the user.

  • task (str, optional) – Machine learning task (used to format labels accordingly), by default ‘classification_binary’. The value determines the loss to be selected: if ‘multilabel’ or ‘binary’ is in the task, the BCEWithLogitsLoss is selected, otherwise the CrossEntropyLoss is used.

  • hparams_preprocess (bool, optional) – if True performs preprocessing operations on the hyperparameters, by default True
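
The loss selection rule described for the task parameter can be sketched in plain Python. This is only an illustration: select_loss_name is a hypothetical helper that is not part of malpolon, and it returns the name of the loss as a string instead of instantiating a PyTorch loss.

```python
def select_loss_name(task: str) -> str:
    """Return the name of the loss implied by a task string (hypothetical helper)."""
    # Rule taken from the description of `task`: 'multilabel' or 'binary'
    # anywhere in the task name selects BCEWithLogitsLoss.
    if "multilabel" in task or "binary" in task:
        return "BCEWithLogitsLoss"
    # Every other task falls back to CrossEntropyLoss.
    return "CrossEntropyLoss"


for task in (
    "classification_binary",
    "classification_multiclass",
    "classification_multilabel",
):
    print(task, "->", select_loss_name(task))
```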

class malpolon.models.standard_prediction_systems.GenericPredictionSystem(model: Union[torch.nn.Module, Mapping], loss: torch.nn.modules.loss._Loss, optimizer: torch.optim.Optimizer, metrics: Optional[dict[str, Callable]] = None, save_hyperparameters: Optional[bool] = True)

Bases: LightningModule

Generic prediction system providing standard methods.

Parameters:
  • model (torch.nn.Module) – Model to use.

  • loss (torch.nn.modules.loss._Loss) – Loss used to fit the model.

  • optimizer (torch.optim.Optimizer) – Optimization algorithm used to train the model.

  • metrics (dict) – Dictionary containing the metrics to monitor during the training and to compute at test time.

configure_optimizers() Optimizer
forward(x: Any) Any
predict(datamodule, trainer)

Predict a model’s output on the test dataset.

This method performs inference on the test dataset using only PyTorch Lightning tools.

Parameters:
  • datamodule (pl.LightningDataModule) – PyTorch Lightning datamodule handling all train/test/val datasets.

  • trainer (pl.Trainer) – PyTorch Lightning trainer in charge of running the model in train and inference mode.

Returns:

Predicted tensor values.

Return type:

array

predict_point(checkpoint_path: str, data: Union[Tensor, tuple[Any, Any]], state_dict_replace_key: Optional[list[str, str]] = None, ckpt_transform: Callable = None)

Predict a model’s output on 1 data point.

Works like predict() but for a single data point and using native PyTorch tools.

Parameters:
  • checkpoint_path (str) – path to the model’s checkpoint to load.

  • data (Union[Tensor, tuple[Any, Any]]) – data point to perform inference on.

  • state_dict_replace_key (Optional[list[str, str]], optional) – list of values used to call the static method state_dict_replace_key(). Defaults to None.

  • ckpt_transform (Callable, optional) – callable function applied to the loaded checkpoint object. Use this to modify the structure of the loaded model’s checkpoint on the fly. Defaults to None.

Returns:

Predicted tensor value.

Return type:

array

predict_step(batch, batch_idx, dataloader_idx=0)
static state_dict_replace_key(state_dict: dict, replace: Optional[list[str]] = ['.', ''])

Replace keys in a state_dict dictionary.

A state_dict is usually an OrderedDict where the keys are the model’s module names. This method lets you change these names by replacing a given string, or token, with another.

Parameters:
  • state_dict (dict) – Model state_dict

  • replace (Optional[list[str]], optional) – Tokens to replace in the state_dict module names. The first element is the token to look for while the second is the replacement value. By default [‘.’, ‘’].

Returns:

State_dict with new keys.

Return type:

dict

Examples

I have loaded a Resnet18 model through a checkpoint after a training session, but the names of the model modules have been altered with a prefix “model.”:

>>> sd = model.state_dict()
>>> print(list(sd)[:2])
['model.conv1.weight', 'model.bn1.weight']

To remove this prefix:

>>> sd = GenericPredictionSystem.state_dict_replace_key(sd, ['model.', ''])
>>> print(list(sd)[:2])
['conv1.weight', 'bn1.weight']
test_step(batch: tuple[Any, Any], batch_idx: int) Union[Tensor, dict[str, Any]]
training_step(batch: tuple[Any, Any], batch_idx: int) Union[Tensor, dict[str, Any]]
validation_step(batch: tuple[Any, Any], batch_idx: int) Union[Tensor, dict[str, Any]]

malpolon.models.multi_modal

This module provides classes for advanced model building.

Author: Titouan Lorieul <titouan.lorieul@gmail.com>

Theo Larcher <theo.larcher@inria.fr>

class malpolon.models.multi_modal.HomogeneousMultiModalModel(modality_names: list, modalities_model: dict, aggregator_model: Union[nn.Module, Mapping])

Bases: MultiModalModel

Straightforward multi-modal model.

__init__(modality_names: list, modalities_model: dict, aggregator_model: Union[nn.Module, Mapping])

Class constructor.

Parameters:
  • modality_names (list) – list of modalities names

  • modalities_model (dict) – dictionary of modality names and their respective models to pass on to the model builder

  • aggregator_model (Union[nn.Module, Mapping]) – Model strategy to aggregate the features from each modality. Can either be a PyTorch module directly (in this case, the module will be directly called), or a mapping in the same fashion as for building the modality models, in which case the model builder will be called again.

class malpolon.models.multi_modal.MultiModalModel(modality_models: Union[nn.Module, Mapping], aggregator_model: Union[nn.Module, Mapping])

Bases: Module

Base multi-modal model.

This class builds an aggregation of multiple models, one for each modality, from the values passed in the config file. It splits the training routine per modality and then aggregates the features from each modality after each forward pass.

__init__(modality_models: Union[nn.Module, Mapping], aggregator_model: Union[nn.Module, Mapping])

Class constructor.

Parameters:
  • modality_models (Union[nn.Module, Mapping]) – dictionary of modality names and their respective models to pass on to the model builder

  • aggregator_model (Union[nn.Module, Mapping]) – Model strategy to aggregate the features from each modality. Can either be a PyTorch module directly (in this case, the module will be directly called), or a mapping in the same fashion as for building the modality models, in which case the model builder will be called again.

forward(x: list[Any]) Any
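
The flow described for MultiModalModel (one model per modality, then aggregation of the per-modality features) can be illustrated without PyTorch. The sketch below is a toy built on assumptions: MultiModalSketch, the lambda "models" and the summing aggregator are hypothetical stand-ins, and concatenating features before aggregation is an assumed design. The real class works on torch modules and tensors.

```python
from typing import Any, Callable, Mapping, Sequence


class MultiModalSketch:
    """Toy stand-in: plain callables replace the per-modality nn.Module objects."""

    def __init__(self, modality_models: Mapping[str, Callable], aggregator: Callable):
        self.modality_models = dict(modality_models)
        self.aggregator = aggregator

    def forward(self, x: Sequence[Any]) -> Any:
        # One input per modality, in the same order as the models.
        features = [model(xi) for model, xi in zip(self.modality_models.values(), x)]
        # Assumption: per-modality features are concatenated, then aggregated.
        merged = [value for feature in features for value in feature]
        return self.aggregator(merged)


model = MultiModalSketch(
    modality_models={
        "rgb": lambda patch: [sum(patch)],                   # yields 1 feature
        "climate": lambda patch: [min(patch), max(patch)],   # yields 2 features
    },
    aggregator=sum,
)
output = model.forward([[1, 2, 3], [4.0, 6.0]])
print(output)
```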
class malpolon.models.multi_modal.ParallelMultiModalModelStrategy(accelerator=None, parallel_devices=None, checkpoint_io=None, precision_plugin=None)

Bases: SingleDeviceStrategy

Model parallelism strategy for multi-modal models.

WARNING: STILL UNDER DEVELOPMENT.

batch_to_device(batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) Any

TODO: Docstring.

model_to_device() None

TODO: Docstring.

strategy_name = 'parallel_multi_modal_model'

malpolon.models.utils

This file compiles useful functions related to models.

Author: Theo Larcher <theo.larcher@inria.fr>

Titouan Lorieul <titouan.lorieul@gmail.com>

class malpolon.models.utils.CrashHandler(trainer)

Bases: object

Saves the model in case of unexpected crash or user interruption.

save_checkpoint()

Save the latest checkpoint.

signal_handler(sig, frame)

Attempt to save the latest checkpoint in case of crash.
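
A save-on-interrupt handler of this kind can be sketched with Python's standard signal module. The names below (CrashHandlerSketch, save_fn, install) are hypothetical and only mirror the two documented methods. Exiting after the save is an assumption, not something the text above states.

```python
import signal


class CrashHandlerSketch:
    """Hypothetical stand-in: save_fn replaces the trainer's checkpoint call."""

    def __init__(self, save_fn):
        self.save_fn = save_fn
        self.calls = 0

    def install(self):
        # Route Ctrl+C (SIGINT) to signal_handler; must be called from the main thread.
        signal.signal(signal.SIGINT, self.signal_handler)

    def save_checkpoint(self):
        self.save_fn()
        self.calls += 1

    def signal_handler(self, sig, frame):
        # Save first, then stop the process cleanly (assumed behaviour).
        self.save_checkpoint()
        raise SystemExit(0)


saved = []
handler = CrashHandlerSketch(save_fn=lambda: saved.append("latest.ckpt"))

# Call the handler directly instead of sending a real signal.
try:
    handler.signal_handler(signal.SIGINT, None)
except SystemExit:
    print("checkpoints saved:", saved)
```

In a real run, install() would be called once before training starts so that an interruption triggers the save.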

malpolon.models.utils.check_loss(loss: _Loss) _Loss

Ensure input loss is a pytorch loss.

Args:

loss (nn.modules.loss._Loss): input loss.

Raises:

ValueError: if input loss isn’t a pytorch loss object.

Returns:

nn.modules.loss._Loss: the pytorch input loss itself.

malpolon.models.utils.check_metric(metrics: OmegaConf) OmegaConf

Ensure user’s model metrics are valid.

Users can either choose from a list of predefined metrics or define their own custom metrics. This function binds the user’s metrics with their corresponding callable function from torchmetrics, by reading the values in metrics which is a dict-like structure returned by hydra when reading the config file. If the user chose predefined metrics, the function will automatically bind the corresponding callable function from torchmetrics. If the user chose custom metrics, the function checks that they also provided the callable function to compute the metric.

Parameters:

metrics (OmegaConf) – user’s input metrics, read from the config file via hydra, in a dict-like structure

Returns:

user’s metrics with their corresponding callable function

Return type:

OmegaConf

malpolon.models.utils.check_model(model: Union[Module, Mapping]) Module

Ensure input model is a pytorch model.

Args:

model (Union[nn.Module, Mapping]): input model.

Raises:

ValueError: if input model isn’t a pytorch model object.

Returns:

nn.Module: the pytorch input model itself.

malpolon.models.utils.check_optimizer(optimizer: Optimizer) Optimizer

Ensure input optimizer is a pytorch optimizer.

Args:

optimizer (optim.Optimizer): input optimizer.

Raises:

ValueError: if input optimizer isn’t a pytorch optimizer object.

Returns:

optim.Optimizer: the pytorch input optimizer itself.

malpolon.data

malpolon.data.data_module

This module provides a base class for data modules.

Author: Theo Larcher <theo.larcher@inria.fr>

Titouan Lorieul <titouan.lorieul@gmail.com>

class malpolon.data.data_module.BaseDataModule(train_batch_size: int = 32, inference_batch_size: int = 256, num_workers: int = 8)

Bases: LightningDataModule, ABC

Base class for data modules.

This class inherits PyTorch Lightning’s LightningDataModule class and provides a base class for data modules by redefining step methods as well as adding new data manipulation methods.

export_predict_csv(predictions: Union[Tensor, np.ndarray], probas: Union[Tensor, np.ndarray] = None, single_point_query: dict = None, out_name: str = 'predictions', out_dir: str = './', return_csv: bool = False, top_k: int = None, **kwargs: Any) Any

Export predictions to csv file.

This method is used to export predictions to a csv file. It can be used with a single point query or with the whole test dataset. This method is adapted for a classification task with an observations file and multi-modal data. Keys in the csv file match the ones used to instantiate the RasterTorchGeoDataset class, that is to say: observation_id, lon, lat, target_species_id. The crs key is also mandatory in the case of a single-point query.

Parameters:
  • predictions (Union[Tensor, np.ndarray]) – model’s predictions.

  • probas (Union[Tensor, np.ndarray], optional) – predictions’ raw logits or logits passed through an activation function, by default None

  • single_point_query (dict, optional) – query dictionary of the single-point prediction. The ‘target_species_id’ key is mandatory and expects a list of numpy arrays of species ids. The ‘predictions’ and ‘probas’ keys expect numpy arrays of predictions and probabilities. By default None (whole test dataset predictions)

  • out_name (str, optional) – output CSV file name, by default “predictions”

  • out_dir (str, optional) – output directory name, by default “./”

  • return_csv (bool, optional) – if true, the method returns the CSV as a pandas DataFrame, by default False

  • top_k (int, optional) – number of top predictions to return, by default None (max number of predictions)

Returns:

CSV content as a pandas DataFrame if return_csv is True

Return type:

pandas.DataFrame

export_predict_csv_basic(predictions: Union[Tensor, np.ndarray], targets: Union[np.ndarray, list], probas: Union[Tensor, np.ndarray] = None, ids: Union[np.ndarray, list] = None, out_name: str = 'predictions', out_dir: str = './', return_csv: bool = False, top_k: int = None, **kwargs: Any)

Export predictions to csv file.

Exports predictions, probabilities and ids to a csv file.

Parameters:
  • predictions (Union[Tensor, np.ndarray]) – model’s predictions.

  • targets (Union[np.ndarray, list]) – target species ids

  • probas (Union[Tensor, np.ndarray], optional) – predictions’ raw logits or logits passed through an activation function, by default None

  • ids (Union[np.ndarray, list], optional) – ids of the observations, by default None

  • out_name (str, optional) – output CSV file name, by default “predictions”

  • out_dir (str, optional) – output directory name, by default “./”

  • return_csv (bool, optional) – if true, the method returns the CSV as a pandas DataFrame, by default False

  • top_k (int, optional) – number of top predictions to return, by default None (max number of predictions)

Returns:

CSV content as a pandas DataFrame if return_csv is True

Return type:

pandas.DataFrame

abstract get_dataset(split: str, transform: Callable, **kwargs: Any) Dataset

Return the dataset corresponding to the split.

Parameters:
  • split (str) – Type of dataset. Values must be one of [“train”, “val”, “test”]

  • transform (Callable) – data transforms to apply when loading the dataset

Returns:

dataset corresponding to the split

Return type:

Dataset

get_test_dataset() Dataset

Call self.get_dataset to return the test dataset.

Returns:

test dataset

Return type:

Dataset

get_train_dataset() Dataset

Call self.get_dataset to return the train dataset.

Returns:

train dataset

Return type:

Dataset

get_val_dataset() Dataset

Call self.get_dataset to return the validation dataset.

Returns:

validation dataset

Return type:

Dataset

predict_dataloader() DataLoader

Return predict dataloader instantiated with class attributes.

Returns:

predict dataloader

Return type:

DataLoader

predict_logits_to_class(predictions: Tensor, classes: Union[np.ndarray, Tensor], activation_fn: torch.nn.modules.activation = Softmax(dim=1)) Tensor

Convert the model’s predictions to class labels.

This method applies an activation function to the model’s predictions and returns the corresponding class labels.

Parameters:
  • predictions (Tensor) – model’s predictions (raw logits)

  • classes (Union[np.ndarray, Tensor]) – class labels

  • activation_fn (torch.nn.modules.activation, optional) – activation function to apply to the model’s predictions, by default torch.nn.Softmax(dim=1)

Returns:

class labels and corresponding probabilities

Return type:

tuple[np.ndarray, np.ndarray]
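
The conversion from logits to class labels can be illustrated in plain Python. This is a sketch under assumptions, not the method's implementation: softmax and logits_to_ranked_classes are hypothetical helpers working on lists, whereas the real method operates on tensors. Returning every class ranked by descending probability is one plausible reading of "class labels and corresponding probabilities".

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(value - peak) for value in row]
    total = sum(exps)
    return [e / total for e in exps]


def logits_to_ranked_classes(logits, classes):
    """For each row of logits, return classes and probabilities, most likely first."""
    ranked_classes, ranked_probas = [], []
    for row in logits:
        probas = softmax(row)
        # Indices of the classes sorted by descending probability.
        order = sorted(range(len(probas)), key=lambda i: probas[i], reverse=True)
        ranked_classes.append([classes[i] for i in order])
        ranked_probas.append([probas[i] for i in order])
    return ranked_classes, ranked_probas


classes = ["oak", "pine", "fern"]
logits = [[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]]
labels, probas = logits_to_ranked_classes(logits, classes)
print(labels)
```

Keeping only the first k entries of each ranked row would give a top-k view, similar in spirit to the top_k argument of the export methods.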

prepare_data() None

Prepare data.

Called once on CPU. Class states defined here are lost afterwards. This method is intended for data downloading, tokenization, permanent transformation…

setup(stage: Optional[str] = None) None

Register the correct datasets to the class attributes.

Depending on the trainer’s stage, this method will retrieve the train, val or test dataset and register it as a class attribute. The “predict” stage calls for the test dataset.

Parameters:

stage (Optional[str], optional) – trainer’s stage, by default None (train)

test_dataloader() DataLoader

Return test dataloader instantiated with class attributes.

Returns:

test dataloader

Return type:

DataLoader

abstract property test_transform: Callable

Return test data transforms.

Returns:

test data transforms

Return type:

Callable

train_dataloader() DataLoader

Return train dataloader instantiated with class attributes.

Returns:

train dataloader

Return type:

DataLoader

abstract property train_transform: Callable

Return train data transforms.

Returns:

train data transforms

Return type:

Callable

val_dataloader() DataLoader

Return validation dataloader instantiated with class attributes.

Returns:

Validation dataloader

Return type:

DataLoader

malpolon.data.environmental_raster

Custom classes to handle environmental rasters without torchgeo.

Author: Titouan Lorieul <titouan.lorieul@gmail.com>

class malpolon.data.environmental_raster.PatchExtractor(root_path: Union[str, Path], size: int = 256)

Bases: object

Handles the loading and extraction of an environmental tensor from multiple rasters given GPS coordinates.

Parameters:
  • root_path (string or pathlib.Path) – Path to the folder containing all the rasters.

  • size (integer) – Size in pixels (size x size) of the patches to extract around each location.

__getitem__(coordinates: Coordinates) npt.NDArray[np.float32]

Extract the patches around the given GPS coordinates for all the previously loaded rasters.

Parameters:

coordinates (tuple containing two floats) – GPS coordinates (latitude, longitude)

Returns:

patch – Extracted patches around the given coordinates.

Return type:

3d array of floats, [n_rasters, size, size], or 1d array of floats, [n_rasters,], if size == 1

__len__() int

Return the number of variables/rasters loaded.

Returns:

n_rasters – Number of loaded rasters

Return type:

integer

add_all_bioclimatic_rasters(**kwargs: Any) None

Add all bioclimatic variables (rasters) available.

Parameters:

kwargs (dict) – Updates the default arguments passed to Raster (nan, out_of_bounds, etc.)

add_all_pedologic_rasters(**kwargs: Any) None

Add all pedologic variables (rasters) available.

Parameters:

kwargs (dict) – Updates the default arguments passed to Raster (nan, out_of_bounds, etc.)

add_all_rasters(**kwargs: Any) None

Add all variables (rasters) available.

Parameters:

kwargs (dict) – Updates the default arguments passed to Raster (nan, out_of_bounds, etc.)

append(raster_name: str, **kwargs: Any) None

Load and append a single raster to the rasters already loaded.

Can be useful to load only a subset of rasters or to pass configurations specific to each raster.

Parameters:
  • raster_name (string) – Name of the raster to load, should be a subfolder of root_path.

  • kwargs (dict) – Updates the default arguments passed to Raster (nan, out_of_bounds, etc.)

clean() None

Remove all rasters from the extractor.

plot(coordinates: Coordinates, return_fig: bool = False, n_cols: int = 5, fig: Optional[plt.Figure] = None, resolution: float = 1.0) Optional[plt.Figure]

Plot an environmental tensor (only works if size > 1).

Parameters:
  • coordinates (tuple containing two floats) – GPS coordinates (latitude, longitude)

  • return_fig (boolean) – If True, returns the created plt.Figure object

  • n_cols (integer) – Number of columns to use

  • fig (plt.Figure or None) – If not None, use the given plt.Figure object instead of creating a new one

  • resolution (float) – Resolution of the created figure

Returns:

fig – If return_fig is True, the used plt.Figure object

Return type:

plt.Figure

class malpolon.data.environmental_raster.Raster(path: Union[str, Path], country: str, size: int = 256, nan: Optional[float] = nan, out_of_bounds: str = 'error')

Bases: object

Loads a GeoTIFF file and extracts patches for a single environmental raster.

Parameters:
  • path (string / pathlib.Path) – Path to the folder containing all the rasters.

  • country (string, either "FR" or "USA") – Which country to load raster from.

  • size (integer) – Size in pixels (size x size) of the patch to extract around each location.

  • nan (float or None) – Value to use to replace missing data in original rasters, if None, leaves default values.

  • out_of_bounds (string, either "error", "warn" or "ignore") – If “error”, raises an exception if the location requested is out of bounds of the rasters. Set to “warn” to only produce a warning and to “ignore” to silently ignore it and return a patch filled with missing data.

__getitem__(coordinates: Coordinates) Patch

Extract the patch around the given GPS coordinates.

Parameters:

coordinates (tuple containing two floats) – GPS coordinates (latitude, longitude)

Returns:

patch – Extracted patch around the given coordinates.

Return type:

2d array of floats, [size, size], or single float if size == 1

__len__() int

Return the number of bands in the raster.

Should always be equal to 1.

Returns:

n_bands – Number of bands in the raster

Return type:

integer

malpolon.data.get_jpeg_patches_stats

Script / module used to compute the mean and std on JPEG files.

When dealing with a large amount of files it should be run only once, and the statistics should be stored in a separate .csv for later use.

Author: Theo Larcher <theo.larcher@inria.fr>

malpolon.data.get_jpeg_patches_stats.standardize(root_path: str = 'sample_data/SatelliteImages/', ext: str = ['jpeg', 'jpg'], output: str = 'root_path')

Perform standardization over images.

Returns and stores the mean and standard deviation of an image dataset organized inside a root directory for computation purposes like deep learning.

Args:

root_path (str): root directory containing the images. Defaults to ‘./sample_data/SatelliteImages/’.

ext (str, optional): the image extensions to consider. Defaults to [‘jpeg’, ‘jpg’].

output (str, optional): output path where to save the csv containing the mean and std of the dataset. If None: doesn’t output anything. Defaults to root_path.

Returns:

(tuple): tuple of mean and std of the jpeg images.

malpolon.data.get_jpeg_patches_stats.standardize_by_parts(fps_fp: str, output: str = 'glc23_stats.csv', max_imgs_per_computation: int = 100000)

Perform standardization over images part by part.

With too many images, memory can overflow. This function addresses this problem by performing the computation in parts. Downside: the computed standard deviation is a mean approximation of the true value.

Args:

fps_fp (str): file path to a text file containing the paths to the images.

output (str, optional): output path where to save the csv containing the mean and std of the dataset. If None: doesn’t output anything. Defaults to ‘glc23_stats.csv’.

max_imgs_per_computation (int, optional): maximum number of images to hold in memory. Defaults to 100000.

Returns:

(tuple): tuple of mean and std of the jpeg images.
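
The trade-off mentioned above (bounded memory against an approximate standard deviation) can be seen on a toy example. The sketch assumes the per-part statistics are combined with a plain average, which the phrase "mean approximation" suggests but does not spell out. stats_by_parts is a hypothetical helper working on a flat list of numbers rather than on images.

```python
import statistics


def stats_by_parts(values, max_per_part):
    """Average the per-part means and standard deviations (assumed combination rule)."""
    parts = [values[i:i + max_per_part] for i in range(0, len(values), max_per_part)]
    means = [statistics.fmean(part) for part in parts]
    stds = [statistics.pstdev(part) for part in parts]
    return statistics.fmean(means), statistics.fmean(stds)


values = [0.0, 0.0, 1.0, 1.0, 10.0, 10.0, 11.0, 11.0]

approx_mean, approx_std = stats_by_parts(values, max_per_part=4)
true_mean, true_std = statistics.fmean(values), statistics.pstdev(values)

print("mean:", approx_mean, "vs", true_mean)
print("std: ", approx_std, "vs", round(true_std, 4))
```

With equal-sized parts the averaged mean matches the true mean, but the averaged standard deviation ignores the spread between parts, so it can be far from the true value when parts differ a lot from each other.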

malpolon.data.utils

This file compiles useful functions related to data and file handling.

Author: Theo Larcher <theo.larcher@inria.fr>

malpolon.data.utils.get_files_path_recursively(path, *args, suffix='') list

Retrieve specific files path recursively from a directory.

Retrieve the path of all files with one of the given extension names, in the given directory and all its subdirectories, recursively. The extension names should be given as a list of strings. The search for extension names is case sensitive.

Parameters:
  • path (str) – root directory from which to search for files recursively

  • *args (list) – list of file extensions to be considered.

Returns:

list of paths of every file in the directory and all its subdirectories.

Return type:

list
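
A recursive, case-sensitive search by extension can be sketched with pathlib. find_files_recursively is a hypothetical helper written for illustration. It does not reproduce the suffix argument of the documented function, whose meaning is not described here.

```python
import tempfile
from pathlib import Path


def find_files_recursively(root, *extensions):
    """Return sorted paths of files under root whose extension is in extensions."""
    wanted = {ext.lstrip(".") for ext in extensions}
    return sorted(
        str(p)
        for p in Path(root).rglob("*")
        if p.is_file() and p.suffix.lstrip(".") in wanted  # case-sensitive match
    )


# Build a small throwaway directory tree to search.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "sub").mkdir()
    for name in ("a.jpg", "sub/b.jpeg", "sub/c.JPG", "notes.txt"):
        (root / name).touch()
    found = [Path(p).name for p in find_files_recursively(root, "jpg", "jpeg")]

print(found)
```

The file c.JPG is skipped because the comparison is case sensitive, as stated in the description of the documented function.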

malpolon.data.utils.is_bbox_contained(bbox1: Union[Iterable, BoundingBox], bbox2: Union[Iterable, BoundingBox], method: str = 'shapely') bool

Determine if a 2D bbox is included inside another.

Returns a boolean answering the question “Is bbox1 contained inside bbox2 ?”. With methods ‘shapely’ and ‘manual’, bounding boxes must follow the format: [xmin, ymin, xmax, ymax]. With method ‘torchgeo’, bounding boxes must be of type: torchgeo.datasets.utils.BoundingBox.

Parameters:
  • bbox1 (Union[Iterable, BoundingBox]) – Bounding box n°1.

  • bbox2 (Union[Iterable, BoundingBox]) – Bounding box n°2.

  • method (str) – Method to use for comparison. Can take any value in [‘shapely’, ‘manual’, ‘torchgeo’], by default ‘shapely’.

Returns:

True if bbox1 ⊂ bbox2, False otherwise.

Return type:

boolean
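
For boxes in the [xmin, ymin, xmax, ymax] format, a containment test reduces to four comparisons. The sketch below is a hypothetical helper and not the library's ‘manual’ method. Treating touching edges as contained is an assumption.

```python
def is_bbox_contained_manual(bbox1, bbox2):
    """True if bbox1 lies inside bbox2; both are [xmin, ymin, xmax, ymax]."""
    xmin1, ymin1, xmax1, ymax1 = bbox1
    xmin2, ymin2, xmax2, ymax2 = bbox2
    # Assumption: boxes sharing an edge still count as contained.
    return xmin2 <= xmin1 and ymin2 <= ymin1 and xmax1 <= xmax2 and ymax1 <= ymax2


inner = [2, 2, 4, 4]
outer = [0, 0, 10, 10]
print(is_bbox_contained_manual(inner, outer))
print(is_bbox_contained_manual(outer, inner))
```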

malpolon.data.utils.is_point_in_bbox(point: Iterable, bbox: Iterable, method: str = 'shapely') bool

Determine if a 2D point is included inside a 2D bounding box.

Returns a boolean answering the question “Is point contained inside bbox ?”. Point must follow the format: [x, y]. Bounding box must follow the format: [xmin, ymin, xmax, ymax].

Parameters:
  • point (Iterable) – Point in the format [x, y].

  • bbox (Iterable) – Bounding box in the format [xmin, ymin, xmax, ymax].

  • method (str) – Method to use for comparison. Can take any value in [‘shapely’, ‘manual’], by default ‘shapely’.

Returns:

True if point ⊂ bbox, False otherwise.

Return type:

boolean
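
The ordering of the bounding box values matters, as a small non-square example shows. The helper below is hypothetical and assumes the [xmin, ymin, xmax, ymax] order given in the function description, with points on the border counted as inside.

```python
def is_point_in_bbox_manual(point, bbox):
    """True if point [x, y] lies inside bbox [xmin, ymin, xmax, ymax]."""
    x, y = point
    xmin, ymin, xmax, ymax = bbox
    # Assumption: points on the border count as inside.
    return xmin <= x <= xmax and ymin <= y <= ymax


bbox = [0, 10, 5, 20]  # x spans 0..5, y spans 10..20
print(is_point_in_bbox_manual([3, 15], bbox))
print(is_point_in_bbox_manual([3, 7], bbox))
```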

malpolon.data.utils.to_one_hot_encoding(labels_predict: int | list, labels_target: list) list

Return a one-hot encoding of class-index predicted labels.

Converts a single label value or a vector of labels into a vector of one-hot encoded labels. The label order follows that of the input labels_target.

Parameters:
  • labels_predict (int | list) – Labels to convert to one-hot encoding.

  • labels_target (list) – All existing labels, in the right order.

Returns:

One-hot encoded labels.

Return type:

list
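
The encoding can be sketched with plain lists. one_hot is a hypothetical helper illustrating the described behaviour. When several labels are passed, marking each of them with a 1 in a single vector is an assumption about how "a vector of labels" is handled.

```python
def one_hot(labels_predict, labels_target):
    """Encode predicted label(s) against the ordered list of all labels."""
    # Accept a single label as well as a list of labels.
    if isinstance(labels_predict, int):
        labels_predict = [labels_predict]
    predicted = set(labels_predict)
    # Output positions follow the order of labels_target.
    return [1 if label in predicted else 0 for label in labels_target]


all_labels = [12, 7, 42, 3]
single = one_hot(42, all_labels)
several = one_hot([12, 3], all_labels)
print(single)
print(several)
```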

malpolon.data.datasets

malpolon.data.datasets.torchgeo_datasets

This module provides raster related classes based on torchgeo.

Author: Theo Larcher <theo.larcher@inria.fr>

class malpolon.data.datasets.torchgeo_datasets.RasterBioclim(root: str = 'data', labels_name: str = None, split: str = None, crs: Any = None, res: float = None, bands: Sequence[str] = None, transform: Callable[..., Any] = None, transform_target: Callable[..., Any] = None, patch_size: int | float | tuple = 256, query_units: str = 'pixel', query_crs: int | str | CRS = 'self', obs_data_columns: Dict = {'index': 'surveyId', 'species_id': 'speciesId', 'split': 'subset', 'x': 'lon', 'y': 'lat'}, task: str = 'multiclass', binary_positive_classes: list = [], cache: TYPE_CHECKING = True, **kwargs)

Bases: RasterTorchGeoDataset

Raster dataset adapted for Sentinel-2 data.

Inherits RasterTorchGeoDataset.

all_bands: list[str] = ['bio_1', 'bio_2', 'bio_3', 'bio_4']

Names of all available bands in the dataset

date_format = '%Y%m%dT%H%M%S'

Date format string used to parse date from filename.

Not used if filename_regex does not contain a date group.

filename_glob = 'bio_*.tif'

Glob expression used to search for files.

This expression should be specific enough that it will not pick up files from other datasets. It should not include a file extension, as the dataset may be in a different file format than what it was originally downloaded as.

filename_regex = '(?P<band>bio_[\\d])'

Regular expression used to extract date from filename.

The expression should use named groups. The expression may contain any number of groups. The following groups are specifically searched for by the base class:

  • date: used to calculate mint and maxt for index insertion

When separate_files is True, the following additional groups are searched for to find other files:

  • band: replaced with requested band name

is_image = True

True if dataset contains imagery, False if dataset contains mask

plot(sample: Patches)

Plot all layers of a given patch.

A patch is selected based on a key matching the associated provider’s __get__() method.

Args:

item (dict): provider’s get index.

plot_bands = 'all_bands'
separate_files = True

True if data is stored in a separate file for each band, else False.

class malpolon.data.datasets.torchgeo_datasets.RasterTorchGeoDataset(root: str = 'data', labels_name: Optional[str] = None, split: Optional[str] = None, crs: Optional[Any] = None, res: Optional[float] = None, bands: Optional[Sequence[str]] = None, transform: Optional[Callable] = None, transform_target: Optional[Callable] = None, patch_size: Union[int, float, tuple] = 256, query_units: str = 'pixel', query_crs: Union[int, str, CRS] = 'self', obs_data_columns: dict = {'index': 'surveyId', 'species_id': 'speciesId', 'split': 'subset', 'x': 'lon', 'y': 'lat'}, task: str = 'multiclass', binary_positive_classes: list = [], cache: bool = True)

Bases: RasterDataset

Generic torchgeo based raster datasets.

Datasets based on this class return patches from raster files and can be queried by either a torchgeo BoundingBox object, a tuple of coordinates in the dataset’s CRS or a dictionary specifying coordinates and the wanted CRS. Additionally one can specify the desired size and units of the wanted patch even if they don’t match the dataset’s.

RasterTorchGeoDataset inherits torchgeo’s RasterDataset class.

__getitem__(query: Union[int, dict, tuple, list, set, BoundingBox]) Dict[str, Any]

Query an item from the dataset.

Supports querying the dataset with coordinates in the dataset’s CRS or in another CRS. The dataset is always queried with a torchgeo BoundingBox because it is itself a torchgeo dataset, but the query in this getter method can be passed as a tuple, list, set, dict or BoundingBox.

  • Use case 1: query is a list, tuple or set of 2 elements: lon, lat. The CRS and unit system default to the dataset’s.

  • Use case 2: query is a torchgeo BoundingBox. The CRS and unit system default to the dataset’s.

  • Use case 3: query is a dict containing the required keys {‘lon’, ‘lat’} and the optional keys {‘crs’, ‘units’, ‘size’}, whose values default to the dataset’s.

In Use case 3, if the ‘crs’ key is registered and it is different from the dataset’s CRS, the coordinates of the point are projected into the dataset’s CRS and the value of the key is overwritten by said CRS.

Use cases 1 and 3 give the possibility to easily query the dataset using only a point and a bounding box (bbox) size, using the desired input CRS.

The unit of measurement of the bbox can be set to [‘m’, ‘meters’, ‘metres’] even if the dataset’s unit is different as the points will be projected in the nearest meter-based CRS (see self.point_to_bbox()). Note that depending on your dataset’s CRS, querying a meter-based bbox may result in rectangular patches because of deformations.

Parameters:

query (Union[dict, tuple, BoundingBox]) –

Item query containing geographical coordinates. It can take different types for different uses. One can query a patch by providing a BoundingBox built with the torchgeo.datasets.BoundingBox constructor, or by giving a center and a size.

— BoundingBox strategy —

Must follow : BoundingBox(minx, maxx, miny, maxy, mint, maxt)

— Point strategy —

If tuple, must follow: (lon, lat), and the CRS of the coordinates is assumed to be the dataset’s. If dict, must follow: {‘lon’: lon, ‘lat’: lat, <’crs’: crs>}, where the CRS of the coordinates can be specified. If it is not, it is assumed to be the dataset’s. In both cases, a BoundingBox is generated to perform the query.

Returns:

dataset patch.

Return type:

Dict[str, Any]
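The point-based query forms described above could be normalized along these lines. This is a minimal sketch, not malpolon's implementation: `normalize_query` and its default arguments are hypothetical names, BoundingBox queries are left out, and sets are not handled because a set has no defined element order.

```python
from typing import Any, Dict, Union


def normalize_query(
    query: Union[dict, tuple, list],
    dataset_crs: Any = 4326,
    dataset_units: str = "pixel",
    patch_size: Any = 256,
) -> Dict[str, Any]:
    """Turn a point query into a dict with every optional key filled in."""
    if isinstance(query, dict):
        # Use case 3: 'lon' and 'lat' are required, the rest is optional.
        missing = {"lon", "lat"} - query.keys()
        if missing:
            raise KeyError(f"query is missing keys: {sorted(missing)}")
        out = dict(query)
    elif isinstance(query, (tuple, list)) and len(query) == 2:
        # Use case 1: a bare (lon, lat) pair expressed in the dataset's CRS.
        lon, lat = query
        out = {"lon": lon, "lat": lat}
    else:
        raise TypeError(f"unsupported query type: {type(query).__name__}")
    # Optional keys fall back to the dataset's own settings (assumed values).
    out.setdefault("crs", dataset_crs)
    out.setdefault("units", dataset_units)
    out.setdefault("size", patch_size)
    return out
```

With such a helper, a bare `(lon, lat)` tuple and a partially filled dict both end up in the same fully specified form before a bounding box is built.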

__init__(root: str = 'data', labels_name: Optional[str] = None, split: Optional[str] = None, crs: Optional[Any] = None, res: Optional[float] = None, bands: Optional[Sequence[str]] = None, transform: Optional[Callable] = None, transform_target: Optional[Callable] = None, patch_size: Union[int, float, tuple] = 256, query_units: str = 'pixel', query_crs: Union[int, str, CRS] = 'self', obs_data_columns: dict = {'index': 'surveyId', 'species_id': 'speciesId', 'split': 'subset', 'x': 'lon', 'y': 'lat'}, task: str = 'multiclass', binary_positive_classes: list = [], cache: bool = True) None

Class constructor.

Parameters:
  • root (str, optional) – path to the directory containing the data and labels, by default “data”

  • labels_name (str, optional) – labels file name, by default None

  • split (str, optional) – dataset subset desired for labels selection, by default None

  • crs (Any | None, optional) – coordinate reference system (CRS) to warp to (defaults to the CRS of the first file found), by default None

  • res (float | None, optional) – resolution of the dataset in units of CRS (defaults to the resolution of the first file found), by default None

  • bands (Sequence[str] | None, optional) – bands to return (defaults to all bands), by default None

  • transform (Callable | None, optional) – a callable function that takes an input sample and returns a transformed version, by default None

  • transform_target (Callable | None, optional) – a callable function that takes an input target and returns a transformed version, by default None

  • patch_size (int, optional) – size of the 2D extracted patches. Patches can either be square (int/float value) or rectangular (tuple of int/float), by default a square of size 256

  • query_units (str, optional) – unit system of the dataset’s queries, by default ‘pixel’

  • query_crs (Union[int, str, CRS], optional) – CRS of the dataset’s queries, by default ‘self’ (same as dataset’s CRS)

  • obs_data_columns (dict, optional) –

    this dictionary allows users to match the dataset attributes with custom column names of their obs data file, by default:

    {'x': 'lon',
     'y': 'lat',
     'index': 'surveyId',
     'species_id': 'speciesId',
     'split': 'subset'}
    

    Here’s a description of the keys:

    • ’x’, ‘y’: coordinates of the obs points (by default ‘lon’, ‘lat’ as per the WGS84 system)

    • ’index’: obs ID over which to iterate during the training loop

    • ’species_id’: species ID (label) associated with the obs points

    • ’split’: dataset split column name

  • task (str, optional) – machine learning task (used to format labels accordingly), by default ‘multiclass’

  • binary_positive_classes (list, optional) – labels’ classes to consider valid in the case of binary classification with multi-class labels (defaults to all 0), by default []

  • cache (bool, optional) – if True, cache file handle to speed up repeated sampling, by default True
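As an illustration of the obs_data_columns mapping, the sketch below reads a hypothetical observation file whose column names differ from the defaults. The file content, the custom column names and the `read_observations` helper are all assumptions made for this example; only the dictionary keys come from the description above.

```python
import csv
import io

# Hypothetical observation file with custom column names.
raw = io.StringIO(
    "plot_id,longitude,latitude,taxon,fold\n"
    "17,3.87,43.61,4921,train\n"
    "18,2.35,48.85,112,val\n"
)

# Keys are the dataset attributes, values are the column names in the file.
obs_data_columns = {
    "x": "longitude",
    "y": "latitude",
    "index": "plot_id",
    "species_id": "taxon",
    "split": "fold",
}


def read_observations(handle, columns, split=None):
    """Yield rows keyed by dataset attribute rather than by file column."""
    for row in csv.DictReader(handle):
        obs = {attr: row[col] for attr, col in columns.items()}
        if split is None or obs["split"] == split:
            yield obs


train_obs = list(read_observations(raw, obs_data_columns, split="train"))
```

Passing an equivalent dictionary as obs_data_columns lets the dataset work with such a file without renaming its columns.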

__len__() int
coords_transform(lon: Union[int, float], lat: Union[int, float], input_crs: Union[str, int, CRS] = '4326', output_crs: Union[str, int, CRS] = 'self') tuple[float, float]

Transform coordinates from one CRS to another.

Parameters:
  • lon (Union[int, float]) – longitude

  • lat (Union[int, float]) – latitude

  • input_crs (Union[str, int, CRS], optional) – Input CRS, by default “4326”

  • output_crs (Union[str, int, CRS], optional) – Output CRS, by default “self”

Returns:

Transformed coordinates.

Return type:

tuple
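To make the idea of a CRS transformation concrete, here is a standalone sketch of a single hard-coded pair: WGS84 longitude/latitude (EPSG:4326) to spherical Web Mercator (EPSG:3857) and back. It only uses the closed-form Mercator formulas. It is not how coords_transform is implemented, which accepts arbitrary CRS values, and the function names are hypothetical.

```python
import math
from typing import Tuple

EARTH_RADIUS_M = 6378137.0  # sphere radius used by EPSG:3857


def wgs84_to_web_mercator(lon: float, lat: float) -> Tuple[float, float]:
    """Project degrees (EPSG:4326) to meters (EPSG:3857)."""
    x = EARTH_RADIUS_M * math.radians(lon)
    y = EARTH_RADIUS_M * math.log(math.tan(math.pi / 4 + math.radians(lat) / 2))
    return x, y


def web_mercator_to_wgs84(x: float, y: float) -> Tuple[float, float]:
    """Inverse projection, from meters back to degrees."""
    lon = math.degrees(x / EARTH_RADIUS_M)
    lat = math.degrees(2 * math.atan(math.exp(y / EARTH_RADIUS_M)) - math.pi / 2)
    return lon, lat


x, y = wgs84_to_web_mercator(3.87, 43.61)
lon, lat = web_mercator_to_wgs84(x, y)  # round trip back to degrees
```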

get_label(df: DataFrame, query_lon: float, query_lat: float, obs_id: Optional[int] = None) Union[ndarray, int]

Return the label(s) matching the query coordinates.

This method takes into account the fact that several labels can match a single coordinate set. For that reason, the labels are chosen according to the value of the ‘obs_id’ parameter (matching the observation_id column of the labels DataFrame). If no value is given to ‘obs_id’, all matching labels are returned.

Parameters:
  • df (pd.DataFrame) – dataset DataFrame composed of columns: [‘lon’, ‘lat’, ‘observation_id’]

  • query_lon (float) – longitude value of the query point.

  • query_lat (float) – latitude value of the query point.

  • obs_id (int, optional) – observation ID tied to the query point, by default None

Returns:

target label(s).

Return type:

Union[np.ndarray, int]
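The selection logic described for get_label can be sketched without pandas as follows. This is an assumption-laden illustration: rows are plain dicts, the label column is assumed to be called `species_id`, coordinates are compared with exact equality, and `get_label_sketch` is a hypothetical name.

```python
def get_label_sketch(rows, query_lon, query_lat, obs_id=None, label_key="species_id"):
    """Return one label if obs_id is given, else every label at the point."""
    # Keep the rows located exactly at the query coordinates.
    matches = [r for r in rows if r["lon"] == query_lon and r["lat"] == query_lat]
    if obs_id is not None:
        # Several observations can share a point: disambiguate by ID.
        matches = [r for r in matches if r["observation_id"] == obs_id]
    if not matches:
        raise LookupError("no label matches the query")
    labels = [r[label_key] for r in matches]
    return labels[0] if obs_id is not None else labels


rows = [
    {"lon": 3.87, "lat": 43.61, "observation_id": 1, "species_id": 10},
    {"lon": 3.87, "lat": 43.61, "observation_id": 2, "species_id": 20},
    {"lon": 2.35, "lat": 48.85, "observation_id": 3, "species_id": 30},
]
all_labels = get_label_sketch(rows, 3.87, 43.61)
one_label = get_label_sketch(rows, 3.87, 43.61, obs_id=2)
```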

point_to_bbox(lon: Union[int, float], lat: Union[int, float], size: Optional[Union[tuple, int]] = None, units: str = 'crs', crs: Union[int, str] = 'self') BoundingBox

Convert a geographical point to a torchgeo BoundingBox.

This method converts a 2D point into a 2D torchgeo bounding box (bbox). If ‘size’ is in the CRS’ unit system, the bbox is computed directly from the point’s coordinates. If ‘size’ is in pixels, ‘size’ is multiplied by the resolution of the dataset. If ‘size’ is in meters and the dataset’s unit system isn’t, the point is projected into the nearest meter-based CRS (from a list defined as a constant at the beginning of this file), the bbox vertices’ min and max are computed in this reference system, then they are projected back into the input CRS ‘crs’. If the dataset’s CRS doesn’t match an EPSG code but is instead built from a WKT, the nearest meter-based CRS will always be EPSG:3035.

By default, ‘size’ is set to the dataset’s ‘patch_size’ value via None.

Parameters:
  • lon (Union[int, float]) – longitude

  • lat (Union[int, float]) – latitude

  • size (Union[tuple, int], optional) – Patch size. If passed as an int, the patch will be square. If passed as a tuple (width, height), it can be rectangular. By default None.

  • units (str, optional) – The coordinates’ unit system, must have a value in [‘pixel’, ‘crs’]. The size of the bbox will adapt to the unit. If ‘pixel’ is selected, the bbox size will be multiplied by the dataset resolution. Selecting ‘crs’ will not modify the bbox size. In that case the returned bbox will be of size: (size[0], size[1]) <metric_of_the_dataset (usually meters)>. Defaults to ‘crs’.

  • crs (Union[int, str]) – CRS of the point’s lon/lat coordinates, by default ‘self’.

Returns:

Corresponding torchgeo BoundingBox.

Return type:

BoundingBox
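The ‘pixel’ and ‘crs’ branches of this conversion can be sketched as below. The sketch assumes the box is centered on the point and that a single resolution value applies to both axes. `BBox` is a stand-in for torchgeo's BoundingBox restricted to its spatial fields, and `point_to_bbox_sketch` and its `res` argument are hypothetical. The meter-based reprojection branch is not covered.

```python
from typing import NamedTuple, Tuple, Union


class BBox(NamedTuple):
    """Stand-in for a bounding box, spatial fields only."""
    minx: float
    maxx: float
    miny: float
    maxy: float


def point_to_bbox_sketch(
    lon: float,
    lat: float,
    size: Union[int, float, Tuple[float, float]],
    units: str = "crs",
    res: float = 10.0,
) -> BBox:
    """Build a box centered on (lon, lat); 'pixel' sizes are scaled by res."""
    width, height = (size, size) if isinstance(size, (int, float)) else size
    if units == "pixel":
        # Pixel sizes become CRS units once multiplied by the resolution.
        width, height = width * res, height * res
    elif units != "crs":
        raise ValueError(f"unsupported units: {units!r}")
    return BBox(lon - width / 2, lon + width / 2, lat - height / 2, lat + height / 2)


pixel_box = point_to_bbox_sketch(500000.0, 4800000.0, 256, units="pixel", res=10.0)
crs_box = point_to_bbox_sketch(500000.0, 4800000.0, (100, 50), units="crs")
```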

malpolon.data.datasets.torchgeo_sentinel2

This module provides Sentinel-2 related classes based on torchgeo.

Sentinel-2 data is queried from Microsoft Planetary Computer (MPC).

Author: Theo Larcher <theo.larcher@inria.fr>

class malpolon.data.datasets.torchgeo_sentinel2.RasterSentinel2(root: str = 'data', labels_name: Optional[str] = None, split: Optional[str] = None, crs: Optional[Any] = None, res: Optional[float] = None, bands: Optional[Sequence[str]] = None, transform: Optional[Callable] = None, transform_target: Optional[Callable] = None, patch_size: Union[int, float, tuple] = 256, query_units: str = 'pixel', query_crs: Union[int, str, CRS] = 'self', obs_data_columns: dict = {'index': 'surveyId', 'species_id': 'speciesId', 'split': 'subset', 'x': 'lon', 'y': 'lat'}, task: str = 'multiclass', binary_positive_classes: list = [], cache: bool = True)

Bases: RasterTorchGeoDataset

Raster dataset adapted for Sentinel-2 data.

Inherits RasterTorchGeoDataset.

all_bands: list[str] = ['B02', 'B03', 'B04', 'B08']

Names of all available bands in the dataset

date_format = '%Y%m%dT%H%M%S'

Date format string used to parse date from filename.

Not used if filename_regex does not contain a date group.

filename_glob = 'T*_B0*_10m.tif'

Glob expression used to search for files.

This expression should be specific enough that it will not pick up files from other datasets. It should not include a file extension, as the dataset may be in a different file format than what it was originally downloaded as.

filename_regex = 'T31TEJ_20190801T104031_(?P<band>B0[\\d])'

Regular expression used to extract date from filename.

The expression should use named groups. The expression may contain any number of groups. The following groups are specifically searched for by the base class:

  • date: used to calculate mint and maxt for index insertion

When separate_files is True, the following additional groups are searched for to find other files:

  • band: replaced with requested band name
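The way these three class attributes interact can be checked with the standard library alone. In the sketch below the file name is hypothetical, chosen only so that it satisfies both patterns. The timestamp is parsed separately because this particular filename_regex contains no date group.

```python
import fnmatch
import re
from datetime import datetime

filename_glob = "T*_B0*_10m.tif"
filename_regex = r"T31TEJ_20190801T104031_(?P<band>B0[\d])"
date_format = "%Y%m%dT%H%M%S"

# Hypothetical file name shaped to satisfy both patterns above.
filename = "T31TEJ_20190801T104031_B04_10m.tif"

# The glob decides which files are considered at all.
is_candidate = fnmatch.fnmatch(filename, filename_glob)

# The regex then extracts the named groups, here only 'band'.
match = re.match(filename_regex, filename)
band = match.group("band")
has_date_group = "date" in match.groupdict()

# date_format describes timestamps such as the one embedded in the name.
acquired = datetime.strptime("20190801T104031", date_format)
```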

is_image = True

True if dataset contains imagery, False if dataset contains mask

plot(sample: Patches) Figure

Plot a 3-bands dataset patch (sample).

Plots a dataset sample by selecting the 3 bands indicated in the plot_bands variable (in the same order). By default, the method plots the RGB bands.

Parameters:

sample (Patches) – dataset’s patch to plot

Returns:

matplotlib figure containing the plot

Return type:

Figure
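Selecting the three channels to plot amounts to looking up each plot_bands entry in all_bands. The sketch below assumes the patch channels are ordered like all_bands and represents a patch as nested lists. `select_plot_channels` is a hypothetical helper, not part of the class.

```python
all_bands = ["B02", "B03", "B04", "B08"]
plot_bands = ["B04", "B03", "B02"]


def select_plot_channels(patch, all_bands, plot_bands):
    """Return the channels named in plot_bands, in that order."""
    indices = [all_bands.index(band) for band in plot_bands]
    return [patch[i] for i in indices]


# Toy patch: 4 channels of shape 1x2, each filled with its band number.
patch = [[[2, 2]], [[3, 3]], [[4, 4]], [[8, 8]]]
rgb = select_plot_channels(patch, all_bands, plot_bands)
```

The result holds the red, green and blue channels in display order, with the near-infrared channel dropped.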

plot_bands = ['B04', 'B03', 'B02']
separate_files = True

True if data is stored in a separate file for each band, else False.

class malpolon.data.datasets.torchgeo_sentinel2.RasterSentinel2GLC23(root: str = 'data', labels_name: Optional[str] = None, split: Optional[str] = None, crs: Optional[Any] = None, res: Optional[float] = None, bands: Optional[Sequence[str]] = None, transform: Optional[Callable] = None, transform_target: Optional[Callable] = None, patch_size: Union[int, float, tuple] = 256, query_units: str = 'pixel', query_crs: Union[int, str, CRS] = 'self', obs_data_columns: dict = {'index': 'surveyId', 'species_id': 'speciesId', 'split': 'subset', 'x': 'lon', 'y': 'lat'}, task: str = 'multiclass', binary_positive_classes: list = [], cache: bool = True)

Bases: RasterSentinel2

Adaptation of RasterSentinel2 for new GLC23 observations.

all_bands: list[str] = ['red', 'green', 'blue', 'nir']

Names of all available bands in the dataset

filename_glob = '*.tif'

Glob expression used to search for files.

This expression should be specific enough that it will not pick up files from other datasets. It should not include a file extension, as the dataset may be in a different file format than what it was originally downloaded as.

filename_regex = '(?P<band>red|green|blue|nir)_2021'

Regular expression used to extract date from filename.

The expression should use named groups. The expression may contain any number of groups. The following groups are specifically searched for by the base class:

  • date: used to calculate mint and maxt for index insertion

When separate_files is True, the following additional groups are searched for to find other files:

  • band: replaced with requested band name
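The "replaced with requested band name" behaviour can be sketched as a substitution on the span of the band group. The file names below are hypothetical and `sibling_band_file` is an illustrative helper, not a function of the library.

```python
import re

filename_regex = r"(?P<band>red|green|blue|nir)_2021"


def sibling_band_file(filename: str, band: str) -> str:
    """Swap the matched band name to locate the file of another band."""
    match = re.match(filename_regex, filename)
    if match is None:
        raise ValueError(f"{filename!r} does not match the filename regex")
    start, end = match.span("band")
    return filename[:start] + band + filename[end:]


# Hypothetical file name following the '<band>_2021' convention.
nir_file = sibling_band_file("red_2021.tif", "nir")
```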

plot_bands = ['red', 'green', 'blue']
class malpolon.data.datasets.torchgeo_sentinel2.Sentinel2GeoSampler(dataset: GeoDataset, size: Union[Tuple[float, float], float], length: Optional[int] = None, roi: Optional[BoundingBox] = None, units: str = 'pixel', crs: str = 'crs')

Bases: GeoSampler

Custom sampler for RasterSentinel2.

This custom sampler is used by RasterSentinel2 to query the dataset with the fully constructed dictionary. The sampler is passed to and used by PyTorch dataloaders in the training/inference workflow.

Inherits GeoSampler.

NOTE: this sampler is compatible with any class inheriting RasterTorchGeoDataset’s __getitem__ method so the name of this sampler may become irrelevant when more dataset-specific classes inheriting RasterTorchGeoDataset are created.

__len__() int

Return the number of samples in a single epoch.

Returns:

length of the epoch
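A sampler that yields fully constructed dictionary queries could look like the sketch below. It is a simplified stand-in: the real class derives from torchgeo's GeoSampler and receives a dataset, whereas this hypothetical `DictQuerySampler` takes a plain list of points. The dictionary keys follow the dict query form accepted by RasterTorchGeoDataset.

```python
from typing import Any, Dict, Iterator, List, Tuple


class DictQuerySampler:
    """Yield one dict query per observation point."""

    def __init__(
        self,
        points: List[Tuple[float, float]],
        size: Any = 256,
        units: str = "pixel",
        crs: Any = 4326,
    ) -> None:
        self.points = points
        self.size = size
        self.units = units
        self.crs = crs

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        for lon, lat in self.points:
            yield {
                "lon": lon,
                "lat": lat,
                "crs": self.crs,
                "units": self.units,
                "size": self.size,
            }

    def __len__(self) -> int:
        # One sample per point in a single epoch.
        return len(self.points)


sampler = DictQuerySampler([(3.87, 43.61), (2.35, 48.85)], size=200)
queries = list(sampler)
```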

class malpolon.data.datasets.torchgeo_sentinel2.Sentinel2TorchGeoDataModule(dataset_path: str, labels_name: str = 'labels.csv', train_batch_size: int = 32, inference_batch_size: int = 256, num_workers: int = 8, size: int = 200, units: str = 'pixel', crs: int = 4326, binary_positive_classes: list = [], task: str = 'classification_multiclass', dataset_kwargs: dict = {}, download_data_sample: bool = False, **kwargs)

Bases: BaseDataModule

Data module for Sentinel-2A dataset.

__init__(dataset_path: str, labels_name: str = 'labels.csv', train_batch_size: int = 32, inference_batch_size: int = 256, num_workers: int = 8, size: int = 200, units: str = 'pixel', crs: int = 4326, binary_positive_classes: list = [], task: str = 'classification_multiclass', dataset_kwargs: dict = {}, download_data_sample: bool = False, **kwargs)

Class constructor.

Parameters:
  • dataset_path (str) – path to the directory containing the data

  • labels_name (str, optional) – labels file name, by default ‘labels.csv’

  • train_batch_size (int, optional) – train batch size, by default 32

  • inference_batch_size (int, optional) – inference batch size, by default 256

  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process, by default 8

  • size (int, optional) – size of the 2D extracted patches. Patches can either be square (int/float value) or rectangular (tuple of int/float), by default a square of size 200

  • units (str, optional) – The queries’ unit system, must have a value in [‘pixel’, ‘crs’, ‘m’, ‘meter’, ‘metre’]. This sets the unit you want your query to be performed in, even if it doesn’t match the dataset’s unit system, by default ‘pixel’

  • crs (int, optional) – The queries’ coordinate reference system (CRS). This argument sets the CRS of the dataset’s queries. The value should be equal to the CRS of your observations. It takes any EPSG integer code, by default 4326

  • binary_positive_classes (list, optional) – labels’ classes to consider valid in the case of binary classification with multi-class labels (defaults to all 0), by default []

  • task (str, optional) – machine learning task (used to format labels accordingly), by default ‘classification_multiclass’

  • dataset_kwargs (dict, optional) – additional keyword arguments for the dataset, by default {}

  • download_data_sample (bool, optional) – whether to download a sample of Sentinel-2 data, by default False

download_data_sample()

Download 4 Sentinel-2A tiles from MPC.

This method is useful to quickly download a sample of Sentinel-2A tiles via Microsoft Planetary Computer (MPC). The references of the downloaded tiles are specified by the tile_id and timestamp variables. Tiles that have already been downloaded and are found locally are not downloaded again.

get_dataset(split, transform, **kwargs)
predict_dataloader() DataLoader
test_dataloader() DataLoader
property test_transform
train_dataloader() DataLoader
property train_transform
val_dataloader() DataLoader

malpolon.data.datasets.torchgeo_concat

This module provides Sentinel-2 related classes based on torchgeo.

Sentinel-2 data is queried from Microsoft Planetary Computer (MPC).

NOTE: “unused” imports are necessary because they are evaluated in the eval() function. These classes are passed by the user in the config file along with their arguments.

Author: Theo Larcher <theo.larcher@inria.fr>

class malpolon.data.datasets.torchgeo_concat.ConcatPatchRasterDataset(datasets: list[dict[torch.utils.data.dataset.Dataset]], split: str, transform: Callable, task: str)

Bases: Dataset

Concatenation dataset.

This class concatenates multiple datasets into a single one with a single sampler. It is useful when you want to train a model on multiple datasets at the same time (e.g. to train on both rasters and pre-extracted jpeg patches). In the case of RasterTorchGeoDataset, the __getitem__ method calls a private method _default_sample_to_getitem to convert the iterating index to the correct index for the dataset. This is necessary because the RasterTorchGeoDataset class uses a custom dict-based sampler but the other classes don’t.

The minimum required class arguments (i.e. observation_ids, targets, coordinates) are taken from the first dataset in the list.

Target labels are taken from the first dataset in the list.

All datasets must return tensors of the same shape.
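The concatenation idea can be sketched as follows. This is an assumption, not the documented behaviour: the sketch joins patches along the channel (first) axis, checks only that height and width agree, and keeps the target of the first dataset. Patches are nested lists shaped [channels][height][width], and `concat_items` is a hypothetical helper.

```python
def concat_items(items):
    """items: one (patch, target) pair per dataset, all for the same index."""
    first_patch, first_target = items[0]
    spatial = (len(first_patch[0]), len(first_patch[0][0]))
    merged = []
    for patch, _ in items:
        if (len(patch[0]), len(patch[0][0])) != spatial:
            raise ValueError("all patches must share the same height and width")
        # Append this dataset's channels after the previous ones.
        merged.extend(patch)
    # Labels come from the first dataset only.
    return merged, first_target


raster_item = ([[[1, 2], [3, 4]]], 7)                               # 1 channel
jpeg_item = ([[[5, 6], [7, 8]], [[9, 10], [11, 12]]], 7)            # 2 channels
merged_patch, target = concat_items([raster_item, jpeg_item])
```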

__getitem__(idx: int) Tuple[Patches, Targets]

Query an item from the dataset.

Parameters:

idx (int) – item index (standard int).

Returns:

concatenated data and corresponding label(s).

Return type:

Tuple[Patches, Targets]

__init__(datasets: list[dict[torch.utils.data.dataset.Dataset]], split: str, transform: Callable, task: str) None

Class constructor.

Parameters:
  • datasets (list[dict[Dataset]]) – list of dictionaries with keys ‘callable’ and ‘kwargs’ on which to call the datasets.

  • transform (Callable) – data transform callable function.

  • task (str) – deep learning task.
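The ‘callable’/‘kwargs’ structure expected in datasets can be illustrated with a toy class. `ToyDataset` and `build_datasets` are hypothetical names used only for this sketch, and the class objects are passed directly rather than resolved from strings in a config file.

```python
class ToyDataset:
    """Minimal dataset stand-in with a name and a length."""

    def __init__(self, name: str, length: int) -> None:
        self.name = name
        self.length = length

    def __len__(self) -> int:
        return self.length


# One dictionary per dataset: the class to call and its keyword arguments.
specs = [
    {"callable": ToyDataset, "kwargs": {"name": "rasters", "length": 100}},
    {"callable": ToyDataset, "kwargs": {"name": "jpeg_patches", "length": 100}},
]


def build_datasets(specs):
    """Instantiate every dataset described by a 'callable'/'kwargs' pair."""
    return [spec["callable"](**spec["kwargs"]) for spec in specs]


built = build_datasets(specs)
```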

__len__() int
class malpolon.data.datasets.torchgeo_concat.ConcatTorchGeoDataModule(dataset_kwargs: list[dict[torch.utils.data.dataset.Dataset, Any]], dataset_path: str = 'dataset/', labels_name: str = 'labels.csv', train_batch_size: int = 32, inference_batch_size: int = 256, num_workers: int = 8, binary_positive_classes: list = [], task: str = 'classification_multiclass', **kwargs)

Bases: BaseDataModule

__init__(dataset_kwargs: list[dict[torch.utils.data.dataset.Dataset, Any]], dataset_path: str = 'dataset/', labels_name: str = 'labels.csv', train_batch_size: int = 32, inference_batch_size: int = 256, num_workers: int = 8, binary_positive_classes: list = [], task: str = 'classification_multiclass', **kwargs)

Class constructor.

Parameters:
  • dataset_kwargs (list[dict[Dataset, Any]]) – list of dictionaries with keys ‘callable’ and ‘kwargs’ on which to call the datasets.

  • dataset_path (str) – path to the directory containing the data

  • labels_name (str, optional) – labels file name, by default ‘labels.csv’

  • train_batch_size (int, optional) – train batch size, by default 32

  • inference_batch_size (int, optional) – inference batch size, by default 256

  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process, by default 8

  • binary_positive_classes (list, optional) – labels’ classes to consider valid in the case of binary classification with multi-class labels (defaults to all 0), by default []

  • task (str, optional) – machine learning task (used to format labels accordingly), by default ‘classification_multiclass’

get_dataset(split, transform=None, **kwargs) Dataset
predict_dataloader() DataLoader
test_dataloader() DataLoader
property test_transform
train_dataloader() DataLoader
property train_transform
val_dataloader() DataLoader

malpolon.data.datasets.geolifeclef2022

GeoLifeCLEF2022Dataset
class malpolon.data.datasets.geolifeclef2022.GeoLifeCLEF2022Dataset(root: Union[str, Path], subset: str, *, region: str = 'both', patch_data: str = 'all', use_rasters: bool = True, patch_extractor: Optional[PatchExtractor] = None, use_localisation: bool = False, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, **kwargs)

Bases: Dataset

Pytorch dataset handler for GeoLifeCLEF 2022 dataset.

Parameters:
  • root (string or pathlib.Path) – Root directory of dataset.

  • subset (string, either "train", "val", "train+val" or "test") – Use the given subset (“train+val” is the complete training data).

  • region (string, either "both", "fr" or "us") – Load the observations of both France and US or only a single region.

  • patch_data (string or list of string) – Specifies what type of patch data to load, possible values: ‘all’, ‘rgb’, ‘near_ir’, ‘landcover’ or ‘altitude’.

  • use_rasters (boolean (optional)) – If True, extracts patches from environmental rasters.

  • patch_extractor (PatchExtractor object (optional)) – Patch extractor to use if rasters are used.

  • use_localisation (boolean) – If True, returns also the localisation as a tuple (latitude, longitude).

  • transform (callable (optional)) – A function/transform that takes a list of arrays and returns a transformed version.

  • target_transform (callable (optional)) – A function/transform that takes in the target and transforms it.

__getitem__(index: int) Union[dict[str, Patches], tuple[dict[str, Patches], Targets]]

Return a dataset item.

Args:

index (int): dataset id.

Returns:
Union[dict[str, Patches], tuple[dict[str, Patches], Targets]]:

data and labels corresponding to the dataset id.

__len__() int

Return the number of observations in the dataset.

MiniGeoLifeCLEF2022Dataset
class malpolon.data.datasets.geolifeclef2022.MiniGeoLifeCLEF2022Dataset(root: Union[str, Path], subset: str, *, patch_data: str = 'all', use_rasters: bool = True, patch_extractor: Optional[PatchExtractor] = None, use_localisation: bool = False, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, **kwargs)

Bases: GeoLifeCLEF2022Dataset

Pytorch dataset handler for a subset of GeoLifeCLEF 2022 dataset.

It is restricted to France and to the 100 most frequent plant species.

Parameters:
  • root (string or pathlib.Path) – Root directory of dataset.

  • subset (string, either "train", "val", "train+val" or "test") – Use the given subset (“train+val” is the complete training data).

  • patch_data (string or list of string) – Specifies what type of patch data to load, possible values: ‘all’, ‘rgb’, ‘near_ir’, ‘landcover’ or ‘altitude’.

  • use_rasters (boolean (optional)) – If True, extracts patches from environmental rasters.

  • patch_extractor (PatchExtractor object (optional)) – Patch extractor to use if rasters are used.

  • use_localisation (boolean) – If True, returns also the localisation as a tuple (latitude, longitude).

  • transform (callable (optional)) – A function/transform that takes a list of arrays and returns a transformed version.

  • target_transform (callable (optional)) – A function/transform that takes in the target and transforms it.

__getitem__(index: int) Union[dict[str, Patches], tuple[dict[str, Patches], Targets]]

Return a dataset item.

Args:

index (int): dataset id.

Returns:
Union[dict[str, Patches], tuple[dict[str, Patches], Targets]]:

data and labels corresponding to the dataset id.

__len__() int

Return the number of observations in the dataset.

MicroGeoLifeCLEF2022Dataset

Pytorch dataset handler for a subset of GeoLifeCLEF 2022 dataset.

It is restricted to France and to the 100 most frequent plant species.

Parameters:
  • root (string or pathlib.Path) – Root directory of dataset.

  • subset (string, either "train", "val", "train+val" or "test") – Use the given subset (“train+val” is the complete training data).

  • patch_data (string or list of string) – Specifies what type of patch data to load, possible values: ‘all’, ‘rgb’, ‘near_ir’, ‘landcover’ or ‘altitude’.

  • use_rasters (boolean (optional)) – If True, extracts patches from environmental rasters.

  • patch_extractor (PatchExtractor object (optional)) – Patch extractor to use if rasters are used.

  • use_localisation (boolean) – If True, returns also the localisation as a tuple (latitude, longitude).

  • transform (callable (optional)) – A function/transform that takes a list of arrays and returns a transformed version.

  • target_transform (callable (optional)) – A function/transform that takes in the target and transforms it.

  • download (boolean (optional)) – If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.

malpolon.data.datasets.geolifeclef2022.MicroGeoLifeCLEF2022Dataset.__getitem__(self, index)

Return a MicroGeoLifeCLEF dataset item.

Args:

index (int): dataset id.

Returns:

(tuple): data and labels.

malpolon.data.datasets.geolifeclef2022.MicroGeoLifeCLEF2022Dataset.__len__(self)

Return the number of observations in the dataset.

malpolon.data.datasets.geolifeclef2023

GeoLifeCLEF2023: datasets
class malpolon.data.datasets.geolifeclef2023.PatchesDataset(occurrences: str, providers: Iterable, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, id_name: str = 'glcID', label_name: str = 'speciesId', item_columns: Iterable = ['lat', 'lon', 'patchID'])

Bases: Dataset

Patches dataset class.

This class provides a PyTorch-friendly dataset to handle patch data and labels. The data can be .jpeg or .tif files of variable depth (i.e. multi-spectral data). Each __getitem__ call returns a patch extracted from the dataset.

Args:

Dataset (Dataset): PyTorch Dataset class.

__getitem__(index)

Return a dataset element.

Returns an element from a dataset id (0 to n) with its label.

Args:

index (int): dataset id.

Returns:

(tuple): tuple of data patch (tensor) and label (int).

__len__()

Return the size of the dataset.

Returns:

(int): number of occurrences.

class malpolon.data.datasets.geolifeclef2023.PatchesDatasetMultiLabel(occurrences: str, providers: Iterable, n_classes: Union[int, str] = 'max', id_getitem: str = 'patchId', **kwargs)

Bases: PatchesDataset

Multilabel patches dataset.

Like PatchesDataset but provides one-hot encoded labels.

Args:

PatchesDataset (PatchesDataset): pytorch friendly dataset class.

__getitem__(index)

Return a dataset element.

Returns an element from a dataset id (0 to n) with the labels in one-hot encoding.

Args:

index (int): dataset id.

Returns:

(tuple): tuple of data patch (tensor) and labels (list).
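For a multilabel target, the encoded vector has a 1 at every species ID present for the patch. The sketch below shows one way to build it. The interpretation of n_classes='max' as "largest species ID plus one" is an assumption made for this example, and both helper names are hypothetical.

```python
def resolve_n_classes(n_classes, all_species_ids):
    """Assumed meaning of 'max': largest species ID in the data, plus one."""
    if n_classes == "max":
        return max(all_species_ids) + 1
    return int(n_classes)


def multi_hot(species_ids, n_classes):
    """Return a 0/1 vector with a 1 at each species ID of the patch."""
    vector = [0] * n_classes
    for species_id in species_ids:
        if not 0 <= species_id < n_classes:
            raise ValueError(f"species ID {species_id} is out of range")
        vector[species_id] = 1
    return vector


n_classes = resolve_n_classes("max", [0, 4, 2])
labels = multi_hot([1, 3], n_classes)
```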

__len__()

Return the size of the dataset.

Returns:

(int): number of occurrences.

class malpolon.data.datasets.geolifeclef2023.TimeSeriesDataset(occurrences: str, providers: Iterable, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, id_name: str = 'glcID', label_name: str = 'speciesId', item_columns: Iterable = ['timeSerieID'])

Bases: Dataset

Time series dataset.

Like PatchesDataset but adapted to time series data, which is formatted as .csv files where each row is an occurrence and each column is a timestamp. Time series have one additional dimension compared to visual patches. They do not all have the same size and have a configurable no_data_value argument.

Args:

Dataset (Dataset): pytorch Dataset class.

__getitem__(index)

Return a time series.

Returns a tuple of time series data and its label. The data is in the following shape: [1, n_bands, max_length_ts]

Args:

index (int): dataset id.

Returns:

(tuple): tuple of time series data (tensor) and labels (list).
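Since time series do not all have the same length, reaching the [1, n_bands, max_length_ts] shape implies some form of padding. The sketch below pads at the end of each series with a fill value. The padding position and the use of -1 as the fill value are assumptions, and `pad_time_series` is a hypothetical helper.

```python
def pad_time_series(bands, fill_value=-1):
    """Pad each band to the longest length and add a leading axis of size 1."""
    max_length = max(len(band) for band in bands)
    padded = [list(band) + [fill_value] * (max_length - len(band)) for band in bands]
    return [padded]  # shape: [1, n_bands, max_length_ts]


series = pad_time_series([[0.1, 0.2, 0.3], [0.5]])
```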

__len__()

Return the size of the dataset.

Returns:

(int): number of occurrences.

GeoLifeCLEF2023: providers
class malpolon.data.datasets.geolifeclef2023.PatchProvider(size: int = 128, normalize: bool = False)

Bases: object

Parent class for all GLC23 patch data providers.

This class provides the concrete and abstract methods shared by all patch providers. In particular, the plot method is designed to accommodate all cases of the GLC23 patch datasets.

abstract __getitem__(item)

Return a patch (parent class).

__len__()

Return the size of the provider.

Returns:

(int): number of layers (depth).

class malpolon.data.datasets.geolifeclef2023.MetaPatchProvider(providers: Iterable[Callable], transform: Optional[Callable] = None)

Bases: PatchProvider

Parent class for patch providers.

This class interfaces patch providers with patch datasets.

Args:

(PatchProvider): inherits PatchProvider.

__getitem__(item)

Return a patch.

This getter is used by a patch dataset class and calls each provider’s getter method to return concatenated patches.

Args:

item (dict): provider index.

Returns:

(array): concatenated patch from all providers.
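The delegation pattern described here, where one getter calls every provider's getter and joins the results, can be sketched with toy providers. Both classes below are hypothetical stand-ins that return nested lists. They are not the library's classes, and joining along the layer axis is an assumption.

```python
class ConstantProvider:
    """Toy provider returning n_layers layers of shape size x size."""

    def __init__(self, n_layers, value, size=2):
        self.n_layers = n_layers
        self.value = value
        self.size = size

    def __getitem__(self, item):
        # 'item' is ignored here; a real provider would use its coordinates.
        return [
            [[self.value] * self.size for _ in range(self.size)]
            for _ in range(self.n_layers)
        ]

    def __len__(self):
        return self.n_layers


class MetaProviderSketch:
    """Query every provider with the same item and join the layers."""

    def __init__(self, providers):
        self.providers = providers

    def __getitem__(self, item):
        layers = []
        for provider in self.providers:
            layers.extend(provider[item])
        return layers

    def __len__(self):
        return sum(len(provider) for provider in self.providers)


meta = MetaProviderSketch([ConstantProvider(1, 0.0), ConstantProvider(3, 1.0)])
patch = meta[{"lon": 3.87, "lat": 43.61}]
```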

__len__()

Return the size of the provider.

Returns:

(int): number of layers (depth).

class malpolon.data.datasets.geolifeclef2023.RasterPatchProvider(raster_path: str, size: int = 128, spatial_noise: int = 0, normalize: bool = True, fill_zero_if_error: bool = False, nan_value: Union[int, float] = 0)

Bases: PatchProvider

Patch provider for .tif raster files.

This class handles rasters stored as .tif files and returns patches from them based on a dict key containing (lon, lat) coordinates.

Args:

(PatchProvider): inherits PatchProvider.

__getitem__(item)

Return a patch from a .tif raster.

This getter returns a patch of size self.size from a .tif raster loaded in self.data using GPS coordinates projected in EPSG:4326.

Args:
item (dict): dictionary containing at least latitude and longitude keys ({‘lat’: lat, ‘lon’: lon})

Returns:

(array): the environmental vector (size>1 or size=1).
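Extracting a window around a coordinate can be sketched as follows for a north-up raster held as a 2D list. The geotransform arguments (`origin_lon`, `origin_lat`, `pixel_size`), the centering convention and the `extract_patch` name are assumptions made for this example. The real provider reads the raster from a .tif file.

```python
import math


def extract_patch(raster, lon, lat, origin_lon, origin_lat, pixel_size, size):
    """Return a size x size window around (lon, lat).

    Row 0 sits at origin_lat (top edge), column 0 at origin_lon (left edge).
    """
    col = math.floor((lon - origin_lon) / pixel_size)
    row = math.floor((origin_lat - lat) / pixel_size)
    half = size // 2
    row0, col0 = row - half, col - half
    if row0 < 0 or col0 < 0 or row0 + size > len(raster) or col0 + size > len(raster[0]):
        raise IndexError("patch falls outside the raster")
    return [line[col0:col0 + size] for line in raster[row0:row0 + size]]


# Toy 4x4 raster whose cell value is row * 4 + col.
raster = [[r * 4 + c for c in range(4)] for r in range(4)]
window = extract_patch(raster, lon=2.5, lat=1.5, origin_lon=0.0, origin_lat=4.0,
                       pixel_size=1.0, size=2)
```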

__len__()

Return the size of the provider.

Returns:

(int): number of layers (depth).

class malpolon.data.datasets.geolifeclef2023.MultipleRasterPatchProvider(rasters_folder: str, select: Optional[Iterable[str]] = None, **kwargs)

Bases: PatchProvider

Patch provider for multiple PatchProvider.

This provider is useful when several patch modalities have to be loaded through RasterPatchProvider, by selecting the desired data in the ‘select’ argument of the constructor.

Args:

(PatchProvider): inherits PatchProvider.

__getitem__(item)

Return multiple patches.

Returns multiple patches from multiple raster providers in a 3-dimensional numpy array.

Args:

item (dict): providers index.

Returns:

(array): array of patches.

__len__()

Return the size of the provider.

Returns:

(int): number of layers (depth).

class malpolon.data.datasets.geolifeclef2023.JpegPatchProvider(root_path: str, select: Optional[Iterable[str]] = None, normalize: bool = False, patch_transform: Optional[Callable] = None, size: int = 128, dataset_stats: str = 'jpeg_patches_stats.csv', id_getitem: str = 'patchID')

Bases: PatchProvider

JPEG patches provider for GLC23.

Provides tensors of multi-modal patches from JPEG patch files of rasters of the GLC23 challenge.

Attributes:

(PatchProvider): inherits PatchProvider.

__getitem__(item)

Return a tensor composed of every channel of a jpeg patch.

Looks for every spectral band listed in self.channels and returns a 3-dimensional patch concatenated in 1 tensor. The index used to query the right patch is a dictionary with at least one key/value pair: {‘patchID’: <patchID_value>}.

Args:

item (dict): dictionary containing the patchID necessary to identify the jpeg patch to return.

Raises:

KeyError: the ‘patchID’ key is missing from item

Exception: item is not a dictionary as expected

Returns:

(tensor): multi-channel patch tensor.

__len__()

Return the size of the provider.

Returns:

(int): number of layers (depth).
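A minimal sketch of the dictionary-based indexing described above. ToyPatchProvider is a hypothetical in-memory stand-in, not part of malpolon; it only illustrates the documented contract, namely that a missing ‘patchID’ key raises KeyError and a non-dictionary item raises an exception (TypeError is chosen here as an assumption).

```python
class ToyPatchProvider:
    """Hypothetical stand-in for a patch provider indexed by a dict."""

    def __init__(self, patches, id_getitem="patchID"):
        self.patches = patches        # maps patch id -> patch data
        self.id_getitem = id_getitem  # key expected in the item dict

    def __getitem__(self, item):
        if not isinstance(item, dict):
            raise TypeError("item must be a dictionary")
        if self.id_getitem not in item:
            raise KeyError(f"missing '{self.id_getitem}' key in item")
        return self.patches[item[self.id_getitem]]


# Toy data: one 2x2 "patch" stored under an arbitrary id.
provider = ToyPatchProvider({3018575: [[0.1, 0.2], [0.3, 0.4]]})

# Extra keys such as lat/lon are ignored by this toy provider.
patch = provider[{"patchID": 3018575, "lat": 43.6, "lon": 3.9}]
```

Keeping the index a dictionary lets the same item carry several keys, so different providers can each read the key they need.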

class malpolon.data.datasets.geolifeclef2023.TimeSeriesProvider(root_path: str, eos_replace_value: Union[int, float] = -1, transform: Optional[Iterable[Callable]] = None)

Bases: object

Provide time series data.

This provider is the parent class of time series providers. It handles time series data stored as .csv files where each file has values for a single spectral band (red, green, infra-red etc…).

abstract __getitem__(item)

Return a time series (parent class).

__len__()

Return the size of the provider.

Returns:

(int): number of layers (depth).

class malpolon.data.datasets.geolifeclef2023.MetaTimeSeriesProvider(providers: Iterable[Callable], transform: Optional[Iterable[Callable]] = None)

Bases: TimeSeriesProvider

Time Series provider called by TimeSeriesDataset to handle TS providers.

This TS provider handles all TS providers passed in TimeSeriesDataset to provide multi-modal time series objects.

Args:

(TimeSeriesProvider) : inherits TimeSeriesProvider.

__getitem__(item)

Return the time series from all TS providers.

This getter is called by a TimeSeriesDataset and returns a 3-dimensional array containing time series from all providers. The time series is selected using the TS index item, which is a dictionary containing at least one key/value pair: {‘timeSerieID’: <timeSerieID_value>}.

Args:

item (dict): time series index.

Returns:

(array): array containing the time series from all TS providers.

__len__()

Return the size of the provider.

Returns:

(int): number of layers (depth).

class malpolon.data.datasets.geolifeclef2023.CSVTimeSeriesProvider(ts_data_path: str, normalize: bool = False, ts_id: str = 'timeSerieID', features_col: list = [], eos_replace_value: Union[int, float] = -1, transform: Optional[Iterable[Callable]] = None)

Bases: TimeSeriesProvider

Implement TimeSeriesProvider for .csv time series.

Only loads time series from a single .csv file. If the time series of an occurrence is shorter than the longest one in the .csv file, the remaining columns are filled with the ‘eos’ string.

Args:

(TimeSeriesProvider) : inherits TimeSeriesProvider.

__getitem__(item)

Return a time series.

This getter returns a time series occurrence as a tensor-like array, based on the item index, which is a dictionary containing at least one key/value pair: {‘timeSerieID’: <timeSerieID_value>}.

Args:

item (dict): time series index.

Returns:

(array): time series occurrence.

__len__()

Return the size of the time series.

Returns:

(int): length of biggest sequence.
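The padding behaviour described above can be sketched in plain Python. The helper names are hypothetical, and the second step (swapping the ‘eos’ marker for eos_replace_value) is an assumption inferred from the constructor argument rather than a description of the actual implementation.

```python
def pad_with_eos(series, eos="eos"):
    """Pad every series with `eos` up to the length of the longest one."""
    max_len = max(len(s) for s in series)
    return [list(s) + [eos] * (max_len - len(s)) for s in series]


def replace_eos(padded, eos="eos", eos_replace_value=-1):
    """Swap the `eos` marker for a numeric placeholder."""
    return [[eos_replace_value if v == eos else v for v in row] for row in padded]


# Three occurrences with time series of different lengths.
raw = [[12, 15, 11], [7], [9, 8]]
padded = pad_with_eos(raw)      # every row now has 3 columns
numeric = replace_eos(padded)   # rows contain numbers only
```

With this layout, the value returned by __len__ would correspond to the padded row length, 3 in this toy example.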

class malpolon.data.datasets.geolifeclef2023.MultipleCSVTimeSeriesProvider(root_path: str, select: list = [], normalize: bool = False, ts_id: str = 'timeSerieID', features_col: list = [], eos_replace_value: Union[int, float] = -1, transform: Optional[Iterable[Callable]] = None)

Bases: TimeSeriesProvider

Like CSVTimeSeriesProvider but with several .csv files.

Args:

(TimeSeriesProvider) : inherits TimeSeriesProvider

__getitem__(item)

Return multiple time series.

Like CSVTimeSeriesProvider but concatenates the time series into a tensor-like 3-dimensional array to provide a multi-modal time series array. This array has the shape (1, n_modalities, max_sequence_length).

Args:

item (dict): time series index.

Returns:

(array): multi-modal time series array of shape (1, n_modalities, max_sequence_length).

__len__()

Return the size of the provider.

Returns:

(int): number of layers (depth).
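To make the (1, n_modalities, max_sequence_length) shape concrete, here is a small NumPy sketch. The band names and values are invented for illustration, and the series are assumed to be already padded to a common length.

```python
import numpy as np

# Hypothetical per-modality series for one occurrence, already padded
# to the same length (max_sequence_length = 4, padding value -1).
red = np.array([0.2, 0.4, 0.1, -1.0])
green = np.array([0.3, 0.5, 0.2, -1.0])
nir = np.array([0.7, 0.6, 0.8, -1.0])

# Stack modalities along a new first axis, then add a leading axis of size 1.
stacked = np.stack([red, green, nir], axis=0)  # (n_modalities, max_len)
multi_modal = stacked[np.newaxis, ...]         # (1, n_modalities, max_len)
```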

malpolon.data.datasets.geolifeclef2024

GeoLifeCLEF2024: datasets
class malpolon.data.datasets.geolifeclef2024.PatchesDataset(occurrences: str, providers: Iterable, transform: Optional[Callable] = None, transform_target: Optional[Callable] = None, id_name: str = 'surveyId', labels_name: str = 'speciesId', item_columns: Iterable = ['lat', 'lon', 'surveyId'], split: Optional[str] = None, **kwargs)

Bases: Dataset

Patches dataset class.

This class provides a PyTorch-friendly dataset to handle patch data and labels. The data can be .jpeg or .tif files of variable depth (i.e. multi-spectral data). Each __getitem__ call returns a patch extracted from the dataset.

Args:

Dataset (Dataset): PyTorch Dataset class.

__getitem__(index)

Return a dataset element.

Returns an element from a dataset id (0 to n) with its label.

Args:

index (int): dataset id.

Returns:

(tuple): tuple of data patch (tensor) and label (int).

__len__()

Return the size of the dataset.

Returns:

(int): number of occurrences.
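The path from a dataset index to a (patch, label) tuple can be sketched as follows. ToyPatchesDataset and CoordinateProvider are hypothetical stand-ins: occurrences are held as a list of dicts instead of being read from a file, and the provider simply echoes the coordinates, so this is an illustration of the idea and not the malpolon implementation.

```python
class CoordinateProvider:
    """Hypothetical provider returning a fake patch built from the item."""

    def __getitem__(self, item):
        return [[item["lat"], item["lon"]]]


class ToyPatchesDataset:
    """Simplified analogue of a patches dataset."""

    def __init__(self, occurrences, provider, labels_name="speciesId",
                 item_columns=("lat", "lon", "surveyId")):
        self.occurrences = occurrences    # list of dict rows
        self.provider = provider          # indexed with an item dict
        self.labels_name = labels_name
        self.item_columns = item_columns

    def __len__(self):
        return len(self.occurrences)

    def __getitem__(self, index):
        row = self.occurrences[index]
        # Keep only the columns the provider needs to locate the patch.
        item = {col: row[col] for col in self.item_columns}
        return self.provider[item], row[self.labels_name]


occurrences = [
    {"surveyId": 1, "lat": 43.6, "lon": 3.9, "speciesId": 7},
    {"surveyId": 2, "lat": 45.1, "lon": 5.7, "speciesId": 12},
]
dataset = ToyPatchesDataset(occurrences, CoordinateProvider())
patch, label = dataset[1]
```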

class malpolon.data.datasets.geolifeclef2024.PatchesDatasetMultiLabel(occurrences: str, providers: Iterable, n_classes: Union[int, str] = 'max', id_getitem: str = 'surveyId', **kwargs)

Bases: PatchesDataset

Multilabel patches dataset.

Like PatchesDataset but provides one-hot encoded labels.

Args:

PatchesDataset (PatchesDataset): pytorch friendly dataset class.

__getitem__(index)

Return a dataset element.

Returns an element from a dataset id (0 to n) with the labels in one-hot encoding.

Args:

index (int): dataset id.

Returns:

(tuple): tuple of data patch (tensor) and labels (list).

__len__()

Return the size of the dataset.

Returns:

(int): number of occurrences.
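As an illustration of the label encoding, the sketch below builds a label vector with a 1 at the position of every species observed in a survey. The helper name is hypothetical, and species ids are assumed to be integers in the range 0 to n_classes - 1.

```python
def encode_labels(species_ids, n_classes):
    """Return a list of length n_classes with 1 at each observed species id."""
    labels = [0] * n_classes
    for species_id in species_ids:
        labels[species_id] = 1
    return labels


# A survey in which species 1 and 4 were observed, out of 6 classes.
labels = encode_labels([1, 4], n_classes=6)
```

When a survey contains several species, the vector holds several ones, which is what distinguishes the multi-label case from a single-label one-hot vector.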

class malpolon.data.datasets.geolifeclef2024.TimeSeriesDataset(occurrences: str, providers: Iterable, transform: Optional[Callable] = None, transform_target: Optional[Callable] = None, id_name: str = 'surveyId', labels_name: str = 'speciesId', item_columns: Iterable = ['timeSerieID'])

Bases: Dataset

Time series dataset.

Like PatchesDataset but adapted to time series data, which is formatted as .csv files where each row is an occurrence and each column is a timestamp. Time series have one additional dimension compared to visual patches. They do not all have the same length, and they take a configurable no_data_value argument.

Args:

Dataset (Dataset): pytorch Dataset class.

__getitem__(index)

Return a time series.

Returns a tuple of time series data and its label. The data has the following shape: [1, n_bands, max_length_ts].

Args:

index (int): dataset id.

Returns:

(tuple): tuple of time series data (tensor) and labels (list).

__len__()

Return the size of the dataset.

Returns:

(int): number of occurrences.

GeoLifeCLEF2024: providers
class malpolon.data.datasets.geolifeclef2024.PatchProvider(size: int = 128, normalize: bool = False)

Bases: object

Parent class for all GLC23 patch data providers.

This class implements the concrete and abstract methods shared by all patch providers. In particular, the plot method is designed to accommodate all cases of the GLC23 patch datasets.

abstract __getitem__(item)

Return a patch (parent class).

__len__()

Return the size of the provider.

Returns:

(int): number of layers (depth).

class malpolon.data.datasets.geolifeclef2024.MetaPatchProvider(providers: Iterable[Callable], transform: Optional[Callable] = None)

Bases: PatchProvider

Parent class for patch providers.

This class interfaces patch providers with patch datasets.

Args:

(PatchProvider): inherits PatchProvider.

__getitem__(item)

Return a patch.

This getter is used by a patch dataset class and calls each provider’s getter method to return concatenated patches.

Args:

item (dict): provider index.

Returns:

(array): concatenated patch from all providers.

__len__()

Return the size of the provider.

Returns:

(int): number of layers (depth).
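The concatenation performed by a meta provider can be pictured with the NumPy sketch below. ConstantProvider and concatenate_patches are hypothetical, and the channel-first layout (depth, height, width) is an assumption made for the example.

```python
import numpy as np


class ConstantProvider:
    """Hypothetical provider returning a (depth, size, size) patch."""

    def __init__(self, depth, value, size=4):
        self.depth, self.value, self.size = depth, value, size

    def __len__(self):
        return self.depth  # number of layers (depth)

    def __getitem__(self, item):
        shape = (self.depth, self.size, self.size)
        return np.full(shape, self.value, dtype=np.float32)


def concatenate_patches(providers, item):
    """Query every provider with the same item and stack along depth."""
    return np.concatenate([p[item] for p in providers], axis=0)


providers = [ConstantProvider(3, 0.5), ConstantProvider(1, 2.0)]
patch = concatenate_patches(providers, {"lat": 43.6, "lon": 3.9})
total_depth = sum(len(p) for p in providers)
```

In this sketch the depth of the combined patch equals the sum of the depths reported by each provider.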

class malpolon.data.datasets.geolifeclef2024.RasterPatchProvider(raster_path: str, size: int = 128, spatial_noise: int = 0, normalize: bool = True, fill_zero_if_error: bool = False, nan_value: Union[int, float] = 0)

Bases: PatchProvider

Patch provider for .tif raster files.

This class handles rasters stored as .tif files and returns patches from them based on a dict key containing (lon, lat) coordinates.

Args:

(PatchProvider): inherits PatchProvider.

__getitem__(item)

Return a patch from a .tif raster.

This getter returns a patch of size self.size from a .tif raster loaded in self.data using GPS coordinates projected in EPSG:4326.

Args:
item (dict): dictionary containing at least latitude and longitude keys ({‘lat’: lat, ‘lon’: lon}).

Returns:

(array): the environmental vector (size>1 or size=1).

__len__()

Return the size of the provider.

Returns:

(int): number of layers (depth).
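A rough sketch of cutting a fixed-size window out of a raster array. It assumes the (lat, lon) pair has already been converted to pixel indices, a step that depends on the raster's georeferencing and is omitted here. The function name and the out-of-bounds handling are assumptions for illustration, not the library's code.

```python
import numpy as np


def extract_patch(raster, row, col, size, nan_value=0):
    """Cut a (bands, size, size) window centred on pixel (row, col).

    Pixels falling outside the raster are filled with `nan_value`.
    """
    bands, height, width = raster.shape
    half = size // 2
    patch = np.full((bands, size, size), nan_value, dtype=raster.dtype)
    # Top-left corner of the window in raster coordinates, before clipping.
    r0, c0 = row - half, col - half
    # Clip the window to the raster extent.
    rs, re = max(r0, 0), min(r0 + size, height)
    cs, ce = max(c0, 0), min(c0 + size, width)
    if rs < re and cs < ce:
        patch[:, rs - r0:re - r0, cs - c0:ce - c0] = raster[:, rs:re, cs:ce]
    return patch


# Toy single-band 5x5 raster with values 0..24.
raster = np.arange(25, dtype=np.float32).reshape(1, 5, 5)
inner = extract_patch(raster, row=2, col=2, size=2)
corner = extract_patch(raster, row=0, col=0, size=2, nan_value=-1)
```

The corner case shows why a fill value is useful: a window that overlaps the raster border still has the requested size.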

class malpolon.data.datasets.geolifeclef2024.MultipleRasterPatchProvider(rasters_folder: str, select: Optional[Iterable[str]] = None, **kwargs)

Bases: PatchProvider

Patch provider combining multiple RasterPatchProvider instances.

This provider is useful for loading several patch modalities through RasterPatchProvider, selecting the desired data with the ‘select’ argument of the constructor.

Args:

(PatchProvider): inherits PatchProvider.

__getitem__(item)

Return multiple patches.

Returns multiple patches from multiple raster providers in a 3-dimensional numpy array.

Args:

item (dict): providers index.

Returns:

(array): array of patches.

__len__()

Return the size of the provider.

Returns:

(int): number of layers (depth).

class malpolon.data.datasets.geolifeclef2024.JpegPatchProvider(root_path: str, select: Optional[Iterable[str]] = None, normalize: bool = False, transform: Optional[Callable] = None, size: int = 128, dataset_stats: str = 'jpeg_patches_stats.csv', id_getitem: str = 'surveyId')

Bases: PatchProvider

JPEG patches provider for GLC23.

Provides tensors of multi-modal patches from JPEG patch files of rasters of the GLC23 challenge.

Attributes:

(PatchProvider): inherits PatchProvider.

__getitem__(item)

Return a tensor composed of every channel of a JPEG patch.

Looks for every spectral band listed in self.channels and returns a 3-dimensional patch concatenated into one tensor. The index used to query the right patch is a dictionary with at least one key/value pair: {‘surveyId’: <surveyId_value>}.

Args:

item (dict): dictionary containing the surveyId necessary to identify the JPEG patch to return.

Raises:

KeyError: the ‘surveyId’ key is missing from item.

Exception: item is not a dictionary as expected.

Returns:

(tensor): multi-channel patch tensor.

__len__()

Return the size of the provider.

Returns:

(int): number of layers (depth).

class malpolon.data.datasets.geolifeclef2024.TimeSeriesProvider(root_path: str, eos_replace_value: Union[int, float] = -1, transform: Optional[Iterable[Callable]] = None)

Bases: object

Provide time series data.

This provider is the parent class of time series providers. It handles time series data stored as .csv files where each file has values for a single spectral band (red, green, infra-red etc…).

abstract __getitem__(item)

Return a time series (parent class).

__len__()

Return the size of the provider.

Returns:

(int): number of layers (depth).

class malpolon.data.datasets.geolifeclef2024.MetaTimeSeriesProvider(providers: Iterable[Callable], transform: Optional[Iterable[Callable]] = None)

Bases: TimeSeriesProvider

Time Series provider called by TimeSeriesDataset to handle TS providers.

This TS provider handles all TS providers passed in TimeSeriesDataset to provide multi-modal time series objects.

Args:

(TimeSeriesProvider) : inherits TimeSeriesProvider.

__getitem__(item)

Return the time series from all TS providers.

This getter is called by a TimeSeriesDataset and returns a 3-dimensional array containing time series from all providers. The time series is selected using the TS index item, which is a dictionary containing at least one key/value pair: {‘timeSerieID’: <timeSerieID_value>}.

Args:

item (dict): time series index.

Returns:

(array): array containing the time series from all TS providers.

__len__()

Return the size of the provider.

Returns:

(int): number of layers (depth).

class malpolon.data.datasets.geolifeclef2024.CSVTimeSeriesProvider(ts_data_path: str, normalize: bool = False, ts_id: str = 'timeSerieID', features_col: list = [], eos_replace_value: Union[int, float] = -1, transform: Optional[Iterable[Callable]] = None)

Bases: TimeSeriesProvider

Implement TimeSeriesProvider for .csv time series.

Only loads time series from a single .csv file. If the time series of an occurrence is shorter than the longest one in the .csv file, the remaining columns are filled with the ‘eos’ string.

Args:

(TimeSeriesProvider) : inherits TimeSeriesProvider.

__getitem__(item)

Return a time series.

This getter returns a time series occurrence as a tensor-like array, based on the item index, which is a dictionary containing at least one key/value pair: {‘timeSerieID’: <timeSerieID_value>}.

Args:

item (dict): time series index.

Returns:

(array): time series occurrence.

__len__()

Return the size of the time series.

Returns:

(int): length of biggest sequence.

class malpolon.data.datasets.geolifeclef2024.MultipleCSVTimeSeriesProvider(root_path: str, select: list = [], normalize: bool = False, ts_id: str = 'timeSerieID', features_col: list = [], eos_replace_value: Union[int, float] = -1, transform: Optional[Iterable[Callable]] = None)

Bases: TimeSeriesProvider

Like CSVTimeSeriesProvider but with several .csv files.

Args:

(TimeSeriesProvider) : inherits TimeSeriesProvider

__getitem__(item)

Return multiple time series.

Like CSVTimeSeriesProvider but concatenates the time series into a tensor-like 3-dimensional array to provide a multi-modal time series array. This array has the shape (1, n_modalities, max_sequence_length).

Args:

item (dict): time series index.

Returns:

(array): multi-modal time series array of shape (1, n_modalities, max_sequence_length).

__len__()

Return the size of the provider.

Returns:

(int): number of layers (depth).

malpolon.plot

malpolon.plot.history

Utilities used for plotting purposes.

Author: Titouan Lorieul <titouan.lorieul@gmail.com>

malpolon.plot.history.escape_tex(s: str) str

Escape special characters for LaTeX rendering.
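For readers unfamiliar with the problem, the sketch below shows one common way to escape LaTeX special characters. The mapping and the function name escape_tex_sketch are illustrative assumptions; the set of characters actually handled by malpolon may differ.

```python
# Illustrative mapping only; the characters handled by malpolon may differ.
_TEX_SPECIALS = {
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
}


def escape_tex_sketch(s: str) -> str:
    # Escape one character at a time so replacements are never re-escaped.
    return "".join(_TEX_SPECIALS.get(ch, ch) for ch in s)


escaped = escape_tex_sketch("val_loss 50%")
```

Escaping matters for plot labels because metric names such as val_loss contain underscores, which LaTeX would otherwise interpret as subscripts.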

malpolon.plot.history.plot_history(df_metrics: pd.DataFrame, *, fig: Optional[plt.Figure] = None, axes: Optional[list[plt.Axis]] = None) tuple[plt.Figure, list[plt.Axis]]

Plot model training history.

Parameters:
  • df_metrics (pd.DataFrame) – Metrics monitored during training.

  • fig (plt.Figure) – Figure to use for plotting.

  • axes (list of plt.Axis) – Axes to use for plotting.

Returns:

  • fig (plt.Figure) – Figure used for plotting.

  • axes (list of plt.Axis) – Axes used for plotting.

malpolon.plot.history.plot_metric(df_metrics: pd.DataFrame, metric: str, ax: plt.Axis) plt.Axis

Plot specific metric monitored during model training history.

Parameters:
  • df_metrics (pd.DataFrame) – Metrics monitored during training.

  • metric (str) – Name of the metric to plot.

  • ax (plt.Axis) – Axis to use for plotting.

Returns:

ax – Axis used for plotting.

Return type:

plt.Axis

malpolon.plot.map

Utilities for plotting maps.

Author: Titouan Lorieul <titouan.lorieul@gmail.com>

malpolon.plot.map.plot_map(*, region: Optional[str] = None, extent: Optional[npt.ArrayLike] = None, ax: Optional[plt.Axes] = None) plt.Axes

Plot a map on which to show the observations.

Parameters:
  • region (string, either "fr" or "us") – Region to show, France or US.

  • extent (array-like of form [longitude min, longitude max, latitude min, latitude max]) – Explicit extent of the area to show, e.g., for zooming.

  • ax (plt.Axes) – Provide an Axes to use instead of creating one.

Returns:

Returns the used Axes.

Return type:

plt.Axes
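The extent argument follows the order [longitude min, longitude max, latitude min, latitude max]. As a sketch, such a list could be derived from a set of observation coordinates as shown below; the helper name and the margin parameter are hypothetical.

```python
def bounding_extent(longitudes, latitudes, margin=0.5):
    """Return [lon min, lon max, lat min, lat max], padded by `margin` degrees."""
    return [
        min(longitudes) - margin,
        max(longitudes) + margin,
        min(latitudes) - margin,
        max(latitudes) + margin,
    ]


# Three toy observations located in southern France.
extent = bounding_extent([2.0, 5.0, 3.5], [43.0, 45.0, 44.0], margin=0.5)
```

A list built this way matches the documented order of the extent argument and can be used to zoom on the observed area.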

malpolon.plot.map.plot_observation_dataset(*, df: DataFrame, obs_data_columns: dict = {'index': 'surveyId', 'species_id': 'speciesId', 'split': 'subset', 'x': 'lon', 'y': 'lat'}, show_map: bool = False) Axes

Plot observations on a map from an observation dataset.

This method expects a pandas DataFrame with columns containing coordinates, species ids, and dataset split information (‘train’, ‘test’ or ‘val’). Users can specify the names of the columns containing this information if they do not match the default names.

Parameters:
  • df (pd.DataFrame) – observation dataset

  • obs_data_columns (dict, optional) – dictionary matching custom dataframe keys with necessary keys, by default {‘x’: ‘lon’, ‘y’: ‘lat’, ‘index’: ‘surveyId’, ‘species_id’: ‘speciesId’, ‘split’: ‘subset’}

  • show_map (bool, optional) – if True, displays the map, by default False

Returns:

map’s ax object

Return type:

plt.Axes
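The role of obs_data_columns can be illustrated with a small helper that overlays custom column names on the documented defaults. The helper is hypothetical, and whether malpolon itself merges a partial dictionary with the defaults is not stated here, so passing the complete dictionary is the safer choice.

```python
# Defaults copied from the documented signature.
DEFAULT_COLUMNS = {
    "x": "lon",
    "y": "lat",
    "index": "surveyId",
    "species_id": "speciesId",
    "split": "subset",
}


def resolve_columns(custom=None):
    """Overlay user-provided column names on the documented defaults."""
    columns = dict(DEFAULT_COLUMNS)
    columns.update(custom or {})
    return columns


# A dataframe whose coordinate columns are named differently.
columns = resolve_columns({"x": "longitude", "y": "latitude"})
```

The resulting dictionary is complete and could be passed as obs_data_columns.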

malpolon.plot.map.plot_observation_map(*, longitudes: npt.ArrayLike, latitudes: npt.ArrayLike, ax: Optional[plt.Axes] = None, **kwargs) plt.Axes

Plot observations on a map.

Parameters:
  • longitudes (array-like) – Longitudes of the observations.

  • latitudes (array-like) – Latitudes of the observations.

  • ax (plt.Axes) – Provide an Axes to use instead of creating one.

  • kwargs – Additional arguments to pass to plt.scatter.

Returns:

Returns the used Axes.

Return type:

plt.Axes

malpolon.logging

class malpolon.logging.Summary

Bases: Callback

Log model summary at the beginning of training.

FIXME handle multi validation data loaders, combined datasets

on_train_start(trainer: pl.Trainer, pl_module: GenericPredictionSystem) None

malpolon.logging.str_object(obj: Any) str

Format an object to string.

Formats an object for printing by returning a string containing the class name and attributes (both names and values).

Parameters:

obj (Any) – Object to print.

Returns:

String containing the class name and attributes.

Return type:

str
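A minimal sketch of this kind of formatting, built on vars(). The function name and the exact output layout are assumptions; the string produced by malpolon may be formatted differently.

```python
def str_object_sketch(obj) -> str:
    """Return 'ClassName(attr=value, ...)' built from the instance attributes."""
    attributes = ", ".join(
        f"{name}={value!r}" for name, value in vars(obj).items()
    )
    return f"{type(obj).__name__}({attributes})"


class Config:
    """Toy object with two attributes."""

    def __init__(self):
        self.lr = 0.01
        self.optimizer = "sgd"


text = str_object_sketch(Config())
```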

malpolon.check_install

This module checks the installation of PyTorch and GPU libraries.

Author: Titouan Lorieul <titouan.lorieul@gmail.com>

malpolon.check_install.print_cuda_info()

Print information about the CUDA/PyTorch installation.
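The kind of information such a check reports can be gathered with standard PyTorch calls, as in the sketch below. The function collect_cuda_info is hypothetical and the selection of fields is an assumption; it does not reproduce the output of print_cuda_info.

```python
def collect_cuda_info() -> dict:
    """Gather basic facts about the PyTorch/CUDA installation."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed, so nothing else can be queried.
        return {"torch_installed": False}

    info = {
        "torch_installed": True,
        "torch_version": torch.__version__,
        "cuda_build": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
        "device_count": torch.cuda.device_count(),
    }
    # Names of the visible GPUs (empty list when there are none).
    info["devices"] = [
        torch.cuda.get_device_name(i) for i in range(info["device_count"])
    ]
    return info


info = collect_cuda_info()
for key, value in info.items():
    print(f"{key}: {value}")
```

Returning a dictionary instead of printing directly keeps the check easy to reuse, for example in a test that skips GPU-only code paths.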