unet node for image segmentation #27
base: main
Changes from all commits: d746a0f, fc191f3, 55301df, 0a1cfc9, 3c34486, 9c690dc, 8ee5956, 468884a, 3233797, de618f8, 5f16759, c999220, cbe6f84, 6c34ddd, 81b01b5, beb0784, b8cbe78
.gitignore
@@ -165,3 +165,5 @@ runs/

# data
data/

unet_segmentation/model/*
CMakeLists.txt (new file)
cmake_minimum_required(VERSION 3.8)
project(unet_segmentation)

find_package(ament_cmake REQUIRED)
find_package(ament_cmake_python REQUIRED)
find_package(rclpy REQUIRED)

ament_python_install_package(${PROJECT_NAME})

install(DIRECTORY
  launch
  config
  model
  DESTINATION share/${PROJECT_NAME}
)

install(PROGRAMS
  ${PROJECT_NAME}/unet_segmentation_node.py
  DESTINATION lib/${PROJECT_NAME}
)

ament_package()
config/unet_segmentation.yaml (new file)
/**:
  ros__parameters:
    model_file: "unet-simple-320-240-l-5-e10-b16(1).pth"
    input_topic: "/gripper_camera/image_raw"
    overlay_topic: "/segmentation/overlay"
    mask_topic: "/segmentation/mask"
    resize_width: 320
    resize_height: 240
    keep_original_size: true  # upsample mask/overlay back to the source image size
    mask_threshold: 0.5
    bilinear: false
    simple: true
    classes: 1
    device: "cuda"
    pred_color: [255, 0, 0]
    overlay_alpha: 0.4
    qos_depth: 3
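The node itself (unet_segmentation_node.py) is not among the files shown in this diff, so the following is only a minimal sketch of how an rclpy node typically declares and reads a subset of these parameters. The class name and structure are assumptions; only the parameter names and defaults come from the YAML above.

```python
import rclpy
from rclpy.node import Node


class UnetSegmentationNode(Node):
    """Hypothetical sketch; not the actual node shipped in this PR."""

    def __init__(self):
        super().__init__('unet_segmentation')
        # Declare a subset of the parameters from config/unet_segmentation.yaml.
        # Values from a --params-file or the launch file override these defaults.
        self.declare_parameter('model_file', '')
        self.declare_parameter('input_topic', '/gripper_camera/image_raw')
        self.declare_parameter('overlay_topic', '/segmentation/overlay')
        self.declare_parameter('mask_topic', '/segmentation/mask')
        self.declare_parameter('resize_width', 320)
        self.declare_parameter('resize_height', 240)
        self.declare_parameter('mask_threshold', 0.5)
        self.declare_parameter('device', 'cuda')
        self.declare_parameter('qos_depth', 3)

        depth = self.get_parameter('qos_depth').value
        topic = self.get_parameter('input_topic').value
        self.get_logger().info(f'Would subscribe to {topic} with QoS depth {depth}')


def main():
    rclpy.init()
    node = UnetSegmentationNode()
    rclpy.spin(node)
    rclpy.shutdown()


if __name__ == '__main__':
    main()
```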
Launch file (new file, installed to share/unet_segmentation/launch/)
import os

from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument, OpaqueFunction
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node

ALLOWED_DEVICES = ['cpu', '0']


def validate_device(device: str):
    if device not in ALLOWED_DEVICES:
        raise RuntimeError(
            f"Invalid device '{device}'. Choose one of: {', '.join(ALLOWED_DEVICES)}"
        )


def launch_setup(context, *args, **kwargs):
    device = LaunchConfiguration('device').perform(context)
    validate_device(device)

    unet_params = os.path.join(
        get_package_share_directory('unet_segmentation'),
        'config/unet_segmentation.yaml',
    )

    unet_node = Node(
        package='unet_segmentation',
        executable='unet_segmentation_node.py',
        name='unet_segmentation',
        namespace='unet',
        output='screen',
        parameters=[
            unet_params,
            {'device': device},
        ],
    )

    return [unet_node]


def generate_launch_description():
    return LaunchDescription(
        [
            DeclareLaunchArgument(
                'device',
                default_value='0',
                description="Device to run inference on: 'cpu' or '0' (GPU index)",
            ),
            OpaqueFunction(function=launch_setup),
        ]
    )
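The launch argument accepts 'cpu' or '0' while the YAML default is 'cuda', and the node script itself is not shown in this diff, so the helper below is only an assumed sketch of how such a device string could be normalized to a torch.device. The function name resolve_device is hypothetical.

```python
import torch


def resolve_device(device: str) -> torch.device:
    """Map a launch/config device string ('cpu', 'cuda', or a GPU index like '0')
    to a torch.device. Hypothetical helper, not taken from this PR."""
    if device == 'cpu':
        return torch.device('cpu')
    if device.isdigit():
        # Bare GPU index, as passed by the launch file.
        device = f'cuda:{device}'
    if device.startswith('cuda') and not torch.cuda.is_available():
        # Fall back to CPU instead of crashing when no GPU is present.
        return torch.device('cpu')
    return torch.device(device)


print(resolve_device('0'))    # cuda:0 when a GPU is available, otherwise cpu
print(resolve_device('cpu'))  # cpu
```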
package.xml (new file)
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
  <name>unet_segmentation</name>
  <version>0.0.0</version>
  <description>U-Net segmentation on the raw camera image</description>
  <maintainer email="[email protected]">gardeg</maintainer>
  <license>MIT</license>

  <buildtool_depend>ament_cmake</buildtool_depend>
  <buildtool_depend>ament_cmake_python</buildtool_depend>

  <depend>rclpy</depend>
  <depend>sensor_msgs</depend>
  <depend>cv_bridge</depend>
  <depend>vision_msgs</depend>

  <exec_depend>launch</exec_depend>
  <exec_depend>launch_ros</exec_depend>
  <exec_depend>ament_index_python</exec_depend>

  <export>
    <build_type>ament_cmake</build_type>
  </export>
</package>
unet_segmentation/__init__.py (new file)
from .unet_model import UNet as UNet

__all__ = ["UNet"]
unet_segmentation/unet_model.py (new file)
from functools import partial

from torch import nn
from torch.utils.checkpoint import checkpoint

from .unet_parts import DoubleConv, Down, OutConv, Up
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, simple=False, bilinear=False):
        """The U-Net architecture.

        :param n_channels: Number of input channels (e.g., 3 for RGB images)
        :param n_classes: Number of output classes (e.g., 1 for binary segmentation)
        :param simple: If True, creates a smaller U-Net with fewer layers.
        :param bilinear: If True, use bilinear upsampling instead of transposed convolutions.
        """
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.simple = simple

        factor = 2 if bilinear else 1

        if not self.simple:
            self.inc = DoubleConv(n_channels, 64)
            self.down1 = Down(64, 128)
            self.down2 = Down(128, 256)
            self.down3 = Down(256, 512)
            self.down4 = Down(512, 1024 // factor)
            self.up1 = Up(1024, 512 // factor, bilinear)
            self.up2 = Up(512, 256 // factor, bilinear)
            self.up3 = Up(256, 128 // factor, bilinear)
            self.up4 = Up(128, 64, bilinear)
            self.outc = OutConv(64, n_classes)
        else:
            self.inc = DoubleConv(n_channels, 64)
            self.down1 = Down(64, 128)
            self.down2 = Down(128, 256 // factor)
            self.up1 = Up(256, 128 // factor, bilinear)
            self.up2 = Up(128, 64, bilinear)
            self.outc = OutConv(64, n_classes)

    def forward(self, x):
        if not self.simple:
            x1 = self.inc(x)
            x2 = self.down1(x1)
            x3 = self.down2(x2)
            x4 = self.down3(x3)
            x5 = self.down4(x4)
            x = self.up1(x5, x4)
            x = self.up2(x, x3)
            x = self.up3(x, x2)
            x = self.up4(x, x1)
            logits = self.outc(x)
        else:
            x1 = self.inc(x)
            x2 = self.down1(x1)
            x3 = self.down2(x2)
            x = self.up1(x3, x2)
            x = self.up2(x, x1)
            logits = self.outc(x)

        return logits
    def use_checkpointing(self):
        """Enable gradient checkpointing to save memory, at the cost of extra
        computation during backpropagation."""
        for module in self.children():
            # Recompute each block's activations during the backward pass instead
            # of storing them. Iterating over children() covers both the full and
            # the simple variant, so the two branches do not need to be duplicated.
            module.forward = partial(
                checkpoint, type(module).forward, module, use_reentrant=False
            )
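As a quick sanity check of the architecture against the config above (simple: true, classes: 1, a 240x320 input), here is a minimal sketch, assuming the installed Python package re-exports UNet as in __init__.py:

```python
import torch

from unet_segmentation import UNet

# The small variant used by the default config: 3 input channels, 1 class,
# transposed-convolution upsampling (bilinear: false).
model = UNet(n_channels=3, n_classes=1, simple=True, bilinear=False)
model.eval()

# One RGB frame at the configured resize_height x resize_width (240 x 320).
x = torch.randn(1, 3, 240, 320)
with torch.no_grad():
    logits = model(x)

print(logits.shape)  # torch.Size([1, 1, 240, 320])
mask = torch.sigmoid(logits) > 0.5  # binarized with mask_threshold = 0.5
```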
unet_segmentation/unet_parts.py (new file)
"""Parts of the U-Net model."""

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2."""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(
                in_channels, in_channels // 2, kernel_size=2, stride=2
            )
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is NCHW
        diffy = x2.size()[2] - x1.size()[2]
        diffx = x2.size()[3] - x1.size()[3]

        # Pad the upsampled decoder features so they match the encoder skip
        # connection before concatenation.
        x1 = F.pad(x1, [diffx // 2, diffx - diffx // 2, diffy // 2, diffy - diffy // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
Review comment: Is this needed to run the code?
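Not part of the PR, but as a small illustration of the padding logic in Up.forward when the encoder skip connection and the upsampled decoder features differ in size: the shapes below are made up for the example, and the import path assumes the module is importable as unet_segmentation.unet_parts.

```python
import torch

from unet_segmentation.unet_parts import Up

# With odd spatial sizes, the maxpooled encoder map and the upsampled decoder
# map can differ by a pixel; Up.forward pads x1 before concatenating with x2.
up = Up(in_channels=256, out_channels=128, bilinear=False)

x1 = torch.randn(1, 256, 12, 15)  # decoder features, upsampled to 24x30 inside Up
x2 = torch.randn(1, 128, 25, 31)  # encoder skip connection, one pixel larger

out = up(x1, x2)
print(out.shape)  # torch.Size([1, 128, 25, 31]), padded to match the skip connection
```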