From 007d6b647ccda8aead590fd24ee17b18ccbcea57 Mon Sep 17 00:00:00 2001 From: Leonardo Rossi Date: Thu, 27 Aug 2020 11:31:50 +0200 Subject: [PATCH] train_retriever: support new version of mmcv Signed-off-by: Leonardo Rossi --- mmfashion/apis/train_retriever.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mmfashion/apis/train_retriever.py b/mmfashion/apis/train_retriever.py index 22fb2c754..480ae3d84 100644 --- a/mmfashion/apis/train_retriever.py +++ b/mmfashion/apis/train_retriever.py @@ -1,4 +1,5 @@ from __future__ import division + from collections import OrderedDict import torch @@ -49,17 +50,17 @@ def train_retriever(model, # start training predictor if distributed: # to do - _dist_train(model, dataset, cfg, validate=validate) + _dist_train(model, dataset, cfg, validate=validate, logger=logger) else: - _non_dist_train(model, dataset, cfg, validate=validate) + _non_dist_train(model, dataset, cfg, validate=validate, logger=logger) -def _dist_train(model, dataset, cfg, validate=False): +def _dist_train(model, dataset, cfg, validate=False, logger=None): """ not implemented yet """ raise NotImplementedError -def _non_dist_train(model, dataset, cfg, validate=False): +def _non_dist_train(model, dataset, cfg, validate=False, logger=None): # prepare data loaders data_loaders = [ build_dataloader( @@ -76,7 +77,7 @@ def _non_dist_train(model, dataset, cfg, validate=False): optimizer = build_optimizer(model, cfg.optimizer) runner = Runner(model, batch_processor, optimizer, cfg.work_dir, - cfg.log_level) + logger) runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, cfg.checkpoint_config, cfg.log_config)