7
7
import mindspore as ms
8
8
9
9
from segment_anything .build_sam import sam_model_registry
10
- from segment_anything .dataset .transform import TransformPipeline , ImageNorm
11
- from segment_anything .utils .transforms import ResizeLongestSide
10
+ from segment_anything .dataset .transform import TransformPipeline , ImageNorm , ImageResizeAndPad
12
11
import matplotlib .pyplot as plt
13
12
import time
14
13
@@ -29,65 +28,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
29
28
print (f'{ self .name } cost time { self .end - self .start :.3f} ' )
30
29
31
30
32
- class ImageResizeAndPad :
33
-
34
- def __init__ (self , target_size ):
35
- """
36
- Args:
37
- target_size (int): target size of model input (1024 in sam)
38
- """
39
- self .target_size = target_size
40
- self .transform = ResizeLongestSide (target_size )
41
-
42
- def __call__ (self , result_dict ):
43
- """
44
- Resize input to the long size and then pad it to the model input size (1024*1024 in sam).
45
- Pad masks and boxes to a fixed length for graph mode
46
- Required keys: image, masks, boxes
47
- Update keys: image, masks, boxes
48
- Add keys:
49
- origin_hw (np.array): array with shape (4), represents original image height, width
50
- and resized height, width, respectively. This array record the trace of image shape transformation
51
- and is used for visualization.
52
- image_pad_area (Tuple): image padding area in h and w direction, in the format of
53
- ((pad_h_left, pad_h_right), (pad_w_left, pad_w_right))
54
- """
55
-
56
- image = result_dict ['image' ]
57
- boxes = result_dict ['boxes' ]
58
-
59
- og_h , og_w , _ = image .shape
60
- image = self .transform .apply_image (image )
61
- resized_h , resized_w , _ = image .shape
62
-
63
- # Pad image and masks to the model input
64
- h , w , c = image .shape
65
- max_dim = max (h , w ) # long side length
66
- assert max_dim == self .target_size
67
- # pad 0 to the right and bottom side
68
- pad_h = max_dim - h
69
- pad_w = max_dim - w
70
- img_padding = ((0 , pad_h ), (0 , pad_w ), (0 , 0 ))
71
- image = np .pad (image , pad_width = img_padding , constant_values = 0 ) # (h, w, c)
72
-
73
- # Adjust bounding boxes
74
- boxes = self .transform .apply_boxes (boxes , (og_h , og_w )).astype (np .float32 )
75
-
76
- result_dict ['origin_hw' ] = np .array ([og_h , og_w , resized_h , resized_w ], np .int32 ) # record image shape trace for visualization
77
- result_dict ['image' ] = image
78
- result_dict ['boxes' ] = boxes
79
- result_dict ['image_pad_area' ] = img_padding [:2 ]
80
-
81
- return result_dict
82
-
83
-
84
31
def infer (args ):
85
32
ms .context .set_context (mode = args .mode , device_target = args .device )
86
33
87
34
# Step1: data preparation
88
35
with Timer ('preprocess' ):
89
36
transform_list = [
90
- ImageResizeAndPad (target_size = 1024 ),
37
+ ImageResizeAndPad (target_size = 1024 , apply_mask = False ),
91
38
ImageNorm (),
92
39
]
93
40
transform_pipeline = TransformPipeline (transform_list )
@@ -99,6 +46,9 @@ def infer(args):
99
46
100
47
transformed = transform_pipeline (dict (image = image_np , boxes = boxes_np ))
101
48
image , boxes , origin_hw = transformed ['image' ], transformed ['boxes' ], transformed ['origin_hw' ]
49
+ # batch_size for speed test
50
+ # image = ms.Tensor(np.expand_dims(image, 0).repeat(8, axis=0)) # b, 3, 1023
51
+ # boxes = ms.Tensor(np.expand_dims(boxes, 0).repeat(8, axis=0)) # b, n, 4
102
52
image = ms .Tensor (image ).unsqueeze (0 ) # b, 3, 1023
103
53
boxes = ms .Tensor (boxes ).unsqueeze (0 ) # b, n, 4
104
54
0 commit comments