
Commit f03e930

Author: Mark-ZhouWX
Message: add lite inference
Parent: bb4a929

File tree

8 files changed: +220 −18 lines

research/segment-anything/configs/export_lite.cfg

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+[acl_init_options]
+ge.exec.precision_mode="allow_fp32_to_fp16"
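This converter config is the one export.py passes via config_file below. Per the semantics of the CANN option ge.exec.precision_mode, "allow_fp32_to_fp16" keeps operators in fp32 where an fp32 implementation exists and falls back to fp16 otherwise; the runtime side complements it with precision_mode = "preferred_fp32" in the new Lite inference script.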

research/segment-anything/export.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+import argparse
+import os
+
+import mindspore_lite as mslite
+import mindspore as ms
+from mindspore import ops
+
+from segment_anything import sam_model_registry
+
+
+def main(args):
+    # Step 0: prepare
+    ms.set_context(device_target='CPU')  # put mindspore on CPU to avoid competing with mindspore_lite for Ascend memory
+    os.makedirs(args.model_path_wo_ext, exist_ok=True)
+    model_path_wo_ext = os.path.join(args.model_path_wo_ext, 'sam_' + args.model_type)
+    mindir_path = model_path_wo_ext + '.mindir'
+    lite_path_wo_ext = model_path_wo_ext + '_lite'
+    lite_path = model_path_wo_ext + '_lite.mindir'
+    # model
+    model = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
+
+    # Step 1: export mindir
+    if args.export_mindir:
+        # dummy inputs that fix the static input shapes of the exported graph
+        image = ops.ones(shape=(1, 3, 1024, 1024), dtype=ms.float32)  # b, 3, 1024, 1024
+        boxes = ops.ones(shape=(1, 1, 4), dtype=ms.float32)  # b, n, 4
+        inputs = (image, boxes)
+        model.set_inputs(*inputs)
+        print("start export mindir")
+        ms.export(model, *inputs, file_name=model_path_wo_ext, file_format="MINDIR")
+        print("finish export mindir")
+
+    print(f'mindir path: {mindir_path}')
+    print(f'lite path wo ext: {lite_path_wo_ext}')
+    print(f'lite path: {lite_path}')
+
+    # Step 2: convert lite
+    if args.convert_lite:
+        optimize_dict = {"ascend": "ascend_oriented", "gpu": "gpu_oriented", "cpu": "general"}
+        converter = mslite.Converter()
+        converter.save_type = mslite.ModelType.MINDIR
+        converter.optimize = optimize_dict[args.device.lower()]
+
+        print("start convert lite")
+        converter.convert(
+            fmk_type=mslite.FmkType.MINDIR,
+            model_file=mindir_path,
+            output_file=lite_path_wo_ext,
+            config_file="./configs/export_lite.cfg",
+        )
+        print("finish convert lite")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Export an online ckpt to an offline mindir")
+
+    parser.add_argument(
+        "--model_path_wo_ext",
+        type=str,
+        default='./models/',
+        help="Directory where the output model is saved; file names are built from it without an extension.",
+    )
+
+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default='vit_b',
+        help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        default='models/sam_vit_b-35e4849c.ckpt',
+        help="Online checkpoint file that stores the weights.",
+    )
+
+    parser.add_argument("--device", type=str, default="Ascend", help="Target device for the converted model.")
+
+    parser.add_argument(
+        "--export-mindir",
+        default=True,
+        help="Flag to enable exporting the MindIR model.",
+    )
+
+    parser.add_argument(
+        "--convert-lite",
+        default=True,
+        help="Flag to enable converting the MindIR model to Lite.",
+    )
+    args = parser.parse_args()
+    print(args)
+    main(args)
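For reference, a typical invocation of the new export script, using its own defaults (the checkpoint must exist locally):

python export.py --model-type vit_b --checkpoint models/sam_vit_b-35e4849c.ckpt --device Ascend

With the default --model_path_wo_ext of ./models/, this writes ./models/sam_vit_b.mindir and then converts it to ./models/sam_vit_b_lite.mindir using ./configs/export_lite.cfg.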

research/segment-anything/inference_one_image.py renamed to research/segment-anything/inference.py

Lines changed: 1 addition & 15 deletions
@@ -9,25 +9,11 @@
 from segment_anything.build_sam import sam_model_registry
 from segment_anything.dataset.transform import TransformPipeline, ImageNorm, ImageResizeAndPad
 import matplotlib.pyplot as plt
-import time

+from segment_anything.utils.utils import Timer
 from use_sam_with_promts import show_mask, show_box


-class Timer:
-    def __init__(self, name=''):
-        self.name = name
-        self.start = 0.0
-        self.end = 0.0
-
-    def __enter__(self):
-        self.start = time.time()
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.end = time.time()
-        print(f'{self.name} cost time {self.end - self.start:.3f}')
-
-
 def infer(args):
     ms.context.set_context(mode=args.mode, device_target=args.device)
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+import argparse
+
+import cv2
+import mindspore_lite as mslite
+import numpy as np
+from matplotlib import pyplot as plt
+
+from segment_anything.dataset.transform import ImageResizeAndPad, ImageNorm, TransformPipeline
+from segment_anything.utils.utils import Timer
+from use_sam_with_promts import show_box, show_mask
+
+
+def set_context(device='Ascend', device_id=0):
+    context = mslite.Context()
+    context.target = [device.lower()]
+    if device.lower() == 'ascend':
+        context.ascend.device_id = device_id
+        context.ascend.precision_mode = "preferred_fp32"  # this line is important for keeping precision
+    elif device.lower() == 'gpu':
+        context.gpu.device_id = device_id
+    else:
+        raise NotImplementedError
+    return context
+
+
+def build_model(lite_mindir_path, context):
+    print(f'build model from: {lite_mindir_path}')
+    model = mslite.Model()
+    model.build_from_file(lite_mindir_path, mslite.ModelType.MINDIR, context)
+    return model
+
+
+def infer(args):
+    # Step 0: prepare model
+    context = set_context(device=args.device, device_id=args.device_id)
+    model = build_model(args.model_path, context)
+
+    # Step 1: data preparation
+    with Timer('preprocess'):
+        transform_list = [
+            ImageResizeAndPad(target_size=1024, apply_mask=False),
+            ImageNorm(),
+        ]
+        transform_pipeline = TransformPipeline(transform_list)
+
+        image_path = args.image_path
+        image_np = cv2.imread(image_path)
+        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
+        boxes_np = np.array([[425, 600, 700, 875]])
+
+        transformed = transform_pipeline(dict(image=image_np, boxes=boxes_np))
+        image, boxes, origin_hw = transformed['image'], transformed['boxes'], transformed['origin_hw']
+        # larger batch size for speed tests:
+        # image = ms.Tensor(np.expand_dims(image, 0).repeat(8, axis=0))  # b, 3, 1024, 1024
+        # boxes = ms.Tensor(np.expand_dims(boxes, 0).repeat(8, axis=0))  # b, n, 4
+        image = np.expand_dims(image, 0)  # b, 3, 1024, 1024
+        boxes = np.expand_dims(boxes, 0)  # b, n, 4
+
+        inputs = model.get_inputs()
+        inputs[0].set_data_from_numpy(image.astype(np.float32))
+        inputs[1].set_data_from_numpy(boxes.astype(np.float32))
+
+    # Step 2: inference
+    with Timer('model inference'):
+        mask_logits = model.predict(inputs)[0]  # (1, 1, 1024, 1024)
+
+    with Timer('second time inference'):
+        mask_logits = model.predict(inputs)[0]  # (1, 1, 1024, 1024)
+
+    # Step 3: post-process
+    with Timer('post-process'):
+        mask_logits = mask_logits.get_data_to_numpy()[0, 0] > 0.0  # (1024, 1024)
+        mask_logits = mask_logits.astype(np.uint8)
+        final_mask = cv2.resize(mask_logits[:origin_hw[2], :origin_hw[3]], (origin_hw[1], origin_hw[0]),
+                                interpolation=cv2.INTER_CUBIC)
+
+    # Step 4: visualize
+    plt.imshow(image_np)
+    show_box(boxes_np[0], plt.gca())
+    show_mask(final_mask, plt.gca())
+    plt.savefig(args.image_path + '_lite_infer.jpg')
+    plt.show()
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Runs inference on one image")
+    parser.add_argument("--image_path", type=str, default='./images/truck.jpg', help="Path to an input image.")
+    parser.add_argument("--model-path", type=str, default='./models/sam_vit_b_lite.mindir', help="MindIR model path for Lite inference.")
+    parser.add_argument("--device", type=str, default="Ascend", help="The device to run inference on.")
+    parser.add_argument("--device_id", type=int, default=0, help="The id of the device to run inference on.")
+
+    args = parser.parse_args()
+    print(args)
+    infer(args)
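For illustration, a typical run of the new script (the file name inference_lite.py is assumed here; the diff header above does not show it):

python inference_lite.py --image_path ./images/truck.jpg --model-path ./models/sam_vit_b_lite.mindir --device Ascend --device_id 0

Each stage prints its elapsed time via Timer, and the visualization is saved next to the input as ./images/truck.jpg_lite_infer.jpg.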

research/segment-anything/segment_anything/modeling/image_encoder.py

Lines changed: 1 addition & 1 deletion
@@ -358,7 +358,7 @@ def add_decomposed_rel_pos(
     r_q = q.reshape(B, q_h, q_w, dim)
     dtype = r_q.dtype
     # rel_h = ops.einsum("bhwc,hkc->bhwk", r_q, Rh)
-    rel_h = ops.BatchMatMul(transpose_b=True)(r_q, ops.broadcast_to(ops.unsqueeze(Rh, 0).astype(dtype), (B, -1, -1, -1)))
+    rel_h = ops.BatchMatMul(transpose_b=True)(r_q, ops.unsqueeze(Rh, 0).astype(dtype).repeat(B, axis=0))
     # rel_w = ops.einsum("bhwc,wkc->bhwk", r_q, Rw)
     rel_w = ops.mul(ops.unsqueeze(r_q, -2), ops.unsqueeze(ops.unsqueeze(Rw, 0), 0).astype(dtype)).sum(axis=-1)
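The commit does not state the motivation, but a plausible one is that broadcast_to with -1 placeholder dimensions does not survive MindIR export / Lite conversion, whereas repeat materializes the same batched tensor with fully concrete shapes. A minimal sketch of the equivalence, assuming standard MindSpore ops.broadcast_to and Tensor.repeat semantics (toy shapes, not the model's real ones):

import mindspore as ms
from mindspore import ops

B = 4
Rh = ops.ones((14, 14, 64), ms.float32)  # toy relative-position table
# old: view-style broadcast, -1 keeps the existing dimensions
old = ops.broadcast_to(ops.unsqueeze(Rh, 0), (B, -1, -1, -1))
# new: explicit copy along the batch axis, all shapes concrete
new = ops.unsqueeze(Rh, 0).repeat(B, axis=0)
assert old.shape == new.shape == (4, 14, 14, 64)
assert (old == new).all()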

research/segment-anything/segment_anything/modeling/mask_decoder.py

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ def predict_masks(
         """Predicts masks. See 'forward' for more details."""
         # Concatenate output tokens
         output_tokens = ops.cat([self.iou_token.embedding_table, self.mask_tokens.embedding_table], axis=0)
-        output_tokens = output_tokens.unsqueeze(0).broadcast_to((sparse_prompt_embeddings.shape[0], -1, -1))
+        output_tokens = output_tokens.unsqueeze(0).repeat(sparse_prompt_embeddings.shape[0], axis=0)
         tokens = ops.cat((output_tokens, sparse_prompt_embeddings), axis=1)

         # Expand per-image data in batch direction to be per-mask
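Same adjustment as in image_encoder.py above: the batch broadcast with -1 placeholders becomes an explicit repeat along axis 0, presumably for the same Lite-conversion reason.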

research/segment-anything/segment_anything/modeling/prompt_encoder.py

Lines changed: 2 additions & 1 deletion
@@ -167,7 +167,8 @@ def construct(
             dense_embeddings = self._embed_masks(masks)
         else:
             dense_embeddings = self.no_mask_embed.embedding_table.reshape(1, -1, 1, 1).broadcast_to(
-                (bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])
+                (bs, self.no_mask_embed.embedding_table.shape[1],
+                 self.image_embedding_size[0], self.image_embedding_size[1])
             )

         return sparse_embeddings, dense_embeddings
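Here broadcast_to is kept, but the -1 placeholder is replaced with the concrete channel width self.no_mask_embed.embedding_table.shape[1], so the target shape passed to broadcast_to is fully specified, again presumably so that the graph exports and converts cleanly.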

research/segment-anything/segment_anything/utils/utils.py

Lines changed: 15 additions & 0 deletions
@@ -1,4 +1,5 @@
 import os
+import time
 from datetime import datetime

 import mindspore as ms
@@ -137,3 +138,17 @@ def set_directory_and_log(main_device, rank_id, rank_size, work_root, log_level,
     hack_list = {'save_dir': save_dir, 'main_device': main_device}
     cb.update(hack_list)
     return save_dir
+
+
+class Timer:
+    def __init__(self, name=''):
+        self.name = name
+        self.start = 0.0
+        self.end = 0.0
+
+    def __enter__(self):
+        self.start = time.time()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.end = time.time()
+        print(f'{self.name} cost time {self.end - self.start:.3f}')
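The relocated Timer is a plain context manager; it is used throughout the new inference code like this (run_preprocess is a hypothetical stand-in for the timed work):

from segment_anything.utils.utils import Timer

with Timer('preprocess'):
    run_preprocess()  # hypothetical workload; Timer prints e.g. "preprocess cost time 0.123" on exit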
