Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,5 @@ runs/

# data
data/

unet_segmentation/model/*
22 changes: 22 additions & 0 deletions unet_segmentation/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Build and install rules for the unet_segmentation ROS 2 package
# (ament_cmake build type with a Python package and a Python node script).
cmake_minimum_required(VERSION 3.8)
project(unet_segmentation)

find_package(ament_cmake REQUIRED)
# Provides ament_python_install_package(), used below.
find_package(ament_cmake_python REQUIRED)
find_package(rclpy REQUIRED)

# Install the Python package directory that shares the project's name.
ament_python_install_package(${PROJECT_NAME})

# Copy the resource directories into share/<project>. The launch file resolves
# config/unet_segmentation.yaml through the package share directory.
install(DIRECTORY
launch
config
model
DESTINATION share/${PROJECT_NAME}
)

# Install the node script under lib/<project>; the launch file starts it via
# executable='unet_segmentation_node.py'.
install(PROGRAMS
${PROJECT_NAME}/unet_segmentation_node.py
DESTINATION lib/${PROJECT_NAME}
)

ament_package()
17 changes: 17 additions & 0 deletions unet_segmentation/config/unet_segmentation.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Default parameters for the unet_segmentation node.
# The "/**" key is a wildcard so the block applies regardless of node name/namespace.
/**:
ros__parameters:
# Weights file name; presumably looked up in the installed model/ directory - confirm in the node.
model_file: "unet-simple-320-240-l-5-e10-b16(1).pth"
input_topic: "/gripper_camera/image_raw"
overlay_topic: "/segmentation/overlay"
mask_topic: "/segmentation/mask"
# Network input size; matches the "320-240" in the model file name.
resize_width: 320
resize_height: 240
keep_original_size: true # upsample mask/overlay back to source image size
mask_threshold: 0.5
# bilinear / simple / classes presumably map to the UNet constructor arguments
# (bilinear, simple, n_classes) - confirm in the node.
bilinear: false
simple: true
classes: 1
# NOTE: the launch file passes its own 'device' argument after this file and only
# accepts 'cpu' or '0', so this value applies only when the node is started without it.
device: "cuda"
pred_color: [255, 0, 0]
overlay_alpha: 0.4
qos_depth: 3
53 changes: 53 additions & 0 deletions unet_segmentation/launch/unet_segmentation.launch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os

from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument, OpaqueFunction
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node

# Values accepted for the ``device`` launch argument.
ALLOWED_DEVICES = ['cpu', '0']


def validate_device(device: str) -> None:
    """Reject any ``device`` string that is not listed in ``ALLOWED_DEVICES``.

    :param device: Device identifier taken from the launch argument.
    :raises RuntimeError: If ``device`` is not one of the allowed values.
    """
    # Guard clause: accepted values return immediately.
    if device in ALLOWED_DEVICES:
        return

    raise RuntimeError(
        f"Invalid device '{device}'. Choose one of: {', '.join(ALLOWED_DEVICES)}"
    )


def launch_setup(context, *args, **kwargs):
    """Resolve the ``device`` launch argument and build the segmentation node.

    Runs inside an ``OpaqueFunction`` so the launch context is available and the
    argument can be turned into a plain string and validated before the node is
    created.

    :param context: Launch context used to evaluate ``LaunchConfiguration``.
    :returns: A list containing the single ``Node`` action to launch.
    :raises RuntimeError: Propagated from ``validate_device`` for unsupported values.
    """
    device = LaunchConfiguration('device').perform(context)
    validate_device(device)

    share_dir = get_package_share_directory('unet_segmentation')
    params_file = os.path.join(share_dir, 'config/unet_segmentation.yaml')

    # The explicit 'device' entry is listed after the YAML file so the launch
    # argument is the value the node ends up with.
    node_parameters = [params_file, {'device': device}]

    return [
        Node(
            package='unet_segmentation',
            executable='unet_segmentation_node.py',
            name='unet_segmentation',
            namespace='unet',
            output='screen',
            parameters=node_parameters,
        )
    ]


def generate_launch_description():
    """Declare the ``device`` launch argument and defer node creation to ``launch_setup``.

    ``launch_setup`` is wrapped in an ``OpaqueFunction`` so that it is called with
    the launch context and can resolve and validate ``device`` before creating
    the node.

    :returns: The ``LaunchDescription`` for the U-Net segmentation node.
    """
    device_arg = DeclareLaunchArgument(
        'device',
        default_value='0',
        # Describe the argument itself (shown by `ros2 launch --show-args`)
        # and list the values that validate_device() accepts.
        description=(
            'Device to run unet segmentation on; one of: '
            + ', '.join(ALLOWED_DEVICES)
        ),
    )

    return LaunchDescription(
        [
            device_arg,
            OpaqueFunction(function=launch_setup),
        ]
    )
Empty file.
20 changes: 20 additions & 0 deletions unet_segmentation/package.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
  <name>unet_segmentation</name>
  <version>0.0.0</version>
  <description>unet segmentation on raw image</description>
  <maintainer email="[email protected]">gardeg</maintainer>
  <license>MIT</license>

  <buildtool_depend>ament_cmake</buildtool_depend>
  <!-- CMakeLists.txt calls find_package(ament_cmake_python REQUIRED). -->
  <buildtool_depend>ament_cmake_python</buildtool_depend>

  <depend>rclpy</depend>
  <depend>sensor_msgs</depend>
  <depend>cv_bridge</depend>
  <depend>vision_msgs</depend>

  <!-- Imported by launch/unet_segmentation.launch.py at runtime. -->
  <exec_depend>ament_index_python</exec_depend>
  <exec_depend>launch</exec_depend>
  <exec_depend>launch_ros</exec_depend>

  <export>
    <build_type>ament_cmake</build_type>
  </export>
</package>
Empty file.
Empty file.
3 changes: 3 additions & 0 deletions unet_segmentation/unet_segmentation/unet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .unet_model import UNet as UNet

__all__ = ["UNet"]
Comment on lines +1 to +3
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed to run the code?

84 changes: 84 additions & 0 deletions unet_segmentation/unet_segmentation/unet/unet_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import torch
from torch import nn

from .unet_parts import DoubleConv, Down, OutConv, Up


class UNet(nn.Module):
    """U-Net encoder/decoder assembled from DoubleConv, Down, Up and OutConv.

    Two layouts are supported:

    * full (``simple=False``): four ``Down`` stages followed by four ``Up`` stages;
    * simple (``simple=True``): two ``Down`` stages followed by two ``Up`` stages.

    Sub-module attribute names (``inc``, ``down1``, ..., ``outc``) are part of
    the ``state_dict`` keys and are kept unchanged.
    """

    # Class-level default so instances created without running __init__
    # (e.g. restored by unpickling) still carry the flag.
    _checkpointing = False

    def __init__(self, n_channels, n_classes, simple=False, bilinear=False):
        """The U-Net architecture.

        :param n_channels: Number of input channels (e.g., 3 for RGB images)
        :param n_classes: Number of output classes (e.g., 1 for binary segmentation)
        :param simple: If True, creates a smaller U-Net with fewer layers.
        :param bilinear: If True, use bilinear upsampling instead of transposed convolutions.
        """
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.simple = simple
        self._checkpointing = False

        # With bilinear upsampling the Up blocks do not halve the channel count
        # themselves, so the adjacent stages are built with half the channels.
        factor = 2 if bilinear else 1

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        if simple:
            self.down2 = Down(128, 256 // factor)
            self.up1 = Up(256, 128 // factor, bilinear)
            self.up2 = Up(128, 64, bilinear)
        else:
            self.down2 = Down(128, 256)
            self.down3 = Down(256, 512)
            self.down4 = Down(512, 1024 // factor)
            self.up1 = Up(1024, 512 // factor, bilinear)
            self.up2 = Up(512, 256 // factor, bilinear)
            self.up3 = Up(256, 128 // factor, bilinear)
            self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def _run_stage(self, stage, *inputs):
        """Call ``stage(*inputs)``, through gradient checkpointing when enabled.

        Checkpointing is only applied while autograd is recording; under
        ``torch.no_grad()`` (inference) the stage is called directly.
        """
        if self._checkpointing and torch.is_grad_enabled():
            # Imported here so the submodule is guaranteed to be loaded.
            from torch.utils.checkpoint import checkpoint

            # use_reentrant=False (available since PyTorch 1.11): the reentrant
            # variant would produce no gradients for the first stage, whose
            # input image does not require grad.
            return checkpoint(stage, *inputs, use_reentrant=False)
        return stage(*inputs)

    def forward(self, x):
        """Run the network on ``x`` and return the output of ``outc`` (logits).

        Encoder activations are kept and passed to the matching ``Up`` stage as
        skip connections.
        """
        run = self._run_stage

        x1 = run(self.inc, x)
        x2 = run(self.down1, x1)
        x3 = run(self.down2, x2)

        if self.simple:
            x = run(self.up1, x3, x2)
            x = run(self.up2, x, x1)
        else:
            x4 = run(self.down3, x3)
            x5 = run(self.down4, x4)
            x = run(self.up1, x5, x4)
            x = run(self.up2, x, x3)
            x = run(self.up3, x, x2)
            x = run(self.up4, x, x1)

        logits = run(self.outc, x)
        return logits

    def use_checkpointing(self):
        """Enable gradient checkpointing to save memory, but at the cost of additional computation during backpropagation.

        ``torch.utils.checkpoint.checkpoint`` is a function applied to a call
        (``checkpoint(module, *inputs)``) during the forward pass, not a
        wrapper that can replace a module. Assigning its result to the module
        attributes would call each module without inputs and overwrite the
        layers. This method therefore only sets a flag; ``forward`` then routes
        every stage through ``checkpoint``.

        NOTE(review): checkpointed stages are re-run during backward, so
        BatchNorm running statistics are presumably updated twice per step
        while checkpointing is on - confirm this is acceptable for training.
        """
        self._checkpointing = True
77 changes: 77 additions & 0 deletions unet_segmentation/unet_segmentation/unet/unet_parts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Parts of the U-Net model."""

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two consecutive (3x3 convolution => BatchNorm => ReLU) stages.

    The convolutions use ``padding=1`` and no bias. The first stage maps
    ``in_channels`` to ``mid_channels`` and the second maps ``mid_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        """Build the two stages.

        :param in_channels: Channels of the input tensor.
        :param out_channels: Channels produced by the second stage.
        :param mid_channels: Channels between the stages; falls back to
            ``out_channels`` when not given (or falsy).
        """
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels

        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)
            )
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))

        # Attribute name and layer order define the state_dict keys.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both stages to ``x``."""
        out = self.double_conv(x)
        return out


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        """Build the pooling layer and the convolution block.

        :param in_channels: Channels of the input tensor.
        :param out_channels: Channels produced by the ``DoubleConv``.
        """
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        # Attribute name and layer order define the state_dict keys.
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Downscale ``x`` and run it through the convolution block."""
        out = self.maxpool_conv(x)
        return out


class Up(nn.Module):
    """Decoder stage: upscale, merge with the skip connection, then ``DoubleConv``."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        """Build the upsampling layer and the convolution block.

        :param in_channels: Channels after concatenating the skip connection
            with the upscaled tensor.
        :param out_channels: Channels produced by the ``DoubleConv``.
        :param bilinear: Use parameter-free bilinear upsampling instead of a
            transposed convolution.
        """
        super().__init__()
        half = in_channels // 2

        if not bilinear:
            # The transposed convolution halves the channel count itself.
            self.up = nn.ConvTranspose2d(
                in_channels, half, kernel_size=2, stride=2
            )
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Bilinear upsampling keeps the channel count, so the regular
            # convolutions reduce it (mid_channels = in_channels // 2).
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upscale ``x1``, pad it to ``x2``'s spatial size, concatenate and convolve.

        :param x1: Tensor coming from the deeper stage.
        :param x2: Skip-connection tensor from the encoder.
        """
        upscaled = self.up(x1)

        # Dimensions 2 and 3 are treated as height and width.
        gap_h = x2.shape[2] - upscaled.shape[2]
        gap_w = x2.shape[3] - upscaled.shape[3]
        left = gap_w // 2
        top = gap_h // 2

        # F.pad takes the last dimension first: [left, right, top, bottom].
        # If you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        upscaled = F.pad(upscaled, [left, gap_w - left, top, gap_h - top])

        merged = torch.cat([x2, upscaled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        """Build the 1x1 convolution.

        :param in_channels: Channels of the incoming feature map.
        :param out_channels: Channels of the output.
        """
        super().__init__()
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        # Attribute name defines the state_dict keys ('conv.weight', 'conv.bias').
        self.conv = projection

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        out = self.conv(x)
        return out
Loading