diff --git a/mmfashion/apis/train_retriever.py b/mmfashion/apis/train_retriever.py
index 22fb2c754..480ae3d84 100644
--- a/mmfashion/apis/train_retriever.py
+++ b/mmfashion/apis/train_retriever.py
@@ -1,4 +1,5 @@
 from __future__ import division
+
 from collections import OrderedDict
 
 import torch
@@ -49,17 +50,17 @@ def train_retriever(model,
 
     # start training predictor
     if distributed: # to do
-        _dist_train(model, dataset, cfg, validate=validate)
+        _dist_train(model, dataset, cfg, validate=validate, logger=logger)
     else:
-        _non_dist_train(model, dataset, cfg, validate=validate)
+        _non_dist_train(model, dataset, cfg, validate=validate, logger=logger)
 
 
-def _dist_train(model, dataset, cfg, validate=False):
+def _dist_train(model, dataset, cfg, validate=False, logger=None):
     """ not implemented yet """
     raise NotImplementedError
 
 
-def _non_dist_train(model, dataset, cfg, validate=False):
+def _non_dist_train(model, dataset, cfg, validate=False, logger=None):
     # prepare data loaders
     data_loaders = [
         build_dataloader(
@@ -76,7 +77,7 @@ def _non_dist_train(model, dataset, cfg, validate=False):
 
     optimizer = build_optimizer(model, cfg.optimizer)
     runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
-                    cfg.log_level)
+                    logger)
 
     runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config)