
Commit

Pose notebook
Idan-BenAmi committed Jun 5, 2024
1 parent d253932 commit ef1e2ce
Showing 2 changed files with 21 additions and 99 deletions.
15 changes: 2 additions & 13 deletions tutorials/mct_model_garden/models_pytorch/yolov8/yolov8.py
@@ -26,7 +26,7 @@
import math
import re
from copy import deepcopy
-from typing import Dict, List, Tuple, Any, Callable
+from typing import Dict, List, Tuple, Any

import numpy as np
import torch
@@ -284,7 +284,7 @@ class Detect_wo_bb_dec(nn.Module):
    def __init__(self, nc: int = 80,
                 ch: List[int] = ()):
        """
-        Detection layer for YOLOv8.
+        Detection layer for YOLOv8. Bounding box decoding was removed.
        Args:
            nc (int): Number of classes.
            ch (List[int]): List of channel values for detection layers.
@@ -304,17 +304,6 @@ def __init__(self, nc: int = 80,
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3),
                                                nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
-        anchors, strides = (x.transpose(0, 1) for x in make_anchors(self.feat_sizes,
-                                                                    self.stride, 0.5))
-        strides = strides / self.img_size
-        anchors = anchors * strides
-        self.relu1 = nn.ReLU()
-        self.relu2 = nn.ReLU()
-        self.relu3 = nn.ReLU()
-        self.relu4 = nn.ReLU()
-
-        self.register_buffer('anchors', anchors)
-        self.register_buffer('strides', strides)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        shape = x[0].shape  # BCHW
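With the anchor and stride buffers removed from `Detect_wo_bb_dec`, the model now emits raw box distances and class scores, and the anchor grid is rebuilt at postprocessing time. A minimal NumPy sketch of that anchor construction; `make_anchor_grid` is an illustrative name, assumed to match what `make_anchors` / `make_anchors_yolo_v8` compute:

```python
import numpy as np

def make_anchor_grid(feat_sizes, strides, offset=0.5):
    # One (x, y) anchor point per feature-map cell, plus a matching
    # per-anchor stride value, concatenated across all detection scales.
    anchor_points, stride_tensor = [], []
    for size, stride in zip(feat_sizes, strides):
        coords = np.arange(size) + offset  # cell-center coordinates
        xv, yv = np.meshgrid(coords, coords)
        anchor_points.append(np.stack((xv, yv), axis=-1).reshape(-1, 2))
        stride_tensor.append(np.full((size * size, 1), stride, dtype=np.float64))
    return np.concatenate(anchor_points), np.concatenate(stride_tensor)

anchors, strides = make_anchor_grid(np.array([80, 40, 20]), np.array([8, 16, 32]))
print(anchors.shape, strides.shape)  # (8400, 2) (8400, 1)
```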
105 changes: 19 additions & 86 deletions tutorials/mct_model_garden/models_pytorch/yolov8/yolov8_postprocess.py
@@ -44,6 +44,22 @@ def nms(dets: np.ndarray, scores: np.ndarray, iou_thres: float = 0.5, max_out_de

def combined_nms(batch_boxes, batch_scores, iou_thres: float = 0.5, conf: float = 0.001, max_out_dets: int = 300):

"""
Performs combined Non-Maximum Suppression (NMS) on batches of bounding boxes and scores.
Parameters:
batch_boxes (List[np.ndarray]): A list of arrays, where each array contains bounding boxes for a batch.
batch_scores (List[np.ndarray]): A list of arrays, where each array contains scores for the corresponding bounding boxes.
iou_thres (float): Intersection over Union (IoU) threshold for NMS. Defaults to 0.5.
conf (float): Confidence threshold for filtering boxes. Defaults to 0.001.
max_out_dets (int): Maximum number of output detections per image. Defaults to 300.
Returns:
List[Tuple[np.ndarray, np.ndarray, np.ndarray]]: A list of tuples for each batch, where each tuple contains:
- nms_bbox: Array of bounding boxes after NMS.
- nms_scores: Array of scores after NMS.
- nms_classes: Array of class IDs after NMS.
"""
    nms_results = []
    for boxes, scores in zip(batch_boxes, batch_scores):

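A hedged usage sketch of `combined_nms`; the box values and xyxy layout here are illustrative assumptions, not taken from the repo:

```python
import numpy as np

# One image, three candidate boxes over 80 classes; the first two overlap
# heavily, so NMS should keep only the higher-scoring one of that pair.
batch_boxes = [np.array([[ 10.,  10., 100., 100.],
                         [ 12.,  12., 102., 102.],
                         [200., 200., 300., 300.]])]
batch_scores = [np.random.rand(3, 80)]

(nms_bbox, nms_scores, nms_classes), = combined_nms(
    batch_boxes, batch_scores, iou_thres=0.5, conf=0.25, max_out_dets=10)
print(nms_bbox.shape, nms_scores.shape, nms_classes.shape)
```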
@@ -220,59 +236,25 @@ def clip_coords(kpts: np.ndarray, h: int, w: int) -> np.ndarray:
    kpts[..., 1] = np.clip(kpts[..., 1], a_min=0, a_max=w)
    return kpts

-def postprocess_yolov8_detection(outputs: Tuple[np.ndarray, np.ndarray, np.ndarray],
-                                 conf: float = 0.001,
-                                 iou_thres: float = 0.7,
-                                 max_out_dets: int = 300) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    """
-    Postprocess the outputs of a YOLOv8 model for object detection and pose estimation.
-    Args:
-        outputs (Tuple[np.ndarray, np.ndarray, np.ndarray]): Tuple containing the model outputs for bounding boxes,
-            class predictions, and keypoint predictions.
-        conf (float, optional): Confidence threshold for bounding box predictions. Default is 0.001.
-        iou_thres (float, optional): IoU (Intersection over Union) threshold for Non-Maximum Suppression (NMS).
-            Default is 0.7.
-        max_out_dets (int, optional): Maximum number of output detections to keep after NMS. Default is 300.
-    Returns:
-        Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple containing the post-processed bounding boxes,
-            their corresponding scores, and keypoints.
-    """
-
-    feat_sizes = np.array([80, 40, 20])
-    stride_sizes = np.array([8, 16, 32])
-    a, s = (x.transpose() for x in make_anchors_yolo_v8(feat_sizes, stride_sizes, 0.5))
-
-    y_bb, y_cls = outputs
-    dbox = dist2bbox_yolo_v8(y_bb, a, xywh=True, dim=1) * s
-    detect_out = np.concatenate((dbox, y_cls), 1)
-
-    xd = detect_out.transpose([0, 2, 1])
-
-    return combined_nms(xd[..., :4], xd[..., 4:84], iou_thres, conf, max_out_dets)


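The removed `postprocess_yolov8_detection` decoded per-anchor distances with `dist2bbox_yolo_v8`. A minimal NumPy sketch of that distance-to-box conversion, under the assumption that it matches the standard Ultralytics `dist2bbox`:

```python
import numpy as np

def dist2bbox_sketch(distance, anchor_points, xywh=True, dim=-1):
    # distance holds (left, top, right, bottom) offsets from each anchor point.
    lt, rb = np.split(distance, 2, axis=dim)
    x1y1 = anchor_points - lt     # top-left corner
    x2y2 = anchor_points + rb     # bottom-right corner
    if xywh:
        c_xy = (x1y1 + x2y2) / 2  # box center
        wh = x2y2 - x1y1          # box width and height
        return np.concatenate((c_xy, wh), axis=dim)
    return np.concatenate((x1y1, x2y2), axis=dim)
```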
def postprocess_yolov8_keypoints(outputs: Tuple[np.ndarray, np.ndarray, np.ndarray],
                                 conf: float = 0.001,
                                 iou_thres: float = 0.7,
                                 max_out_dets: int = 300) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
-    Postprocess the outputs of a YOLOv8 model for object detection and pose estimation.
+    Postprocess the outputs of a YOLOv8 model for pose estimation.
    Args:
        outputs (Tuple[np.ndarray, np.ndarray, np.ndarray]): Tuple containing the model outputs for bounding boxes,
-            class predictions, and keypoint predictions.
+            scores and keypoint predictions.
        conf (float, optional): Confidence threshold for bounding box predictions. Default is 0.001.
        iou_thres (float, optional): IoU (Intersection over Union) threshold for Non-Maximum Suppression (NMS).
            Default is 0.7.
        max_out_dets (int, optional): Maximum number of output detections to keep after NMS. Default is 300.
    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple containing the post-processed bounding boxes,
-            their corresponding scores, and keypoints.
+            their corresponding scores and keypoints.
    """
    kpt_shape = (17, 3)
    feat_sizes = np.array([80, 40, 20])
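A usage sketch of `postprocess_yolov8_keypoints`; the output shapes below (8400 anchors, one class, 17×3 keypoints) are assumptions inferred from the constants above, not taken from the repo:

```python
import numpy as np

# Assumed raw model outputs for one image.
y_bb = np.random.rand(1, 4, 8400)     # box distances per anchor
y_cls = np.random.rand(1, 1, 8400)    # person-class scores
y_kpts = np.random.rand(1, 51, 8400)  # 17 keypoints x (x, y, confidence)

boxes, scores, keypoints = postprocess_yolov8_keypoints(
    (y_bb, y_cls, y_kpts), conf=0.25, iou_thres=0.7, max_out_dets=20)
```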
@@ -403,52 +385,3 @@ def clip_coords(kpts: np.ndarray, h: int, w: int) -> np.ndarray:
    kpts[..., 0] = np.clip(kpts[..., 0], a_min=0, a_max=h)
    kpts[..., 1] = np.clip(kpts[..., 1], a_min=0, a_max=w)
    return kpts
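As the context lines show, `clip_coords` clamps the first keypoint channel to [0, h] and the second to [0, w], leaving any confidence channel untouched; a quick check:

```python
import numpy as np

kpts = np.array([[650.0, -5.0, 0.9]])   # one out-of-bounds keypoint
print(clip_coords(kpts, h=640, w=640))  # [[640.   0.   0.9]]
```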


-class COCODrawer:
-    def __init__(self, categories_file):
-        self.categories = self.get_categories(categories_file)
-
-    def get_categories(self, filename):
-        with open(filename, 'r') as f:
-            return [line.strip() for line in f.readlines()]
-
-    def draw_bounding_box(self, img, annotation, class_id, ax):
-        x_min, y_min = int(annotation[1]), int(annotation[0])
-        x_max, y_max = int(annotation[3]), int(annotation[2])
-        text = self.categories[int(class_id)]
-        cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (0, 0, 255), 2)
-        ax.text(x_min, y_min, text, style='italic',
-                bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})
-
-    def draw_keypoints(self, img, keypoints):
-        skeleton = [
-            [0, 1], [0, 2], [1, 3], [2, 4],  # Head
-            [5, 6], [5, 7], [7, 9], [6, 8],  # Arms
-            [8, 10], [5, 11], [6, 12], [11, 12],  # Body
-            [11, 13], [12, 14], [13, 15], [14, 16]  # Legs
-        ]
-
-        # Draw skeleton lines
-        for connection in skeleton:
-            start_point = (int(keypoints[connection[0]][0]), int(keypoints[connection[0]][1]))
-            end_point = (int(keypoints[connection[1]][0]), int(keypoints[connection[1]][1]))
-            cv2.line(img, start_point, end_point, (255, 0, 0), 2)
-
-        # Draw keypoints as colored circles
-        for point in keypoints:
-            x, y = int(point[0]), int(point[1])
-            cv2.circle(img, (x, y), 3, (0, 255, 0), -1)
-
-    def annotate_image(self, img, b, s, c, k, scale, ax):
-        for index, row in enumerate(b):
-            if s[index] > 0.55:
-                self.draw_bounding_box(img, row * scale, c[index], ax)
-                if k is not None:
-                    self.draw_keypoints(img, k[index] * scale)
-
-    def plot_image(self, img, b, s, c, k=None):
-        fig, ax = plt.subplots()
-        self.annotate_image(img, b, s, c, k, scale=1, ax=ax)
-        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-        # plt.title(img.shape)
-        return fig, ax
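The deleted `COCODrawer` was driven roughly as follows; the label file name and the [y_min, x_min, y_max, x_max] box layout (inferred from `draw_bounding_box`'s indexing) are assumptions:

```python
import cv2
import numpy as np
from matplotlib import pyplot as plt

drawer = COCODrawer('coco_labels.txt')          # hypothetical: one class name per line
img = np.zeros((640, 640, 3), dtype=np.uint8)   # stand-in for a real frame
boxes = np.array([[50.0, 40.0, 300.0, 260.0]])  # [y_min, x_min, y_max, x_max]
scores = np.array([0.9])
classes = np.array([0])
keypoints = np.random.rand(1, 17, 3) * 640

fig, ax = drawer.plot_image(img, boxes, scores, classes, k=keypoints)
plt.show()
```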
