Skip to content

Commit

Permalink
Add instance segmentation and keypoint detection visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
pvti committed Nov 2, 2023
1 parent aa9817d commit 202e2c4
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 16 deletions.
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

# :star2: News
Project is under development :construction_worker:. Please stay tuned for more :fire: updates.
* **2023.10.31:** Add instance segmentation and keypoint detection [visualization](#throughput-acceleration) :horse_racing:.
* **2023.10.02:** Efficacy :fast_forward: [study](#kmeans) is added.
* **2023.8.22:** Throughput acceleration :stars: [experiment](#throughput-acceleration) is released :tada:.
* **2023.8.14:** [Poster](assets/poster.pdf) :bar_chart: is released. [Part of the project](https://gretsi.fr/data/colloque/pdf/2023_pham1312.pdf) will be :mega: presented at [GRETSI'23](https://gretsi.fr/colloque2023/) :clap:.
Expand Down Expand Up @@ -354,6 +355,7 @@ Time consumption to calculate the similarity matrix on VGG-16-BN. For tail layer
A comprehensive ablation study on CIFAR-10, showcasing comparable final accuracies achieved with the 3 considered distances.
## 5. Throughput acceleration. <a name="throughput-acceleration"></a>
+ FasterRCNN for object detection
<table style="width: 100%; border: none; border-collapse: collapse;">
<tr>
<td style="width: 50%; padding: 10px; border: none;">
Expand All @@ -365,6 +367,30 @@ A comprehensive ablation study on CIFAR-10, showcasing comparable final accuraci
</tr>
</table>
+ MaskRCNN for instance segmentation
<table style="width: 100%; border: none; border-collapse: collapse;">
<tr>
<td style="width: 50%; padding: 10px; border: none;">
<img src="assets/baseline_mask.gif" alt="Baseline" style="width: 100%;">
</td>
<td style="width: 50%; padding: 10px; border: none;">
<img src="assets/pruned_mask.gif" alt="Pruned" style="width: 100%;">
</td>
</tr>
</table>
+ KeypointRCNN for human keypoint detection
<table style="width: 100%; border: none; border-collapse: collapse;">
<tr>
<td style="width: 50%; padding: 10px; border: none;">
<img src="assets/baseline_keypoint.gif" alt="Baseline" style="width: 100%;">
</td>
<td style="width: 50%; padding: 10px; border: none;">
<img src="assets/pruned_keypoint.gif" alt="Pruned" style="width: 100%;">
</td>
</tr>
</table>
<div align="center">
Baseline (<em>left</em>) vs Pruned (<em>right</em>) model inference.
</div>
Expand Down
Binary file modified assets/baseline.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/baseline_keypoint.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/baseline_mask.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified assets/pruned.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/pruned_keypoint.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/pruned_mask.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 3 additions & 9 deletions detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,7 @@ def maskrcnn_resnet50_fpn(

if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(
"num_classes", num_classes, len(weights.meta["categories"])
)
num_classes = 91
elif num_classes is None:
num_classes = 91

Expand Down Expand Up @@ -276,12 +274,8 @@ def keypointrcnn_resnet50_fpn(

if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(
"num_classes", num_classes, len(weights.meta["categories"])
)
num_keypoints = _ovewrite_value_param(
"num_keypoints", num_keypoints, len(weights.meta["keypoint_names"])
)
num_classes = 2
num_keypoints = 17
else:
if num_classes is None:
num_classes = 2
Expand Down
143 changes: 136 additions & 7 deletions detection/visualize.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,30 @@
import cv2
import time
import random
import numpy as np
import matplotlib as mpl
import torch
import argparse
from torchvision.transforms.functional import to_pil_image
from torchvision.models.detection import (
fasterrcnn_resnet50_fpn,
maskrcnn_resnet50_fpn,
keypointrcnn_resnet50_fpn,
FasterRCNN_ResNet50_FPN_Weights,
MaskRCNN_ResNet50_FPN_Weights,
KeypointRCNN_ResNet50_FPN_Weights,
)
from model import fasterrcnn_resnet50_fpn as custom_fasterrcnn_resnet50_fpn
import model
from utils import get_cpr


def get_args_parser(add_help=True):
parser = argparse.ArgumentParser(
description="FasterRCNN_ResNet50_FPN Video Inference", add_help=add_help
description="Faster/Mask/KeypointRCNN_ResNet50_FPN Video Inference",
add_help=add_help,
)
parser.add_argument(
"--model", default="fasterrcnn_resnet50_fpn", type=str, help="model name"
)
parser.add_argument(
"--custom", dest="custom", action="store_true", help="Use custom model"
Expand All @@ -31,6 +42,9 @@ def get_args_parser(add_help=True):
"--output", default="output.mp4", help="Path to output video file"
)
parser.add_argument("--fps", type=int, default=10, help="FPS to write output video")
parser.add_argument(
"--confidence", type=float, default=0.75, help="confidence threshold"
)
parser.add_argument("--device", default="cuda:0", help="Device to use (cuda/cpu)")

return parser
Expand All @@ -39,7 +53,12 @@ def get_args_parser(add_help=True):
def main(args):
device = torch.device(args.device)

weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
if "faster" in args.model:
weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
elif "mask" in args.model:
weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
else:
weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
preprocess = weights.transforms()
labels = weights.meta["categories"]
# Create a color map for each class
Expand All @@ -52,11 +71,11 @@ def main(args):

if args.custom:
compress_rate = get_cpr(args.compress_rate)
model = custom_fasterrcnn_resnet50_fpn(
model = eval(f"model.{args.model}")(
weights=args.weight, compress_rate=compress_rate
)
else:
model = fasterrcnn_resnet50_fpn(weights=weights)
model = eval(args.model)(weights=weights)
model = model.to(device)
model.eval()

Expand All @@ -76,6 +95,26 @@ def main(args):
start_time = time.time()
frame_count = 0

# Pairs of edges for 17 of the keypoints detected
# omit any of the undesired connecting points
edges = [
(0, 1),
(0, 2),
(2, 4),
(1, 3),
(6, 8),
(8, 10),
(5, 7),
(7, 9),
(5, 11),
(11, 13),
(13, 15),
(6, 12),
(12, 14),
(14, 16),
(5, 6),
]

while True:
ret, frame = cap.read()
if not ret:
Expand All @@ -98,14 +137,14 @@ def main(args):
for score, label, box in zip(
prediction["scores"], prediction["labels"], prediction["boxes"]
):
if score > 0.5: # Adjust the confidence threshold as needed
if score > args.confidence: # Adjust the confidence threshold as needed
box = [int(coord) for coord in box.tolist()] # Convert to integers
class_name = labels[label]
color = color_map[label.item()]
frame = cv2.rectangle(
frame, (box[0], box[1]), (box[2], box[3]), color, 2
)
text = f"{class_name}"
text = f"{class_name} {score.item():.2f}"
frame = cv2.putText(
frame,
text,
Expand All @@ -116,6 +155,51 @@ def main(args):
2,
)

# Draw masks
if "mask" in args.model:
pred_score = list(prediction["scores"].cpu().numpy())
pred_t = [pred_score.index(x) for x in pred_score if x > args.confidence]
if len(pred_t) > 0:
masks = (prediction["masks"] > 0.5).squeeze().detach().cpu().numpy()
masks = masks[: pred_t[-1] + 1]
for i in range(len(masks)):
rgb_mask = random_colour_masks(masks[i])
frame = cv2.addWeighted(frame, 1, rgb_mask, 0.5, 0)

# Draw keypoints
if "keypoint" in args.model:
for i in range(len(prediction["keypoints"])):
# get the detected keypoints
keypoints = prediction["keypoints"][i].cpu().detach().numpy()
# proceed to draw the lines
if prediction["scores"][i] > args.confidence:
keypoints = keypoints[:, :].reshape(-1, 3)
for p in range(keypoints.shape[0]):
# draw the keypoints
cv2.circle(
frame,
(int(keypoints[p, 0]), int(keypoints[p, 1])),
3,
(0, 0, 255),
thickness=-1,
lineType=cv2.FILLED,
)
# draw the lines joining the keypoints
for ie, e in enumerate(edges):
# get different colors for the edges
rgb = mpl.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0])
rgb = rgb * 255
# join the keypoint pairs to draw the skeletal structure
cv2.line(
frame,
(int(keypoints[e, 0][0]), int(keypoints[e, 1][0])),
(int(keypoints[e, 0][1]), int(keypoints[e, 1][1])),
tuple(rgb),
2,
lineType=cv2.LINE_AA,
)

# Write FPS
fps_text = f"FPS {fps:.2f}"
frame = cv2.putText(
frame, fps_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2
Expand All @@ -137,6 +221,51 @@ def main(args):
print(f"Processed {frame_count} frames at {fps:.2f} FPS")


def random_colour_masks(mask):
"""
Apply random colors to mask regions.
Args:
mask (np.ndarray): Binary mask.
Returns:
np.ndarray: Mask with random colors applied.
"""

# Define a list of predefined colors
colors = [
[0, 255, 0],
[0, 0, 255],
[255, 0, 0],
[0, 255, 255],
[255, 255, 0],
[255, 0, 255],
[80, 70, 180],
[250, 80, 190],
[245, 145, 50],
[70, 150, 250],
[50, 190, 190],
]

# Initialize empty color channels
r, g, b = (
np.zeros_like(mask).astype(np.uint8),
np.zeros_like(mask).astype(np.uint8),
np.zeros_like(mask).astype(np.uint8),
)

# Assign random colors to mask regions
r = np.zeros_like(mask).astype(np.uint8)
g = np.zeros_like(mask).astype(np.uint8)
b = np.zeros_like(mask).astype(np.uint8)
r[mask == 1], g[mask == 1], b[mask == 1] = colors[random.randrange(0, 10)]

# Create a colored mask by stacking the color channels
colored_mask = np.stack([r, g, b], axis=2)

return colored_mask


if __name__ == "__main__":
args = get_args_parser().parse_args()
main(args)

0 comments on commit 202e2c4

Please sign in to comment.