diff --git a/README.md b/README.md
index 53f78d7..72742d1 100644
--- a/README.md
+++ b/README.md
@@ -48,6 +48,7 @@
# :star2: News
Project is under development :construction_worker:. Please stay tuned for more :fire: updates.
+* **2023.10.31:** Instance segmentation and keypoint detection [visualization](#throughput-acceleration) :horse_racing: is added.
* **2023.10.02:** Efficacy :fast_forward: [study](#kmeans) is added.
* **2023.8.22:** Throughput acceleration :stars: [experiment](#throughput-acceleration) is released :tada:.
* **2023.8.14:** [Poster](assets/poster.pdf) :bar_chart: is released. [Part of the project](https://gretsi.fr/data/colloque/pdf/2023_pham1312.pdf) will be :mega: presented at [GRETSI'23](https://gretsi.fr/colloque2023/) :clap:.
@@ -354,6 +355,7 @@ Time consumption to calculate the similarity matrix on VGG-16-BN. For tail layer
A comprehensive ablation study on CIFAR-10, showcasing comparable final accuracies achieved with the 3 considered distances.
## 5. Throughput acceleration.
++ FasterRCNN for object detection
@@ -365,6 +367,30 @@ A comprehensive ablation study on CIFAR-10, showcasing comparable final accuraci
 ![Baseline](assets/baseline.gif) | ![Pruned](assets/pruned.gif)
+
++ MaskRCNN for instance segmentation
+
+![Baseline](assets/baseline_mask.gif) | ![Pruned](assets/pruned_mask.gif)
+:---: | :---:
+
++ KeypointRCNN for human keypoint detection
+
+![Baseline](assets/baseline_keypoint.gif) | ![Pruned](assets/pruned_keypoint.gif)
+:---: | :---:
+
Baseline (left) vs Pruned (right) model inference.
diff --git a/assets/baseline.gif b/assets/baseline.gif
index 96e6c6b..5ca5d4c 100644
Binary files a/assets/baseline.gif and b/assets/baseline.gif differ
diff --git a/assets/baseline_keypoint.gif b/assets/baseline_keypoint.gif
new file mode 100644
index 0000000..9be7679
Binary files /dev/null and b/assets/baseline_keypoint.gif differ
diff --git a/assets/baseline_mask.gif b/assets/baseline_mask.gif
new file mode 100644
index 0000000..18f1edc
Binary files /dev/null and b/assets/baseline_mask.gif differ
diff --git a/assets/pruned.gif b/assets/pruned.gif
index d7a2f18..28f7c4a 100644
Binary files a/assets/pruned.gif and b/assets/pruned.gif differ
diff --git a/assets/pruned_keypoint.gif b/assets/pruned_keypoint.gif
new file mode 100644
index 0000000..8a535a3
Binary files /dev/null and b/assets/pruned_keypoint.gif differ
diff --git a/assets/pruned_mask.gif b/assets/pruned_mask.gif
new file mode 100644
index 0000000..51cd7f0
Binary files /dev/null and b/assets/pruned_mask.gif differ
diff --git a/detection/model.py b/detection/model.py
index e9d5875..61b51c0 100644
--- a/detection/model.py
+++ b/detection/model.py
@@ -169,9 +169,7 @@ def maskrcnn_resnet50_fpn(
if weights is not None:
weights_backbone = None
- num_classes = _ovewrite_value_param(
- "num_classes", num_classes, len(weights.meta["categories"])
- )
+        num_classes = 91  # COCO detection weights cover 91 categories
elif num_classes is None:
num_classes = 91
@@ -276,12 +274,8 @@ def keypointrcnn_resnet50_fpn(
if weights is not None:
weights_backbone = None
- num_classes = _ovewrite_value_param(
- "num_classes", num_classes, len(weights.meta["categories"])
- )
- num_keypoints = _ovewrite_value_param(
- "num_keypoints", num_keypoints, len(weights.meta["keypoint_names"])
- )
+        num_classes = 2  # COCO person weights: "no person" / "person"
+        num_keypoints = 17  # the 17 standard COCO person keypoints
else:
if num_classes is None:
num_classes = 2
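
For reference, the constants hardcoded above match what the removed `_ovewrite_value_param` calls used to derive from torchvision's weight metadata; a standalone check (no assumptions beyond torchvision itself):

```python
from torchvision.models.detection import (
    KeypointRCNN_ResNet50_FPN_Weights,
    MaskRCNN_ResNet50_FPN_Weights,
)

# MaskRCNN COCO weights carry the full 91-entry COCO category list
print(len(MaskRCNN_ResNet50_FPN_Weights.DEFAULT.meta["categories"]))  # 91

# KeypointRCNN COCO weights only distinguish person vs. no person,
# with the 17 standard COCO person keypoints
kp_meta = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.meta
print(len(kp_meta["categories"]))      # 2
print(len(kp_meta["keypoint_names"]))  # 17
```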
diff --git a/detection/visualize.py b/detection/visualize.py
index 4acf3dd..391726e 100644
--- a/detection/visualize.py
+++ b/detection/visualize.py
@@ -1,19 +1,30 @@
import cv2
import time
+import random
+import numpy as np
+import matplotlib as mpl
import torch
import argparse
from torchvision.transforms.functional import to_pil_image
from torchvision.models.detection import (
fasterrcnn_resnet50_fpn,
+ maskrcnn_resnet50_fpn,
+ keypointrcnn_resnet50_fpn,
FasterRCNN_ResNet50_FPN_Weights,
+ MaskRCNN_ResNet50_FPN_Weights,
+ KeypointRCNN_ResNet50_FPN_Weights,
)
-from model import fasterrcnn_resnet50_fpn as custom_fasterrcnn_resnet50_fpn
+import model as custom_models
from utils import get_cpr
def get_args_parser(add_help=True):
parser = argparse.ArgumentParser(
- description="FasterRCNN_ResNet50_FPN Video Inference", add_help=add_help
+ description="Faster/Mask/KeypointRCNN_ResNet50_FPN Video Inference",
+ add_help=add_help,
+ )
+ parser.add_argument(
+ "--model", default="fasterrcnn_resnet50_fpn", type=str, help="model name"
)
parser.add_argument(
"--custom", dest="custom", action="store_true", help="Use custom model"
@@ -31,6 +42,9 @@ def get_args_parser(add_help=True):
"--output", default="output.mp4", help="Path to output video file"
)
parser.add_argument("--fps", type=int, default=10, help="FPS to write output video")
+ parser.add_argument(
+ "--confidence", type=float, default=0.75, help="confidence threshold"
+ )
parser.add_argument("--device", default="cuda:0", help="Device to use (cuda/cpu)")
return parser
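
As a usage sketch, the new flags can be exercised through the parser directly; the argument values below are hypothetical:

```python
# Hypothetical invocation; normally the script is driven from the command line.
parser = get_args_parser()
args = parser.parse_args(
    ["--model", "maskrcnn_resnet50_fpn", "--confidence", "0.8", "--device", "cpu"]
)
print(args.model, args.confidence)  # maskrcnn_resnet50_fpn 0.8
```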
@@ -39,7 +53,12 @@ def get_args_parser(add_help=True):
def main(args):
device = torch.device(args.device)
- weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
+ if "faster" in args.model:
+ weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
+ elif "mask" in args.model:
+ weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
+ else:
+ weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
preprocess = weights.transforms()
labels = weights.meta["categories"]
# Create a color map for each class
@@ -52,11 +71,11 @@ def main(args):
if args.custom:
compress_rate = get_cpr(args.compress_rate)
- model = custom_fasterrcnn_resnet50_fpn(
+ model = eval(f"model.{args.model}")(
weights=args.weight, compress_rate=compress_rate
)
else:
- model = fasterrcnn_resnet50_fpn(weights=weights)
+        model = globals()[args.model](weights=weights)  # look up the torchvision builder by name
model = model.to(device)
model.eval()
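
The non-custom branch resolves the torchvision builder by name rather than through `eval`. A standalone sketch of the same dispatch (built without pretrained weights so nothing is downloaded):

```python
from torchvision.models.detection import (
    fasterrcnn_resnet50_fpn,
    maskrcnn_resnet50_fpn,
    keypointrcnn_resnet50_fpn,
)

name = "keypointrcnn_resnet50_fpn"
builder = globals()[name]  # same name-based lookup as in main()
net = builder(weights=None, weights_backbone=None, num_classes=2, num_keypoints=17)
net.eval()
```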
@@ -76,6 +95,26 @@ def main(args):
start_time = time.time()
frame_count = 0
+    # Skeleton edges over the 17 COCO keypoints (index pairs);
+    # drop a pair to omit that connection from the drawing
+ edges = [
+ (0, 1),
+ (0, 2),
+ (2, 4),
+ (1, 3),
+ (6, 8),
+ (8, 10),
+ (5, 7),
+ (7, 9),
+ (5, 11),
+ (11, 13),
+ (13, 15),
+ (6, 12),
+ (12, 14),
+ (14, 16),
+ (5, 6),
+ ]
+
while True:
ret, frame = cap.read()
if not ret:
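
The indices in `edges` follow torchvision's COCO keypoint ordering, which can be read back from the weight metadata:

```python
from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights

# 0 nose, 1 left_eye, 2 right_eye, 3 left_ear, 4 right_ear,
# 5 left_shoulder, 6 right_shoulder, 7 left_elbow, 8 right_elbow,
# 9 left_wrist, 10 right_wrist, 11 left_hip, 12 right_hip,
# 13 left_knee, 14 right_knee, 15 left_ankle, 16 right_ankle
names = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.meta["keypoint_names"]
print(list(enumerate(names)))
```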
@@ -98,14 +137,14 @@ def main(args):
for score, label, box in zip(
prediction["scores"], prediction["labels"], prediction["boxes"]
):
- if score > 0.5: # Adjust the confidence threshold as needed
+            if score > args.confidence:
box = [int(coord) for coord in box.tolist()] # Convert to integers
class_name = labels[label]
color = color_map[label.item()]
frame = cv2.rectangle(
frame, (box[0], box[1]), (box[2], box[3]), color, 2
)
- text = f"{class_name}"
+ text = f"{class_name} {score.item():.2f}"
frame = cv2.putText(
frame,
text,
@@ -116,6 +155,51 @@ def main(args):
2,
)
+            # Draw masks
+            if "mask" in args.model:
+                scores = prediction["scores"].detach().cpu().numpy()
+                keep = scores > args.confidence
+                if keep.any():
+                    # squeeze only the channel dim: masks stay (N, H, W) even for N == 1
+                    masks = (prediction["masks"] > 0.5).squeeze(1).detach().cpu().numpy()
+                    masks = masks[keep]
+                    for i in range(len(masks)):
+                        rgb_mask = random_colour_masks(masks[i])
+                        frame = cv2.addWeighted(frame, 1, rgb_mask, 0.5, 0)
+
+ # Draw keypoints
+ if "keypoint" in args.model:
+ for i in range(len(prediction["keypoints"])):
+ # get the detected keypoints
+ keypoints = prediction["keypoints"][i].cpu().detach().numpy()
+ # proceed to draw the lines
+ if prediction["scores"][i] > args.confidence:
+                        keypoints = keypoints.reshape(-1, 3)  # (17, 3): x, y, visibility
+ for p in range(keypoints.shape[0]):
+ # draw the keypoints
+ cv2.circle(
+ frame,
+ (int(keypoints[p, 0]), int(keypoints[p, 1])),
+ 3,
+ (0, 0, 255),
+ thickness=-1,
+ lineType=cv2.FILLED,
+ )
+ # draw the lines joining the keypoints
+ for ie, e in enumerate(edges):
+ # get different colors for the edges
+ rgb = mpl.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0])
+ rgb = rgb * 255
+ # join the keypoint pairs to draw the skeletal structure
+ cv2.line(
+ frame,
+ (int(keypoints[e, 0][0]), int(keypoints[e, 1][0])),
+ (int(keypoints[e, 0][1]), int(keypoints[e, 1][1])),
+                        tuple(int(c) for c in rgb),  # cast to plain ints for OpenCV
+ 2,
+ lineType=cv2.LINE_AA,
+ )
+
+ # Write FPS
fps_text = f"FPS {fps:.2f}"
frame = cv2.putText(
frame, fps_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2
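
The endpoint lookup inside `cv2.line` above relies on numpy fancy indexing, which is easy to misread; a small demonstration with made-up coordinates:

```python
import numpy as np

# One person's keypoints in (x, y, visibility) layout
keypoints = np.zeros((17, 3))
keypoints[5] = [100, 200, 1]  # left shoulder
keypoints[6] = [300, 220, 1]  # right shoulder

e = (5, 6)
xs = keypoints[e, 0]  # x of both endpoints -> [100., 300.]
ys = keypoints[e, 1]  # y of both endpoints -> [200., 220.]
p1 = (int(xs[0]), int(ys[0]))  # (100, 200)
p2 = (int(xs[1]), int(ys[1]))  # (300, 220)
print(p1, p2)
```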
@@ -137,6 +221,51 @@ def main(args):
print(f"Processed {frame_count} frames at {fps:.2f} FPS")
+def random_colour_masks(mask):
+ """
+    Color the region of a binary mask with one randomly chosen color.
+
+    Args:
+        mask (np.ndarray): Binary mask of shape (H, W).
+
+    Returns:
+        np.ndarray: (H, W, 3) uint8 colored mask.
+ """
+
+ # Define a list of predefined colors
+ colors = [
+ [0, 255, 0],
+ [0, 0, 255],
+ [255, 0, 0],
+ [0, 255, 255],
+ [255, 255, 0],
+ [255, 0, 255],
+ [80, 70, 180],
+ [250, 80, 190],
+ [245, 145, 50],
+ [70, 150, 250],
+ [50, 190, 190],
+ ]
+
+    # Initialize empty color channels
+    r = np.zeros_like(mask).astype(np.uint8)
+    g = np.zeros_like(mask).astype(np.uint8)
+    b = np.zeros_like(mask).astype(np.uint8)
+
+    # Paint the mask region with one randomly chosen color
+    r[mask == 1], g[mask == 1], b[mask == 1] = colors[random.randrange(len(colors))]
+
+ # Create a colored mask by stacking the color channels
+ colored_mask = np.stack([r, g, b], axis=2)
+
+ return colored_mask
+
+
if __name__ == "__main__":
args = get_args_parser().parse_args()
main(args)
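
Finally, a self-contained sketch of how the colored mask is blended into a frame, using a synthetic frame and mask. One caveat: `random_colour_masks` stacks channels in r, g, b order while OpenCV frames are BGR, so drawn colors come out channel-swapped unless reordered with `colored_mask[..., ::-1]`:

```python
import cv2
import numpy as np

frame = np.zeros((240, 320, 3), dtype=np.uint8)  # synthetic BGR frame
mask = np.zeros((240, 320), dtype=np.uint8)
mask[60:120, 80:160] = 1                         # synthetic binary mask

colored = np.zeros_like(frame)
colored[mask == 1] = (0, 255, 0)                 # stand-in for random_colour_masks(mask)
blended = cv2.addWeighted(frame, 1, colored, 0.5, 0)
cv2.imwrite("blended.png", blended)
```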