-
Notifications
You must be signed in to change notification settings - Fork 123
/
Copy pathtrain.py
84 lines (68 loc) · 4.26 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import argparse
import ast
import mindspore as ms
from mindspore import amp
from segment_anything.build_sam import create_model
from segment_anything.dataset.dataset import create_dataloader
from segment_anything.modeling.loss import create_loss_fn
from segment_anything.optim.loss_scaler import create_loss_scaler
from segment_anything.optim.optimizer import create_optimizer
from segment_anything.utils import logger
from segment_anything.utils.callbacks import create_callback
from segment_anything.utils.config import parse_args
from segment_anything.utils.model_wrapper import NetWithLossWrapper, TrainOneStepCellWrapper, SamIterativeSegModel
from segment_anything.utils.utils import set_distributed, set_directory_and_log, update_rank_to_dataloader_config, set_env
def main(args) -> None:
    """Build the SAM fine-tuning pipeline from the parsed config and run training."""
    # Step1: initialize environment (device, distribution, work dirs, logging)
    set_env(args)
    rank_id, rank_size, main_device = set_distributed(args.distributed)
    update_rank_to_dataloader_config(rank_id, rank_size, args.train_loader, args.eval_loader, args.callback)
    set_directory_and_log(main_device, rank_id, rank_size, args.work_root, args.log_level, args.callback)
    logger.info(args.pretty())

    # Step2: create the training dataset
    dataloader = create_dataloader(args.train_loader)

    # Create network and loss (pretrained ckpt loading / layer freezing happen
    # inside create_model), then cast both to the configured AMP level.
    amp_level = args.get('amp_level', 'O0')
    net = create_model(args.network.model)
    criterion = create_loss_fn(args.network.loss)
    net.set_train()
    net = amp.auto_mixed_precision(net, amp_level)
    criterion = amp.auto_mixed_precision(criterion, amp_level)

    # Step3: optimizer, including LR schedule and parameter-group settings
    optimizer = create_optimizer(
        params=net.trainable_params(),
        args=args.optimizer,
        step_per_epoch=dataloader.get_dataset_size(),
        epoch_size=args.train_loader.epoch_size,
    )

    # Step4: wrap network + loss and the optimizer into a single train step
    net_with_loss = NetWithLossWrapper(
        net,
        loss_fn=criterion,
        input_columns=[args.train_loader.model_column, args.train_loader.loss_column],
        all_columns=args.train_loader.dataset.output_column,
    )
    loss_scaler = create_loss_scaler(args.loss_manager.loss_scaler)
    train_step = TrainOneStepCellWrapper(
        net_with_loss,
        optimizer=optimizer,
        scale_sense=loss_scaler,
        drop_overflow_update=args.loss_manager.drop_overflow_update,
    )

    # Step5: train, optionally with SAM's iterative-segmentation schedule
    callbacks = create_callback(args.callback)
    if not args.get('iterative_training'):
        trainer = ms.Model(train_step)
    else:
        args.loss_manager.loss_scaler.use_amp_scale = True
        amp_loss_scaler = create_loss_scaler(args.loss_manager.loss_scaler)  # a workaround to use amp scaler
        trainer = SamIterativeSegModel(
            train_step,
            num_iter=args.network.num_iter,
            mask_only_iter=args.network.mask_only_iter,
            loss_scaler=amp_loss_scaler,
        )
    trainer.train(epoch=args.train_loader.epoch_size, train_dataset=dataloader, callbacks=callbacks)
if __name__ == "__main__":
    # CLI entry point: parse the YAML config plus command-line overrides, then train.
    parser_config = argparse.ArgumentParser(description="SAM Config", add_help=False)
    parser_config.add_argument(
        "-c", "--config", type=str, default="configs/coco_box_finetune.yaml",
        help="YAML config file specifying default arguments."
    )
    # FIX: the implicitly-concatenated help fragments lacked separating spaces,
    # so --help printed fused sentences like "config file.For dict, ...".
    parser_config.add_argument(
        '-o', '--override-cfg', nargs='+',
        help="command line to override configuration in config file. "
             "For dict, use key=value format, eg: device=False. "
             "For nested dict, use '.' to denote hierarchy, eg: optimizer.weight_decay=1e-3. "
             "For list, use number to denote position, eg: callback.1.interval=100."
    )
    # ModelArts (cloud platform) arguments; ast.literal_eval parses "True"/"False"
    # from the command line into real booleans.
    parser_config.add_argument("--enable-modelarts", type=ast.literal_eval, default=False,
                               help="whether the job runs on ModelArts")
    parser_config.add_argument("--train-url", type=str, default="", help="obs path to output folder")
    parser_config.add_argument("--data-url", type=str, default="", help="obs path to dataset folder")
    args = parse_args(parser_config)
    main(args)