@@ -51,15 +51,17 @@ def get_available_augmentations() -> list[str]:
5151
5252def get_transform (
5353 augmentations : str | list [str ] | dict [str , Any ] | None = None ,
54+ task : str = "box" ,
5455) -> A .Compose :
55- """Create Albumentations transform for bounding boxes.
56+ """Create Albumentations transform for boxes or keypoints .
5657
5758 Args:
5859 augmentations: Augmentation configuration:
5960 - str: Single augmentation name
6061 - list: List of augmentation names
6162 - dict: Dict with names as keys and params as values
6263 - None: No augmentations
64+ task: Task type - "box" for bounding boxes or "keypoint" for keypoints
6365
6466 Returns:
6567 Composed albumentations transform
@@ -79,9 +81,13 @@ def get_transform(
7981 ... "HorizontalFlip": {"p": 0.5},
8082 ... "Downscale": {"scale_min": 0.25, "scale_max": 0.75}
8183 ... })
84+
85+ >>> # Keypoint augmentations
86+ >>> transform = get_transform(augmentations=["HorizontalFlip"], task="keypoint")
8287 """
8388 transforms_list = []
8489 bbox_params = None
90+ keypoint_params = None
8591
8692 if augmentations is not None :
8793 augment_configs = _parse_augmentations (augmentations )
@@ -90,12 +96,19 @@ def get_transform(
9096 aug_transform = _create_augmentation (aug_name , aug_params )
9197 transforms_list .append (aug_transform )
9298
93- bbox_params = A .BboxParams (format = "pascal_voc" , label_fields = ["category_ids" ])
99+ if task == "box" :
100+ bbox_params = A .BboxParams (format = "pascal_voc" , label_fields = ["labels" ])
101+ elif task == "keypoint" :
102+ keypoint_params = A .KeypointParams (format = "xy" , label_fields = ["labels" ])
103+ else :
104+ raise ValueError (f"Unsupported task: { task } . Must be 'box' or 'keypoint'." )
94105
95106 # Always add ToTensorV2 at the end
96107 transforms_list .append (ToTensorV2 ())
97108
98- return A .Compose (transforms_list , bbox_params = bbox_params )
109+ return A .Compose (
110+ transforms_list , bbox_params = bbox_params , keypoint_params = keypoint_params
111+ )
99112
100113
101114def _parse_augmentations (
0 commit comments