diff --git a/.gitignore b/.gitignore
index af8bb7b..b4021ff 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,3 +165,5 @@
 runs/
 # data
 data/
+
+unet_segmentation/model/*
diff --git a/unet_segmentation/CMakeLists.txt b/unet_segmentation/CMakeLists.txt
new file mode 100644
index 0000000..df9b0ed
--- /dev/null
+++ b/unet_segmentation/CMakeLists.txt
@@ -0,0 +1,22 @@
+cmake_minimum_required(VERSION 3.8)
+project(unet_segmentation)
+
+find_package(ament_cmake REQUIRED)
+find_package(ament_cmake_python REQUIRED)
+find_package(rclpy REQUIRED)
+
+ament_python_install_package(${PROJECT_NAME})
+
+install(DIRECTORY
+  launch
+  config
+  model
+  DESTINATION share/${PROJECT_NAME}
+)
+
+install(PROGRAMS
+  ${PROJECT_NAME}/unet_segmentation_node.py
+  DESTINATION lib/${PROJECT_NAME}
+)
+
+ament_package()
diff --git a/unet_segmentation/config/unet_segmentation.yaml b/unet_segmentation/config/unet_segmentation.yaml
new file mode 100644
index 0000000..ab14c77
--- /dev/null
+++ b/unet_segmentation/config/unet_segmentation.yaml
@@ -0,0 +1,17 @@
+/**:
+  ros__parameters:
+    model_file: "unet-simple-320-240-l-5-e10-b16(1).pth"
+    input_topic: "/gripper_camera/image_raw"
+    overlay_topic: "/segmentation/overlay"
+    mask_topic: "/segmentation/mask"
+    resize_width: 320
+    resize_height: 240
+    keep_original_size: true  # upsample mask/overlay back to source image size
+    mask_threshold: 0.5
+    bilinear: false
+    simple: true
+    classes: 1
+    device: "cuda"
+    pred_color: [255, 0, 0]
+    overlay_alpha: 0.4
+    qos_depth: 3
diff --git a/unet_segmentation/launch/unet_segmentation.launch.py b/unet_segmentation/launch/unet_segmentation.launch.py
new file mode 100644
index 0000000..a924f45
--- /dev/null
+++ b/unet_segmentation/launch/unet_segmentation.launch.py
@@ -0,0 +1,53 @@
+import os
+
+from ament_index_python.packages import get_package_share_directory
+from launch import LaunchDescription
+from launch.actions import DeclareLaunchArgument, OpaqueFunction
+from launch.substitutions import LaunchConfiguration
+from launch_ros.actions import Node
+
+ALLOWED_DEVICES = ['cpu', '0']
+
+
+def validate_device(device: str):
+    if device not in ALLOWED_DEVICES:
+        raise RuntimeError(
+            f"Invalid device '{device}'. Choose one of: {', '.join(ALLOWED_DEVICES)}"
+        )
+
+
+def launch_setup(context, *args, **kwargs):
+    device = LaunchConfiguration('device').perform(context)
+    validate_device(device)
+
+    unet_params = os.path.join(
+        get_package_share_directory('unet_segmentation'),
+        'config/unet_segmentation.yaml',
+    )
+
+    unet_node = Node(
+        package='unet_segmentation',
+        executable='unet_segmentation_node.py',
+        name='unet_segmentation',
+        namespace='unet',
+        output='screen',
+        parameters=[
+            unet_params,
+            {'device': device},
+        ],
+    )
+
+    return [unet_node]
+
+
+def generate_launch_description():
+    return LaunchDescription(
+        [
+            DeclareLaunchArgument(
+                'device',
+                default_value='0',
+                description="Inference device: 'cpu' or a CUDA index such as '0'",
+            ),
+            OpaqueFunction(function=launch_setup),
+        ]
+    )
diff --git a/unet_segmentation/model/.gitkeep b/unet_segmentation/model/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/unet_segmentation/package.xml b/unet_segmentation/package.xml
new file mode 100644
index 0000000..3225488
--- /dev/null
+++ b/unet_segmentation/package.xml
@@ -0,0 +1,20 @@
+<?xml version="1.0"?>
+<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
+<package format="3">
+  <name>unet_segmentation</name>
+  <version>0.0.0</version>
+  <description>unet segmentation on raw image</description>
+  <maintainer>gardeg</maintainer>
+  <license>MIT</license>
+
+  <buildtool_depend>ament_cmake</buildtool_depend>
+
+  <depend>rclpy</depend>
+  <depend>sensor_msgs</depend>
+  <depend>cv_bridge</depend>
+  <depend>vision_msgs</depend>
+
+  <export>
+    <build_type>ament_cmake</build_type>
+  </export>
+</package>
diff --git a/unet_segmentation/resource/unet_segmentation b/unet_segmentation/resource/unet_segmentation
new file mode 100644
index 0000000..e69de29
diff --git a/unet_segmentation/unet_segmentation/__init__.py b/unet_segmentation/unet_segmentation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/unet_segmentation/unet_segmentation/unet/__init__.py b/unet_segmentation/unet_segmentation/unet/__init__.py
new file mode 100644
index 0000000..2af036e
--- /dev/null
+++ b/unet_segmentation/unet_segmentation/unet/__init__.py
@@ -0,0 +1,3 @@
+from .unet_model import UNet as UNet
+
+__all__ = ["UNet"]
diff --git a/unet_segmentation/unet_segmentation/unet/unet_model.py b/unet_segmentation/unet_segmentation/unet/unet_model.py
new file mode 100644
index 0000000..b287667
--- /dev/null
+++ b/unet_segmentation/unet_segmentation/unet/unet_model.py
@@ -0,0 +1,84 @@
+import torch
+from torch import nn
+
+from .unet_parts import DoubleConv, Down, OutConv, Up
+
+
+class UNet(nn.Module):
+    def __init__(self, n_channels, n_classes, simple=False, bilinear=False):
+        """The U-Net architecture.
+
+        :param n_channels: Number of input channels (e.g., 3 for RGB images)
+        :param n_classes: Number of output classes (e.g., 1 for binary segmentation)
+        :param simple: If True, creates a smaller U-Net with fewer layers.
+        :param bilinear: If True, use bilinear upsampling instead of transposed convolutions.
+ """ + super().__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.bilinear = bilinear + self.simple = simple + + factor = 2 if bilinear else 1 + + if not self.simple: + self.inc = DoubleConv(n_channels, 64) + self.down1 = Down(64, 128) + self.down2 = Down(128, 256) + self.down3 = Down(256, 512) + self.down4 = Down(512, 1024 // factor) + self.up1 = Up(1024, 512 // factor, bilinear) + self.up2 = Up(512, 256 // factor, bilinear) + self.up3 = Up(256, 128 // factor, bilinear) + self.up4 = Up(128, 64, bilinear) + self.outc = OutConv(64, n_classes) + else: + self.inc = DoubleConv(n_channels, 64) + self.down1 = Down(64, 128) + self.down2 = Down(128, 256 // factor) + self.up1 = Up(256, 128 // factor, bilinear) + self.up2 = Up(128, 64, bilinear) + self.outc = OutConv(64, n_classes) + + def forward(self, x): + if not self.simple: + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + logits = self.outc(x) + else: + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x = self.up1(x3, x2) + x = self.up2(x, x1) + logits = self.outc(x) + + return logits + + def use_checkpointing(self): + """Enable gradient checkpointing to save memory, but at the cost of additional computation during backpropagation.""" + if not self.simple: + self.inc = torch.utils.checkpoint.checkpoint(self.inc) + self.down1 = torch.utils.checkpoint.checkpoint(self.down1) + self.down2 = torch.utils.checkpoint.checkpoint(self.down2) + self.down3 = torch.utils.checkpoint.checkpoint(self.down3) + self.down4 = torch.utils.checkpoint.checkpoint(self.down4) + self.up1 = torch.utils.checkpoint.checkpoint(self.up1) + self.up2 = torch.utils.checkpoint.checkpoint(self.up2) + self.up3 = torch.utils.checkpoint.checkpoint(self.up3) + self.up4 = torch.utils.checkpoint.checkpoint(self.up4) + self.outc = torch.utils.checkpoint.checkpoint(self.outc) + else: + self.inc = torch.utils.checkpoint.checkpoint(self.inc) + self.down1 = torch.utils.checkpoint.checkpoint(self.down1) + self.down2 = torch.utils.checkpoint.checkpoint(self.down2) + self.up1 = torch.utils.checkpoint.checkpoint(self.up1) + self.up2 = torch.utils.checkpoint.checkpoint(self.up2) + self.outc = torch.utils.checkpoint.checkpoint(self.outc) diff --git a/unet_segmentation/unet_segmentation/unet/unet_parts.py b/unet_segmentation/unet_segmentation/unet/unet_parts.py new file mode 100644 index 0000000..062d69f --- /dev/null +++ b/unet_segmentation/unet_segmentation/unet/unet_parts.py @@ -0,0 +1,77 @@ +"""Parts of the U-Net model.""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2.""" + + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(mid_channels), + nn.ReLU(inplace=True), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.double_conv(x) + + +class Down(nn.Module): + """Downscaling with maxpool then double conv.""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), DoubleConv(in_channels, 
+        )
+
+    def forward(self, x):
+        return self.maxpool_conv(x)
+
+
+class Up(nn.Module):
+    """Upscaling then double conv."""
+
+    def __init__(self, in_channels, out_channels, bilinear=True):
+        super().__init__()
+
+        # if bilinear, use the normal convolutions to reduce the number of channels
+        if bilinear:
+            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+        else:
+            self.up = nn.ConvTranspose2d(
+                in_channels, in_channels // 2, kernel_size=2, stride=2
+            )
+            self.conv = DoubleConv(in_channels, out_channels)
+
+    def forward(self, x1, x2):
+        x1 = self.up(x1)
+        # input is CHW
+        diffy = x2.size()[2] - x1.size()[2]
+        diffx = x2.size()[3] - x1.size()[3]
+
+        x1 = F.pad(x1, [diffx // 2, diffx - diffx // 2, diffy // 2, diffy - diffy // 2])
+        # if you have padding issues, see
+        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
+        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
+        x = torch.cat([x2, x1], dim=1)
+        return self.conv(x)
+
+
+class OutConv(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+    def forward(self, x):
+        return self.conv(x)
diff --git a/unet_segmentation/unet_segmentation/unet_segmentation_node.py b/unet_segmentation/unet_segmentation/unet_segmentation_node.py
new file mode 100644
index 0000000..466d3e6
--- /dev/null
+++ b/unet_segmentation/unet_segmentation/unet_segmentation_node.py
@@ -0,0 +1,203 @@
+#!/usr/bin/env python3
+from pathlib import Path
+
+import cv2
+import numpy as np
+import rclpy
+import torch
+from ament_index_python.packages import get_package_share_directory
+from cv_bridge import CvBridge
+from PIL import Image as PILImage
+from rclpy.node import Node
+from rclpy.qos import QoSProfile
+from sensor_msgs.msg import Image
+
+from unet_segmentation.utils import (
+    ResizeIfLargerKeepAspect,
+    build_image_transforms,
+    load_unet,
+    make_overlay,
+    mask_to_mono8,
+    predict_mask,
+    upsample_mask_nearest,
+)
+
+
+def default_config_path() -> str:
+    share_dir = Path(get_package_share_directory('unet_segmentation'))
+    return str(share_dir / 'config' / 'unet_segmentation.yaml')
+
+
+class UnetSegmentationNode(Node):
+    def __init__(self, config_path: str | None = None):
+        super().__init__("unet_segmentation")
+
+        # Optional: ensure a params file exists when started directly
+        cfg_path = Path(config_path) if config_path else Path(default_config_path())
+        if not cfg_path.is_file():
+            self.get_logger().warn(
+                f"Params file not found at {cfg_path}. "
+                "This is fine if parameters are provided via --params-file in the launch."
+ ) + + self._declare_and_load_parameters() + + # Resolve device + self.device = self._make_device(self.device_param) + self.bridge = CvBridge() + + # Validate model path + model_f = Path(self.model_file).expanduser() + if not model_f.exists(): + self.get_logger().fatal(f"Model file not found: {model_f}") + raise SystemExit(1) + + # Load network + self.net = load_unet( + model_file=str(model_f), + n_classes=self.classes, + device=self.device, + bilinear=bool(self.bilinear), + simple=bool(self.simple), + logger=self.get_logger(), + ) + + # Build transforms (downscale-only) + self.image_transforms = build_image_transforms( + int(self.resize_width), int(self.resize_height) + ) + + # I/O + qos = QoSProfile(depth=int(self.qos_depth)) + self.subscription = self.create_subscription( + Image, self.input_topic, self.image_callback, qos + ) + self.overlay_pub = self.create_publisher(Image, self.overlay_topic, qos) + self.mask_pub = self.create_publisher(Image, self.mask_topic, qos) + + self.get_logger().info( + f"Subscribing to '{self.input_topic}', " + f"publishing overlay to '{self.overlay_topic}' and mask to '{self.mask_topic}'." + ) + self.get_logger().info(f"Running on device: {self.device}") + + # --- helpers ------------------------------------------------------------- + + def _declare_and_load_parameters(self): + """Declare parameters with defaults and bind them to attributes.""" + defaults = { + 'model_file': 'model/unet.pth', + 'input_topic': 'image_raw', + 'overlay_topic': '/segmentation/overlay', + 'mask_topic': '/segmentation/mask', + 'resize_width': 320, + 'resize_height': 240, + 'keep_original_size': True, + 'mask_threshold': 0.5, + 'bilinear': False, + 'simple': True, + 'classes': 1, # YAML key is 'classes' + 'device': 'cpu', # 'cpu', 'cuda', or CUDA index like '0' + 'pred_color': [255, 0, 0], + 'overlay_alpha': 0.4, + 'qos_depth': 3, + } + + for name, default in defaults.items(): + self.declare_parameter(name, default) + + # Bind as attributes + self.model_flie = self.get_parameter('model_file').value + self.input_topic = self.get_parameter('input_topic').value + self.overlay_topic = self.get_parameter('overlay_topic').value + self.mask_topic = self.get_parameter('mask_topic').value + self.resize_width = self.get_parameter('resize_width').value + self.resize_height = self.get_parameter('resize_height').value + self.keep_original_size = self.get_parameter('keep_original_size').value + self.mask_threshold = float(self.get_parameter('mask_threshold').value) + self.bilinear = self.get_parameter('bilinear').value + self.simple = self.get_parameter('simple').value + self.classes = int(self.get_parameter('classes').value) + self.device_param = self.get_parameter('device').value + self.pred_color = tuple(self.get_parameter('pred_color').value) + self.overlay_alpha = float(self.get_parameter('overlay_alpha').value) + self.qos_depth = int(self.get_parameter('qos_depth').value) + + @staticmethod + def _make_device(device_param: str) -> torch.device: + if device_param == 'cpu': + return torch.device('cpu') + if device_param == 'cuda': + return torch.device('cuda') + # allow CUDA index like '0' + if device_param.isdigit(): + idx = int(device_param) + return torch.device(f'cuda:{idx}') + # fallback + return torch.device(device_param) + + # --- ROS callbacks ------------------------------------------------------- + + def image_callback(self, msg: Image): + try: + cv_bgr = self.bridge.imgmsg_to_cv2(msg, "bgr8") + base_rgb = cv2.cvtColor(cv_bgr, cv2.COLOR_BGR2RGB) + orig_h, orig_w = 
+            pil_img = PILImage.fromarray(base_rgb)
+
+            # Resize (downscale only) for inference
+            resized_pil = ResizeIfLargerKeepAspect(
+                int(self.resize_width), int(self.resize_height)
+            )(pil_img)
+            resized_w, resized_h = resized_pil.size
+
+            image_tensor = self.image_transforms(resized_pil)
+
+            # Predict (in resized space)
+            pred_mask_np = predict_mask(
+                self.net,
+                image_tensor,
+                self.device,
+                out_threshold=self.mask_threshold,
+            )
+
+            # Optionally upsample mask to original size
+            if self.keep_original_size:
+                mask_out = upsample_mask_nearest(pred_mask_np, orig_w, orig_h)
+                base_for_overlay = base_rgb
+            else:
+                mask_out = pred_mask_np
+                base_for_overlay = np.array(resized_pil)
+
+            # Build overlay
+            overlay_np = make_overlay(
+                base_for_overlay,
+                mask_out,
+                color=self.pred_color,
+                alpha=self.overlay_alpha,
+            )
+
+            # Publish mask (mono8) and overlay (rgb8)
+            mask_mono8 = mask_to_mono8(mask_out)
+            mask_msg = self.bridge.cv2_to_imgmsg(mask_mono8, encoding="mono8")
+            mask_msg.header = msg.header
+            self.mask_pub.publish(mask_msg)
+
+            overlay_msg = self.bridge.cv2_to_imgmsg(overlay_np, encoding="rgb8")
+            overlay_msg.header = msg.header
+            self.overlay_pub.publish(overlay_msg)
+
+        except Exception as e:
+            self.get_logger().error(f'Failed to process image: {e}')
+
+
+def main():
+    rclpy.init()
+    node = UnetSegmentationNode(default_config_path())
+    rclpy.spin(node)
+    node.destroy_node()
+    rclpy.shutdown()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/unet_segmentation/unet_segmentation/utils.py b/unet_segmentation/unet_segmentation/utils.py
new file mode 100644
index 0000000..b7fa382
--- /dev/null
+++ b/unet_segmentation/unet_segmentation/utils.py
@@ -0,0 +1,152 @@
+import os
+
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms.functional as F
+from PIL import Image as PILImage
+from torchvision import transforms
+
+from .unet import UNet
+
+
+def predict_mask(
+    net: torch.nn.Module,
+    image_tensor: torch.Tensor,
+    device: torch.device,
+    out_threshold: float = 0.5,
+) -> np.ndarray:
+    """Performs inference on a single image tensor.
+
+    Returns a HxW mask (int64) with values in {0,1,...,n_classes-1} for multi-class
+    or {0,1} for binary models.
+    """
+    net.eval()
+    img = image_tensor.unsqueeze(0).to(device=device, dtype=torch.float32)
+
+    with torch.no_grad():
+        output = net(img).cpu()
+        if getattr(net, "n_classes", 1) > 1:
+            mask = output.argmax(dim=1)
+        else:
+            mask = (torch.sigmoid(output) > out_threshold).long()
+
+    return mask[0].long().squeeze().numpy()
+
+
+def blend_image_and_mask(
+    original_image: PILImage.Image,
+    mask_array: np.ndarray,
+    color: tuple[int, int, int],
+    alpha: float = 0.4,
+) -> PILImage.Image:
+    """Blends a mask over a PIL image using RGBA compositing."""
+    original_image = original_image.convert("RGBA")
+    overlay = PILImage.new("RGBA", original_image.size, (0, 0, 0, 0))
+    overlay_np = np.array(overlay)
+    overlay_np[mask_array == 1] = (*color, int(255 * alpha))
+    overlay = PILImage.fromarray(overlay_np)
+    blended = PILImage.alpha_composite(original_image, overlay)
+    return blended.convert("RGB")
+
+
+class ResizeIfLargerKeepAspect:
+    """Downscale a PIL image only if it's larger than the target size, preserving aspect ratio. Never upscales."""
+
+    def __init__(
+        self,
+        max_width: int,
+        max_height: int,
+        interpolation=transforms.InterpolationMode.BILINEAR,
+    ):
+        self.max_width = max_width
+        self.max_height = max_height
+        self.interpolation = interpolation
+
+    def __call__(self, img: PILImage.Image) -> PILImage.Image:
+        w, h = img.size
+        if w > self.max_width or h > self.max_height:
+            scale = min(self.max_width / w, self.max_height / h)
+            new_w, new_h = int(w * scale), int(h * scale)
+            img = F.resize(img, (new_h, new_w), interpolation=self.interpolation)
+        return img
+
+
+def build_image_transforms(max_w: int, max_h: int) -> transforms.Compose:
+    """Returns a torchvision Compose that resizes (downscale only), converts to tensor, and normalizes."""
+    return transforms.Compose(
+        [
+            ResizeIfLargerKeepAspect(max_width=max_w, max_height=max_h),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ]
+    )
+
+
+def load_unet(
+    model_path: str,
+    n_classes: int,
+    device: torch.device,
+    bilinear: bool,
+    simple: bool,
+    logger=None,
+) -> "UNet":
+    # Resolve relative paths to absolute paths
+    model_path = os.path.abspath(os.path.expanduser(model_path))
+
+    if logger:
+        logger.info(f'Loading model from {model_path}')
+        logger.info(f'Using device {device}')
+        logger.info(f'simple={simple}, bilinear={bilinear}, n_classes={n_classes}')
+
+    # Create network with named args
+    net = UNet(n_channels=3, n_classes=n_classes, simple=simple, bilinear=bilinear)
+
+    try:
+        state_dict = torch.load(model_path, map_location=device)
+        _ = state_dict.pop('mask_values', None)
+
+        missing, unexpected = net.load_state_dict(state_dict, strict=False)
+
+        if logger:
+            if missing:
+                logger.warning(f'Missing keys when loading: {missing}')
+            if unexpected:
+                logger.warning(f'Unexpected keys when loading: {unexpected}')
+            logger.info('Model loaded successfully!')
+    except FileNotFoundError:
+        if logger:
+            logger.fatal(
+                f"Model file not found at {model_path}. Please check the path."
+            )
+        raise
+
+    net.to(device)
+    net.eval()
+    return net
+
+
+def upsample_mask_nearest(
+    mask_np: np.ndarray, target_w: int, target_h: int
+) -> np.ndarray:
+    """Upsample a HxW mask (int) to (target_h, target_w) via nearest-neighbor."""
+    return cv2.resize(
+        mask_np.astype(np.uint8), (target_w, target_h), interpolation=cv2.INTER_NEAREST
+    )
+
+
+def mask_to_mono8(mask_np: np.ndarray) -> np.ndarray:
+    mask_np = mask_np.astype(np.uint8)
+    # If it's binary (only 0/1), scale to 0/255 so it's visible
+    if mask_np.max() <= 1:
+        return (mask_np * 255).astype(np.uint8)
+    return mask_np
+
+
+def make_overlay(
+    base_rgb_np: np.ndarray, mask_np: np.ndarray, color=(255, 0, 0), alpha=0.4
+) -> np.ndarray:
+    mask_bin = (mask_np > 0).astype(np.uint8)
+    overlay = np.zeros_like(base_rgb_np)
+    overlay[mask_bin == 1] = np.array(color, dtype=np.uint8)
+    return cv2.addWeighted(base_rgb_np, 1.0, overlay, alpha, 0)
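Usage note (not part of the patch): after building the workspace, the node is started through the launch file, e.g. "ros2 launch unet_segmentation unet_segmentation.launch.py device:=0" (or device:=cpu). The helpers in utils.py can also be exercised outside ROS; the sketch below is a minimal, hypothetical offline check of the same per-frame inference path the node runs. The weights path 'model/unet.pth' and the test image 'sample.png' are placeholders, not files introduced by this diff.

import numpy as np
import torch
from PIL import Image as PILImage

from unet_segmentation.utils import (
    build_image_transforms,
    load_unet,
    make_overlay,
    predict_mask,
    upsample_mask_nearest,
)

# Same device convention as the node: 'cpu' or a CUDA device.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Placeholder checkpoint path; point this at real trained weights.
net = load_unet(
    model_path='model/unet.pth',
    n_classes=1,
    device=device,
    bilinear=False,
    simple=True,
)

# Placeholder test image.
img = PILImage.open('sample.png').convert('RGB')

# Same preprocessing as the node: downscale-only resize, ToTensor, normalize.
tensor = build_image_transforms(320, 240)(img)
mask = predict_mask(net, tensor, device, out_threshold=0.5)

# Bring the mask back to the source resolution and draw the red overlay.
mask_full = upsample_mask_nearest(mask, img.width, img.height)
overlay = make_overlay(np.array(img), mask_full, color=(255, 0, 0), alpha=0.4)
PILImage.fromarray(overlay).save('overlay.png')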