diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py
index e16ef6a8..d90a0961 100644
--- a/bmtrain/block_layer.py
+++ b/bmtrain/block_layer.py
@@ -66,7 +66,7 @@ class CheckpointBlock(torch.nn.Module):
         >>> y2, ... = transformer_block(x)
         >>> assert torch.allclose(y1, y2)
     """
-    def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_level=3):
+    def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, offload_level=0, zero_level=3):
         super().__init__()
         self._module = inner_module
         self._inputs = None
@@ -80,7 +80,8 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev
         self._ready = False
         # sort parameters by name
         ordered_parameters = list(self._module.named_parameters())
-
+        use_offload = offload_level in [1,2]
+        assert not (use_checkpoint and use_offload), "It does not make sense to use offload and checkpointing at the same time"
         # calc total number of parameters
         for name, param in ordered_parameters:
             if not isinstance(param, DistributedParameter):
@@ -202,6 +203,11 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev
         self._pre_module = [] #save the pre module of self
         self._ref_count = 0 #incremental in forward and decreasing in backward
         self._mode = "BLOCK" #BLOCK or ZERO or PIPE
+        self.offload_level = offload_level
+        if use_offload:
+            self._mode = "OFFLOAD"
+            self._on_device = False
+
         self.all_input_no_grad = False
         self.all_param_no_grad = False
         self._zero_level = zero_level
@@ -212,12 +218,16 @@ def set_pre_module(self, pre_module):
             pre_module._next_module.append(self)
 
     def pre_module(self):
-        assert len(self._pre_module) == self._ref_count, "{} != {}".format(len(self._pre_module), self._ref_count)
-        return self._pre_module[self._ref_count-1]
+        if len(self._pre_module) > 0:
+            return self._pre_module[self._ref_count-1]
+        else:
+            return None
 
     def next_module(self):
-        assert len(self._next_module) == self._ref_count, "{} != {}".format(len(self._next_module), self._ref_count)
-        return self._next_module[self._ref_count-1]
+        if len(self._next_module) > 0:
+            return self._next_module[self._ref_count-1]
+        else:
+            return None
 
     def backward_release(self, flag):
         if self._ref_count == 1 and self._backward_block_ctx is not None:
@@ -536,19 +546,21 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False)
 
         self._modules = {}
         pre_module = None
+        offload = 0
         for i, module in enumerate(modules):
             if not isinstance(module, CheckpointBlock):
                 module = CheckpointBlock(module)
-
-            module._mode = "ZERO"
+            module._mode = "ZERO" if module._mode == "BLOCK" else module._mode
             module.set_pre_module(pre_module)
             pre_module = module
             module._is_first_layer = False
             module._is_last_layer = False
-
+            if module._mode == "OFFLOAD":
+                offload+=1
+                module.calc_event = torch.cuda.Event()
+                module.offload_event = torch.cuda.Event()
             self._modules[str(i)] = module
             self.add_module(str(i), module)
-
         self._modules[str(0)]._is_first_layer = True
         self._modules[str(len(modules)-1)]._is_last_layer = True
 
@@ -575,7 +587,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False)
             self.save_list = save_list
         else:
             self.save_list = [(i, i) for i in range(len(self))]
-
+
     def __len__(self) -> int:
         return len(self._modules)
diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py
index 4d91d1d0..b74c3149 100644
--- a/bmtrain/hook_func.py
+++ b/bmtrain/hook_func.py
@@ -1,10 +1,30 @@
 import torch
 from .global_var import config
 from .checkpointing import CheckpointBlockContext
-
+from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations
+from contextlib import contextmanager
+from .utils import round_up, find_pre_module_helper
+from .offload import Offload_Dict, offload_wrapper, offload_pre_hook, offload_post_hook
 def zero_pre_forward(module, inputs):
     enter = True
     pipe = False
+    if module._mode == "OFFLOAD":
+        if not hasattr(module, "_offload_dict"):
+            module._offload_dict = Offload_Dict()
+        pack_hook, unpack_hook = offload_wrapper(module._offload_dict)
+        if module.offload_level == 1:
+            for n, m in module.named_modules():
+                if m.__class__.__name__ == "Linear" and not hasattr(m, "_offload_hook"):
+                    m._offload_hook = (pack_hook, unpack_hook)
+                    m.register_forward_pre_hook(offload_pre_hook)
+                    m.register_forward_hook(offload_post_hook)
+        elif module.offload_level == 2:
+            if not hasattr(module, "_offload_hook"):
+                module._offload_hook = (pack_hook, unpack_hook)
+            torch._C._autograd._push_saved_tensors_default_hooks(
+                pack_hook, unpack_hook
+            )
+
     if module._mode == "PIPE":
         enter = module._micro_idx == 0
         pipe = True
@@ -25,7 +45,21 @@ def zero_post_forward(module, inputs, outputs):
     exit = True
     if module._mode == "PIPE":
         exit = module._micro_idx == config['micros'] - 1
-
+    elif module._mode == "OFFLOAD":
+        torch.cuda.current_stream().record_event(module.calc_event)
+        pre_offload_module = find_pre_module_helper(module.pre_module())
+        if pre_offload_module is not None:
+            torch.cuda.current_stream().wait_event(pre_offload_module.offload_event)
+        with torch.cuda.stream(config["offload_stream"]):
+            config["offload_stream"].wait_event(module.calc_event)
+            if not hasattr(module._offload_dict, "fp16_storage"):
+                module._offload_dict.make_cpu_storage()
+            module._offload_dict.record_stream(config["offload_stream"])
+            module._offload_dict.d2h_memcpy()
+            if len(module._next_module) > 0:
+                config["offload_stream"].record_event(module.offload_event)
+        if module.offload_level == 2:
+            torch._C._autograd._pop_saved_tensors_default_hooks()
     if exit:
         module._forward_block_ctx.exit(forward_flag)
         module._ref_count += 1
@@ -33,6 +67,20 @@
 def zero_pre_backward(module, grad_outputs):
     backward_flag = 2 if module._zero_level == 2 else 0
     if module._mode != "PIPE":
+        if module._mode == "OFFLOAD" or (len(module._next_module) == 0):
+            if len(module._next_module) != 0:
+                current_stream = torch.cuda.current_stream()
+                current_stream.wait_event(module.offload_event)
+            pre_module = find_pre_module_helper(module.pre_module())
+            if pre_module is not None:
+                pre_module._on_device = True
+                with torch.cuda.stream(config["offload_stream"]):
+                    if (len(module._next_module) != 0):
+                        torch.cuda.current_stream().wait_event(module.calc_event)
+                    pre_module._offload_dict.h2d_memcpy()
+                    torch.cuda.current_stream().record_event(pre_module.offload_event)
+            if (len(module._next_module) != 0):
+                module._offload_dict.record_stream(current_stream)
         module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict)
         module._backward_block_ctx.enter(backward_flag, True)
         if not module._is_last_layer:
@@ -45,6 +93,10 @@
 def zero_post_backward(module, grad_inputs, grad_outputs):
     backward_flag = 2 if module._zero_level == 2 else 0
     if module._mode != "PIPE":
+        if module._mode == "OFFLOAD":
+            module._on_device = False
+            module._offload_dict.pop_all()
+            torch.cuda.current_stream().record_event(module.calc_event)
         if module._is_first_layer:
             module.backward_release(backward_flag)
         else:
diff --git a/bmtrain/init.py b/bmtrain/init.py
index a6214d78..e002ae9c 100644
--- a/bmtrain/init.py
+++ b/bmtrain/init.py
@@ -72,6 +72,7 @@ def init_distributed(
     config["rank"] = rank
     config["world_size"] = world_size
     config["calc_stream"] = torch.cuda.current_stream()
+    config["offload_stream"] = torch.cuda.Stream(priority=-1)
     config["load_stream"] = torch.cuda.Stream(priority=-1)
     config["tp_comm_stream"] = torch.cuda.Stream(priority=-1)
    config["pp_comm_stream"] = torch.cuda.Stream(priority=-1)
diff --git a/bmtrain/offload.py b/bmtrain/offload.py
new file mode 100644
index 00000000..d589d663
--- /dev/null
+++ b/bmtrain/offload.py
@@ -0,0 +1,121 @@
+import torch
+from collections import OrderedDict
+
+class Offload_Dict:
+
+    def __init__(self):
+        self._offload_dict = OrderedDict()
+
+    def add(self, tensor):
+        tensor = tensor.contiguous()
+        tensor_id = id(tensor)
+        data_ptr = tensor.storage().data_ptr()
+        if data_ptr not in self._offload_dict:
+            self._offload_dict[data_ptr] = {}
+            self._offload_dict[data_ptr]["stor"] = tensor.storage()
+            self._offload_dict[data_ptr]["size"] = tensor.storage().size()
+            self._offload_dict[data_ptr]["dtype"] = tensor.storage().dtype
+            self._offload_dict[data_ptr]["tensors"] = {}
+
+        self._offload_dict[data_ptr]["tensors"][id(tensor)] = {}
+        self._offload_dict[data_ptr]["tensors"][id(tensor)]["numel"] = tensor.numel()
+        self._offload_dict[data_ptr]["tensors"][id(tensor)]['dtype'] = tensor.dtype
+        self._offload_dict[data_ptr]["tensors"][id(tensor)]['offset'] = tensor.storage_offset()
+        self._offload_dict[data_ptr]["tensors"][id(tensor)]['tensor'] = tensor
+        self._offload_dict[data_ptr]["tensors"][id(tensor)]["shape"] = tensor.shape
+        self._device = "cuda"
+        return (data_ptr,tensor_id)
+
+    def get_total(self):
+        fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16])
+        fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32])
+        return fp16_total,fp32_total
+
+    def make_cpu_storage(self):
+        fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16])
+        fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32])
+        fp16_storage = torch.HalfStorage(fp16_total).pin_memory()
+        fp32_storage = torch.FloatStorage(fp32_total).pin_memory()
+        self.fp16_storage = fp16_storage
+        self.fp32_storage = fp32_storage
+        self.fp16_total = fp16_total
+        self.fp32_total = fp32_total
+
+    def get(self, key):
+        data_ptr, tensor_id = key
+        return self._offload_dict[data_ptr]['tensors'][tensor_id]["tensor"]
+
+    def pop_all(self):
+        self._offload_dict.clear()
+
+    def h2d_memcpy(self):
+        fp16_storage_cuda = self.fp16_storage.cuda(non_blocking=True)
+        fp32_storage_cuda = self.fp32_storage.cuda(non_blocking=True)
+        for key,val in self._offload_dict.items():
+            for id_val in val['tensors'].values():
+                id_val['tensor'] = torch.tensor([], dtype=id_val['dtype'],device=fp16_storage_cuda.device)
+                if id_val['dtype'] == torch.float16:
+                    id_val['tensor'].set_(fp16_storage_cuda, id_val['abs_offset'], id_val['shape'])
+                elif id_val['dtype'] == torch.float32:
+                    id_val['tensor'].set_(fp32_storage_cuda, id_val['abs_offset'], id_val['shape'])
+
+    def record_stream(self, stream):
+        for key, val in self._offload_dict.items():
+            for id_val in val['tensors'].values():
+                id_val['tensor'].record_stream(stream)
+
+    def d2h_memcpy(self):
+        fp16_offset = 0
+        fp32_offset = 0
+        fp16_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float16])
+        fp32_total = sum([v['size'] for v in self._offload_dict.values() if v['dtype'] == torch.float32])
+        assert fp16_total <= self.fp16_total
+        assert fp32_total <= self.fp32_total
+        fp16_storage = self.fp16_storage
+        fp32_storage = self.fp32_storage
+        for key,val in self._offload_dict.items():
+            assert val['dtype'] in [torch.float16, torch.float32]
+            storage = fp16_storage if val['dtype'] == torch.float16 else fp32_storage
+            offset = fp16_offset if val['dtype'] == torch.float16 else fp32_offset
+            for id_val in val['tensors'].values():
+                cpu_tensor = torch.tensor([], dtype=id_val['dtype'], device="cpu") \
+                    .set_(storage, offset+id_val['offset'], id_val['shape'])
+                id_val["abs_offset"] = offset+id_val['offset']
+                id_val['tensor'] = cpu_tensor.copy_(id_val['tensor'], non_blocking=True)
+            if val['dtype'] == torch.float16:
+                fp16_offset += val['size']
+            else:
+                fp32_offset += val['size']
+            val['stor'] = None
+
+
+def offload_wrapper(offload_dict):
+    def pack_hook(tensor):
+        if isinstance(tensor, torch.nn.Parameter):
+            return (tensor,)
+        elif tensor.dtype not in [torch.float16]:
+            return (tensor,)
+        else:
+            key = offload_dict.add(tensor)
+            return (tensor.device, key)
+    def unpack_hook(packed):
+        if len(packed) == 2:
+            device, key = packed
+            tensor = offload_dict.get(key)
+            assert tensor.device == device
+            return tensor
+        else:
+            tensor, = packed
+            return tensor
+    return pack_hook, unpack_hook
+
+def offload_pre_hook(module, input):
+    if hasattr(module, "_offload_hook"):
+        pack_hook, unpack_hook = module._offload_hook
+        torch._C._autograd._push_saved_tensors_default_hooks(
+            pack_hook, unpack_hook
+        )
+
+def offload_post_hook(module, input, output):
+    if hasattr(module, "_offload_hook"):
+        torch._C._autograd._pop_saved_tensors_default_hooks()
\ No newline at end of file
diff --git a/bmtrain/utils.py b/bmtrain/utils.py
index 8cb87808..0d72106b 100644
--- a/bmtrain/utils.py
+++ b/bmtrain/utils.py
@@ -32,7 +32,14 @@ def load_nccl_pypi():
         if file_split[-1] == "so" or (len(file_split)>1 and file_split[-2] == "so"):
             ctypes.CDLL(os.path.join(path, file_so))
 
-
+def find_pre_module_helper(m):
+    if m is None:
+        return m
+    if m._mode == "OFFLOAD":
+        return m
+    else:
+        return find_pre_module_helper(m.pre_module())
+
 def round_up(x, d):
     return (x + d - 1) // d * d
 
@@ -80,6 +87,17 @@ def print_rank(*args, rank=0, **kwargs):
     if config["rank"] == rank:
         print(*args, **kwargs)
 
+def print_strategy(model):
+    print_rank(" "*24+"|"+" Offload Level |" + " ZeRO Level |"+" Activation Recompute |")
+    for idx,ckpt in enumerate(model):
+        print_rank(f"CheckpointBlock Layer {idx} |{ckpt.offload_level:^14} | {ckpt._zero_level:^10} | {ckpt.use_checkpoint.__repr__():^20} |")
+
+def print_inspect(model):
+    model_inspect = bmt.inspect.inspect_model(model, "*")
+    print_rank(bmt.inspect.format_summary(model_inspect))
+
+
+
 def see_memory(message, detail=False):
     """
     Outputs a message followed by GPU memory status summary on rank 0.
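
Usage sketch (illustrative, not part of the patch): one way the new offload_level argument could be enabled from user code, assuming a user-defined TransformerLayer module and the bmt.CheckpointBlock / bmt.TransformerBlockList API shown in block_layer.py. use_checkpoint is disabled here because the patch asserts that activation checkpointing and offloading cannot be combined.

import bmtrain as bmt

# TransformerLayer is a hypothetical placeholder for any torch.nn.Module;
# 24 layers and the chosen levels are arbitrary example values.
blocks = bmt.TransformerBlockList([
    bmt.CheckpointBlock(
        TransformerLayer(),      # placeholder user module
        use_checkpoint=False,    # required: cannot be combined with offload_level 1 or 2
        offload_level=1,         # 1: offload activations of Linear submodules; 2: offload all saved fp16 activations
        zero_level=3,
    )
    for _ in range(24)
])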
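Background sketch (illustrative, not the Offload_Dict implementation): offload.py routes activations through PyTorch saved-tensor hooks (the private torch._C._autograd._push/_pop_saved_tensors_default_hooks calls). The same idea, simplified, can be expressed with the public torch.autograd.graph.saved_tensors_hooks context manager; the real patch additionally batches tensors into pinned CPU storages and synchronizes copies with CUDA events on a dedicated offload_stream.

import torch

def pack_to_cpu(tensor):
    # Mirror the patch's filter: keep parameters and non-fp16 tensors on the GPU.
    if isinstance(tensor, torch.nn.Parameter) or tensor.dtype != torch.float16:
        return (tensor,)
    cpu = torch.empty(tensor.shape, dtype=tensor.dtype, device="cpu", pin_memory=True)
    cpu.copy_(tensor, non_blocking=True)          # device-to-host copy of the saved activation
    return (tensor.device, cpu)

def unpack_from_cpu(packed):
    if len(packed) == 1:                          # tensor was never offloaded
        return packed[0]
    device, cpu = packed
    return cpu.to(device, non_blocking=True)      # host-to-device copy when backward needs it

layer = torch.nn.Linear(1024, 1024).cuda().half()
x = torch.randn(8, 1024, device="cuda", dtype=torch.float16, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    loss = layer(x).sum()                         # tensors saved for backward pass through pack_to_cpu
loss.backward()                                   # unpack_from_cpu restores them on demand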