Commit 6528014

Add multi-gpu support to MonaiAlgo (#5228)
Signed-off-by: Holger Roth <[email protected]>

Fixes #5195.

### Description
Add support for MonaiAlgo to be run with torchrun for multi-gpu training.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
1 parent ef68f16 commit 6528014
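
For context, a hedged usage sketch of the new options (not part of the commit): `multi_gpu`, `backend`, and `init_method` come from this change, while `bundle_root`, the `ExtraItems` import path, and the `extra` contents below are illustrative assumptions; an FL runtime would normally supply them.

```python
# Hedged usage sketch; launch one process per GPU, e.g.:
#   torchrun --nproc_per_node=2 fl_client.py
from monai.fl.client.monai_algo import MonaiAlgo
from monai.fl.utils.constants import ExtraItems  # assumed import path for ExtraItems

algo = MonaiAlgo(
    bundle_root="./my_bundle",   # assumed: path to the MONAI bundle used for training
    multi_gpu=True,              # new in this commit: enables the torch.distributed code path
    backend="nccl",              # new: torch.distributed backend (default)
    init_method="env://",        # new: reads the env vars set by torchrun
)
algo.initialize(extra={ExtraItems.CLIENT_NAME: "site-1"})  # client name otherwise defaults to "noname"
# ... FL rounds: algo.train(...), algo.get_weights(...), algo.evaluate(...) ...
algo.finalize()  # tears down the process group when multi_gpu=True
```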

File tree

2 files changed: +36 -0 lines changed

monai/fl/client/monai_algo.py

Lines changed: 34 additions & 0 deletions
```diff
@@ -15,6 +15,7 @@
 from typing import Optional, Union
 
 import torch
+import torch.distributed as dist
 
 import monai
 from monai.bundle import ConfigParser
@@ -112,6 +113,9 @@ class MonaiAlgo(ClientAlgo):
         benchmark: set benchmark to `False` for full deterministic behavior in cuDNN components.
             Note, full determinism in federated learning depends also on deterministic behavior of other FL components,
             e.g., the aggregator, which is not controlled by this class.
+        multi_gpu: whether to run MonaiAlgo in a multi-GPU setting; defaults to `False`.
+        backend: backend to use for torch.distributed; defaults to "nccl".
+        init_method: init_method for torch.distributed; defaults to "env://".
     """
 
     def __init__(
@@ -128,6 +132,9 @@ def __init__(
         save_dict_key: Optional[str] = "model",
         seed: Optional[int] = None,
         benchmark: bool = True,
+        multi_gpu: bool = False,
+        backend: str = "nccl",
+        init_method: str = "env://",
     ):
         self.logger = logging.getLogger(self.__class__.__name__)
         if config_evaluate_filename == "default":
@@ -144,6 +151,9 @@ def __init__(
         self.save_dict_key = save_dict_key
         self.seed = seed
         self.benchmark = benchmark
+        self.multi_gpu = multi_gpu
+        self.backend = backend
+        self.init_method = init_method
 
         self.app_root = None
         self.train_parser = None
@@ -156,6 +166,7 @@ def __init__(
         self.post_evaluate_filters = None
         self.iter_of_start_time = 0
         self.global_weights = None
+        self.rank = 0
 
         self.phase = FlPhase.IDLE
         self.client_name = None
@@ -174,6 +185,15 @@ def initialize(self, extra=None):
         self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname")
         self.logger.info(f"Initializing {self.client_name} ...")
 
+        if self.multi_gpu:
+            dist.init_process_group(backend=self.backend, init_method=self.init_method)
+            self._set_cuda_device()
+            self.logger.info(
+                f"Using multi-gpu training on rank {self.rank} (available devices: {torch.cuda.device_count()})"
+            )
+            if self.rank > 0:
+                self.logger.setLevel(logging.WARNING)
+
         if self.seed:
             monai.utils.set_determinism(seed=self.seed)
         torch.backends.cudnn.benchmark = self.benchmark
@@ -243,6 +263,8 @@ def train(self, data: ExchangeObject, extra=None):
             extra: Dict with additional information that can be provided by FL system.
 
         """
+        self._set_cuda_device()
+
         if extra is None:
             extra = {}
         if not isinstance(data, ExchangeObject):
@@ -284,6 +306,8 @@ def get_weights(self, extra=None):
                 or load requested model type from disk (`ModelType.BEST_MODEL` or `ModelType.FINAL_MODEL`).
 
         """
+        self._set_cuda_device()
+
         if extra is None:
             extra = {}
 
@@ -361,6 +385,8 @@ def evaluate(self, data: ExchangeObject, extra=None):
             return_metrics: `ExchangeObject` containing evaluation metrics.
 
         """
+        self._set_cuda_device()
+
         if extra is None:
             extra = {}
         if not isinstance(data, ExchangeObject):
@@ -421,6 +447,9 @@ def finalize(self, extra=None):
             self.logger.info(f"Terminating {self.client_name} evaluator...")
             self.evaluator.terminate()
 
+        if self.multi_gpu:
+            dist.destroy_process_group()
+
     def _check_converted(self, global_weights, local_var_dict, n_converted):
         if n_converted == 0:
             self.logger.warning(
@@ -447,3 +476,8 @@ def _add_config_files(self, config_files):
                     f"Expected config files to be of type str or list but got {type(config_files)}: {config_files}"
                 )
         return files
+
+    def _set_cuda_device(self):
+        if self.multi_gpu:
+            self.rank = int(os.environ["LOCAL_RANK"])
+            torch.cuda.set_device(self.rank)
```
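
For readers unfamiliar with the `env://` pattern, here is a minimal standalone sketch of the initialization flow added above, independent of MONAI; it assumes the script is launched with `torchrun` and that NCCL-capable GPUs are available.

```python
import os

import torch
import torch.distributed as dist

# Assumed launch: torchrun --nproc_per_node=<num_gpus> this_script.py
# torchrun sets LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT,
# which init_method="env://" reads to form the process group.
dist.init_process_group(backend="nccl", init_method="env://")

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)  # pin this process to its own GPU, as _set_cuda_device does above

# ... per-rank training would run here; MonaiAlgo raises non-zero ranks to WARNING-level logging ...

dist.destroy_process_group()  # mirrors the cleanup added to MonaiAlgo.finalize()
```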

monai/networks/utils.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -516,6 +516,8 @@ def copy_model_state(
     unchanged_keys = sorted(set(all_keys).difference(updated_keys))
     logger.info(f"'dst' model updated: {len(updated_keys)} of {len(dst_dict)} variables.")
     if inplace and isinstance(dst, torch.nn.Module):
+        if isinstance(dst, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+            dst = dst.module
         dst.load_state_dict(dst_dict)
     return dst_dict, updated_keys, unchanged_keys
 
```
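
A small CPU-only sketch (using plain `torch.nn` modules, no MONAI involved) of why `copy_model_state` now unwraps the parallel wrapper: `DataParallel`/`DistributedDataParallel` prefix parameter names with `module.`, so a state dict built for the plain model has to be loaded into `dst.module`.

```python
import torch.nn as nn

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)  # DistributedDataParallel renames keys the same way

print(list(model.state_dict()))    # ['weight', 'bias']
print(list(wrapped.state_dict()))  # ['module.weight', 'module.bias']

plain_sd = model.state_dict()
wrapped.module.load_state_dict(plain_sd)   # what the patched copy_model_state effectively does
# wrapped.load_state_dict(plain_sd) would fail with missing "module.*" keys
```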
