From 37bceea3557e7ae4ccf6473627579ea6f80f8d98 Mon Sep 17 00:00:00 2001
From: Dominic Yu
Date: Mon, 16 Dec 2019 17:16:38 +0800
Subject: [PATCH] use one replica and rank 0 if no distributed support

---
 mmfashion/datasets/loader/sampler.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/mmfashion/datasets/loader/sampler.py b/mmfashion/datasets/loader/sampler.py
index 11e1c4e9d..c748342df 100644
--- a/mmfashion/datasets/loader/sampler.py
+++ b/mmfashion/datasets/loader/sampler.py
@@ -2,9 +2,16 @@
 
 import numpy as np
 import torch
-from torch.distributed import get_world_size, get_rank
 from torch.utils.data import Sampler
 from torch.utils.data.distributed import DistributedSampler as _DistributedSampler
+# Single-process fallbacks used when torch.distributed cannot be imported.
+ONLY_ONE_REPLICA = 1
+SINGLE_PROCESS_RANK = 0
+try:
+    from torch.distributed import get_world_size, get_rank
+    NO_DISTRIBUTED_SUPPORT = False
+except ImportError:
+    NO_DISTRIBUTED_SUPPORT = True
 
 
 class DistributedSampler(_DistributedSampler):
@@ -91,11 +98,17 @@ def __init__(self,
                  dataset,
                  samples_per_gpu=1,
                  num_replicas=None,
-                 rank=None):
-        if num_replicas is None:
-            num_replicas = get_world_size()
-        if rank is None:
-            rank = get_rank()
+                 rank=None,
+                 no_distributed_support=False):
+
+        if no_distributed_support:
+            num_replicas = ONLY_ONE_REPLICA
+            rank = SINGLE_PROCESS_RANK
+        else:
+            if num_replicas is None:
+                num_replicas = get_world_size()
+            if rank is None:
+                rank = get_rank()
         self.dataset = dataset
         self.samples_per_gpu = samples_per_gpu
         self.num_replicas = num_replicas
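
Note for reviewers (not part of the patch): the sketch below reproduces the fallback logic in isolation so it can be sanity-checked without an mmfashion checkout. The helper name resolve_replicas_and_rank is hypothetical and written only for illustration; the try/except import and the (num_replicas, rank) choice mirror the hunks above.

# Standalone sketch of the fallback added by this patch (illustrative only).
try:
    from torch.distributed import get_world_size, get_rank
    NO_DISTRIBUTED_SUPPORT = False
except ImportError:
    # torch.distributed is missing, e.g. on builds compiled without it.
    NO_DISTRIBUTED_SUPPORT = True


def resolve_replicas_and_rank(num_replicas=None, rank=None,
                              no_distributed_support=NO_DISTRIBUTED_SUPPORT):
    """Choose (num_replicas, rank) the way the patched sampler does."""
    if no_distributed_support:
        # One replica, rank 0: the sampler then iterates over every sample.
        return 1, 0
    if num_replicas is None:
        num_replicas = get_world_size()
    if rank is None:
        rank = get_rank()
    return num_replicas, rank


print(resolve_replicas_and_rank(no_distributed_support=True))  # (1, 0)

With no_distributed_support=True the result is always (1, 0), which keeps rank smaller than num_replicas and so stays compatible with the per-replica offset arithmetic in the samplers of this file.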