import argparse

import cv2
import mindspore_lite as mslite
import numpy as np
from matplotlib import pyplot as plt

from segment_anything.dataset.transform import ImageResizeAndPad, ImageNorm, TransformPipeline
from segment_anything.utils.utils import Timer
from use_sam_with_promts import show_box, show_mask

def set_context(device='Ascend', device_id=0):
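    # Set up a MindSpore Lite context for the chosen backend; only Ascend and GPU are handled here.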
    context = mslite.Context()
    context.target = [device.lower()]
    if device.lower() == 'ascend':
        context.ascend.device_id = device_id
        context.ascend.precision_mode = "preferred_fp32"  # this line is important for keeping precision
    elif device.lower() == 'gpu':
        context.gpu.device_id = device_id
    else:
        raise NotImplementedError
    return context

def build_model(lite_mindir_path, context):
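    # Load the exported MindIR file and compile it for the device described by `context`.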
    print(f'build model from: {lite_mindir_path}')
    model = mslite.Model()
    model.build_from_file(lite_mindir_path, mslite.ModelType.MINDIR, context)
    return model

def infer(args):
    # Step0: prepare model
    context = set_context(device=args.device, device_id=args.device_id)
    model = build_model(args.model_path, context)

    # Step1: data preparation
    with Timer('preprocess'):
        transform_list = [
            ImageResizeAndPad(target_size=1024, apply_mask=False),
            ImageNorm(),
        ]
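        # Judging by the transform names, this pipeline resizes the image so its longer side is 1024,
        # pads it to a square 1024x1024 input, and then applies SAM's pixel normalization.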
        transform_pipeline = TransformPipeline(transform_list)

        image_path = args.image_path
        image_np = cv2.imread(image_path)
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
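        # A single example box prompt, given as (x0, y0, x1, y1) pixel coordinates on the original image.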
        boxes_np = np.array([[425, 600, 700, 875]])

        transformed = transform_pipeline(dict(image=image_np, boxes=boxes_np))
        image, boxes, origin_hw = transformed['image'], transformed['boxes'], transformed['origin_hw']
        # batch the input for a speed test (the ms.Tensor variants would also require `import mindspore as ms`)
        # image = ms.Tensor(np.expand_dims(image, 0).repeat(8, axis=0))  # b, 3, 1024, 1024
        # boxes = ms.Tensor(np.expand_dims(boxes, 0).repeat(8, axis=0))  # b, n, 4
        image = np.expand_dims(image, 0)  # b, 3, 1024, 1024
        boxes = np.expand_dims(boxes, 0)  # b, n, 4

        inputs = model.get_inputs()
        inputs[0].set_data_from_numpy(image.astype(np.float32))
        inputs[1].set_data_from_numpy(boxes.astype(np.float32))
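        # The host numpy arrays are copied into the Lite model's input tensors
        # (input 0 is the image batch, input 1 holds the box prompts).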

    # Step2: inference
    with Timer('model inference'):
        mask_logits = model.predict(inputs)[0]  # (1, 1, 1024, 1024)

    with Timer('Second time inference'):
        mask_logits = model.predict(inputs)[0]  # (1, 1, 1024, 1024)
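        # The first predict() call typically includes one-time initialization/warm-up overhead,
        # so the second timed run is a better estimate of steady-state inference latency.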

    # Step3: post-process
    with Timer('post-process'):
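        # Threshold the logits at 0 to obtain a binary mask, crop away the padded border
        # (origin_hw is assumed to hold (orig_h, orig_w, resized_h, resized_w) from the
        # resize-and-pad transform), then resize the mask back to the original resolution.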
        print('mask_logits', mask_logits)
        mask_logits = mask_logits.get_data_to_numpy()[0, 0] > 0.0  # (1024, 1024)
        mask_logits = mask_logits.astype(np.uint8)
        final_mask = cv2.resize(mask_logits[:origin_hw[2], :origin_hw[3]], (origin_hw[1], origin_hw[0]),
                                interpolation=cv2.INTER_CUBIC)

    # Step4: visualize
    plt.imshow(image_np)
    show_box(boxes_np[0], plt.gca())
    show_mask(final_mask, plt.gca())
    plt.savefig(args.image_path + '_lite_infer.jpg')
    plt.show()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Runs inference on one image.")
    parser.add_argument("--image_path", type=str, default='./images/truck.jpg', help="Path to an input image.")
    parser.add_argument("--model-path", type=str, default='./models/sam_vit_b_lite.mindir',
                        help="Path to the MindIR model for lite inference.")
    parser.add_argument("--device", type=str, default="Ascend", help="The device to run inference on.")
    parser.add_argument("--device_id", type=int, default=0, help="The device id to run inference on.")
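    # Example invocation (the script name is illustrative):
    #   python lite_infer.py --image_path ./images/truck.jpg --model-path ./models/sam_vit_b_lite.mindir --device Ascend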

    args = parser.parse_args()
    print(args)
    infer(args)