Yolov8n pose tutorial (#1097)
* Add YOLOv8n pose estimation tutorial
Idan-BenAmi authored Jun 5, 2024
1 parent d13319f commit 32235a4
Showing 7 changed files with 1,117 additions and 154 deletions.
179 changes: 71 additions & 108 deletions tutorials/mct_model_garden/evaluation_metrics/coco_evaluation.py
@@ -24,11 +24,11 @@
import torch
from tqdm import tqdm

from ..models_pytorch.yolov8.yolov8_postprocess import scale_boxes, scale_coords
from ..models_pytorch.yolov8.yolov8_preprocess import yolov8_preprocess_chw_transpose
from ..models_pytorch.yolov8.postprocess_yolov8_seg import process_masks, postprocess_yolov8_inst_seg



def coco80_to_coco91(x: np.ndarray) -> np.ndarray:
"""
Converts COCO 80-class indices to COCO 91-class indices.
@@ -46,69 +46,9 @@ def coco80_to_coco91(x: np.ndarray) -> np.ndarray:
return coco91Indexs[x.astype(np.int32)]


def clip_boxes(boxes: np.ndarray, h: int, w: int) -> np.ndarray:
"""
Clip bounding boxes to stay within the image boundaries.
Args:
boxes (numpy.ndarray): Array of bounding boxes in format [y_min, x_min, y_max, x_max].
h (int): Height of the image.
w (int): Width of the image.
Returns:
numpy.ndarray: Clipped bounding boxes.
"""
boxes[..., 0] = np.clip(boxes[..., 0], a_min=0, a_max=h)
boxes[..., 1] = np.clip(boxes[..., 1], a_min=0, a_max=w)
boxes[..., 2] = np.clip(boxes[..., 2], a_min=0, a_max=h)
boxes[..., 3] = np.clip(boxes[..., 3], a_min=0, a_max=w)
return boxes
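
clip_boxes, removed from this file in this commit, simply clamps box coordinates to the image frame. A quick illustrative call (the values are made up):

boxes = np.array([[-5.0, 10.0, 700.0, 650.0]])  # [y_min, x_min, y_max, x_max]
clip_boxes(boxes, h=480, w=640)
# -> [[0., 10., 480., 640.]]  (y clipped to [0, 480], x clipped to [0, 640])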



def format_results(outputs: List, img_ids: List, orig_img_dims: List, output_resize: Dict) -> List[Dict]:
"""
Format model outputs into a list of detection dictionaries.
Args:
outputs (list): List of model outputs, typically containing bounding boxes, scores, and labels.
img_ids (list): List of image IDs corresponding to each output.
orig_img_dims (list): List of tuples representing the original image dimensions (h, w) for each output.
output_resize (Dict): Contains the resize information to map between the model's
output and the original image dimensions.
Returns:
list: A list of detection dictionaries, each containing information about the detected object.
"""
detections = []
h_model, w_model = output_resize['shape']
preserve_aspect_ratio = output_resize['aspect_ratio_preservation']

# Process model outputs and convert to detection format
for idx, output in enumerate(outputs):
image_id = img_ids[idx]
scores = output[1].numpy().squeeze() # Extract scores
labels = (coco80_to_coco91(
output[2].numpy())).squeeze() # Convert COCO 80-class indices to COCO 91-class indices
boxes = output[0].numpy().squeeze() # Extract bounding boxes
boxes = scale_boxes(boxes, orig_img_dims[idx][0], orig_img_dims[idx][1], h_model, w_model,
preserve_aspect_ratio)

for score, label, box in zip(scores, labels, boxes):
detection = {
"image_id": image_id,
"category_id": label,
"bbox": [box[1], box[0], box[3] - box[1], box[2] - box[0]],
"score": score
}
detections.append(detection)

return detections


# COCO evaluation class
class CocoEval:
def __init__(self, path2json: str, output_resize: Dict = None):
def __init__(self, path2json: str, output_resize: Dict = None, task: str = 'Detection'):
"""
Initialize the CocoEval class.
@@ -128,6 +68,9 @@ def __init__(self, path2json: str, output_resize: Dict = None):
# Resizing information to map between the model's output and the original image dimensions
self.output_resize = output_resize if output_resize else {'shape': (1, 1), 'aspect_ratio_preservation': False}

# Set the task type (Detection/Segmentation/Keypoints)
self.task = task

def add_batch_detections(self, outputs: Tuple[List, List, List, List], targets: List[Dict]):
"""
Add batch detections to the evaluation.
@@ -142,9 +85,9 @@ def add_batch_detections(self, outputs: Tuple[List, List, List, List], targets:
if len(t) > 0:
img_ids.append(t[0]['image_id'])
orig_img_dims.append(t[0]['orig_img_dims'])
_outs.append([outputs[0][idx], outputs[1][idx], outputs[2][idx], outputs[3][idx]])
_outs.append([o[idx] for o in outputs])

batch_detections = format_results(_outs, img_ids, orig_img_dims, self.output_resize)
batch_detections = self.format_results(_outs, img_ids, orig_img_dims, self.output_resize)

self.all_detections.extend(batch_detections)

@@ -157,7 +100,12 @@ def result(self) -> List[float]:
"""
# Initialize COCO evaluation object
self.coco_dt = self.coco_gt.loadRes(self.all_detections)
coco_eval = COCOeval(self.coco_gt, self.coco_dt, 'bbox')
if self.task == 'Detection':
coco_eval = COCOeval(self.coco_gt, self.coco_dt, 'bbox')
elif self.task == 'Keypoints':
coco_eval = COCOeval(self.coco_gt, self.coco_dt, 'keypoints')
else:
raise Exception("Unsupported task type of CocoEval")

# Run evaluation
coco_eval.evaluate()
@@ -175,6 +123,62 @@ def reset(self):
"""
self.all_detections = []

def format_results(self, outputs: List, img_ids: List, orig_img_dims: List, output_resize: Dict) -> List[Dict]:
"""
Format model outputs into a list of detection dictionaries.
Args:
outputs (list): List of model outputs, typically containing bounding boxes, scores, and labels.
img_ids (list): List of image IDs corresponding to each output.
orig_img_dims (list): List of tuples representing the original image dimensions (h, w) for each output.
output_resize (Dict): Contains the resize information to map between the model's
output and the original image dimensions.
Returns:
list: A list of detection dictionaries, each containing information about the detected object.
"""
detections = []
h_model, w_model = output_resize['shape']
preserve_aspect_ratio = output_resize['aspect_ratio_preservation']

if self.task == 'Detection':
# Process model outputs and convert to detection format
for idx, output in enumerate(outputs):
image_id = img_ids[idx]
scores = output[1].numpy().squeeze() # Extract scores
labels = (coco80_to_coco91(
output[2].numpy())).squeeze() # Convert COCO 80-class indices to COCO 91-class indices
boxes = output[0].numpy().squeeze() # Extract bounding boxes
boxes = scale_boxes(boxes, orig_img_dims[idx][0], orig_img_dims[idx][1], h_model, w_model,
preserve_aspect_ratio)

for score, label, box in zip(scores, labels, boxes):
detection = {
"image_id": image_id,
"category_id": label,
"bbox": [box[1], box[0], box[3] - box[1], box[2] - box[0]],
"score": score
}
detections.append(detection)

elif self.task == 'Keypoints':
for output, image_id, (w_orig, h_orig) in zip(outputs, img_ids, orig_img_dims):

bbox, scores, kpts = output

# Add detection results to predicted_keypoints list
if kpts.shape[0]:
kpts = kpts.reshape(-1, 17, 3)
kpts = scale_coords(kpts, h_orig, w_orig, 640, 640, True)
for ind, k in enumerate(kpts):
detections.append({
'category_id': 1,
'image_id': image_id,
'keypoints': k.reshape(51).tolist(),
'score': scores.tolist()[ind] if isinstance(scores.tolist(), list) else scores.tolist()
})

return detections
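
The 'Keypoints' branch assumes the standard 17-keypoint COCO skeleton: each detected person contributes 17 (x, y, confidence) triplets, flattened into the 51-value 'keypoints' list that pycocotools expects. A minimal sketch of a single entry (all values below are made up for illustration):

import numpy as np

kpts = np.zeros((17, 3))          # 17 keypoints of (x, y, confidence) for one person
kpts[0] = [320.0, 240.0, 0.9]     # e.g. the nose keypoint
detection = {
    'category_id': 1,             # category 1 is 'person' in COCO
    'image_id': 42,               # hypothetical image id
    'keypoints': kpts.reshape(51).tolist(),
    'score': 0.88
}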

def load_and_preprocess_image(image_path: str, preprocess: Callable) -> np.ndarray:
"""
@@ -376,7 +380,7 @@ def model_predict(model: Any,


def coco_evaluate(model: Any, preprocess: Callable, dataset_folder: str, annotation_file: str, batch_size: int,
output_resize: tuple, model_inference: Callable = model_predict) -> dict:
output_resize: tuple, model_inference: Callable = model_predict, task: str = 'Detection') -> dict:
"""
Evaluate a model on the COCO dataset.
@@ -400,7 +404,7 @@ def coco_evaluate(model: Any, preprocess: Callable, dataset_folder: str, annotat
coco_loader = DataLoader(coco_dataset, batch_size)

# Initialize the evaluation metric object
coco_metric = CocoEval(annotation_file, output_resize)
coco_metric = CocoEval(annotation_file, output_resize, task)

# Iterate over the evaluation set
for batch_idx, (images, targets) in enumerate(coco_loader):
@@ -415,44 +419,6 @@ def coco_evaluate(model: Any, preprocess: Callable, dataset_folder: str, annotat

return coco_metric.result()
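
With the new task argument, the same helper can drive keypoint (pose) evaluation. A hedged usage sketch; the model object and local paths below are placeholders, not values taken from this commit:

pose_map = coco_evaluate(model=pose_model,  # hypothetical YOLOv8n-pose model callable
                         preprocess=yolov8_preprocess_chw_transpose,
                         dataset_folder='coco/val2017',  # placeholder dataset path
                         annotation_file='coco/annotations/person_keypoints_val2017.json',
                         batch_size=8,
                         output_resize={'shape': (640, 640), 'aspect_ratio_preservation': True},
                         task='Keypoints')

Depending on how the pose model packages its outputs, a task-specific model_inference callable may be needed in place of the default model_predict.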

def scale_boxes(boxes: np.ndarray, h_image: int, w_image: int, h_model: int, w_model: int, preserve_aspect_ratio: bool, normalized: bool = True) -> np.ndarray:
"""
Scale and offset bounding boxes based on model output size and original image size.
Args:
boxes (numpy.ndarray): Array of bounding boxes in format [y_min, x_min, y_max, x_max].
h_image (int): Original image height.
w_image (int): Original image width.
h_model (int): Model output height.
w_model (int): Model output width.
preserve_aspect_ratio (bool): Whether to preserve the image aspect ratio during scaling.
normalized (bool, optional): Whether the input boxes are in normalized coordinates relative to the model output size. Default is True.
Returns:
numpy.ndarray: Scaled and offset bounding boxes.
"""
deltaH, deltaW = 0, 0
H, W = h_model, w_model
scale_H, scale_W = h_image / H, w_image / W

if preserve_aspect_ratio:
scale_H = scale_W = max(h_image / H, w_image / W)
H_tag = int(np.round(h_image / scale_H))
W_tag = int(np.round(w_image / scale_W))
deltaH, deltaW = int((H - H_tag) / 2), int((W - W_tag) / 2)

nh, nw = (H, W) if normalized else (1, 1)

# Scale and offset boxes
boxes[..., 0] = (boxes[..., 0] * nh - deltaH) * scale_H
boxes[..., 1] = (boxes[..., 1] * nw - deltaW) * scale_W
boxes[..., 2] = (boxes[..., 2] * nh - deltaH) * scale_H
boxes[..., 3] = (boxes[..., 3] * nw - deltaW) * scale_W

# Clip boxes
boxes = clip_boxes(boxes, h_image, w_image)

return boxes
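
scale_boxes, removed here in favour of the copy imported from yolov8_postprocess at the top of the file, first removes the letterbox padding and then rescales to the original resolution. A small numeric check, assuming a 480x640 image fed to a 640x640 model with aspect-ratio preservation:

# scale_H = scale_W = max(480/640, 640/640) = 1.0
# H_tag = 480, W_tag = 640  ->  deltaH = (640 - 480) // 2 = 80, deltaW = 0
boxes = np.array([[0.125, 0.0, 0.875, 1.0]])  # [y_min, x_min, y_max, x_max], normalized
scale_boxes(boxes, 480, 640, 640, 640, preserve_aspect_ratio=True)
# -> [[0., 0., 480., 640.]]  (e.g. 0.125 * 640 - 80 = 0 and 0.875 * 640 - 80 = 480)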

def masks_to_coco_rle(masks, boxes, image_id, height, width, scores, classes, mask_threshold):
"""
Converts masks to COCO RLE format and compiles results including bounding boxes and scores.
@@ -590,6 +556,3 @@ def evaluate_yolov8_segmentation(model, data_dir, data_type='val2017', img_ids_l

save_results_to_json(results, output_file)
evaluate_seg_model(ann_file, output_file)



@@ -1,50 +1,11 @@
from typing import List
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from typing import Tuple

from tutorials.mct_model_garden.models_pytorch.yolov8.yolov8_postprocess import nms

def nms(dets: np.ndarray, scores: np.ndarray, iou_thres: float = 0.3, max_out_dets: int = 300) -> List[int]:
"""
Perform Non-Maximum Suppression (NMS) on detected bounding boxes.
Args:
dets (np.ndarray): Array of bounding box coordinates of shape (N, 4) representing [y1, x1, y2, x2].
scores (np.ndarray): Array of confidence scores associated with each bounding box.
iou_thres (float, optional): IoU threshold for NMS. Default is 0.3.
max_out_dets (int, optional): Maximum number of output detections to keep. Default is 300.
Returns:
List[int]: List of indices representing the indices of the bounding boxes to keep after NMS.
"""
y1, x1 = dets[:, 0], dets[:, 1]
y2, x2 = dets[:, 2], dets[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]

keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])

w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)

inds = np.where(ovr <= iou_thres)[0]
order = order[inds + 1]

return keep[:max_out_dets]
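
This greedy NMS, removed here and imported from yolov8_postprocess instead, keeps the highest-scoring box and suppresses neighbours whose IoU exceeds the threshold. A tiny self-contained check with made-up boxes:

import numpy as np

dets = np.array([[0, 0, 10, 10],     # [y1, x1, y2, x2]
                 [1, 1, 11, 11],     # overlaps the first box heavily (IoU ~ 0.7)
                 [50, 50, 60, 60]])  # disjoint box
scores = np.array([0.9, 0.8, 0.7])
nms(dets, scores, iou_thres=0.3)
# -> [0, 2]: the second box is suppressed by the first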


def combined_nms_seg(batch_boxes, batch_scores, batch_masks, iou_thres: float = 0.3, conf: float = 0.1, max_out_dets: int = 300):
"""
Perform combined Non-Maximum Suppression (NMS) and segmentation mask processing for batched inputs.
@@ -1,3 +1,6 @@
# The following code was mostly duplicated from https://github.com/ultralytics/ultralytics
# ==============================================================================

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment
##