
Commit 3ffde59

Author: Mark-ZhouWX

add SAM training and inference model
1 parent 7547ce8 commit 3ffde59


59 files changed (+6091 −0 lines)

research/segment-anything/README.md (+129 lines)

@@ -0,0 +1,129 @@

# Segment Anything

The **Segment Anything Model (SAM)** produces high-quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a [dataset](https://segment-anything.com/dataset/index.html) of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.

## Installation

The code requires `python>=3.7` and `mindspore>=2.0`, and supports both the GPU and Ascend platforms. Please follow the instructions [here](https://www.mindspore.cn/install) to install the MindSpore dependencies.

Clone the repository locally and install with

```shell
git clone https://github.com/mindspore-lab/models.git
cd models/research/segment-anything
pip install -r requirements.txt
```
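
To verify the installation, a quick sanity check can be run (a minimal sketch; pick the device target that matches your platform):

```python
# Check that MindSpore is installed and the backend works.
import mindspore as ms

print(ms.__version__)  # expect >= 2.0
ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")  # or "GPU"
x = ms.Tensor([1.0, 2.0, 3.0])
print((x * 2).asnumpy())  # a simple op to confirm the device is usable
```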

## Inference

First download the weights ([sam_vit_b](sam_vit_b-35e4849c.ckpt), [sam_vit_l](sam_vit_l-1b460f38.ckpt), [sam_vit_h](sam_vit_h-c72f8ba1.ckpt)) and put them under the `${project_root}/models` directory.

There are two recommended ways to use SAM.

### Using SAM with prompts

SAM predicts object masks given prompts that indicate the desired object. If a single point prompt is given, three plausible masks are generated.

```shell
python use_sam_with_promts.py --prompt-type point --model-type vit_h
```

<p float="left">
<img src=images/truck_mask1.png width="400"/><img src=images/truck_mask2.png width="400"/><img src=images/truck_mask3.png width="400"/>
</p>
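
Under the hood, the script presumably drives a predictor-style interface. Below is a minimal sketch of what a point-prompt call might look like, assuming this port mirrors the original SAM Python API (`sam_model_registry`, `SamPredictor`, and `multimask_output` are assumptions here, not confirmed by this README; the example image path is illustrative):

```python
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor  # assumed API, mirroring original SAM

sam = sam_model_registry["vit_h"](checkpoint="models/sam_vit_h-c72f8ba1.ckpt")
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("images/truck.jpg"), cv2.COLOR_BGR2RGB)  # illustrative path
predictor.set_image(image)

# One foreground point (x, y); label 1 = positive. With a single point,
# multimask_output=True returns the three plausible masks shown above.
masks, scores, _ = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
print(masks.shape, scores)  # (3, H, W) boolean masks, one quality score each
```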

If a prompt with two points is given, only one plausible mask is generated instead of three, because two points are less ambiguous than a single point. The green and red stars denote the positive and negative points, respectively.

<div align="center">
<img alt="img.png" src="images/truck_two_point.png" width="600"/>
</div>

If a box prompt is given, one plausible mask is generated.

```shell
python use_sam_with_promts.py --prompt-type box --model-type vit_h
```

<div align="center">
<img alt="img.png" width="600" src="images/truck_box.png"/>
</div>

If a prompt with both a box and a point is given, one plausible mask is generated.

```shell
python use_sam_with_promts.py --prompt-type point_box --model-type vit_h
```

<div align="center">
<img alt="img.png" width="600" src="images/truck_point_box.png"/>
</div>
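
Continuing the sketch above, boxes and points can be combined in a single call (again assuming the original SAM predictor signature; the XYXY box coordinates below are illustrative):

```python
# Box in XYXY pixel coordinates plus one negative point inside it;
# `predictor` is the one set up in the previous sketch.
masks, scores, _ = predictor.predict(
    point_coords=np.array([[575, 750]]),
    point_labels=np.array([0]),        # 0 = background (negative) point
    box=np.array([425, 600, 700, 875]),
    multimask_output=False,            # unambiguous prompt -> one mask
)
```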

See `python use_sam_with_promts.py --help` to explore more custom settings.

### Using SAM with Automatic Mask Generation (AMG)

Since SAM can efficiently process prompts, masks for the entire image can be generated by sampling a large number of prompts over an image. AMG works by sampling single-point input prompts in a grid over the image, from each of which SAM can predict multiple masks. Masks are then filtered for quality and deduplicated using non-maximal suppression. Additional options allow for further improvement of mask quality and quantity, such as running prediction on multiple crops of the image or post-processing masks to remove small disconnected regions and holes.

```shell
python use_sam_with_amg.py --model-type vit_h
```

<div align="center">
<img src="images/dengta.jpg" height="350" />
<img src="images/dengta-amg-vith.png" height="350" />
</div>

See `python use_sam_with_amg.py --help` to explore more custom settings.
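
For programmatic use, the original SAM exposes an automatic mask generator; here is a minimal sketch assuming this port mirrors that interface (`SamAutomaticMaskGenerator` and its `points_per_side` / `pred_iou_thresh` parameters are assumptions, not confirmed by this README):

```python
import cv2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator  # assumed API

sam = sam_model_registry["vit_h"](checkpoint="models/sam_vit_h-c72f8ba1.ckpt")

# Grid-sample single-point prompts, then filter by predicted quality and
# deduplicate with non-maximal suppression, as described above.
mask_generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=32,    # 32x32 grid of point prompts
    pred_iou_thresh=0.88,  # drop low-quality masks
)

image = cv2.cvtColor(cv2.imread("images/dengta.jpg"), cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)
print(len(masks), masks[0].keys())  # each record carries a segmentation, bbox, area, ...
```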

## Finetune

Fine-tuning is a popular method for adapting a large pretrained model to specific downstream tasks. Currently, fine-tuning with box prompts is supported: ground-truth bounding boxes are used as prompt input to predict masks.

Besides fine-tuning on the COCO2017 dataset, which contains commonly seen objects and lies close to the distribution of SAM's original [training dataset](https://segment-anything.com/dataset/index.html), we have run further experiments on a medical imaging segmentation dataset, [FLARE22](https://flare22.grand-challenge.org/Dataset/). The results show that the fine-tuning method in this repository is effective.

The table below shows the mask quality before and after fine-tuning.

| pretrained_model | dataset  |    epochs     | mIoU |
|:----------------:|----------|:-------------:|:----:|
|    sam-vit-b     | COCO2017 | 0 (zero-shot) | 77.4 |
|    sam-vit-b     | COCO2017 |      20       | 83.5 |
|    sam-vit-b     | FLARE22  | 0 (zero-shot) | 79.5 |
|    sam-vit-b     | FLARE22  |      10       | 88.1 |
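
Here, mIoU is the mean intersection-over-union between predicted and ground-truth binary masks. A minimal sketch of that computation (illustrative only, not necessarily the repository's `MaskMiou` implementation):

```python
import numpy as np

def mask_iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """IoU between two binary masks of the same shape."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    return float(inter) / float(union) if union > 0 else 1.0

def mask_miou(preds, gts) -> float:
    """Mean IoU over paired (prediction, ground truth) masks."""
    return float(np.mean([mask_iou(p, g) for p, g in zip(preds, gts)]))
```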

To fine-tune on the COCO dataset, run:

```shell
mpirun --allow-run-as-root -n 8 python train.py -c configs/coco_box_finetune.yaml
```
The original FLARE22 dataset contains images in 3D format, with ground truth labelled as instance segmentation IDs. Run

```shell
python scripts/preprocess_CT_MR_dataset.py
```

to preprocess it into 2D RGB images and binary masks.
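
Conceptually, the preprocessing slices each 3D volume into 2D images and splits the instance-ID label map into per-organ binary masks. A rough sketch of that idea (the helper name and file layout here are hypothetical, not the script's actual interface):

```python
import numpy as np

def slices_to_samples(volume: np.ndarray, label_map: np.ndarray):
    """Yield (rgb_image, binary_masks) pairs from a 3D CT/MR volume.

    volume: (D, H, W) intensity array; label_map: (D, H, W) instance IDs.
    """
    for img, lbl in zip(volume, label_map):
        # Normalize one slice to uint8 and replicate to 3 channels (RGB).
        lo, hi = img.min(), img.max()
        gray = ((img - lo) / (hi - lo + 1e-8) * 255).astype(np.uint8)
        rgb = np.stack([gray] * 3, axis=-1)
        # One binary mask per instance ID present in the slice (0 = background).
        masks = [(lbl == i) for i in np.unique(lbl) if i != 0]
        if masks:
            yield rgb, np.stack(masks)
```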

To fine-tune on the FLARE22 dataset, run:

```shell
mpirun --allow-run-as-root -n 8 python train.py -c configs/flare_box_finetune.yaml
```

Here are examples of segmentation results predicted by the fine-tuned SAM:

<div align="center">
<img src="images/coco_bear.jpg" height="350" />
<img src="images/flare_organ.jpg" height="350" />
</div>

<p align="center">
<em> COCO2017 image example </em>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
<em> FLARE22 image example </em>
</p>

@@ -0,0 +1,113 @@

This file appears to be `configs/coco_box_finetune.yaml`, the config referenced in the Finetune section above (indentation reconstructed):

```yaml
# ---------------------------------------------
# Part 1: system basic config setting
distributed: False
device: Ascend
mode: 0  # 0: graph, 1: pynative
work_root: &work_root ./work_dir/
log_level: info
amp_level: O2

# ---------------------------------------------
# Part 2: module setting
loss_manager:
  # type: fixed  # dynamic or fixed
  # scale_sense: 1024
  loss_scaler:
    type: dynamic
  grad_clip: False
  drop_overflow_update: False

optimizer:
  type: segment_anything.optim.optimizer.AdamW
  weight_decay: 1e-4
  group_param:

lr_scheduler:
  type: segment_anything.optim.scheduler.SAMDynamicDecayLR
  learning_rate: 8e-6
  warmup_steps: 250
  decay_steps: [ 60000, 86666 ]
  decay_factor: 10


network:
  model:
    type: vit_b
    checkpoint: ./models/sam_vit_b-35e4849c.ckpt
    freeze:
      image_encoder: True
      prompt_encoder: True

  loss:
    type: segment_anything.modeling.loss.SAMLoss


train_loader:
  dataset:
    type: segment_anything.dataset.dataset.COCODataset
    data_dir: ./datasets/coco2017/train2017
    annotation_path: ./datasets/coco2017/annotations/instances_train2017.json
    transform_pipeline:
      - type: segment_anything.dataset.transform.ImageResizeAndPad
        target_size: 1024
      - type: segment_anything.dataset.transform.ImageNorm
        hwc2chw: True
      - type: segment_anything.dataset.transform.LabelPad
        gt_size: 20
    output_column: ['image', 'masks', 'boxes', 'valid_boxes']

  model_column: ['image', 'boxes']  # columns for model cell input
  loss_column: ['masks', 'valid_boxes']  # columns for loss function input

  shuffle: True
  batch_size: 1
  epoch_size: 20
  drop_remainder: True
  num_workers: 2
  max_rowsize: 24  # 24M space for dataloader


eval_loader: &eval_loader
  dataset:
    type: segment_anything.dataset.dataset.COCODataset
    data_dir: ./datasets/coco2017/val2017
    annotation_path: ./datasets/coco2017/annotations/instances_val2017.json
    transform_pipeline:
      - type: segment_anything.dataset.transform.ImageResizeAndPad
        target_size: 1024
      - type: segment_anything.dataset.transform.ImageNorm
        hwc2chw: True
      - type: segment_anything.dataset.transform.LabelPad
        gt_size: 20
    output_column: ['image', 'masks', 'boxes', 'valid_boxes', 'origin_hw']

  model_column: &model_column [ 'image', 'boxes' ]  # columns for model cell input
  eval_column: &eval_column [ 'masks', 'valid_boxes', 'origin_hw' ]  # columns for evaluation, usually for metric calculation or visualization

  shuffle: True
  batch_size: 1
  drop_remainder: False
  num_workers: 1
  max_rowsize: 36  # 36M space for dataloader, increase with gt_size
  max_eval_iter: null  # the max iterations to eval, defaults to evaluating the whole dataset


eval_metric: &eval_metric
  - type: segment_anything.evaluate.metrics.MaskMiou
#  - type: MaskVisualization


callback:
  - type: segment_anything.utils.callbacks.TrainStatusLog
    loss_item: ['focal_loss', 'dice_loss', 'mse_loss']  # for log
    interval: 100
  - type: segment_anything.utils.callbacks.SaveCkpt
    work_root: *work_root
    interval: 1  # in epochs
  - type: segment_anything.utils.callbacks.EvalWhileTrain
    data_loader: *eval_loader
    metric: *eval_metric
    input_column:
      - *model_column
      - *eval_column
    interval: 1  # in epochs
```
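
The `type:` keys are fully qualified Python paths, which suggests the trainer instantiates components by dynamic import. A minimal sketch of that pattern (the `create_from_config` helper is hypothetical, not the repository's actual loader):

```python
import importlib

def create_from_config(cfg: dict):
    """Instantiate a class from a {'type': 'pkg.mod.Class', ...kwargs} dict."""
    cfg = dict(cfg)
    path = cfg.pop("type")
    module_name, class_name = path.rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**cfg)  # remaining keys become constructor kwargs

# e.g. loss = create_from_config({"type": "segment_anything.modeling.loss.SAMLoss"})
```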

@@ -0,0 +1,108 @@

This file appears to be `configs/flare_box_finetune.yaml`, the FLARE22 fine-tuning config referenced above (indentation reconstructed):

```yaml
# ---------------------------------------------
# Part 1: system basic config setting
distributed: False
device: Ascend
mode: 0  # 0: graph, 1: pynative
work_root: &work_root ./work_dir/
log_level: info
amp_level: O2

# ---------------------------------------------
# Part 2: module setting
loss_manager:
  # type: fixed  # dynamic or fixed
  # scale_sense: 1024
  loss_scaler:
    type: dynamic
  grad_clip: False
  drop_overflow_update: False


optimizer:
  type: segment_anything.optim.optimizer.AdamW
  weight_decay: 1e-4
  group_param:

lr_scheduler:
  type: segment_anything.optim.scheduler.SAMDynamicDecayLR
  learning_rate: 8e-6
  warmup_steps: 250
  decay_steps: [ 60000, 86666 ]
  decay_factor: 10


network:
  model:
    type: vit_b
    checkpoint: ./models/sam_vit_b-35e4849c.ckpt
    freeze:
      image_encoder: True
      prompt_encoder: True

  loss:
    type: segment_anything.modeling.loss.SAMLoss


train_loader:
  dataset:
    type: segment_anything.dataset.dataset.FLAREDataset
    data_dir: ./datasets/FLARE22Train_processed/train/
    transform_pipeline:
      - type: segment_anything.dataset.transform.BinaryMaskFromInstanceSeg
      - type: segment_anything.dataset.transform.BoxFormMask
      - type: segment_anything.dataset.transform.LabelPad
        gt_size: 20
    output_column: ['image', 'masks', 'boxes', 'valid_boxes']

  model_column: ['image', 'boxes']  # columns for model cell input
  loss_column: ['masks', 'valid_boxes']  # columns for loss function input

  shuffle: True
  batch_size: 1
  epoch_size: 20
  drop_remainder: True
  num_workers: 2
  max_rowsize: 64  # 64M space for dataloader


eval_loader: &eval_loader
  dataset:
    type: segment_anything.dataset.dataset.FLAREDataset
    data_dir: ./datasets/FLARE22Train_processed/val/
    transform_pipeline:
      - type: segment_anything.dataset.transform.BinaryMaskFromInstanceSeg
      - type: segment_anything.dataset.transform.BoxFormMask
      - type: segment_anything.dataset.transform.LabelPad
        gt_size: 20
    output_column: ['image', 'masks', 'boxes', 'valid_boxes']

  model_column: &model_column [ 'image', 'boxes' ]  # columns for model cell input
  eval_column: &eval_column [ 'masks', 'valid_boxes' ]  # columns for evaluation, usually for metric calculation or visualization

  shuffle: True
  batch_size: 1
  drop_remainder: False
  num_workers: 1
  max_rowsize: 64  # 64M space for dataloader, increase with gt_size
  max_eval_iter: null  # the max iterations to eval, defaults to evaluating the whole dataset


eval_metric: &eval_metric
  - type: segment_anything.evaluate.metrics.MaskMiou
#  - type: MaskVisualization


callback:
  - type: segment_anything.utils.callbacks.TrainStatusLog
    loss_item: ['focal_loss', 'dice_loss', 'mse_loss']  # for log
    interval: 100
  - type: segment_anything.utils.callbacks.SaveCkpt
    work_root: *work_root
    interval: 1  # in epochs
  - type: segment_anything.utils.callbacks.EvalWhileTrain
    data_loader: *eval_loader
    metric: *eval_metric
    input_column:
      - *model_column
      - *eval_column
    interval: 1  # in epochs
```
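
As configured, `SAMDynamicDecayLR` seems to describe a warmup-then-step-decay schedule. A hypothetical sketch of such a schedule using the config's values (the class internals are not shown in this commit, so this is an assumption about its behavior):

```python
def lr_at_step(step, base_lr=8e-6, warmup_steps=250,
               decay_steps=(60000, 86666), decay_factor=10):
    """Hypothetical warmup + step-decay schedule matching the config values."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps  # linear warmup
    lr = base_lr
    for boundary in decay_steps:
        if step >= boundary:
            lr /= decay_factor                      # divide LR at each boundary
    return lr

# e.g. lr_at_step(0) ~= 3.2e-8, lr_at_step(1000) == 8e-6, lr_at_step(70000) == 8e-7
```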
