# *****************************************************************************
#  Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#      * Redistributions of source code must retain the above copyright
#        notice, this list of conditions and the following disclaimer.
#      * Redistributions in binary form must reproduce the above copyright
#        notice, this list of conditions and the following disclaimer in the
#        documentation and/or other materials provided with the distribution.
#      * Neither the name of the NVIDIA CORPORATION nor the
#        names of its contributors may be used to endorse or promote products
#        derived from this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
#  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
#  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
#  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
#  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import os
import sys
import time
import subprocess
import argparse

import torch
import torch.distributed as dist
from torch.autograd import Variable

def reduce_tensor(tensor, num_gpus):
    # Sum the tensor across all ranks, then divide to get the average.
    # dist.ReduceOp replaces the deprecated dist.reduce_op namespace.
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= num_gpus
    return rt

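# Illustrative usage sketch (not part of this file's control flow): after a
# backward pass, each rank can average its scalar loss so that all workers
# log the same value. `loss` here is a hypothetical loss tensor:
#
#   reduced_loss = reduce_tensor(loss.data, num_gpus).item()
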
def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url):
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing Distributed")

    # Set cuda device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize distributed communication
    dist.init_process_group(dist_backend, init_method=dist_url,
                            world_size=num_gpus, rank=rank,
                            group_name=group_name)

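# Illustrative call sketch: each spawned worker is expected to run this once
# at startup, before building the model. The backend and URL values below are
# example assumptions, not mandated by this file:
#
#   init_distributed(rank, num_gpus, group_name,
#                    dist_backend="nccl", dist_url="tcp://localhost:54321")
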
def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    same dense type.
    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operation on this buffer will be equivalent to
    operating individually.
    Arguments:
        tensors (Iterable[Tensor]): dense tensors to flatten.
    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    if len(tensors) == 1:
        return tensors[0].contiguous().view(-1)
    flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
    return flat

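# Quick illustrative check of the flattening contract (values chosen
# arbitrarily): two dense tensors with 2 and 3 elements flatten into one
# contiguous buffer of 5 elements.
#
#   a, b = torch.zeros(2), torch.ones(3)
#   flat = _flatten_dense_tensors([a, b])   # flat.shape == (5,)
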
def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    same dense type, and that flat is given by _flatten_dense_tensors.
    Arguments:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
            unflatten flat.
    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    outputs = []
    offset = 0
    for tensor in tensors:
        numel = tensor.numel()
        outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
        offset += numel
    return tuple(outputs)

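# Illustrative round trip, continuing the sketch above: unflattening with the
# original tensors as shape templates recovers views matching their shapes,
# so a collective on the flat buffer is equivalent to per-tensor collectives.
#
#   out_a, out_b = _unflatten_dense_tensors(flat, [a, b])
#   # out_a.shape == a.shape and out_b.shape == b.shape
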
def apply_gradient_allreduce(module):
    """
    Modifies an existing model to all-reduce its gradients after backward,
    without wrapping it in a new class, so no ``.module`` indirection is
    needed (unlike DistributedDataParallel).
    """
    # Warn when half-precision gradients run over the Gloo backend, which
    # handles them poorly. dist.get_backend() replaces the private
    # dist._backend attribute used by early PyTorch releases.
    module.warn_on_half = (dist.get_backend() == "gloo")

    # Broadcast parameters and buffers from rank 0 so every worker starts
    # from identical weights.
    for p in module.state_dict().values():
        if not torch.is_tensor(p):
            continue
        dist.broadcast(p, 0)

    def allreduce_params():
        if module.needs_reduction:
            module.needs_reduction = False
            # Bucket gradients by tensor type so each bucket can be
            # flattened into one contiguous buffer and reduced with a
            # single collective call.
            buckets = {}
            for param in module.parameters():
                if param.requires_grad and param.grad is not None:
                    tp = type(param.data)
                    if tp not in buckets:
                        buckets[tp] = []
                    buckets[tp].append(param)
            if module.warn_on_half:
                if torch.cuda.HalfTensor in buckets:
                    print("WARNING: gloo dist backend for half parameters may be extremely slow. " +
                          "It is recommended to use the NCCL backend in this case.")
                    module.warn_on_half = False

            for tp in buckets:
                bucket = buckets[tp]
                grads = [param.grad.data for param in bucket]
                coalesced = _flatten_dense_tensors(grads)
                dist.all_reduce(coalesced)
                coalesced /= dist.get_world_size()
                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                    buf.copy_(synced)

    # Register a gradient hook on every parameter: the first hook that fires
    # during backward queues allreduce_params() to run once the rest of the
    # backward graph has finished executing.
    for param in list(module.parameters()):
        def allreduce_hook(*unused):
            Variable._execution_engine.queue_callback(allreduce_params)
        if param.requires_grad:
            param.register_hook(allreduce_hook)

    def set_needs_reduction(self, input, output):
        self.needs_reduction = True

    module.register_forward_hook(set_needs_reduction)
    return module

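# Illustrative wiring sketch, assuming a hypothetical model class `MyModel`
# and a training loop like the one train.py is expected to contain: once the
# module is wrapped, a plain backward() call triggers the gradient all-reduce.
#
#   model = MyModel().cuda()
#   model = apply_gradient_allreduce(model)
#   loss = model(batch)                # forward hook sets needs_reduction
#   loss.backward()                    # gradient hooks queue allreduce_params()
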
def main(config, stdout_dir, args_str):
    args_list = ['train.py']
    args_list += args_str.split(' ') if len(args_str) > 0 else []

    args_list.append('--config={}'.format(config))

    num_gpus = torch.cuda.device_count()
    args_list.append('--num_gpus={}'.format(num_gpus))
    # Placeholder rank, rewritten per worker below via args_list[-2]. Without
    # this entry, that assignment would clobber --num_gpus instead.
    args_list.append('--rank={}'.format(0))
    args_list.append("--group_name=group_{}".format(time.strftime("%Y_%m_%d-%H%M%S")))

    if not os.path.isdir(stdout_dir):
        os.makedirs(stdout_dir)
        os.chmod(stdout_dir, 0o775)

    workers = []

    for i in range(num_gpus):
        args_list[-2] = '--rank={}'.format(i)
        # Rank 0 inherits the console; every other rank logs to its own file.
        stdout = None if i == 0 else open(
            os.path.join(stdout_dir, "GPU_{}.log".format(i)), "w")
        print(args_list)
        p = subprocess.Popen([str(sys.executable)] + args_list, stdout=stdout)
        workers.append(p)

    for p in workers:
        p.wait()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, required=True,
                        help='JSON file for configuration')
    parser.add_argument('-s', '--stdout_dir', type=str, default=".",
                        help='directory to save stdout logs')
    parser.add_argument(
        '-a', '--args_str', type=str, default='',
        help='double-quoted string of space-separated key=value pairs')

    args = parser.parse_args()
    main(args.config, args.stdout_dir, args.args_str)
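
# Illustrative launch sketch, assuming this file is saved as distributed.py
# (the filename is not fixed by the source) and that config.json and the extra
# flag are placeholders. This spawns one train.py process per visible GPU:
#
#   python distributed.py -c config.json -s logs -a "--some_flag=value"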