diff --git a/tutorials/mct_model_garden/evaluation_metrics/coco_evaluation.py b/tutorials/mct_model_garden/evaluation_metrics/coco_evaluation.py index a5ccb15c8..218fc6cdd 100644 --- a/tutorials/mct_model_garden/evaluation_metrics/coco_evaluation.py +++ b/tutorials/mct_model_garden/evaluation_metrics/coco_evaluation.py @@ -24,11 +24,11 @@ import torch from tqdm import tqdm +from ..models_pytorch.yolov8.yolov8_postprocess import scale_boxes, scale_coords from ..models_pytorch.yolov8.yolov8_preprocess import yolov8_preprocess_chw_transpose from ..models_pytorch.yolov8.postprocess_yolov8_seg import process_masks, postprocess_yolov8_inst_seg - def coco80_to_coco91(x: np.ndarray) -> np.ndarray: """ Converts COCO 80-class indices to COCO 91-class indices. @@ -46,69 +46,9 @@ def coco80_to_coco91(x: np.ndarray) -> np.ndarray: return coco91Indexs[x.astype(np.int32)] -def clip_boxes(boxes: np.ndarray, h: int, w: int) -> np.ndarray: - """ - Clip bounding boxes to stay within the image boundaries. - - Args: - boxes (numpy.ndarray): Array of bounding boxes in format [y_min, x_min, y_max, x_max]. - h (int): Height of the image. - w (int): Width of the image. - - Returns: - numpy.ndarray: Clipped bounding boxes. - """ - boxes[..., 0] = np.clip(boxes[..., 0], a_min=0, a_max=h) - boxes[..., 1] = np.clip(boxes[..., 1], a_min=0, a_max=w) - boxes[..., 2] = np.clip(boxes[..., 2], a_min=0, a_max=h) - boxes[..., 3] = np.clip(boxes[..., 3], a_min=0, a_max=w) - return boxes - - - -def format_results(outputs: List, img_ids: List, orig_img_dims: List, output_resize: Dict) -> List[Dict]: - """ - Format model outputs into a list of detection dictionaries. - - Args: - outputs (list): List of model outputs, typically containing bounding boxes, scores, and labels. - img_ids (list): List of image IDs corresponding to each output. - orig_img_dims (list): List of tuples representing the original image dimensions (h, w) for each output. - output_resize (Dict): Contains the resize information to map between the model's - output and the original image dimensions. - - Returns: - list: A list of detection dictionaries, each containing information about the detected object. - """ - detections = [] - h_model, w_model = output_resize['shape'] - preserve_aspect_ratio = output_resize['aspect_ratio_preservation'] - - # Process model outputs and convert to detection format - for idx, output in enumerate(outputs): - image_id = img_ids[idx] - scores = output[1].numpy().squeeze() # Extract scores - labels = (coco80_to_coco91( - output[2].numpy())).squeeze() # Convert COCO 80-class indices to COCO 91-class indices - boxes = output[0].numpy().squeeze() # Extract bounding boxes - boxes = scale_boxes(boxes, orig_img_dims[idx][0], orig_img_dims[idx][1], h_model, w_model, - preserve_aspect_ratio) - - for score, label, box in zip(scores, labels, boxes): - detection = { - "image_id": image_id, - "category_id": label, - "bbox": [box[1], box[0], box[3] - box[1], box[2] - box[0]], - "score": score - } - detections.append(detection) - - return detections - - # COCO evaluation class class CocoEval: - def __init__(self, path2json: str, output_resize: Dict = None): + def __init__(self, path2json: str, output_resize: Dict = None, task: str = 'Detection'): """ Initialize the CocoEval class. 
@@ -128,6 +68,9 @@ def __init__(self, path2json: str, output_resize: Dict = None): # Resizing information to map between the model's output and the original image dimensions self.output_resize = output_resize if output_resize else {'shape': (1, 1), 'aspect_ratio_preservation': False} + # Set the task type (Detection/Segmentation/Keypoints) + self.task = task + def add_batch_detections(self, outputs: Tuple[List, List, List, List], targets: List[Dict]): """ Add batch detections to the evaluation. @@ -142,9 +85,9 @@ def add_batch_detections(self, outputs: Tuple[List, List, List, List], targets: if len(t) > 0: img_ids.append(t[0]['image_id']) orig_img_dims.append(t[0]['orig_img_dims']) - _outs.append([outputs[0][idx], outputs[1][idx], outputs[2][idx], outputs[3][idx]]) + _outs.append([o[idx] for o in outputs]) - batch_detections = format_results(_outs, img_ids, orig_img_dims, self.output_resize) + batch_detections = self.format_results(_outs, img_ids, orig_img_dims, self.output_resize) self.all_detections.extend(batch_detections) @@ -157,7 +100,12 @@ def result(self) -> List[float]: """ # Initialize COCO evaluation object self.coco_dt = self.coco_gt.loadRes(self.all_detections) - coco_eval = COCOeval(self.coco_gt, self.coco_dt, 'bbox') + if self.task == 'Detection': + coco_eval = COCOeval(self.coco_gt, self.coco_dt, 'bbox') + elif self.task == 'Keypoints': + coco_eval = COCOeval(self.coco_gt, self.coco_dt, 'keypoints') + else: + raise Exception("Unsupported task type of CocoEval") # Run evaluation coco_eval.evaluate() @@ -175,6 +123,62 @@ def reset(self): """ self.all_detections = [] + def format_results(self, outputs: List, img_ids: List, orig_img_dims: List, output_resize: Dict) -> List[Dict]: + """ + Format model outputs into a list of detection dictionaries. + + Args: + outputs (list): List of model outputs, typically containing bounding boxes, scores, and labels. + img_ids (list): List of image IDs corresponding to each output. + orig_img_dims (list): List of tuples representing the original image dimensions (h, w) for each output. + output_resize (Dict): Contains the resize information to map between the model's + output and the original image dimensions. + + Returns: + list: A list of detection dictionaries, each containing information about the detected object. 
+ """ + detections = [] + h_model, w_model = output_resize['shape'] + preserve_aspect_ratio = output_resize['aspect_ratio_preservation'] + + if self.task == 'Detection': + # Process model outputs and convert to detection format + for idx, output in enumerate(outputs): + image_id = img_ids[idx] + scores = output[1].numpy().squeeze() # Extract scores + labels = (coco80_to_coco91( + output[2].numpy())).squeeze() # Convert COCO 80-class indices to COCO 91-class indices + boxes = output[0].numpy().squeeze() # Extract bounding boxes + boxes = scale_boxes(boxes, orig_img_dims[idx][0], orig_img_dims[idx][1], h_model, w_model, + preserve_aspect_ratio) + + for score, label, box in zip(scores, labels, boxes): + detection = { + "image_id": image_id, + "category_id": label, + "bbox": [box[1], box[0], box[3] - box[1], box[2] - box[0]], + "score": score + } + detections.append(detection) + + elif self.task == 'Keypoints': + for output, image_id, (w_orig, h_orig) in zip(outputs, img_ids, orig_img_dims): + + bbox, scores, kpts = output + + # Add detection results to predicted_keypoints list + if kpts.shape[0]: + kpts = kpts.reshape(-1, 17, 3) + kpts = scale_coords(kpts, h_orig, w_orig, 640, 640, True) + for ind, k in enumerate(kpts): + detections.append({ + 'category_id': 1, + 'image_id': image_id, + 'keypoints': k.reshape(51).tolist(), + 'score': scores.tolist()[ind] if isinstance(scores.tolist(), list) else scores.tolist() + }) + + return detections def load_and_preprocess_image(image_path: str, preprocess: Callable) -> np.ndarray: """ @@ -376,7 +380,7 @@ def model_predict(model: Any, def coco_evaluate(model: Any, preprocess: Callable, dataset_folder: str, annotation_file: str, batch_size: int, - output_resize: tuple, model_inference: Callable = model_predict) -> dict: + output_resize: tuple, model_inference: Callable = model_predict, task: str = 'Detection') -> dict: """ Evaluate a model on the COCO dataset. @@ -400,7 +404,7 @@ def coco_evaluate(model: Any, preprocess: Callable, dataset_folder: str, annotat coco_loader = DataLoader(coco_dataset, batch_size) # Initialize the evaluation metric object - coco_metric = CocoEval(annotation_file, output_resize) + coco_metric = CocoEval(annotation_file, output_resize, task) # Iterate and the evaluation set for batch_idx, (images, targets) in enumerate(coco_loader): @@ -415,44 +419,6 @@ def coco_evaluate(model: Any, preprocess: Callable, dataset_folder: str, annotat return coco_metric.result() -def scale_boxes(boxes: np.ndarray, h_image: int, w_image: int, h_model: int, w_model: int, preserve_aspect_ratio: bool, normalized: bool = True) -> np.ndarray: - """ - Scale and offset bounding boxes based on model output size and original image size. - - Args: - boxes (numpy.ndarray): Array of bounding boxes in format [y_min, x_min, y_max, x_max]. - h_image (int): Original image height. - w_image (int): Original image width. - h_model (int): Model output height. - w_model (int): Model output width. - preserve_aspect_ratio (bool): Whether to preserve image aspect ratio during scaling - - Returns: - numpy.ndarray: Scaled and offset bounding boxes. 
- """ - deltaH, deltaW = 0, 0 - H, W = h_model, w_model - scale_H, scale_W = h_image / H, w_image / W - - if preserve_aspect_ratio: - scale_H = scale_W = max(h_image / H, w_image / W) - H_tag = int(np.round(h_image / scale_H)) - W_tag = int(np.round(w_image / scale_W)) - deltaH, deltaW = int((H - H_tag) / 2), int((W - W_tag) / 2) - - nh, nw = (H, W) if normalized else (1, 1) - - # Scale and offset boxes - boxes[..., 0] = (boxes[..., 0] * nh - deltaH) * scale_H - boxes[..., 1] = (boxes[..., 1] * nw - deltaW) * scale_W - boxes[..., 2] = (boxes[..., 2] * nh - deltaH) * scale_H - boxes[..., 3] = (boxes[..., 3] * nw - deltaW) * scale_W - - # Clip boxes - boxes = clip_boxes(boxes, h_image, w_image) - - return boxes - def masks_to_coco_rle(masks, boxes, image_id, height, width, scores, classes, mask_threshold): """ Converts masks to COCO RLE format and compiles results including bounding boxes and scores. @@ -590,6 +556,3 @@ def evaluate_yolov8_segmentation(model, data_dir, data_type='val2017', img_ids_l save_results_to_json(results, output_file) evaluate_seg_model(ann_file, output_file) - - - diff --git a/tutorials/mct_model_garden/models_pytorch/yolov8/postprocess_yolov8_seg.py b/tutorials/mct_model_garden/models_pytorch/yolov8/postprocess_yolov8_seg.py index 9512b62ca..b5a590139 100644 --- a/tutorials/mct_model_garden/models_pytorch/yolov8/postprocess_yolov8_seg.py +++ b/tutorials/mct_model_garden/models_pytorch/yolov8/postprocess_yolov8_seg.py @@ -1,50 +1,11 @@ from typing import List import numpy as np import cv2 -import matplotlib.pyplot as plt -import matplotlib.colors as mcolors from typing import Tuple +from tutorials.mct_model_garden.models_pytorch.yolov8.yolov8_postprocess import nms -def nms(dets: np.ndarray, scores: np.ndarray, iou_thres: float = 0.3, max_out_dets: int = 300) -> List[int]: - """ - Perform Non-Maximum Suppression (NMS) on detected bounding boxes. - - Args: - dets (np.ndarray): Array of bounding box coordinates of shape (N, 4) representing [y1, x1, y2, x2]. - scores (np.ndarray): Array of confidence scores associated with each bounding box. - iou_thres (float, optional): IoU threshold for NMS. Default is 0.5. - max_out_dets (int, optional): Maximum number of output detections to keep. Default is 300. - - Returns: - List[int]: List of indices representing the indices of the bounding boxes to keep after NMS. - - """ - y1, x1 = dets[:, 0], dets[:, 1] - y2, x2 = dets[:, 2], dets[:, 3] - areas = (x2 - x1 + 1) * (y2 - y1 + 1) - order = scores.argsort()[::-1] - - keep = [] - while order.size > 0: - i = order[0] - keep.append(i) - xx1 = np.maximum(x1[i], x1[order[1:]]) - yy1 = np.maximum(y1[i], y1[order[1:]]) - xx2 = np.minimum(x2[i], x2[order[1:]]) - yy2 = np.minimum(y2[i], y2[order[1:]]) - - w = np.maximum(0.0, xx2 - xx1 + 1) - h = np.maximum(0.0, yy2 - yy1 + 1) - inter = w * h - ovr = inter / (areas[i] + areas[order[1:]] - inter) - - inds = np.where(ovr <= iou_thres)[0] - order = order[inds + 1] - - return keep[:max_out_dets] - def combined_nms_seg(batch_boxes, batch_scores, batch_masks, iou_thres: float = 0.3, conf: float = 0.1, max_out_dets: int = 300): """ Perform combined Non-Maximum Suppression (NMS) and segmentation mask processing for batched inputs. 
diff --git a/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8-seg.yaml b/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8-seg.yaml index 9588ac31f..b53fe511e 100644 --- a/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8-seg.yaml +++ b/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8-seg.yaml @@ -1,3 +1,6 @@ +# The following code was mostly duplicated from https://github.com/ultralytics/ultralytics +# ============================================================================== + # Ultralytics YOLO πŸš€, AGPL-3.0 license # YOLOv8-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment ## diff --git a/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8.py b/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8.py index dce19e0ce..1f4ec2a35 100644 --- a/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8.py +++ b/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8.py @@ -14,7 +14,7 @@ The code is organized as follows: - Classes definitions of Yolov8n building blocks: Conv, Bottleneck, C2f, SPPF, Upsample, Concaat, DFL and Detect -- Detection Model definition: DetectionModelPytorch +- Detection Model definition: ModelPyTorch - PostProcessWrapper Wrapping the Yolov8n model with PostProcess layer (Specifically, sony_custom_layers/multiclass_nms) - A getter function for getting a new instance of the model @@ -38,6 +38,8 @@ from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms +from tutorials.mct_model_garden.models_pytorch.yolov8.yolov8_postprocess import postprocess_yolov8_keypoints + def yaml_load(file: str = 'data.yaml', append_filename: bool = False) -> Dict[str, any]: """ @@ -277,6 +279,71 @@ def bias_init(self): a[-1].bias.data[:] = 1.0 # box b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) + +class Detect_wo_bb_dec(nn.Module): + def __init__(self, nc: int = 80, + ch: List[int] = ()): + """ + Detection layer for YOLOv8. Bounding box decoding was removed. + Args: + nc (int): Number of classes. + ch (List[int]): List of channel values for detection layers. 
+ """ + super().__init__() + self.nc = nc # number of classes + self.nl = len(ch) # number of detection layers + self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x) + self.no = nc + self.reg_max * 4 # number of outputs per anchor + self.stride = torch.Tensor([8, 16, 32]) + self.feat_sizes = torch.Tensor([80, 40, 20]) + self.img_size = 640 # img size + c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels + self.cv2 = nn.ModuleList( + nn.Sequential(Conv(x, c2, 3), + Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch) + self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), + nn.Conv2d(c3, self.nc, 1)) for x in ch) + self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + shape = x[0].shape # BCHW + for i in range(self.nl): + x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1) + box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split( + (self.reg_max * 4, self.nc), 1) + + y_cls = cls.sigmoid() + y_bb = self.dfl(box) + return y_bb, y_cls + + + def bias_init(self): + """Initialize Detect() biases, WARNING: requires stride availability.""" + m = self # self.model[-1] # Detect() module + for a, b, s in zip(m.cv2, m.cv3, m.stride): # from + a[-1].bias.data[:] = 1.0 # box + b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) + +class Pose(Detect_wo_bb_dec): + """YOLOv8 Pose head for keypoints models.""" + + def __init__(self, nc=80, kpt_shape=(17, 3), ch=()): + """Initialize YOLO network with default parameters and Convolutional Layers.""" + super().__init__(nc, ch) + self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) + self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total + self.detect = Detect_wo_bb_dec.forward + + c4 = max(ch[0] // 4, self.nk) + self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch) + + def forward(self, x): + """Perform forward pass through YOLO model and return predictions.""" + bs = x[0].shape[0] # batch size + kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w) + y_bb, y_cls = self.detect(self, x) + return y_bb, y_cls, kpt + def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) """Parse a YOLO model.yaml dictionary into a PyTorch model.""" import ast @@ -321,7 +388,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) args = [ch[f]] elif m is Concat: c2 = sum(ch[x] for x in f) - elif m in [Segment, Detect]: + elif m in [Segment, Detect, Pose]: args.append([ch[x] for x in f]) else: c2 = ch[f] @@ -405,6 +472,31 @@ def forward(self, images): iou_threshold=self.iou_threshold, max_detections=self.max_detections) return nms +def keypoints_model_predict(model: Any, inputs: np.ndarray) -> List: + """ + Perform inference using the provided PyTorch model on the given inputs. + + This function handles moving the inputs to the appropriate torch device and data type, + and detaches and moves the outputs to the CPU. + + Args: + model (Any): The PyTorch model used for inference. + inputs (np.ndarray): Input data to perform inference on. + + Returns: + List: List containing tensors of predictions. 
+ """ + device = get_working_device() + inputs = torch.from_numpy(inputs).to(device=device, dtype=torch.float) + + # Run Pytorch inference on the batch + outputs = model(inputs) + + # Detach outputs and move to cpu + output_np = [o.detach().cpu().numpy() for o in outputs] + + return postprocess_yolov8_keypoints(output_np) + def yolov8_pytorch(model_yaml: str) -> (nn.Module, Dict): """ @@ -419,7 +511,7 @@ def yolov8_pytorch(model_yaml: str) -> (nn.Module, Dict): """ cfg = model_yaml cfg_dict = yaml_load(cfg, append_filename=True) # model dict - model = DetectionModelPyTorch(cfg_dict) # model + model = ModelPyTorch(cfg_dict) # model return model, cfg_dict @@ -442,7 +534,7 @@ def yolov8_pytorch_pp(model_yaml: str, """ cfg = model_yaml cfg_dict = yaml_load(cfg, append_filename=True) # model dict - model = DetectionModelPyTorch(cfg_dict) # model + model = ModelPyTorch(cfg_dict) # model model_pp = PostProcessWrapper(model=model, score_threshold=score_threshold, iou_threshold=iou_threshold, @@ -490,8 +582,8 @@ def forward(self, x): mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients y_bb, y_cls = self.detect(self, x) - - return y_bb, y_cls, mc, p + + return y_bb, y_cls, mc, p class ModelPyTorch(nn.Module, PyTorchModelHubMixin): diff --git a/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8_postprocess.py b/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8_postprocess.py new file mode 100644 index 000000000..0c9e2e7df --- /dev/null +++ b/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8_postprocess.py @@ -0,0 +1,387 @@ +from enum import Enum +from typing import List, Tuple +import numpy as np +import cv2 +import matplotlib.pyplot as plt + +def nms(dets: np.ndarray, scores: np.ndarray, iou_thres: float = 0.5, max_out_dets: int = 300) -> List[int]: + """ + Perform Non-Maximum Suppression (NMS) on detected bounding boxes. + + Args: + dets (np.ndarray): Array of bounding box coordinates of shape (N, 4) representing [y1, x1, y2, x2]. + scores (np.ndarray): Array of confidence scores associated with each bounding box. + iou_thres (float, optional): IoU threshold for NMS. Default is 0.5. + max_out_dets (int, optional): Maximum number of output detections to keep. Default is 300. + + Returns: + List[int]: List of indices representing the indices of the bounding boxes to keep after NMS. + + """ + y1, x1 = dets[:, 0], dets[:, 1] + y2, x2 = dets[:, 2], dets[:, 3] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= iou_thres)[0] + order = order[inds + 1] + + return keep[:max_out_dets] + +def combined_nms(batch_boxes, batch_scores, iou_thres: float = 0.5, conf: float = 0.001, max_out_dets: int = 300): + + """ + Performs combined Non-Maximum Suppression (NMS) on batches of bounding boxes and scores. + + Parameters: + batch_boxes (List[np.ndarray]): A list of arrays, where each array contains bounding boxes for a batch. + batch_scores (List[np.ndarray]): A list of arrays, where each array contains scores for the corresponding bounding boxes. + iou_thres (float): Intersection over Union (IoU) threshold for NMS. 
Defaults to 0.5. + conf (float): Confidence threshold for filtering boxes. Defaults to 0.001. + max_out_dets (int): Maximum number of output detections per image. Defaults to 300. + + Returns: + List[Tuple[np.ndarray, np.ndarray, np.ndarray]]: A list of tuples for each batch, where each tuple contains: + - nms_bbox: Array of bounding boxes after NMS. + - nms_scores: Array of scores after NMS. + - nms_classes: Array of class IDs after NMS. + """ + nms_results = [] + for boxes, scores in zip(batch_boxes, batch_scores): + + xc = np.argmax(scores, 1) + xs = np.amax(scores, 1) + x = np.concatenate([boxes, np.expand_dims(xs, 1), np.expand_dims(xc, 1)], 1) + + xi = xs > conf + x = x[xi] + + x = x[np.argsort(-x[:, 4])[:8400]] + scores = x[:, 4] + x[..., :4] = convert_to_ymin_xmin_ymax_xmax_format(x[..., :4], BoxFormat.XC_YC_W_H) + offset = x[:, 5] * 640 + boxes = x[..., :4] + np.expand_dims(offset, 1) + + # Original post-processing part + valid_indexs = nms(boxes, scores, iou_thres=iou_thres, max_out_dets=max_out_dets) + x = x[valid_indexs] + nms_classes = x[:, 5] + nms_bbox = x[:, :4] + nms_scores = x[:, 4] + + nms_results.append((nms_bbox, nms_scores, nms_classes)) + + return nms_results + + +class BoxFormat(Enum): + YMIM_XMIN_YMAX_XMAX = 'ymin_xmin_ymax_xmax' + XMIM_YMIN_XMAX_YMAX = 'xmin_ymin_xmax_ymax' + XMIN_YMIN_W_H = 'xmin_ymin_width_height' + XC_YC_W_H = 'xc_yc_width_height' + + +def convert_to_ymin_xmin_ymax_xmax_format(boxes, orig_format: BoxFormat): + """ + changes the box from one format to another (XMIN_YMIN_W_H --> YMIM_XMIN_YMAX_XMAX ) + also support in same format mode (returns the same format) + + :param boxes: + :param orig_format: + :return: box in format YMIM_XMIN_YMAX_XMAX + """ + if len(boxes) == 0: + return boxes + elif orig_format == BoxFormat.YMIM_XMIN_YMAX_XMAX: + return boxes + elif orig_format == BoxFormat.XMIN_YMIN_W_H: + boxes[:, 2] += boxes[:, 0] # convert width to xmax + boxes[:, 3] += boxes[:, 1] # convert height to ymax + boxes[:, 0], boxes[:, 1] = boxes[:, 1], boxes[:, 0].copy() # swap xmin, ymin columns + boxes[:, 2], boxes[:, 3] = boxes[:, 3], boxes[:, 2].copy() # swap xmax, ymax columns + return boxes + elif orig_format == BoxFormat.XMIM_YMIN_XMAX_YMAX: + boxes[:, 0], boxes[:, 1] = boxes[:, 1], boxes[:, 0].copy() # swap xmin, ymin columns + boxes[:, 2], boxes[:, 3] = boxes[:, 3], boxes[:, 2].copy() # swap xmax, ymax columns + return boxes + elif orig_format == BoxFormat.XC_YC_W_H: + new_boxes = np.copy(boxes) + new_boxes[:, 0] = boxes[:, 1] - boxes[:, 3] / 2 # top left y + new_boxes[:, 1] = boxes[:, 0] - boxes[:, 2] / 2 # top left x + new_boxes[:, 2] = boxes[:, 1] + boxes[:, 3] / 2 # bottom right y + new_boxes[:, 3] = boxes[:, 0] + boxes[:, 2] / 2 # bottom right x + return new_boxes + else: + raise Exception("Unsupported boxes format") + +def clip_boxes(boxes: np.ndarray, h: int, w: int) -> np.ndarray: + """ + Clip bounding boxes to stay within the image boundaries. + + Args: + boxes (numpy.ndarray): Array of bounding boxes in format [y_min, x_min, y_max, x_max]. + h (int): Height of the image. + w (int): Width of the image. + + Returns: + numpy.ndarray: Clipped bounding boxes. 
+ """ + boxes[..., 0] = np.clip(boxes[..., 0], a_min=0, a_max=h) + boxes[..., 1] = np.clip(boxes[..., 1], a_min=0, a_max=w) + boxes[..., 2] = np.clip(boxes[..., 2], a_min=0, a_max=h) + boxes[..., 3] = np.clip(boxes[..., 3], a_min=0, a_max=w) + return boxes + + +def scale_boxes(boxes: np.ndarray, h_image: int, w_image: int, h_model: int, w_model: int, preserve_aspect_ratio: bool, normalized: bool = True) -> np.ndarray: + """ + Scale and offset bounding boxes based on model output size and original image size. + + Args: + boxes (numpy.ndarray): Array of bounding boxes in format [y_min, x_min, y_max, x_max]. + h_image (int): Original image height. + w_image (int): Original image width. + h_model (int): Model output height. + w_model (int): Model output width. + preserve_aspect_ratio (bool): Whether to preserve image aspect ratio during scaling + + Returns: + numpy.ndarray: Scaled and offset bounding boxes. + """ + deltaH, deltaW = 0, 0 + H, W = h_model, w_model + scale_H, scale_W = h_image / H, w_image / W + + if preserve_aspect_ratio: + scale_H = scale_W = max(h_image / H, w_image / W) + H_tag = int(np.round(h_image / scale_H)) + W_tag = int(np.round(w_image / scale_W)) + deltaH, deltaW = int((H - H_tag) / 2), int((W - W_tag) / 2) + + nh, nw = (H, W) if normalized else (1, 1) + + # Scale and offset boxes + boxes[..., 0] = (boxes[..., 0] * nh - deltaH) * scale_H + boxes[..., 1] = (boxes[..., 1] * nw - deltaW) * scale_W + boxes[..., 2] = (boxes[..., 2] * nh - deltaH) * scale_H + boxes[..., 3] = (boxes[..., 3] * nw - deltaW) * scale_W + + # Clip boxes + boxes = clip_boxes(boxes, h_image, w_image) + + return boxes + + +def scale_coords(kpts: np.ndarray, h_image: int, w_image: int, h_model: int, w_model: int, preserve_aspect_ratio: bool) -> np.ndarray: + """ + Scale and offset keypoints based on model output size and original image size. + + Args: + kpts (numpy.ndarray): Array of bounding keypoints in format [..., 17, 3] where the last dim is (x, y, visible). + h_image (int): Original image height. + w_image (int): Original image width. + h_model (int): Model output height. + w_model (int): Model output width. + preserve_aspect_ratio (bool): Whether to preserve image aspect ratio during scaling + + Returns: + numpy.ndarray: Scaled and offset bounding boxes. + """ + deltaH, deltaW = 0, 0 + H, W = h_model, w_model + scale_H, scale_W = h_image / H, w_image / W + + if preserve_aspect_ratio: + scale_H = scale_W = max(h_image / H, w_image / W) + H_tag = int(np.round(h_image / scale_H)) + W_tag = int(np.round(w_image / scale_W)) + deltaH, deltaW = int((H - H_tag) / 2), int((W - W_tag) / 2) + + # Scale and offset boxes + kpts[..., 0] = (kpts[..., 0] - deltaH) * scale_H + kpts[..., 1] = (kpts[..., 1] - deltaW) * scale_W + + # Clip boxes + kpts = clip_coords(kpts, h_image, w_image) + + return kpts + +def clip_coords(kpts: np.ndarray, h: int, w: int) -> np.ndarray: + """ + Clip keypoints to stay within the image boundaries. + + Args: + kpts (numpy.ndarray): Array of bounding keypoints in format [..., 17, 3] where the last dim is (x, y, visible). + h (int): Height of the image. + w (int): Width of the image. + + Returns: + numpy.ndarray: Clipped bounding boxes. 
+ """ + kpts[..., 0] = np.clip(kpts[..., 0], a_min=0, a_max=h) + kpts[..., 1] = np.clip(kpts[..., 1], a_min=0, a_max=w) + return kpts + + +def postprocess_yolov8_keypoints(outputs: Tuple[np.ndarray, np.ndarray, np.ndarray], + conf: float = 0.001, + iou_thres: float = 0.7, + max_out_dets: int = 300) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Postprocess the outputs of a YOLOv8 model for pose estimation. + + Args: + outputs (Tuple[np.ndarray, np.ndarray, np.ndarray]): Tuple containing the model outputs for bounding boxes, + scores and keypoint predictions. + conf (float, optional): Confidence threshold for bounding box predictions. Default is 0.001. + iou_thres (float, optional): IoU (Intersection over Union) threshold for Non-Maximum Suppression (NMS). + Default is 0.7. + max_out_dets (int, optional): Maximum number of output detections to keep after NMS. Default is 300. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple containing the post-processed bounding boxes, + their corresponding scores and keypoints. + """ + kpt_shape = (17, 3) + feat_sizes = np.array([80, 40, 20]) + stride_sizes = np.array([8, 16, 32]) + a, s = (x.transpose() for x in make_anchors_yolo_v8(feat_sizes, stride_sizes, 0.5)) + + y_bb, y_cls, kpts = outputs + dbox = dist2bbox_yolo_v8(y_bb, np.expand_dims(a, 0), xywh=True, dim=1) * s + detect_out = np.concatenate((dbox, y_cls), 1) + # additional part for pose estimation + ndim = kpt_shape[1] + pred_kpt = kpts.copy() + if ndim == 3: + pred_kpt[:, 2::3] = 1 / (1 + np.exp(-pred_kpt[:, 2::3])) # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug) + pred_kpt[:, 0::ndim] = (pred_kpt[:, 0::ndim] * 2.0 + (a[0] - 0.5)) * s + pred_kpt[:, 1::ndim] = (pred_kpt[:, 1::ndim] * 2.0 + (a[1] - 0.5)) * s + + x_batch = np.concatenate([detect_out.transpose([0, 2, 1]), pred_kpt.transpose([0, 2, 1])], 2) + nms_bbox, nms_scores, nms_kpts = [], [], [] + for x in x_batch: + x = x[(x[:, 4] > conf)] + x = x[np.argsort(-x[:, 4])[:8400]] + x[..., :4] = convert_to_ymin_xmin_ymax_xmax_format(x[..., :4], BoxFormat.XC_YC_W_H) + boxes = x[..., :4] + scores = x[..., 4] + + # Original post-processing part + valid_indexs = nms(boxes, scores, iou_thres=iou_thres, max_out_dets=max_out_dets) + x = x[valid_indexs] + nms_bbox.append(x[:, :4]) + nms_scores.append(x[:, 4]) + nms_kpts.append(x[:, 5:]) + + return nms_bbox, nms_scores, nms_kpts + + +def postprocess_yolov8_inst_seg(outputs: Tuple[np.ndarray, np.ndarray, np.ndarray], + conf: float = 0.001, + iou_thres: float = 0.7, + max_out_dets: int = 300) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + + feat_sizes = np.array([80, 40, 20]) + stride_sizes = np.array([8, 16, 32]) + a, s = (x.transpose() for x in make_anchors_yolo_v8(feat_sizes, stride_sizes, 0.5)) + + y_bb, y_cls, y_masks = outputs + dbox = dist2bbox_yolo_v8(y_bb, a, xywh=True, dim=1) * s + detect_out = np.concatenate((dbox, y_cls), 1) + + xd = detect_out.transpose([0, 2, 1]) + + return combined_nms(xd[..., :4], xd[..., 4:84], iou_thres, conf, max_out_dets) + + +def make_anchors_yolo_v8(feats, strides, grid_cell_offset=0.5): + """Generate anchors from features.""" + anchor_points, stride_tensor = [], [] + assert feats is not None + for i, stride in enumerate(strides): + h, w = feats[i], feats[i] + sx = np.arange(stop=w) + grid_cell_offset # shift x + sy = np.arange(stop=h) + grid_cell_offset # shift y + sy, sx = np.meshgrid(sy, sx, indexing='ij') + anchor_points.append(np.stack((sx, sy), -1).reshape((-1, 2))) + stride_tensor.append(np.full((h * w, 1), stride)) + return 
np.concatenate(anchor_points), np.concatenate(stride_tensor) + + +def dist2bbox_yolo_v8(distance, anchor_points, xywh=True, dim=-1): + """Transform distance(ltrb) to box(xywh or xyxy).""" + lt, rb = np.split(distance,2,axis=dim) + x1y1 = anchor_points - lt + x2y2 = anchor_points + rb + if xywh: + c_xy = (x1y1 + x2y2) / 2 + wh = x2y2 - x1y1 + return np.concatenate((c_xy, wh), dim) # xywh bbox + return np.concatenate((x1y1, x2y2), dim) # xyxy bbox + + +def scale_coords(kpts: np.ndarray, h_image: int, w_image: int, h_model: int, w_model: int, preserve_aspect_ratio: bool) -> np.ndarray: + """ + Scale and offset keypoints based on model output size and original image size. + + Args: + kpts (numpy.ndarray): Array of bounding keypoints in format [..., 17, 3] where the last dim is (x, y, visible). + h_image (int): Original image height. + w_image (int): Original image width. + h_model (int): Model output height. + w_model (int): Model output width. + preserve_aspect_ratio (bool): Whether to preserve image aspect ratio during scaling + + Returns: + numpy.ndarray: Scaled and offset bounding boxes. + """ + deltaH, deltaW = 0, 0 + H, W = h_model, w_model + scale_H, scale_W = h_image / H, w_image / W + + if preserve_aspect_ratio: + scale_H = scale_W = max(h_image / H, w_image / W) + H_tag = int(np.round(h_image / scale_H)) + W_tag = int(np.round(w_image / scale_W)) + deltaH, deltaW = int((H - H_tag) / 2), int((W - W_tag) / 2) + + # Scale and offset boxes + kpts[..., 0] = (kpts[..., 0] - deltaH) * scale_H + kpts[..., 1] = (kpts[..., 1] - deltaW) * scale_W + + # Clip boxes + kpts = clip_coords(kpts, h_image, w_image) + + return kpts + + +def clip_coords(kpts: np.ndarray, h: int, w: int) -> np.ndarray: + """ + Clip keypoints to stay within the image boundaries. + + Args: + kpts (numpy.ndarray): Array of bounding keypoints in format [..., 17, 3] where the last dim is (x, y, visible). + h (int): Height of the image. + w (int): Width of the image. + + Returns: + numpy.ndarray: Clipped bounding boxes. + """ + kpts[..., 0] = np.clip(kpts[..., 0], a_min=0, a_max=h) + kpts[..., 1] = np.clip(kpts[..., 1], a_min=0, a_max=w) + return kpts diff --git a/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8n-pose.yaml b/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8n-pose.yaml new file mode 100644 index 000000000..24341d857 --- /dev/null +++ b/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8n-pose.yaml @@ -0,0 +1,50 @@ +# The following code was mostly duplicated from https://github.com/ultralytics/ultralytics +# ============================================================================== + +# Ultralytics YOLO πŸš€, AGPL-3.0 license +# YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose + +# Parameters +nc: 1 # number of classes +kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +scales: # model compound scaling constants, i.e. 
'model=yolov8n-pose.yaml' will call yolov8-pose.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 768] + l: [1.00, 1.00, 512] + x: [1.00, 1.25, 512] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5) \ No newline at end of file diff --git a/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_pose_for_imx500.ipynb b/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_pose_for_imx500.ipynb new file mode 100644 index 000000000..25f3dd5ed --- /dev/null +++ b/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_pose_for_imx500.ipynb @@ -0,0 +1,507 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fab9d9939dc74da4", + "metadata": { + "collapsed": false + }, + "source": [ + "# YOLOv8n Object Detection PyTorch Model - Quantization for IMX500\n", + "\n", + "[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_pose_for_imx500.ipynb)\n", + "\n", + "## Overview\n", + "\n", + "In this tutorial, we will illustrate a basic and quick process of preparing a pre-trained model for deployment using MCT. Specifically, we will demonstrate how to download a pre-trained YOLOv8n model from the MCT Models Library, compress it, and make it deployment-ready using MCT's post-training quantization techniques.\n", + "\n", + "We will use an existing pre-trained YOLOv8n pose estimation model based on [Ultralytics](https://github.com/ultralytics/ultralytics). The model was slightly adjusted for model quantization. We will quantize the model using MCT post training quantization and evaluate the performance of the floating point model and the quantized model on COCO dataset.\n", + "\n", + "\n", + "## Summary\n", + "\n", + "In this tutorial we will cover:\n", + "\n", + "1. Post-Training Quantization using MCT of PyTorch pose estimation model.\n", + "2. Data preparation - loading and preprocessing validation and representative datasets from COCO.\n", + "3. Accuracy evaluation of the floating-point and the quantized models." 
+ ] + }, + { + "cell_type": "markdown", + "id": "d74f9c855ec54081", + "metadata": { + "collapsed": false + }, + "source": [ + "## Setup\n", + "### Install the relevant packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c7fa04c9903736f", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "!pip install -q torch\n", + "!pip install onnx\n", + "!pip install -q pycocotools\n", + "!pip install 'huggingface-hub>=0.21.0'\n", + "!pip install --pre 'sony-custom-layers-dev>=0.2.0.dev5'" + ] + }, + { + "cell_type": "markdown", + "id": "57717bc8f59a0d85", + "metadata": { + "collapsed": false + }, + "source": [ + "Install MCT (if it’s not already installed). Additionally, in order to use all the necessary utility functions for this tutorial, we also copy [MCT tutorials folder](https://github.com/sony/model_optimization/tree/main/tutorials) and add it to the system path." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9728247bc20d0600", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "import importlib\n", + "\n", + "if not importlib.util.find_spec('model_compression_toolkit'):\n", + " !pip install model_compression_toolkit\n", + "!git clone https://github.com/sony/model_optimization.git temp_mct && mv temp_mct/tutorials . && \\rm -rf temp_mct\n", + "sys.path.insert(0,\"tutorials\")" + ] + }, + { + "cell_type": "markdown", + "id": "7a1038b9fd98bba2", + "metadata": { + "collapsed": false + }, + "source": [ + "### Download COCO evaluation set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bea492d71b4060f", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "if not os.path.isdir('coco'):\n", + " !wget -nc http://images.cocodataset.org/annotations/annotations_trainval2017.zip\n", + " !unzip -q -o annotations_trainval2017.zip -d ./coco\n", + " !echo Done loading annotations\n", + " !wget -nc http://images.cocodataset.org/zips/val2017.zip\n", + " !unzip -q -o val2017.zip -d ./coco\n", + " !echo Done loading val2017 images" + ] + }, + { + "cell_type": "markdown", + "id": "084c2b8b-3175-4d46-a18a-7c4d8b6fcb38", + "metadata": {}, + "source": [ + "## Model Quantization\n", + "\n", + "### Download a Pre-Trained Model \n", + "\n", + "We begin by loading a pre-trained [YOLOv8n](https://huggingface.co/SSI-DNN/pytorch_yolov8n_640x640) model. This implementation is based on [Ultralytics](https://github.com/ultralytics/ultralytics) and includes a slightly modified version of yolov8 pose-head that was adapted for model quantization. For further insights into the model's implementation details, please refer to [MCT Models Garden - yolov8](https://github.com/sony/model_optimization/tree/main/tutorials/mct_model_garden/models_pytorch/yolov8). 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e8395b28-4732-4d18-b081-5d3bdf508691", + "metadata": {}, + "outputs": [], + "source": [ + "from tutorials.mct_model_garden.models_pytorch.yolov8.yolov8 import ModelPyTorch, yaml_load, model_predict\n", + "cfg_dict = yaml_load(\"tutorials/mct_model_garden/models_pytorch/yolov8/yolov8n-pose.yaml\", append_filename=True)\n", + "model = ModelPyTorch.from_pretrained(\"SSI-DNN/pytorch_yolov8n_640x640\", cfg=cfg_dict)" + ] + }, + { + "cell_type": "markdown", + "id": "3cde2f8e-0642-4374-a1f4-df2775fe7767", + "metadata": {}, + "source": [ + "### Post training quantization using Model Compression Toolkit \n", + "\n", + "Now, we're all set to use MCT's post-training quantization. To begin, we'll define a representative dataset and proceed with the model quantization. Please note that, for demonstration purposes, we'll use the evaluation dataset as our representative dataset. We'll calibrate the model using 80 representative images, divided into 20 iterations of 'batch_size' images each. \n", + "\n", + "Additionally, to further compress the model's memory footprint, we will employ the mixed-precision quantization technique. This method allows each layer to be quantized with different precision options: 2, 4, and 8 bits, aligning with the imx500 target platform capabilities." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56393342-cecf-4f64-b9ca-2f515c765942", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import model_compression_toolkit as mct\n", + "from tutorials.mct_model_garden.evaluation_metrics.coco_evaluation import CocoDataset, DataLoader\n", + "from tutorials.mct_model_garden.models_pytorch.yolov8.yolov8_preprocess import yolov8_preprocess_chw_transpose\n", + "from typing import Iterator, Tuple, List\n", + "\n", + "REPRESENTATIVE_DATASET_FOLDER = './coco/val2017/'\n", + "REPRESENTATIVE_DATASET_ANNOTATION_FILE = './coco/annotations/person_keypoints_val2017.json'\n", + "BATCH_SIZE = 4\n", + "n_iters = 20\n", + "\n", + "representative_dataset = CocoDataset(dataset_folder=REPRESENTATIVE_DATASET_FOLDER,\n", + " annotation_file=REPRESENTATIVE_DATASET_ANNOTATION_FILE,\n", + " preprocess=yolov8_preprocess_chw_transpose)\n", + "\n", + "gptq_representative_dataset = DataLoader(representative_dataset, BATCH_SIZE, shuffle=True)\n", + "\n", + "# Define representative dataset generator\n", + "def get_representative_dataset(n_iter: int, dataset_loader: Iterator[Tuple]):\n", + " \"\"\"\n", + " This function creates a representative dataset generator. The generator yields numpy\n", + " arrays of batches of shape: [Batch, H, W ,C].\n", + " Args:\n", + " n_iter: number of iterations for MCT to calibrate on\n", + " Returns:\n", + " A representative dataset generator\n", + " \"\"\" \n", + " def representative_dataset() -> Iterator[List]:\n", + " ds_iter = iter(dataset_loader)\n", + " for _ in range(n_iter):\n", + " yield [next(ds_iter)[0]]\n", + "\n", + " return representative_dataset\n", + "\n", + "# Get representative dataset generator\n", + "representative_dataset_gen = get_representative_dataset(n_iter=n_iters,\n", + " dataset_loader=representative_dataset)\n", + "\n", + "# Set IMX500-v1 TPC\n", + "tpc = mct.get_target_platform_capabilities(fw_name=\"pytorch\",\n", + " target_platform_name='imx500',\n", + " target_platform_version='v1')\n", + "\n", + "# Specify the necessary configuration for mixed precision quantization. 
To keep the tutorial brief, we'll use a small set of images and omit the hessian metric for mixed precision calculations. It's important to be aware that this choice may impact the resulting accuracy. \n", + "mp_config = mct.core.MixedPrecisionQuantizationConfig(num_of_images=64)\n", + "config = mct.core.CoreConfig(mixed_precision_config=mp_config,\n", + " quantization_config=mct.core.QuantizationConfig(shift_negative_activation_correction=True,\n", + " concat_threshold_update=True))\n", + "\n", + "# Define target Resource Utilization for mixed precision weights quantization (80% of 'standard' 8bits quantization)\n", + "resource_utilization_data = mct.core.pytorch_resource_utilization_data(in_model=model,\n", + " representative_data_gen=representative_dataset_gen,\n", + " core_config=config,\n", + " target_platform_capabilities=tpc)\n", + "resource_utilization = mct.core.ResourceUtilization(weights_memory=resource_utilization_data.weights_memory * 0.8)\n", + "\n", + "# Perform post training quantization\n", + "quant_model, _ = mct.ptq.pytorch_post_training_quantization(in_module=model,\n", + " representative_data_gen=representative_dataset_gen,\n", + " target_resource_utilization=resource_utilization,\n", + " core_config=config,\n", + " target_platform_capabilities=tpc)\n", + "print('Quantized model is ready')" + ] + }, + { + "cell_type": "markdown", + "id": "3be2016acdc9da60", + "metadata": { + "collapsed": false + }, + "source": [ + "### Model Export\n", + "\n", + "Now, we can export the quantized model, ready for deployment, into a `.onnx` format file. Please ensure that the `save_model_path` has been set correctly. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72dd885c7b92fa93", + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "mct.exporter.pytorch_export_model(model=quant_model,\n", + " save_model_path='./qmodel.onnx',\n", + " repr_dataset=representative_dataset_gen)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Gradient-Based Post Training Quantization using Model Compression Toolkit\n", + "Here we demonstrate how to further optimize the quantized model performance using gradient-based PTQ technique.\n", + "**Please note that this section is computationally heavy, and it's recommended to run it on a GPU. For fast deployment, you may choose to skip this step.** \n", + "\n", + "We will start by loading the COCO training set, and re-define the representative dataset accordingly. 
" + ], + "metadata": { + "collapsed": false + }, + "id": "655d764593af0763" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "!wget -nc http://images.cocodataset.org/zips/train2017.zip\n", + "!unzip -q -o train2017.zip -d ./coco\n", + "!echo Done loading train2017 images\n", + "\n", + "GPTQ_REPRESENTATIVE_DATASET_FOLDER = './coco/train2017/'\n", + "GPTQ_REPRESENTATIVE_DATASET_ANNOTATION_FILE = './coco/annotations/person_keypoints_train2017.json'\n", + "BATCH_SIZE = 4\n", + "n_iters = 20\n", + "\n", + "# Load representative dataset\n", + "representative_dataset = CocoDataset(dataset_folder=GPTQ_REPRESENTATIVE_DATASET_FOLDER,\n", + " annotation_file=GPTQ_REPRESENTATIVE_DATASET_ANNOTATION_FILE,\n", + " preprocess=yolov8_preprocess_chw_transpose)\n", + "\n", + "representative_dataset_gen = DataLoader(representative_dataset, BATCH_SIZE, shuffle=True)" + ], + "metadata": { + "collapsed": false + }, + "id": "20fe96b6cc95d38c" + }, + { + "cell_type": "markdown", + "source": [ + "Next, we'll set up the Gradient-Based PTQ configuration and execute the necessary MCT command. Keep in mind that this step can be time-consuming, depending on your runtime." + ], + "metadata": { + "collapsed": false + }, + "id": "29d54f733139d114" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# Specify the necessary configuration for Gradient-Based PTQ.\n", + "n_gptq_epochs = 1000\n", + "gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=n_gptq_epochs, use_hessian_based_weights=False)\n", + "\n", + "# Perform Gradient-Based Post Training Quantization\n", + "gptq_quant_model, _ = mct.gptq.pytorch_gradient_post_training_quantization(\n", + " model=model,\n", + " representative_data_gen=representative_dataset_gen,\n", + " target_resource_utilization=resource_utilization,\n", + " gptq_config=gptq_config,\n", + " core_config=config,\n", + " target_platform_capabilities=tpc)\n", + "\n", + "print('Quantized-GPTQ model is ready')" + ], + "metadata": { + "collapsed": false + }, + "id": "240421e00f6cce34" + }, + { + "cell_type": "markdown", + "source": [ + "### Model Export\n", + "\n", + "Now, we can export the quantized model, ready for deployment, into a `.onnx` format file. Please ensure that the `save_model_path` has been set correctly. " + ], + "metadata": { + "collapsed": false + }, + "id": "b5d72e8420550101" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "mct.exporter.pytorch_export_model(model=gptq_quant_model,\n", + " save_model_path='./qmodel_gptq.onnx',\n", + " repr_dataset=representative_dataset_gen)" + ], + "metadata": { + "collapsed": false + }, + "id": "546ff946af81702b" + }, + { + "cell_type": "markdown", + "source": [ + "## Evaluation on COCO dataset\n", + "\n", + "### Floating point model evaluation\n", + "Next, we evaluate the floating point model by using `cocoeval` library alongside additional dataset utilities. We can verify the mAP accuracy aligns with that of the original model. \n", + "Note that we set the \"batch_size\" to 4 and the preprocessing according to [Ultralytics](https://github.com/ultralytics/ultralytics).\n", + "Please ensure that the dataset path has been set correctly before running this code cell." 
+ ], + "metadata": { + "collapsed": false + }, + "id": "43a8a6d11d696b09" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01e90967-594b-480f-b2e6-45e2c9ce9cee", + "metadata": {}, + "outputs": [], + "source": [ + "from tutorials.mct_model_garden.models_pytorch.yolov8.yolov8 import keypoints_model_predict\n", + "from tutorials.mct_model_garden.evaluation_metrics.coco_evaluation import coco_evaluate\n", + "\n", + "EVAL_DATASET_FOLDER = './coco/val2017'\n", + "EVAL_DATASET_ANNOTATION_FILE = './coco/annotations/person_keypoints_val2017.json'\n", + "INPUT_RESOLUTION = 640\n", + "\n", + "# Define resizing information to map between the model's output and the original image dimensions\n", + "output_resize = {'shape': (INPUT_RESOLUTION, INPUT_RESOLUTION), 'aspect_ratio_preservation': True}\n", + "\n", + "# Evaluate the model on coco\n", + "eval_results = coco_evaluate(model=model,\n", + " dataset_folder=EVAL_DATASET_FOLDER,\n", + " annotation_file=EVAL_DATASET_ANNOTATION_FILE,\n", + " preprocess=yolov8_preprocess_chw_transpose,\n", + " output_resize=output_resize,\n", + " batch_size=BATCH_SIZE,\n", + " model_inference=keypoints_model_predict,\n", + " task='Keypoints')\n", + "\n", + "# Print float model mAP results\n", + "print(\"Float model mAP: {:.4f}\".format(eval_results[0]))" + ] + }, + { + "cell_type": "markdown", + "id": "4fb6bffc-23d1-4852-8ec5-9007361c8eeb", + "metadata": {}, + "source": [ + "### Quantized model evaluation\n", + "We can evaluate the performance of the quantized model. There is a slight decrease in performance that can be further mitigated by either expanding the representative dataset or employing MCT's advanced quantization methods, such as GPTQ (Gradient-Based/Enhanced Post Training Quantization)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8dc7b87c-a9f4-4568-885a-fe009c8f4e8f", + "metadata": {}, + "outputs": [], + "source": [ + "# Evaluate the quantized model with PostProcess on coco\n", + "eval_results = coco_evaluate(model=quant_model,\n", + " dataset_folder=EVAL_DATASET_FOLDER,\n", + " annotation_file=EVAL_DATASET_ANNOTATION_FILE,\n", + " preprocess=yolov8_preprocess_chw_transpose,\n", + " output_resize=output_resize,\n", + " batch_size=BATCH_SIZE,\n", + " model_inference=keypoints_model_predict,\n", + " task='Keypoints')\n", + "\n", + "# Print quantized model mAP results\n", + "print(\"Quantized model mAP: {:.4f}\".format(eval_results[0]))" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Finally, we can evaluate the performance of the quantized model through GPTQ (Gradient-Based/Enhanced Post Training Quantization). We anticipate an improvement in performance compare to the quantized model utilizing PTQ." 
+ ], + "metadata": { + "collapsed": false + }, + "id": "3bb5cc7c91dc8f21" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# Evaluate the quantized using GPTQ model with PostProcess on coco\n", + "eval_results = coco_evaluate(model=gptq_quant_model,\n", + " dataset_folder=EVAL_DATASET_FOLDER,\n", + " annotation_file=EVAL_DATASET_ANNOTATION_FILE,\n", + " preprocess=yolov8_preprocess_chw_transpose,\n", + " output_resize=output_resize,\n", + " batch_size=BATCH_SIZE,\n", + " model_inference=keypoints_model_predict,\n", + " task='Keypoints')\n", + "\n", + "# Print quantized using GPTQ model mAP results\n", + "print(\"Quantized using GPTQ model mAP: {:.4f}\".format(eval_results[0]))" + ], + "metadata": { + "collapsed": false + }, + "id": "168468f17ae8bc59" + }, + { + "cell_type": "markdown", + "id": "6d93352843a27433", + "metadata": { + "collapsed": false + }, + "source": [ + "\\\n", + "Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + " http://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License." + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
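
For reference, a minimal usage sketch of the keypoint rescaling helper introduced in yolov8_postprocess.py above (not part of the patch). It assumes the repository root is on sys.path, as the notebook arranges via sys.path.insert, and the image/model sizes and random keypoints are purely illustrative.

import numpy as np
from tutorials.mct_model_garden.models_pytorch.yolov8.yolov8_postprocess import scale_coords

# One fake detection with 17 keypoints, last dim (x, y, visible), in the model's 640x640 letterboxed space
kpts = np.random.rand(1, 17, 3) * 640
kpts[..., 2] = 1.0  # visibility scores

# Map the keypoints back to an original 480x640 image while undoing the letterbox padding;
# scale_coords also clips the result to the original image bounds via clip_coords
kpts_scaled = scale_coords(kpts, h_image=480, w_image=640, h_model=640, w_model=640,
                           preserve_aspect_ratio=True)
print(kpts_scaled.shape)  # (1, 17, 3)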