Skip to content

Commit 6106754

Browse files
author
Mark-ZhouWX
committed
add lite inference
1 parent bb4a929 commit 6106754

File tree

8 files changed

+220
-18
lines changed

8 files changed

+220
-18
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[acl_init_options]
2+
ge.exec.precision_mode="allow_fp32_to_fp16"

research/segment-anything/export.py

+102
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import argparse
2+
import os
3+
import mindspore as ms
4+
from mindspore import ops
5+
6+
from segment_anything import sam_model_registry
7+
8+
9+
def main(args):
10+
# Step0: prepare
11+
os.makedirs(args.model_path_wo_ext, exist_ok=True)
12+
model_path_wo_ext = os.path.join(args.model_path_wo_ext, 'sam_' + args.model_type)
13+
mindir_path = os.path.join(model_path_wo_ext + '.mindir')
14+
lite_path_wo_ext = os.path.join(model_path_wo_ext + f"_lite")
15+
lite_path = os.path.join(model_path_wo_ext + f"_lite.mindir")
16+
# model
17+
model = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
18+
19+
# Step 1: export mindir
20+
if args.export_mindir:
21+
# input
22+
image = ops.ones(shape=(1, 3, 1024, 1024), dtype=ms.float32) # b, 3, 1024, 1024
23+
boxes = ops.ones(shape=(1, 1, 4), dtype=ms.float32) # b, n, 4
24+
inputs = (image, boxes)
25+
model.set_inputs(*inputs)
26+
print(f"start export mindir")
27+
ms.export(model, *inputs, file_name=model_path_wo_ext, file_format="MINDIR")
28+
print(f"finish export mindir")
29+
30+
print(f'mind ir path: {mindir_path}')
31+
print(f'lite path wo_ext: {lite_path_wo_ext}')
32+
print(f'lite path: {lite_path}')
33+
34+
# Step 2: convert lite
35+
if args.convert_lite:
36+
import mindspore_lite as mslite
37+
optimize_dict = {"ascend": "ascend_oriented", "gpu": "gpu_oriented", "cpu": "general"}
38+
converter = mslite.Converter()
39+
converter.save_type = mslite.ModelType.MINDIR
40+
converter.optimize = optimize_dict[args.device.lower()]
41+
42+
print(f"start convert lite")
43+
converter.convert(
44+
fmk_type=mslite.FmkType.MINDIR,
45+
model_file=mindir_path,
46+
output_file=lite_path_wo_ext,
47+
config_file="./configs/export_lite.cfg",
48+
)
49+
print(converter)
50+
print(f"finish convert lite")
51+
52+
53+
if __name__ == '__main__':
54+
parser = argparse.ArgumentParser(
55+
description=(
56+
"Export online ckpt to offline mindir"
57+
)
58+
)
59+
60+
parser.add_argument(
61+
"--model_path_wo_ext",
62+
type=str,
63+
default='./models/',
64+
help=(
65+
"Full path to the directory where the output model is saved, without file extension."
66+
),
67+
)
68+
69+
parser.add_argument(
70+
"--model-type",
71+
type=str,
72+
default='vit_b',
73+
help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
74+
)
75+
76+
parser.add_argument(
77+
"--checkpoint",
78+
type=str,
79+
default='models/sam_vit_b-35e4849c.ckpt',
80+
help="online checkpoint file that stores weight",
81+
)
82+
83+
parser.add_argument("--device", type=str, default="Ascend", help="The device to run generation on.")
84+
85+
parser.add_argument(
86+
"--export-mindir",
87+
default=True,
88+
help=(
89+
"Button to enable export mindir."
90+
),
91+
)
92+
93+
parser.add_argument(
94+
"--convert-lite",
95+
default=True,
96+
help=(
97+
"Button to enable convert lite."
98+
),
99+
)
100+
args = parser.parse_args()
101+
print(args)
102+
main(args)

research/segment-anything/inference_one_image.py renamed to research/segment-anything/inference.py

+1-15
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,11 @@
99
from segment_anything.build_sam import sam_model_registry
1010
from segment_anything.dataset.transform import TransformPipeline, ImageNorm, ImageResizeAndPad
1111
import matplotlib.pyplot as plt
12-
import time
1312

13+
from segment_anything.utils.utils import Timer
1414
from use_sam_with_promts import show_mask, show_box
1515

1616

17-
class Timer:
18-
def __init__(self, name=''):
19-
self.name = name
20-
self.start = 0.0
21-
self.end = 0.0
22-
23-
def __enter__(self):
24-
self.start = time.time()
25-
26-
def __exit__(self, exc_type, exc_val, exc_tb):
27-
self.end = time.time()
28-
print(f'{self.name} cost time {self.end - self.start:.3f}')
29-
30-
3117
def infer(args):
3218
ms.context.set_context(mode=args.mode, device_target=args.device)
3319

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import argparse
2+
3+
import cv2
4+
import mindspore_lite as mslite
5+
import numpy as np
6+
from matplotlib import pyplot as plt
7+
8+
from segment_anything.dataset.transform import ImageResizeAndPad, ImageNorm, TransformPipeline
9+
from segment_anything.utils.utils import Timer
10+
from use_sam_with_promts import show_box, show_mask
11+
12+
13+
def set_context(device='Ascend', device_id=0):
14+
context = mslite.Context()
15+
context.target = [device.lower()]
16+
if device.lower() == 'ascend':
17+
context.ascend.device_id = device_id
18+
context.ascend.precision_mode = "preferred_fp32" # this line is important for keeping precision
19+
elif device.lower() == 'gpu':
20+
context.gpu.device_id = device_id
21+
else:
22+
raise NotImplementedError
23+
return context
24+
25+
26+
def build_model(lite_mindir_path, context):
27+
print(f'build model from: {lite_mindir_path}')
28+
model = mslite.Model()
29+
model.build_from_file(lite_mindir_path, mslite.ModelType.MINDIR, context)
30+
return model
31+
32+
33+
def infer(args):
34+
# Step0: prepare model
35+
context = set_context(device=args.device, device_id=args.device_id)
36+
model = build_model(args.model_path, context)
37+
38+
# Step1: data preparation
39+
with Timer('preprocess'):
40+
transform_list = [
41+
ImageResizeAndPad(target_size=1024, apply_mask=False),
42+
ImageNorm(),
43+
]
44+
transform_pipeline = TransformPipeline(transform_list)
45+
46+
image_path = args.image_path
47+
image_np = cv2.imread(image_path)
48+
image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
49+
boxes_np = np.array([[425, 600, 700, 875]])
50+
51+
transformed = transform_pipeline(dict(image=image_np, boxes=boxes_np))
52+
image, boxes, origin_hw = transformed['image'], transformed['boxes'], transformed['origin_hw']
53+
# batch_size for speed test
54+
# image = ms.Tensor(np.expand_dims(image, 0).repeat(8, axis=0)) # b, 3, 1023
55+
# boxes = ms.Tensor(np.expand_dims(boxes, 0).repeat(8, axis=0)) # b, n, 4
56+
image = np.expand_dims(image, 0) # b, 3, 1023
57+
boxes = np.expand_dims(boxes, 0) # b, n, 4
58+
59+
inputs = model.get_inputs()
60+
inputs[0].set_data_from_numpy(image.astype(np.float32))
61+
inputs[1].set_data_from_numpy(boxes.astype(np.float32))
62+
63+
64+
# Step2: inference
65+
with Timer('model inference'):
66+
mask_logits = model.predict(inputs)[0] # (1, 1, 1024, 1024)
67+
68+
with Timer('Second time inference'):
69+
mask_logits = model.predict(inputs)[0] # (1, 1, 1024, 1024)
70+
71+
# Step3: post-process
72+
with Timer('post-process'):
73+
print(f'mask_logits', mask_logits)
74+
mask_logits = mask_logits.get_data_to_numpy()[0, 0] > 0.0 # (1024, 1024)
75+
mask_logits = mask_logits.astype(np.uint8)
76+
final_mask = cv2.resize(mask_logits[:origin_hw[2], :origin_hw[3]], tuple((origin_hw[1], origin_hw[0])),
77+
interpolation=cv2.INTER_CUBIC)
78+
79+
# Step4: visualize
80+
plt.imshow(image_np)
81+
show_box(boxes_np[0], plt.gca())
82+
show_mask(final_mask, plt.gca())
83+
plt.savefig(args.image_path + '_lite_infer.jpg')
84+
plt.show()
85+
86+
87+
if __name__ == '__main__':
88+
parser = argparse.ArgumentParser(description=("Runs inference on one image"))
89+
parser.add_argument("--image_path", type=str, default='./images/truck.jpg', help="Path to an input image.")
90+
parser.add_argument("--model-path", type=str, default='./models/sam_vit_b_lite.mindir', help="mindir model path for lite inference")
91+
parser.add_argument("--device", type=str, default="Ascend", help="The device to run generation on.")
92+
parser.add_argument("--device_id", type=int, default=0, help="The device to run inference on.")
93+
94+
args = parser.parse_args()
95+
print(args)
96+
infer(args)

research/segment-anything/segment_anything/modeling/image_encoder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def add_decomposed_rel_pos(
358358
r_q = q.reshape(B, q_h, q_w, dim)
359359
dtype = r_q.dtype
360360
# rel_h = ops.einsum("bhwc,hkc->bhwk", r_q, Rh)
361-
rel_h = ops.BatchMatMul(transpose_b=True)(r_q, ops.broadcast_to(ops.unsqueeze(Rh, 0).astype(dtype), (B, -1, -1, -1)))
361+
rel_h = ops.BatchMatMul(transpose_b=True)(r_q, ops.unsqueeze(Rh, 0).astype(dtype).repeat(B, axis=0))
362362
# rel_w = ops.einsum("bhwc,wkc->bhwk", r_q, Rw)
363363
rel_w = ops.mul(ops.unsqueeze(r_q, -2), ops.unsqueeze(ops.unsqueeze(Rw, 0), 0).astype(dtype)).sum(axis=-1)
364364

research/segment-anything/segment_anything/modeling/mask_decoder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def predict_masks(
114114
"""Predicts masks. See 'forward' for more details."""
115115
# Concatenate output tokens
116116
output_tokens = ops.cat([self.iou_token.embedding_table, self.mask_tokens.embedding_table], axis=0)
117-
output_tokens = output_tokens.unsqueeze(0).broadcast_to((sparse_prompt_embeddings.shape[0], -1, -1))
117+
output_tokens = output_tokens.unsqueeze(0).repeat(sparse_prompt_embeddings.shape[0], axis=0)
118118
tokens = ops.cat((output_tokens, sparse_prompt_embeddings), axis=1)
119119

120120
# Expand per-image data in batch direction to be per-mask

research/segment-anything/segment_anything/modeling/prompt_encoder.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ def construct(
167167
dense_embeddings = self._embed_masks(masks)
168168
else:
169169
dense_embeddings = self.no_mask_embed.embedding_table.reshape(1, -1, 1, 1).broadcast_to(
170-
(bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])
170+
(bs, self.no_mask_embed.embedding_table.shape[1],
171+
self.image_embedding_size[0], self.image_embedding_size[1])
171172
)
172173

173174
return sparse_embeddings, dense_embeddings

research/segment-anything/segment_anything/utils/utils.py

+15
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import time
23
from datetime import datetime
34

45
import mindspore as ms
@@ -137,3 +138,17 @@ def set_directory_and_log(main_device, rank_id, rank_size, work_root, log_level,
137138
hack_list = {'save_dir': save_dir, 'main_device': main_device}
138139
cb.update(hack_list)
139140
return save_dir
141+
142+
143+
class Timer:
144+
def __init__(self, name=''):
145+
self.name = name
146+
self.start = 0.0
147+
self.end = 0.0
148+
149+
def __enter__(self):
150+
self.start = time.time()
151+
152+
def __exit__(self, exc_type, exc_val, exc_tb):
153+
self.end = time.time()
154+
print(f'{self.name} cost time {self.end - self.start:.3f}')

0 commit comments

Comments
 (0)