diff --git a/HISTORY.md b/HISTORY.md index ab6f857a5..f563f0c30 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,8 +2,27 @@ ## Version 2.0.0 (Date: TBD) +The major innovations are: + +1. **Migration from albumentations to kornia for data augmentations** - Replaced albumentations with kornia for better PyTorch integration and GPU acceleration + +Additional features and enhancements include: + +- **Enhancement:** Better PyTorch integration with kornia transforms +- **Enhancement:** Simplified API without bbox parameter complexity +- **Enhancement:** GPU acceleration support for augmentation transforms +- **Enhancement:** More consistent with PyTorch ecosystem +- **Documentation:** Updated augmentation documentation with kornia examples + ### Breaking Changes - Deprecated Items Removed: +**Augmentation Changes:** +- **Migration from albumentations to kornia** - All augmentation transforms now use kornia instead of albumentations +- Some augmentation parameter names have changed (e.g., `scale_range` → `scale`, `height/width` → `size`) +- Custom transforms now use `torch.nn.Sequential` instead of `A.Compose` +- No longer requires bbox parameter configuration +- See migration guide in documentation for detailed parameter changes + **Removed Functions:** - `xml_to_annotations()` - Use `utilities.read_pascal_voc(path)` or the general `utilities.read_file(path)`. - `boxes_to_shapefile()` - Use `image_to_geo_coordinates()`. @@ -23,6 +42,7 @@ - `raster_path` parameter from predict_tile() - Use `path` parameter instead **Migration Guide:** +- **Augmentations:** Update parameter names and use kornia transforms (see documentation) - Replace `xml_to_annotations(xml_path)` with `read_pascal_voc(xml_path)` - Replace `boxes_to_shapefile(df, root_dir)` with `image_to_geo_coordinates(df, root_dir)` - Replace `plot_points(image, points)` with `plot_results(results)` diff --git a/docs/user_guide/11_training.md b/docs/user_guide/11_training.md index 8d3317250..881a3212a 100644 --- a/docs/user_guide/11_training.md +++ b/docs/user_guide/11_training.md @@ -271,7 +271,7 @@ Note that if you trained on GPU and restore on cpu, you will need the map_locati ### Data Augmentations -DeepForest supports configurable data augmentations using [Albumentations](https://albumentations.ai/docs/3-basic-usage/bounding-boxes-augmentations/) to improve model generalization across different sensors and acquisition conditions. Augmentations can be specified through the configuration file or passed directly to the model. +DeepForest supports configurable data augmentations using [Kornia](https://kornia.readthedocs.io/en/latest/augmentation.html) to improve model generalization across different sensors and acquisition conditions. Augmentations can be specified through the configuration file or passed directly to the model. #### Configuration-based Augmentations @@ -285,9 +285,9 @@ train: # Or as a list of custom parameters augmentations: - HorizontalFlip: {p: 0.5} - - Downscale: {scale_range: [0.25, 0.75], p: 0.5} - - RandomSizedBBoxSafeCrop: {height: 400, width: 400, p: 0.3} - - PadIfNeeded: {min_height: 400, min_width: 400, p: 1.0} + - Downscale: {scale: [0.25, 0.75], p: 0.5} + - RandomSizedBBoxSafeCrop: {size: [400, 400], scale: [0.5, 1.0], p: 0.3} + - PadIfNeeded: {size: [400, 400], p: 1.0} ``` Note that augmentations are provided as a list (prepended with a `-` in YAML). If you omit this, the parameter will be interpreted as a dictionary and the config parser may fail. If you provide only the augmentation name, default settings will be used. These have been chosen to reflect sensible parameters for different transformations, as it's quite easy to "over augment" which can make models harder to train. By default, if you enable augmentation and do not specify a transform explicitly, only `HorizontalFlip` will be used. @@ -310,7 +310,7 @@ config_args = { "train": { "augmentations": [ "HorizontalFlip": {"p": 0.8}, - "Downscale": {"scale_range": (0.5, 0.9), "p": 0.3} + "Downscale": {"scale": (0.5, 0.9), "p": 0.3} ] } } @@ -321,16 +321,16 @@ model = main.deepforest(config_args=config_args) DeepForest supports the following augmentations optimized for object detection: -- **[HorizontalFlip](https://albumentations.ai/docs/api-reference/albumentations/augmentations/geometric/flip/#HorizontalFlip)**: Randomly flip images horizontally -- **[VerticalFlip](https://albumentations.ai/docs/api-reference/albumentations/augmentations/geometric/flip/#VerticalFlip)**: Randomly flip images vertically -- **[Downscale](https://albumentations.ai/docs/api-reference/albumentations/augmentations/pixel/transforms/#Downscale)**: Randomly downscale images to simulate different resolutions -- **[RandomSizedBBoxSafeCrop](https://albumentations.ai/docs/api-reference/albumentations/augmentations/crops/transforms/#RandomSizedBBoxSafeCrop)**: Crop image while preserving bounding boxes -- **[PadIfNeeded](https://albumentations.ai/docs/api-reference/albumentations/augmentations/geometric/pad/#PadIfNeeded)**: Pad images to minimum size -- **[Rotate](https://albumentations.ai/docs/api-reference/albumentations/augmentations/geometric/rotate/#Rotate)**: Rotate images by small angles -- **[RandomBrightnessContrast](https://albumentations.ai/docs/api-reference/albumentations/augmentations/pixel/transforms/#RandomBrightnessContrast)**: Adjust brightness and contrast -- **[HueSaturationValue](https://albumentations.ai/docs/api-reference/albumentations/augmentations/pixel/transforms/#HueSaturationValue)**: Adjust color properties -- **[GaussNoise](https://albumentations.ai/docs/api-reference/albumentations/augmentations/pixel/transforms/#GaussNoise)**: Add gaussian noise -- **[Blur](https://albumentations.ai/docs/api-reference/albumentations/augmentations/blur/transforms/#Blur)**: Apply blur effect +- **[HorizontalFlip](https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomHorizontalFlip)**: Randomly flip images horizontally +- **[VerticalFlip](https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomVerticalFlip)**: Randomly flip images vertically +- **[Downscale](https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomResizedCrop)**: Randomly downscale images to simulate different resolutions +- **[RandomSizedBBoxSafeCrop](https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomResizedCrop)**: Crop image while preserving bounding boxes +- **[PadIfNeeded](https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.PadTo)**: Pad images to minimum size +- **[Rotate](https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomRotation)**: Rotate images by small angles +- **[RandomBrightnessContrast](https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.ColorJitter)**: Adjust brightness and contrast +- **[HueSaturationValue](https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.ColorJitter)**: Adjust color properties +- **[GaussNoise](https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomGaussianNoise)**: Add gaussian noise +- **[Blur](https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomGaussianBlur)**: Apply blur effect #### Zoom Augmentations for Multi-Resolution Training @@ -342,13 +342,13 @@ config_args = { "train": { "augmentations": [ # Simulate different acquisition heights/resolutions - "Downscale": {"scale_range": (0.25, 0.75), "p": 0.5}, + "Downscale": {"scale": (0.25, 0.75), "p": 0.5}, # Crop at different scales while preserving objects - "RandomSizedBBoxSafeCrop": {"height": 400, "width": 400, "p": 0.3}, + "RandomSizedBBoxSafeCrop": {"size": (400, 400), "scale": (0.5, 1.0), "p": 0.3}, # Ensure minimum image size - "PadIfNeeded": {"min_height": 400, "min_width": 400, "p": 1.0}, + "PadIfNeeded": {"size": (400, 400), "p": 1.0}, # Basic data augmentation "HorizontalFlip": {"p": 0.5} @@ -364,26 +364,24 @@ model = main.deepforest(config_args=config_args) For complete control over augmentations, you can still provide custom transforms: ```python -import albumentations as A -from albumentations.pytorch import ToTensorV2 +import torch +import kornia.augmentation as K def get_transform(augment): """Custom transform function""" if augment: - transform = A.Compose([ - A.HorizontalFlip(p=0.5), - A.Downscale(scale_range=(0.25, 0.75), p=0.5), - ToTensorV2() - ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=["category_ids"])) + transform = torch.nn.Sequential([ + K.RandomHorizontalFlip(p=0.5), + K.RandomResizedCrop(size=(200, 200), scale=(0.25, 0.75), p=0.5) + ]) else: - transform = A.Compose([ToTensorV2()], - bbox_params=A.BboxParams(format='pascal_voc', label_fields=["category_ids"])) + transform = torch.nn.Identity() return transform model = main.deepforest(transforms=get_transform) ``` -**Note**: When creating custom transforms, always include `ToTensorV2()` and properly configure `bbox_params` for object detection. If your augmentation pipeline does not contain any geometric transformations, `bbox_params` is not required. Otherwise it's important that you keep the format as `pascal_voc` so that the boxes are correctly interpreted by Albumentations. +**Note**: When creating custom transforms, use PyTorch's `torch.nn.Sequential` to compose multiple augmentations. Kornia transforms work directly with PyTorch tensors and don't require special bbox parameter handling like Albumentations. **How do I make training faster?** diff --git a/pyproject.toml b/pyproject.toml index ff61074f8..9de15a9a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,12 +34,12 @@ authors = [ dependencies = [ "aiohttp", "aiolimiter", - "albumentations>=2.0.0", "faster-coco-eval>=1.6.8", "geopandas", "h5py", "huggingface_hub>=0.25.0", "hydra-core", + "kornia", "matplotlib", "numpy<2.0", "omegaconf", diff --git a/src/deepforest/augmentations.py b/src/deepforest/augmentations.py index 5a65d25ce..3b758f930 100644 --- a/src/deepforest/augmentations.py +++ b/src/deepforest/augmentations.py @@ -1,4 +1,4 @@ -"""Augmentation module for DeepForest using albumentations. +"""Augmentation module for DeepForest using kornia. This module provides configurable augmentations for training and validation that can be specified through configuration files or direct @@ -7,36 +7,48 @@ from typing import Any -import albumentations as A -from albumentations.pytorch import ToTensorV2 +import kornia.augmentation as K +import torch from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig from omegaconf.listconfig import ListConfig _SUPPORTED_TRANSFORMS = { - "HorizontalFlip": (A.HorizontalFlip, {"p": 0.5}), - "VerticalFlip": (A.VerticalFlip, {"p": 0.5}), - "Downscale": (A.Downscale, {"scale_range": (0.25, 0.5), "p": 0.5}), - "RandomCrop": (A.RandomCrop, {"height": 200, "width": 200, "p": 0.5}), + "HorizontalFlip": (K.RandomHorizontalFlip, {"p": 0.5}), + "VerticalFlip": (K.RandomVerticalFlip, {"p": 0.5}), + "Downscale": ( + K.RandomResizedCrop, + {"size": (200, 200), "scale": (0.25, 0.5), "p": 0.5}, + ), + "RandomCrop": (K.RandomCrop, {"size": (200, 200), "p": 0.5}), "RandomSizedBBoxSafeCrop": ( - A.RandomSizedBBoxSafeCrop, - {"height": 200, "width": 200, "p": 0.5}, + K.RandomResizedCrop, + {"size": (200, 200), "scale": (0.5, 1.0), "p": 0.5}, ), - "PadIfNeeded": (A.PadIfNeeded, {"min_height": 800, "min_width": 800, "p": 1.0}), - "Rotate": (A.Rotate, {"limit": 15, "p": 0.5}), + "PadIfNeeded": (K.PadTo, {"size": (800, 800), "p": 1.0}), + "Rotate": (K.RandomRotation, {"degrees": 15, "p": 0.5}), "RandomBrightnessContrast": ( - A.RandomBrightnessContrast, - {"brightness_limit": 0.2, "contrast_limit": 0.2, "p": 0.5}, + K.ColorJitter, + {"brightness": 0.2, "contrast": 0.2, "p": 0.5}, ), "HueSaturationValue": ( - A.HueSaturationValue, - {"hue_shift_limit": 10, "sat_shift_limit": 10, "val_shift_limit": 10, "p": 0.5}, + K.ColorJitter, + {"hue": 0.1, "saturation": 0.1, "p": 0.5}, + ), + "GaussNoise": (K.RandomGaussianNoise, {"std": 0.1, "p": 0.3}), + "Blur": ( + K.RandomGaussianBlur, + {"kernel_size": (3, 3), "sigma": (0.1, 2.0), "p": 0.3}, + ), + "GaussianBlur": ( + K.RandomGaussianBlur, + {"kernel_size": (3, 3), "sigma": (0.1, 2.0), "p": 0.3}, + ), + "MotionBlur": ( + K.RandomMotionBlur, + {"kernel_size": 3, "angle": 45, "direction": 0.0, "p": 0.3}, ), - "GaussNoise": (A.GaussNoise, {"var_limit": (5.0, 20.0), "p": 0.3}), - "Blur": (A.Blur, {"blur_limit": 2, "p": 0.3}), - "GaussianBlur": (A.GaussianBlur, {"blur_limit": 2, "p": 0.3}), - "MotionBlur": (A.MotionBlur, {"blur_limit": 2, "p": 0.3}), - "ZoomBlur": (A.ZoomBlur, {"max_factor": 1.05, "p": 0.3}), + "ZoomBlur": (K.RandomAffine, {"degrees": 0, "scale": (1.0, 1.05), "p": 0.3}), } @@ -51,8 +63,8 @@ def get_available_augmentations() -> list[str]: def get_transform( augmentations: str | list[str] | dict[str, Any] | None = None, -) -> A.Compose: - """Create Albumentations transform for bounding boxes. +) -> torch.nn.Module: + """Create Kornia transform for bounding boxes. Args: augmentations: Augmentation configuration: @@ -62,10 +74,10 @@ def get_transform( - None: No augmentations Returns: - Composed albumentations transform + Composed kornia transform Examples: - >>> # Default behavior, returns a ToTensorV2 transform + >>> # Default behavior, returns a basic transform >>> transform = get_transform() >>> # Single augmentation @@ -77,11 +89,10 @@ def get_transform( >>> # Augmentations with parameters >>> transform = get_transform(augmentations={ ... "HorizontalFlip": {"p": 0.5}, - ... "Downscale": {"scale_min": 0.25, "scale_max": 0.75} + ... "Downscale": {"scale": (0.25, 0.75)} ... }) """ transforms_list = [] - bbox_params = None if augmentations is not None: augment_configs = _parse_augmentations(augmentations) @@ -90,12 +101,12 @@ def get_transform( aug_transform = _create_augmentation(aug_name, aug_params) transforms_list.append(aug_transform) - bbox_params = A.BboxParams(format="pascal_voc", label_fields=["category_ids"]) - - # Always add ToTensorV2 at the end - transforms_list.append(ToTensorV2()) - - return A.Compose(transforms_list, bbox_params=bbox_params) + # Create a sequential container for all transforms + if transforms_list: + return torch.nn.Sequential(*transforms_list) + else: + # Return identity transform if no augmentations + return torch.nn.Identity() def _parse_augmentations( @@ -151,15 +162,15 @@ def _parse_augmentations( raise ValueError(f"Unable to parse augmentation parameters: {augmentations}") -def _create_augmentation(name: str, params: dict[str, Any]) -> A.BasicTransform | None: - """Create an albumentations transform by name with given parameters. +def _create_augmentation(name: str, params: dict[str, Any]) -> torch.nn.Module | None: + """Create a kornia transform by name with given parameters. Args: name: Name of the augmentation params: Parameters to pass to the augmentation Returns: - Albumentations transform or None if name not recognized + Kornia transform or None if name not recognized """ if name not in get_available_augmentations(): diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py index 8ec1f41d1..032429729 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -2,9 +2,9 @@ import io import os -import albumentations as A +import torch +import kornia.augmentation as K import pytest -from albumentations.pytorch import ToTensorV2 from deepforest import main, get_data from deepforest.augmentations import _create_augmentation @@ -17,44 +17,40 @@ def test_get_transform_default(): """Test default behavior (backward compatibility).""" # Test without augmentations transform = get_transform() - assert isinstance(transform, A.Compose) - assert len(transform.transforms) == 1 - assert isinstance(transform.transforms[0], ToTensorV2) + assert isinstance(transform, torch.nn.Identity) def test_get_transform_single_augmentation(): """Test with single augmentation name.""" transform = get_transform(augmentations="Downscale") - assert isinstance(transform, A.Compose) - assert len(transform.transforms) == 2 # Downscale + ToTensorV2 - assert isinstance(transform.transforms[0], A.Downscale) - assert isinstance(transform.transforms[1], ToTensorV2) + assert isinstance(transform, torch.nn.Sequential) + assert len(transform) == 1 # Just Downscale + assert isinstance(transform[0], K.RandomResizedCrop) def test_get_transform_multiple_augmentations(): """Test with list of augmentation names (strings).""" transform = get_transform(augmentations=["HorizontalFlip", "Downscale"]) - assert isinstance(transform, A.Compose) - assert len(transform.transforms) == 3 # HorizontalFlip + Downscale + ToTensorV2 - assert isinstance(transform.transforms[0], A.HorizontalFlip) - assert isinstance(transform.transforms[1], A.Downscale) - assert isinstance(transform.transforms[2], ToTensorV2) + assert isinstance(transform, torch.nn.Sequential) + assert len(transform) == 2 # HorizontalFlip + Downscale + assert isinstance(transform[0], K.RandomHorizontalFlip) + assert isinstance(transform[1], K.RandomResizedCrop) def test_get_transform_with_parameters(): """Test with augmentation parameters.""" augmentations = { "HorizontalFlip": {"p": 0.8}, - "Downscale": {"scale_range": (0.5, 0.9), "p": 0.3} + "Downscale": {"scale": (0.5, 0.9), "p": 0.3} } transform = get_transform(augmentations=augmentations) - assert isinstance(transform, A.Compose) - assert len(transform.transforms) == 3 # HorizontalFlip + Downscale + ToTensorV2 + assert isinstance(transform, torch.nn.Sequential) + assert len(transform) == 2 # HorizontalFlip + Downscale # Check parameters were applied - assert transform.transforms[0].p == 0.8 # HorizontalFlip - assert transform.transforms[1].scale_range == (0.5, 0.9) # Downscale - assert transform.transforms[1].p == 0.3 + assert transform[0].p == 0.8 # HorizontalFlip + assert transform[1].scale == (0.5, 0.9) # Downscale + assert transform[1].p == 0.3 def test_parse_augmentations_string(): @@ -146,7 +142,7 @@ def test_create_augmentation(): """Test _create_augmentation function.""" # Valid augmentation aug = _create_augmentation("HorizontalFlip", {"p": 0.7}) - assert isinstance(aug, A.HorizontalFlip) + assert isinstance(aug, K.RandomHorizontalFlip) assert aug.p == 0.7 # Invalid augmentation should raise ValueError @@ -165,14 +161,13 @@ def test_get_available_augmentations(): def test_bbox_params(): - """Test that bbox_params are properly set.""" + """Test that transforms are properly created.""" transform = get_transform(augmentations="HorizontalFlip") - # Check that bbox_params is configured in the transform repr - transform_repr = repr(transform) - assert "bbox_params" in transform_repr - assert "'format': 'pascal_voc'" in transform_repr - assert "'label_fields': ['category_ids']" in transform_repr + # Check that transform is properly created + assert isinstance(transform, torch.nn.Sequential) + assert len(transform) == 1 + assert isinstance(transform[0], K.RandomHorizontalFlip) def test_blur_augmentations(): @@ -181,37 +176,35 @@ def test_blur_augmentations(): for blur_aug in blur_augmentations: transform = get_transform(augmentations=[{blur_aug: {}}]) - assert isinstance(transform, A.Compose) - assert len(transform.transforms) == 2 # Blur augmentation + ToTensorV2 - assert isinstance(transform.transforms[1], ToTensorV2) + assert isinstance(transform, torch.nn.Sequential) + assert len(transform) == 1 # Just blur augmentation + assert isinstance(transform[0], K.RandomGaussianBlur) def test_blur_augmentations_with_parameters(): """Test blur augmentations with custom parameters.""" blur_configs = { - "GaussianBlur": {"blur_limit": 5, "p": 0.8}, - "MotionBlur": {"blur_limit": 7, "p": 0.6}, - "ZoomBlur": {"max_factor": 1.3, "p": 0.4} + "GaussianBlur": {"kernel_size": (5, 5), "p": 0.8}, + "MotionBlur": {"kernel_size": 7, "p": 0.6}, + "ZoomBlur": {"scale": (1.0, 1.3), "p": 0.4} } transform = get_transform(augmentations=blur_configs) - assert isinstance(transform, A.Compose) - assert len(transform.transforms) == 4 # 3 blur augmentations + ToTensorV2 - assert isinstance(transform.transforms[3], ToTensorV2) + assert isinstance(transform, torch.nn.Sequential) + assert len(transform) == 3 # 3 blur augmentations def test_mixed_blur_and_other_augmentations(): """Test combining blur augmentations with other augmentations using mixed format.""" - mixed_augmentations = ["HorizontalFlip", {"GaussianBlur": {"blur_limit": 3}}, "Downscale", {"MotionBlur": {"blur_limit": 5}}] + mixed_augmentations = ["HorizontalFlip", {"GaussianBlur": {"kernel_size": (3, 3)}}, "Downscale", {"MotionBlur": {"kernel_size": 5}}] transform = get_transform(augmentations=mixed_augmentations) - assert isinstance(transform, A.Compose) - assert len(transform.transforms) == 5 # 4 augmentations + ToTensorV2 - assert isinstance(transform.transforms[0], A.HorizontalFlip) - assert isinstance(transform.transforms[1], A.GaussianBlur) - assert isinstance(transform.transforms[2], A.Downscale) - assert isinstance(transform.transforms[3], A.MotionBlur) - assert isinstance(transform.transforms[4], ToTensorV2) + assert isinstance(transform, torch.nn.Sequential) + assert len(transform) == 4 # 4 augmentations + assert isinstance(transform[0], K.RandomHorizontalFlip) + assert isinstance(transform[1], K.RandomGaussianBlur) + assert isinstance(transform[2], K.RandomResizedCrop) + assert isinstance(transform[3], K.RandomMotionBlur) def test_unknown_augmentation_error(): @@ -226,13 +219,11 @@ def get_transform(augment): """This is the new transform""" if augment: print("I'm a new augmentation!") - transform = A.Compose( - [A.HorizontalFlip(p=0.5), ToTensorV2()], - bbox_params=A.BboxParams(format='pascal_voc', - label_fields=["category_ids"])) - + transform = torch.nn.Sequential( + K.RandomHorizontalFlip(p=0.5) + ) else: - transform = ToTensorV2() + transform = torch.nn.Identity() return transform m = main.deepforest(transforms=get_transform) @@ -272,7 +263,7 @@ def test_config_augmentations_with_params(): "train": { "augmentations": [ {"HorizontalFlip": {"p": 0.8}}, - {"Downscale": {"scale_range": (0.5, 0.9), "p": 0.3}} + {"Downscale": {"scale": (0.5, 0.9), "p": 0.3}} ] }, "validation": {