From 3e19c365736402158364c871915f9596a449d5bb Mon Sep 17 00:00:00 2001 From: Shadab Date: Sun, 25 Aug 2024 00:01:22 +0100 Subject: [PATCH 1/2] feat : added support for keypoints dataset --- supervision/dataset/core.py | 297 ++++++++++++++++++++++++++++ supervision/dataset/formats/yolo.py | 134 ++++++++++++- supervision/dataset/utils.py | 9 +- 3 files changed, 427 insertions(+), 13 deletions(-) diff --git a/supervision/dataset/core.py b/supervision/dataset/core.py index 7d320bbc9..0af08b5a3 100644 --- a/supervision/dataset/core.py +++ b/supervision/dataset/core.py @@ -21,6 +21,7 @@ ) from supervision.dataset.formats.yolo import ( load_yolo_annotations, + load_yolo_keypoint_annotations, save_data_yaml, save_yolo_annotations, ) @@ -32,6 +33,7 @@ train_test_split, ) from supervision.detection.core import Detections +from supervision.keypoint.core import KeyPoints from supervision.utils.internal import deprecated, warn_deprecated from supervision.utils.iterables import find_duplicates @@ -908,3 +910,298 @@ def from_folder_structure(cls, root_directory_path: str) -> ClassificationDatase images=image_paths, annotations=annotations, ) + + +class KeyPointDataset(BaseDataset): + """ + Contains information about a keypoint dataset. Handles lazy image loading + and annotation retrieval, dataset splitting, conversion into yolo + formats. + + Attributes: + classes (List[str]): List containing dataset class names. + images (List[str]): Accepts a list of image paths. + If you pass a list of paths, the dataset will + lazily load images on demand, which is much more memory-efficient. + annotations (Dict[str, Keypoints]): Dictionary mapping + image path to annotations. The dictionary keys match + match the keys in `images` or entries in the list of + image paths. + """ + + def __init__( + self, + classes: List[str], + images: List[str], + annotations: Dict[str, KeyPoints], + ) -> None: + self.classes = classes + + if set(images) != set(annotations): + raise ValueError( + "The keys of the images and annotations dictionaries must match." + ) + self.annotations = annotations + + # Eliminate duplicates while preserving order + self.image_paths = list(dict.fromkeys(images)) + + def _get_image(self, image_path: str) -> np.ndarray: + """Assumes that image is in dataset""" + return cv2.imread(image_path) + + def __len__(self) -> int: + return len(self.image_paths) + + def __getitem__(self, i: int) -> Tuple[str, np.ndarray, KeyPoints]: + """ + Returns: + Tuple[str, np.ndarray, KeyPoints]: The image path, image data, + and its corresponding annotation at index i. + """ + image_path = self.image_paths[i] + image = self._get_image(image_path) + annotation = self.annotations[image_path] + return image_path, image, annotation + + def __iter__(self) -> Iterator[Tuple[str, np.ndarray, KeyPoints]]: + """ + Iterate over the images and annotations in the dataset. + + Yields: + Iterator[Tuple[str, np.ndarray, KeyPoints]]: + An iterator that yields tuples containing the image path, + the image data, and its corresponding KeyPoint annotation. + """ + for i in range(len(self)): + image_path, image, annotation = self[i] + yield image_path, image, annotation + + def __eq__(self, other) -> bool: + if not isinstance(other, KeyPointDataset): + return False + + if set(self.classes) != set(other.classes): + return False + + if self.image_paths != other.image_paths: + return False + + if self.annotations != other.annotations: + return False + + return True + + def split( + self, split_ratio=0.8, random_state=None, shuffle: bool = True + ) -> Tuple[KeyPointDataset, KeyPointDataset]: + """ + Splits the dataset into two parts (training and testing) + using the provided split_ratio. + + Args: + split_ratio (float, optional): The ratio of the training + set to the entire dataset. + random_state (int, optional): The seed for the random number generator. + This is used for reproducibility. + shuffle (bool, optional): Whether to shuffle the data before splitting. + + Returns: + Tuple[KeyPointDataset, KeyPointDataset]: A tuple containing + the training and testing datasets. + + Examples: + ```python + import supervision as sv + + ds = sv.KeyPointDataset(...) + train_ds, test_ds = ds.split(split_ratio=0.7, random_state=42, shuffle=True) + len(train_ds), len(test_ds) + # (700, 300) + ``` + """ + + train_paths, test_paths = train_test_split( + data=self.image_paths, + train_ratio=split_ratio, + random_state=random_state, + shuffle=shuffle, + ) + + train_annotations = {path: self.annotations[path] for path in train_paths} + test_annotations = {path: self.annotations[path] for path in test_paths} + + train_dataset = KeyPointDataset( + classes=self.classes, + images=train_paths, + annotations=train_annotations, + ) + test_dataset = KeyPointDataset( + classes=self.classes, + images=test_paths, + annotations=test_annotations, + ) + return train_dataset, test_dataset + + @classmethod + def merge(cls, dataset_list: List[KeyPointDataset]) -> KeyPointDataset: + """ + Merge a list of `KeyPointDataset` objects into a single + `KeyPointDataset` object. + + This method takes a list of `KeyPointDataset` objects and combines + their respective fields (`classes`, `images`, + `annotations`) into a single `KeyPointDataset` object. + + Args: + dataset_list (List[KeyPointDataset]): A list of `KeyPointDataset` + objects to merge. + + Returns: + (KeyPointDataset): A single `KeyPointDataset` object containing + the merged data from the input list. + + Examples: + ```python + import supervision as sv + + ds_1 = sv.KeyPointDataset(...) + len(ds_1) + # 100 + ds_1.classes + # ['dog', 'person'] + + ds_2 = sv.KeyPointDataset(...) + len(ds_2) + # 200 + ds_2.classes + # ['cat'] + + ds_merged = sv.KeyPointDataset.merge([ds_1, ds_2]) + len(ds_merged) + # 300 + ds_merged.classes + # ['cat', 'dog', 'person'] + ``` + """ + + image_paths = list( + chain.from_iterable(dataset.image_paths for dataset in dataset_list) + ) + image_paths_unique = list(dict.fromkeys(image_paths)) + if len(image_paths) != len(image_paths_unique): + duplicates = find_duplicates(image_paths) + raise ValueError( + f"Image paths {duplicates} are not unique across datasets." + ) + image_paths = image_paths_unique + + classes = merge_class_lists( + class_lists=[dataset.classes for dataset in dataset_list] + ) + + annotations = {} + for dataset in dataset_list: + annotations.update(dataset.annotations) + for dataset in dataset_list: + class_index_mapping = build_class_index_mapping( + source_classes=dataset.classes, target_classes=classes + ) + for image_path in dataset.image_paths: + annotations[image_path] = map_detections_class_id( + source_to_target_mapping=class_index_mapping, + detections=annotations[image_path], + ) + + return cls( + classes=classes, + images=image_paths, + annotations=annotations, + ) + + @classmethod + def from_yolo( + cls, + images_directory_path: str, + annotations_directory_path: str, + data_yaml_path: str, + force_masks: bool = False, + ) -> KeyPointDataset: + """ + Creates a Dataset instance from YOLO formatted data. + + Args: + images_directory_path (str): The path to the + directory containing the images. + annotations_directory_path (str): The path to the directory + containing the YOLO annotation files. + data_yaml_path (str): The path to the data + YAML file containing class information. + + Returns: + KeyPointDataset: A KeyPointDataset instance + containing the loaded images and annotations. + + Examples: + ```python + import roboflow + from roboflow import Roboflow + import supervision as sv + + roboflow.login() + rf = Roboflow() + + project = rf.workspace(WORKSPACE_ID).project(PROJECT_ID) + dataset = project.version(PROJECT_VERSION).download("yolov8") + + ds = sv.KeyPointDataset.from_yolo( + images_directory_path=f"{dataset.location}/train/images", + annotations_directory_path=f"{dataset.location}/train/labels", + data_yaml_path=f"{dataset.location}/data.yaml" + ) + + ds.classes + # ['dog', 'person'] + ``` + """ + classes, image_paths, annotations = load_yolo_keypoint_annotations( + images_directory_path=images_directory_path, + annotations_directory_path=annotations_directory_path, + data_yaml_path=data_yaml_path, + ) + return KeyPointDataset( + classes=classes, images=image_paths, annotations=annotations + ) + + def as_yolo( + self, + images_directory_path: Optional[str] = None, + annotations_directory_path: Optional[str] = None, + data_yaml_path: Optional[str] = None, + ) -> None: + """ + Exports the dataset to YOLO format. This method saves the + images and their corresponding annotations in YOLO format. + + Args: + images_directory_path (Optional[str]): The path to the + directory where the images should be saved. + If not provided, images will not be saved. + annotations_directory_path (Optional[str]): The path to the + directory where the annotations in + YOLO format should be saved. If not provided, + annotations will not be saved. + data_yaml_path (Optional[str]): The path where the data.yaml + file should be saved. + If not provided, the file will not be saved. + """ + if images_directory_path is not None: + save_dataset_images( + dataset=self, images_directory_path=images_directory_path + ) + if annotations_directory_path is not None: + save_yolo_annotations( + dataset=self, annotations_directory_path=annotations_directory_path + ) + if data_yaml_path is not None: + save_data_yaml(data_yaml_path=data_yaml_path, classes=self.classes) diff --git a/supervision/dataset/formats/yolo.py b/supervision/dataset/formats/yolo.py index 0ecbac4b5..7ae193e83 100644 --- a/supervision/dataset/formats/yolo.py +++ b/supervision/dataset/formats/yolo.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import cv2 import numpy as np @@ -9,6 +9,7 @@ from supervision.dataset.utils import approximate_mask_with_polygons from supervision.detection.core import Detections from supervision.detection.utils import polygon_to_mask, polygon_to_xyxy +from supervision.keypoint.core import KeyPoints from supervision.utils.file import ( list_files_with_extensions, read_txt_file, @@ -18,7 +19,7 @@ ) if TYPE_CHECKING: - from supervision.dataset.core import DetectionDataset + from supervision.dataset.core import DetectionDataset, KeyPointDataset def _parse_box(values: List[str]) -> np.ndarray: @@ -243,7 +244,7 @@ def detections_to_yolo_annotations( def save_yolo_annotations( - dataset: "DetectionDataset", + dataset: Union["DetectionDataset", "KeyPointDataset"], annotations_directory_path: str, min_image_area_percentage: float = 0.0, max_image_area_percentage: float = 1.0, @@ -256,13 +257,19 @@ def save_yolo_annotations( yolo_annotations_path = os.path.join( annotations_directory_path, yolo_annotations_name ) - lines = detections_to_yolo_annotations( - detections=annotation, - image_shape=image.shape, # type: ignore - min_image_area_percentage=min_image_area_percentage, - max_image_area_percentage=max_image_area_percentage, - approximation_percentage=approximation_percentage, - ) + + if isinstance(dataset, DetectionDataset): + lines = detections_to_yolo_annotations( + detections=annotation, + image_shape=image.shape, # type: ignore + min_image_area_percentage=min_image_area_percentage, + max_image_area_percentage=max_image_area_percentage, + approximation_percentage=approximation_percentage, + ) + else: + lines = keypoints_to_yolo_annotations( + keypoints=annotation, image_shape=image.shape + ) save_text_file(lines=lines, file_path=yolo_annotations_path) @@ -270,3 +277,110 @@ def save_data_yaml(data_yaml_path: str, classes: List[str]) -> None: data = {"nc": len(classes), "names": classes} Path(data_yaml_path).parent.mkdir(parents=True, exist_ok=True) save_yaml_file(data=data, file_path=data_yaml_path) + + +def load_yolo_keypoint_annotations( + images_directory_path: str, + annotations_directory_path: str, + data_yaml_path: str, +) -> Tuple[List[str], List[str], Dict[str, KeyPoints]]: + """ + Loads YOLO annotations and returns class names, images, + and their corresponding KeyPoints annotations. + + Args: + images_directory_path (str): The path to the directory containing the images. + annotations_directory_path (str): The path to the directory + containing the YOLO annotation files. + data_yaml_path (str): The path to the data + YAML file containing class information. + + Returns: + Tuple[List[str], List[str], Dict[str, KeyPoints]]: A tuple containing a list + of class names, a list of image paths, and a dictionary with image names + as keys and corresponding KeyPoints annotations as values. + """ + image_paths = [ + str(path) + for path in list_files_with_extensions( + directory=images_directory_path, extensions=["jpg", "jpeg", "png"] + ) + ] + + classes = _extract_class_names(file_path=data_yaml_path) + annotations = {} + + for image_path in image_paths: + image_stem = Path(image_path).stem + annotation_path = os.path.join(annotations_directory_path, f"{image_stem}.txt") + if not os.path.exists(annotation_path): + annotations[image_path] = KeyPoints.empty() + continue + + image = cv2.imread(image_path) + lines = read_txt_file(file_path=annotation_path, skip_empty=True) + h, w, _ = image.shape + resolution_wh = (w, h) + + annotation = yolo_annotations_to_keypoints( + lines=lines, resolution_wh=resolution_wh + ) + annotations[image_path] = annotation + return classes, image_paths, annotations + + +def yolo_annotations_to_keypoints( + lines: List[str], + resolution_wh: Tuple[int, int], +) -> KeyPoints: + if len(lines) == 0: + return KeyPoints.empty() + + class_ids, keypoints_list = [], [] + w, h = resolution_wh + for line in lines: + values = line.split() + class_ids.append(int(values[0])) + if len(values) > 5: + keypoints = np.array([float(value) for value in values[5:]]).reshape(-1, 2) + keypoints *= np.array([w, h], dtype=np.float32) + keypoints_list.append(keypoints) + + class_ids = np.array(class_ids, dtype=np.int_) + keypoints = np.array(keypoints_list, dtype=np.float32) + data = {} + + return KeyPoints(xy=keypoints, class_id=class_ids, data=data) + + +def keypoints_to_yolo_annotations( + keypoints: KeyPoints, image_shape: Tuple[int, int, int] +) -> List[str]: + """ + Converts keypoints data into YOLO format annotations. + + Args: + keypoints (KeyPoints): The keypoints object containing class IDs + and keypoints coordinates. + image_shape (Tuple[int, int, int]): The shape of the image + as (height, width, channels), used to normalize the coordinates. + + Returns: + List[str]: + A list of YOLO-formatted annotations where each line corresponds + to an object with its class ID and keypoints coordinates. + """ + h, w, _ = image_shape + annotations = [] + + for i, (cls_id, points) in enumerate(zip(keypoints.class_id, keypoints.xy)): + normalized_points = points / np.array([w, h]) + annotation = [str(cls_id)] + annotation.extend(map(str, normalized_points.flatten())) + + if keypoints.confidence is not None: + annotation.append(str(keypoints.confidence[i])) + + annotations.append(" ".join(annotation)) + + return annotations diff --git a/supervision/dataset/utils.py b/supervision/dataset/utils.py index 20b80978f..f41376f31 100644 --- a/supervision/dataset/utils.py +++ b/supervision/dataset/utils.py @@ -17,7 +17,7 @@ ) if TYPE_CHECKING: - from supervision.dataset.core import DetectionDataset + from supervision.dataset.core import DetectionDataset, KeyPointDataset T = TypeVar("T") @@ -99,12 +99,15 @@ def map_detections_class_id( def save_dataset_images( - dataset: "DetectionDataset", images_directory_path: str + dataset: Union["DetectionDataset", "KeyPointDataset"], images_directory_path: str ) -> None: Path(images_directory_path).mkdir(parents=True, exist_ok=True) for image_path in dataset.image_paths: final_path = os.path.join(images_directory_path, Path(image_path).name) - if image_path in dataset._images_in_memory: + if ( + hasattr(dataset, "_images_in_memory") + and image_path in dataset._images_in_memory + ): image = dataset._images_in_memory[image_path] cv2.imwrite(final_path, image) else: From b477a618caac57685c642b3e1a196c78f0958f14 Mon Sep 17 00:00:00 2001 From: Shadab Date: Tue, 27 Aug 2024 20:18:03 +0100 Subject: [PATCH 2/2] refactor : added types for split method --- supervision/dataset/core.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/supervision/dataset/core.py b/supervision/dataset/core.py index 77ed4ccba..2b7991720 100644 --- a/supervision/dataset/core.py +++ b/supervision/dataset/core.py @@ -1002,7 +1002,10 @@ def __eq__(self, other) -> bool: return True def split( - self, split_ratio=0.8, random_state=None, shuffle: bool = True + self, + split_ratio: float = 0.8, + random_state: Optional[int] = None, + shuffle: bool = True, ) -> Tuple[KeyPointDataset, KeyPointDataset]: """ Splits the dataset into two parts (training and testing)