diff --git a/README.md b/README.md index 3820e17b..42725e3c 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,113 @@ +# YOLOX-Bytetrack Algorithm Optimization with CuPy + +#### This repository contains an optimized version of the YOLOX-Bytetrack algorithm. + +## Abstract +The primary enhancement involves the use of CuPy instead of NumPy for the `preproc` function, resulting in significant performance improvements for the preprocessing stage. The changes also include the utilization of multithreading for parallel processing of multiple images. + +## Key Improvements +### 1. CuPy Integration +The original `preproc` function used NumPy for various operations, which are now replaced with CuPy to leverage GPU acceleration. This change drastically reduces the preprocessing time, especially when dealing with large batches of images. + +**Original `preproc` Function** +```python +def preproc(image, input_size, mean, std, swap=(2, 0, 1)): + if len(image.shape) == 3: + padded_img = np.ones((input_size[0], input_size[1], 3)) * 114.0 + else: + padded_img = np.ones(input_size) * 114.0 + img = np.array(image) + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.float32) + padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img + + padded_img = padded_img[:, :, ::-1] + padded_img /= 255.0 + if mean is not None: + padded_img -= mean + if std is not None: + padded_img /= std + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r +``` + +**Optimized preproc Function with CuPy** +```python +def preproc_with_cupy(image, input_size, mean, std, swap=(2, 0, 1)): + device = cp.cuda.Device(0) + device.use() + + if len(image.shape) == 3: + padded_img = cp.ones((input_size[0], input_size[1], 3)) * 114.0 + else: + padded_img = cp.ones(input_size) * 114.0 + + img = cp.array(image) + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + + target_height = int(img.shape[0] * r) + target_width = int(img.shape[1] * r) + + if target_height <= 0 or target_width <= 0: + raise ValueError(f"Invalid target size: ({target_width}, {target_height})") + + resized_img = cp.array(cv2.resize( + cp.asnumpy(img), + (target_width, target_height), + interpolation=cv2.INTER_LINEAR, + ).astype(np.float32)) + + if len(image.shape) == 3: + padded_img[:target_height, :target_width, :] = resized_img + else: + padded_img[:target_height, :target_width] = resized_img + + padded_img = padded_img[:, :, ::-1] / 255.0 # BGR to RGB and normalize + + if mean is not None: + mean_array = cp.array(mean).reshape(1, 1, 3) + padded_img -= mean_array + + if std is not None: + std_array = cp.array(std).reshape(1, 1, 3) + padded_img /= std_array + + padded_img = padded_img.transpose(swap) + padded_img = cp.ascontiguousarray(padded_img, dtype=cp.float32) + return padded_img, r +``` +### 2. Multithreading for Image Processing + +To further enhance performance, the process_images method now uses multithreading to preprocess multiple images in parallel. This change utilizes Python's ThreadPoolExecutor to handle image preprocessing concurrently. + +**Added process_images Method** + +```python + def process_images(self, image_list, input_size, mean, std, swap=(2, 0, 1)): + with ThreadPoolExecutor() as executor: + futures= [executor.submit(preproc_with_cupy, img, input_size, mean, std, swap) for img in image_list] + results = [future.result() for future in futures] + if results: + return results +``` +## Result +The integration of CuPy and the use of multithreading have significantly improved the preprocessing time for image batches. Below is a comparison of the preprocessing time before and after the optimization: +The tests resulted in an FPS increase of around 1.5X-2X, depending on the graphics card used and the type of model. + +**Before optimization** +

+ +**After optimization** +

+ +------------------ + + # ByteTrack [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bytetrack-multi-object-tracking-by-1/multi-object-tracking-on-mot17)](https://paperswithcode.com/sota/multi-object-tracking-on-mot17?p=bytetrack-multi-object-tracking-by-1) diff --git a/assets/with_cupy.png b/assets/with_cupy.png new file mode 100644 index 00000000..a123a41c Binary files /dev/null and b/assets/with_cupy.png differ diff --git a/assets/without_cupy.png b/assets/without_cupy.png new file mode 100644 index 00000000..df992f99 Binary files /dev/null and b/assets/without_cupy.png differ diff --git a/tools/demo_track.py b/tools/demo_track.py index 4f4e7dc3..e70c585e 100644 --- a/tools/demo_track.py +++ b/tools/demo_track.py @@ -4,20 +4,22 @@ import time import cv2 import torch - from loguru import logger +import cupy as cp + +import sys +sys.path.append('.') + -from yolox.data.data_augment import preproc +from yolox.data.data_augment import preproc, preproc_with_cupy from yolox.exp import get_exp from yolox.utils import fuse_model, get_model_info, postprocess from yolox.utils.visualize import plot_tracking from yolox.tracker.byte_tracker import BYTETracker from yolox.tracking_utils.timer import Timer - +from concurrent.futures import ThreadPoolExecutor IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"] - - def make_parser(): parser = argparse.ArgumentParser("ByteTrack Demo!") parser.add_argument( @@ -100,7 +102,6 @@ def get_image_list(path): image_names.append(apath) return image_names - def write_results(filename, results): save_format = '{frame},{id},{x1},{y1},{w},{h},{s},-1,-1,-1\n' with open(filename, 'w') as f: @@ -113,7 +114,6 @@ def write_results(filename, results): f.write(line) logger.info('save results to {}'.format(filename)) - class Predictor(object): def __init__( self, @@ -121,7 +121,7 @@ def __init__( exp, trt_file=None, decoder=None, - device=torch.device("cpu"), + device=None, fp16=False ): self.model = model @@ -130,7 +130,7 @@ def __init__( self.confthre = exp.test_conf self.nmsthre = exp.nmsthre self.test_size = exp.test_size - self.device = device + self.device = str(device) self.fp16 = fp16 if trt_file is not None: from torch2trt import TRTModule @@ -157,9 +157,20 @@ def inference(self, img, timer): img_info["width"] = width img_info["raw_img"] = img - img, ratio = preproc(img, self.test_size, self.rgb_means, self.std) - img_info["ratio"] = ratio - img = torch.from_numpy(img).unsqueeze(0).float().to(self.device) + if self.device=='cuda': + img=[img] + processed_images = self.process_images(img, self.test_size, self.rgb_means, self.std) + + if processed_images: + img = processed_images[0][0] + img_info["ratio"] = processed_images[0][1] + + img = torch.from_numpy(cp.asnumpy(img)).unsqueeze(0).float().to(self.device) + else: + img, ratio = preproc(img, self.test_size, self.rgb_means, self.std) + img_info["ratio"] = ratio + img = torch.from_numpy(img).unsqueeze(0).float().to(self.device) + if self.fp16: img = img.half() # to FP16 @@ -174,6 +185,12 @@ def inference(self, img, timer): #logger.info("Infer time: {:.4f}s".format(time.time() - t0)) return outputs, img_info + def process_images(self, image_list, input_size, mean, std, swap=(2, 0, 1)): + with ThreadPoolExecutor() as executor: + futures= [executor.submit(preproc_with_cupy, img, input_size, mean, std, swap) for img in image_list] + results = [future.result() for future in futures] + if results: + return results def image_demo(predictor, vis_folder, current_time, args): if osp.isdir(args.path): diff --git a/yolox/data/data_augment.py b/yolox/data/data_augment.py index 99fb30a2..fd641cc1 100644 --- a/yolox/data/data_augment.py +++ b/yolox/data/data_augment.py @@ -11,14 +11,10 @@ import cv2 import numpy as np - -import torch - from yolox.utils import xyxy2cxcywh - import math import random - +import cupy as cp def augment_hsv(img, hgain=0.015, sgain=0.7, vgain=0.4): r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains @@ -210,6 +206,50 @@ def preproc(image, input_size, mean, std, swap=(2, 0, 1)): padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) return padded_img, r +def preproc_with_cupy(image, input_size, mean, std, swap=(2, 0, 1)): + device = cp.cuda.Device(0) + device.use() + + if len(image.shape) == 3: + padded_img = cp.ones((input_size[0], input_size[1], 3)) * 114.0 + else: + padded_img = cp.ones(input_size) * 114.0 + + img = cp.array(image) + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + + # Hedef boyutları hesaplayalım + target_height = int(img.shape[0] * r) + target_width = int(img.shape[1] * r) + + # Hedef boyutların sıfır veya negatif olmadığını kontrol edelim + if target_height <= 0 or target_width <= 0: + raise ValueError(f"Invalid target size: ({target_width}, {target_height})") + + resized_img = cp.array(cv2.resize( + cp.asnumpy(img), + (target_width, target_height), + interpolation=cv2.INTER_LINEAR, + ).astype(np.float32)) + + if len(image.shape) == 3: + padded_img[:target_height, :target_width, :] = resized_img + else: + padded_img[:target_height, :target_width] = resized_img + + padded_img = padded_img[:, :, ::-1] / 255.0 # BGR to RGB and normalize + + if mean is not None: + mean_array = cp.array(mean).reshape(1, 1, 3) + padded_img -= mean_array + + if std is not None: + std_array = cp.array(std).reshape(1, 1, 3) + padded_img /= std_array + + padded_img = padded_img.transpose(swap) + padded_img = cp.ascontiguousarray(padded_img, dtype=cp.float32) + return padded_img, r class TrainTransform: def __init__(self, p=0.5, rgb_means=None, std=None, max_labels=100):