Skip to content

Commit 51abffc

Browse files
author
Mark-ZhouWX
committed
optimize saveckpt callback
1 parent 60fb20d commit 51abffc

File tree

4 files changed

+12
-10
lines changed

4 files changed

+12
-10
lines changed

research/segment-anything/README.md

+7-7
Original file line numberDiff line numberDiff line change
@@ -79,17 +79,17 @@ See `python use_sam_with_amg.py --help` to explore more custom settings.
7979
## Finetune
8080

8181
Finetune is a popular method that adapts large pretrained model to specific downstream tasks. Currently, finetune with box-prompt are supported. The bounding boxes are used as prompt input to predict mask.
82-
Beside fine-tuning our code on COCO2017 dataset which contains common seen objects and lies in the similar distribution of the original [training dataset](https://segment-anything.com/dataset/index.html)) of SAM, We have done further experiments on a medical imaging segmentation dataset [FLARE22](https://flare22.grand-challenge.org/Dataset/). Result shows that the finetune method in this repository is effective.
82+
Beside fine-tuning our code on COCO2017 dataset which contains common seen objects and lies in the similar distribution of the original [training dataset](https://segment-anything.com/dataset/index.html) of SAM, We have done further experiments on a medical imaging segmentation dataset [FLARE22](https://flare22.grand-challenge.org/Dataset/). Result shows that the finetune method in this repository is effective.
8383

8484
The bellowing shows the mask quality before and after finetune.
8585

8686

87-
| pretrained_model | dataset | epochs | mIOU |
88-
| :--------------: | -------- | :-----------: | ---- |
89-
| sam-vit-b | COCO2017 | 0 (zero-shot) | 77.4 |
90-
| sam-vit-b | COCO2017 | 20 | 83.6 |
91-
| sam-vit-b | FLARE22 | 0 (zero-shot) | 79.5 |
92-
| sam-vit-b | FLARE22 | 10 | 88.1 |
87+
| pretrained_model | dataset | epochs | mIOU |
88+
|:----------------:| -------- |:-------------:|------|
89+
| sam-vit-b | COCO2017 | 0 (zero-shot) | 77.4 |
90+
| sam-vit-b | COCO2017 | 20 | 83.5 |
91+
| sam-vit-b | FLARE22 | 0 (zero-shot) | 79.5 |
92+
| sam-vit-b | FLARE22 | 10 | 88.1 |
9393

9494
To finetune COCO dataset, please run:
9595

research/segment-anything/configs/coco_box_finetune.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ eval_metric: &eval_metric
100100
callback:
101101
- type: segment_anything.utils.callbacks.TrainStatusLog
102102
loss_item: ['focal_loss', 'dice_loss', 'mse_loss'] # for log
103+
interval: 100
103104
- type: segment_anything.utils.callbacks.SaveCkpt
104105
work_root: *work_root
105106
interval: 1 # in epoch

research/segment-anything/configs/flare_box_finetune.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ eval_metric: &eval_metric
9595
callback:
9696
- type: segment_anything.utils.callbacks.TrainStatusLog
9797
loss_item: ['focal_loss', 'dice_loss', 'mse_loss'] # for log
98-
interval: 20
98+
interval: 100
9999
- type: segment_anything.utils.callbacks.SaveCkpt
100100
work_root: *work_root
101101
interval: 1 # in epoch

research/segment-anything/segment_anything/utils/callbacks.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class TrainStatusLog(Callback):
8383
"""
8484
Callback to record the status of training, mainly including loss and time performance information.
8585
"""
86-
def __init__(self, interval=10, loss_item=()):
86+
def __init__(self, interval=100, loss_item=()):
8787
self.log_interval = interval
8888
self.loss_item = loss_item
8989
self.step_start_time = 0.0
@@ -166,7 +166,8 @@ def __init__(self, interval=1, work_root='./work_root', save_dir='', main_device
166166
def on_train_epoch_end(self, run_context: RunContext):
167167
cb_params = run_context.original_args()
168168
cur_epoch = cb_params.cur_epoch_num
169-
if self.main_device and cur_epoch % self.interval == 0:
169+
total_epoch_num = cb_params.epoch_num
170+
if self.main_device and (cur_epoch % self.interval == 0 or cur_epoch == total_epoch_num):
170171
save_path = os.path.join(self.full_save_dir, f'sam_{cur_epoch:03d}.ckpt')
171172
logger.info(f'saving ckpt of epoch {cur_epoch} at {save_path}, interval is {self.interval}')
172173
# model without loss function, cb_params.network is train_one_step_cell

0 commit comments

Comments
 (0)