Skip to content

Commit bb4a929

Browse files
author
Mark-ZhouWX
committed
decouple image , box and mask in resize and pad transform
1 parent 3ba859f commit bb4a929

File tree

2 files changed

+23
-66
lines changed

2 files changed

+23
-66
lines changed

research/segment-anything/inference_one_image.py

+5-55
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
import mindspore as ms
88

99
from segment_anything.build_sam import sam_model_registry
10-
from segment_anything.dataset.transform import TransformPipeline, ImageNorm
11-
from segment_anything.utils.transforms import ResizeLongestSide
10+
from segment_anything.dataset.transform import TransformPipeline, ImageNorm, ImageResizeAndPad
1211
import matplotlib.pyplot as plt
1312
import time
1413

@@ -29,65 +28,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
2928
print(f'{self.name} cost time {self.end - self.start:.3f}')
3029

3130

32-
class ImageResizeAndPad:
33-
34-
def __init__(self, target_size):
35-
"""
36-
Args:
37-
target_size (int): target size of model input (1024 in sam)
38-
"""
39-
self.target_size = target_size
40-
self.transform = ResizeLongestSide(target_size)
41-
42-
def __call__(self, result_dict):
43-
"""
44-
Resize input to the long size and then pad it to the model input size (1024*1024 in sam).
45-
Pad masks and boxes to a fixed length for graph mode
46-
Required keys: image, masks, boxes
47-
Update keys: image, masks, boxes
48-
Add keys:
49-
origin_hw (np.array): array with shape (4), represents original image height, width
50-
and resized height, width, respectively. This array record the trace of image shape transformation
51-
and is used for visualization.
52-
image_pad_area (Tuple): image padding area in h and w direction, in the format of
53-
((pad_h_left, pad_h_right), (pad_w_left, pad_w_right))
54-
"""
55-
56-
image = result_dict['image']
57-
boxes = result_dict['boxes']
58-
59-
og_h, og_w, _ = image.shape
60-
image = self.transform.apply_image(image)
61-
resized_h, resized_w, _ = image.shape
62-
63-
# Pad image and masks to the model input
64-
h, w, c = image.shape
65-
max_dim = max(h, w) # long side length
66-
assert max_dim == self.target_size
67-
# pad 0 to the right and bottom side
68-
pad_h = max_dim - h
69-
pad_w = max_dim - w
70-
img_padding = ((0, pad_h), (0, pad_w), (0, 0))
71-
image = np.pad(image, pad_width=img_padding, constant_values=0) # (h, w, c)
72-
73-
# Adjust bounding boxes
74-
boxes = self.transform.apply_boxes(boxes, (og_h, og_w)).astype(np.float32)
75-
76-
result_dict['origin_hw'] = np.array([og_h, og_w, resized_h, resized_w], np.int32) # record image shape trace for visualization
77-
result_dict['image'] = image
78-
result_dict['boxes'] = boxes
79-
result_dict['image_pad_area'] = img_padding[:2]
80-
81-
return result_dict
82-
83-
8431
def infer(args):
8532
ms.context.set_context(mode=args.mode, device_target=args.device)
8633

8734
# Step1: data preparation
8835
with Timer('preprocess'):
8936
transform_list = [
90-
ImageResizeAndPad(target_size=1024),
37+
ImageResizeAndPad(target_size=1024, apply_mask=False),
9138
ImageNorm(),
9239
]
9340
transform_pipeline = TransformPipeline(transform_list)
@@ -99,6 +46,9 @@ def infer(args):
9946

10047
transformed = transform_pipeline(dict(image=image_np, boxes=boxes_np))
10148
image, boxes, origin_hw = transformed['image'], transformed['boxes'], transformed['origin_hw']
49+
# batch_size for speed test
50+
# image = ms.Tensor(np.expand_dims(image, 0).repeat(8, axis=0)) # b, 3, 1023
51+
# boxes = ms.Tensor(np.expand_dims(boxes, 0).repeat(8, axis=0)) # b, n, 4
10252
image = ms.Tensor(image).unsqueeze(0) # b, 3, 1023
10353
boxes = ms.Tensor(boxes).unsqueeze(0) # b, n, 4
10454

research/segment-anything/segment_anything/dataset/transform.py

+18-11
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,17 @@ def __call__(self, result_dict):
136136
@TRANSFORM_REGISTRY.registry_module()
137137
class ImageResizeAndPad:
138138

139-
def __init__(self, target_size):
139+
def __init__(self, target_size, apply_box=True, apply_mask=True):
140140
"""
141141
Args:
142142
target_size (int): target size of model input (1024 in sam)
143+
apply_box: also resize and pad box accordingly beside image
144+
apply_mask: also resize and pad mask accordingly beside image
143145
"""
144146
self.target_size = target_size
145147
self.transform = ResizeLongestSide(target_size)
148+
self.apply_box = apply_box
149+
self.apply_mask = apply_mask
146150

147151
def __call__(self, result_dict):
148152
"""
@@ -159,13 +163,12 @@ def __call__(self, result_dict):
159163
"""
160164

161165
image = result_dict['image']
162-
masks = result_dict['masks']
163-
boxes = result_dict['boxes']
166+
masks = result_dict.get('masks')
167+
boxes = result_dict.get('boxes')
164168

165169
og_h, og_w, _ = image.shape
166170
image = self.transform.apply_image(image)
167171
resized_h, resized_w, _ = image.shape
168-
masks = np.stack([self.transform.apply_image(mask) for mask in masks])
169172

170173
# Pad image and masks to the model input
171174
h, w, c = image.shape
@@ -176,18 +179,22 @@ def __call__(self, result_dict):
176179
pad_w = max_dim - w
177180
img_padding = ((0, pad_h), (0, pad_w), (0, 0))
178181
image = np.pad(image, pad_width=img_padding, constant_values=0) # (h, w, c)
179-
mask_padding = ((0, 0), (0, pad_h), (0, pad_w)) # (n, h, w)
180-
masks = np.pad(masks, pad_width=mask_padding, constant_values=0)
181-
182-
# Adjust bounding boxes
183-
boxes = self.transform.apply_boxes(boxes, (og_h, og_w)).astype(np.float32)
184182

185183
result_dict['origin_hw'] = np.array([og_h, og_w, resized_h, resized_w], np.int32) # record image shape trace for visualization
186184
result_dict['image'] = image
187-
result_dict['masks'] = masks
188-
result_dict['boxes'] = boxes
189185
result_dict['image_pad_area'] = img_padding[:2]
190186

187+
if self.apply_box:
188+
# Adjust bounding boxes
189+
boxes = self.transform.apply_boxes(boxes, (og_h, og_w)).astype(np.float32)
190+
result_dict['boxes'] = boxes
191+
192+
if self.apply_mask:
193+
masks = np.stack([self.transform.apply_image(mask) for mask in masks])
194+
mask_padding = ((0, 0), (0, pad_h), (0, pad_w)) # (n, h, w)
195+
masks = np.pad(masks, pad_width=mask_padding, constant_values=0)
196+
result_dict['masks'] = masks
197+
191198
return result_dict
192199

193200

0 commit comments

Comments
 (0)