diff --git a/open_lm/main.py b/open_lm/main.py
index 898c9c53..10508637 100644
--- a/open_lm/main.py
+++ b/open_lm/main.py
@@ -107,10 +107,7 @@ def load_model(args, model):
         global_step = checkpoint.get("step", None)
         if next(iter(sd.items()))[0].startswith("module"):
             sd = {k[len("module.") :]: v for k, v in sd.items()}
-        if args.distributed:
-            model.module.load_state_dict(sd)
-        else:
-            model.load_state_dict(sd)
+        model.load_state_dict(sd)
         logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
     else:
         # loading a bare (model only) checkpoint for fine-tune or evaluation
@@ -391,10 +388,7 @@ def main(args):
     if args.hf_model is not None:
         model = create_wrapped_hf_model(args)
     else:
-        with torch.device("meta" if args.fsdp else args.device):
-            model = create_model(args)
-        if not args.fsdp:
-            model.reset_parameters()
+        model = create_model(args)
 
     args.vocab_size = model.vocab_size
     args.seq_len = model.seq_len
@@ -403,6 +397,8 @@ def main(args):
     if args.val_num_samples is not None:
         args.val_num_samples //= args.seq_len
 
+    model = model.to(device)
+
     random_seed(args.seed, args.rank)
 
     if args.distributed:
@@ -540,6 +536,82 @@ def main(args):
         if samples_seen >= args.train_num_samples * args.epochs:
             raise RuntimeError("Loaded a checkpoint which has already seen the desired number of tokens.")
 
+    if args.distributed:
+        if args.fsdp:
+            transformer_layer_cls = None
+
+            if args.hf_model is not None:
+                # retrieve the user specified block class for fsdp
+                for _, target_cls in model.named_modules():
+                    if args.hf_fsdp_block in type(target_cls).__name__:
+                        transformer_layer_cls = {type(target_cls)}
+                        break
+
+                if transformer_layer_cls is None:
+                    print(f"--hf-fsdp-block {args.hf_fsdp_block} not found in --hf-model {args.hf_model}")
+                    return -1
+
+            else:
+                transformer_layer_cls = {Block}
+            # from https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/
+            transformer_auto_wrapper_policy = functools.partial(
+                transformer_auto_wrap_policy,
+                transformer_layer_cls=transformer_layer_cls,
+            )
+            # tries to follow gopher...
+            mp_policy = None
+            if args.fsdp_amp:
+                print("=> using bfloat16 params as part of fsdp amp policy.")
+                mp_policy = MixedPrecision(
+                    param_dtype=torch.bfloat16,
+                    reduce_dtype=torch.float32,
+                    buffer_dtype=torch.bfloat16,
+                )
+            elif args.fsdp_pure_bf16:
+                print("=> using pure bfloat16 params as part of fsdp amp policy.")
+                mp_policy = MixedPrecision(
+                    param_dtype=torch.bfloat16,
+                    reduce_dtype=torch.bfloat16,
+                    buffer_dtype=torch.bfloat16,
+                )
+
+            if args.rank == 0:
+                print(f"Before FSDP parameter num: {sum(p.numel() for p in model.parameters())}")
+                print(f"Before FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB")
+
+            fsdp_kwargs = {}
+            assert not (
+                args.fsdp_hybrid and args.fsdp_hybrid_o2
+            ), "Only --fsdp-hybrid or --fsdp-hybrid-o2 should be set."
+            if args.fsdp_backward_prefetch:
+                fsdp_kwargs["backward_prefetch"] = BackwardPrefetch.BACKWARD_PRE
+            if args.fsdp_hybrid:
+                fsdp_kwargs["sharding_strategy"] = ShardingStrategy.HYBRID_SHARD
+            if args.fsdp_hybrid_o2:
+                fsdp_kwargs["sharding_strategy"] = ShardingStrategy._HYBRID_SHARD_ZERO2
+            print("=> FSDP kwargs: ", fsdp_kwargs)
+
+            # init FSDP
+            model = FSDP(
+                model,
+                auto_wrap_policy=transformer_auto_wrapper_policy,
+                device_id=device,
+                mixed_precision=mp_policy,
+                cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload),
+                use_orig_params=args.fsdp_use_orig_params,
+                limit_all_gathers=args.fsdp_limit_all_gathers,
+                **fsdp_kwargs,
+            )
+
+            print(f"After FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}")
+            print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}")
+        else:
+            ddp_args = {}
+            if args.ddp_static_graph:
+                # this doesn't exist in older PyTorch, arg only added if enabled
+                ddp_args["static_graph"] = True
+            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args)
+
     # create optimizer and scaler
     optimizer = None
     scaler = None
diff --git a/open_lm/model.py b/open_lm/model.py
index d0750af0..ad135892 100644
--- a/open_lm/model.py
+++ b/open_lm/model.py
@@ -119,6 +119,13 @@ def __init__(self, layer_id, args: Params):
         self.attn_fn = xformers_attn if torch.cuda.is_available() else torch_attn
         self.apply_qk_norm = args.apply_qk_norm
 
+        # initialize weights by trunc_normal(1/sqrt(fan_in))
+        std = 1.0 / math.sqrt(args.dim)
+        torch.nn.init.trunc_normal_(self.in_proj.weight, std=std, a=-3 * std, b=3 * std)
+        # scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better.
+        std = std / math.sqrt(2 * (layer_id + 1))
+        torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std)
+
         # initialize norm layers for queries and keys if needed
         self.q_norm = (
             args.norm_type(
@@ -137,18 +144,6 @@ def __init__(self, layer_id, args: Params):
             else nn.Identity()
         )
 
-        self.layer_id = layer_id
-        self.dim = args.dim
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        # initialize weights by trunc_normal(1/sqrt(fan_in))
-        std = 1.0 / math.sqrt(self.dim)
-        torch.nn.init.trunc_normal_(self.in_proj.weight, std=std, a=-3 * std, b=3 * std)
-        # scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better.
-        std = std / math.sqrt(2 * (self.layer_id + 1))
-        torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std)
-
     def forward(self, x: torch.Tensor, is_causal=True):
         batchsize, seqlen, _ = x.shape
         queries, keys, vals = self.in_proj(x).chunk(3, dim=-1)
@@ -176,19 +171,17 @@ def __init__(self, layer_id, args: Params):
         self.dim = args.dim
         self.head_dim = args.dim // args.n_heads
         self.attention = CustomAttn(layer_id, args)
-        self._ffn_type = args.ffn_type
 
         if args.ffn_type == "swiglu":
             # this follows llama / lit llama -- go to multiple of 256
-            self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
-            self.feed_forward = xops.SwiGLU(args.dim, self.hidden_dim, args.dim, bias=False)
+            hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
+            self.feed_forward = xops.SwiGLU(args.dim, hidden_dim, args.dim, bias=False)
         elif args.ffn_type == "gelu":
             # Follows mosaic mpt7b, but without a bias.
-            self.hidden_dim = args.dim * 4
-            self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False)
-            self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False)
+            hidden_dim = args.dim * 4
+            self._ff_w1 = nn.Linear(args.dim, hidden_dim, bias=False)
+            self._ff_w2 = nn.Linear(hidden_dim, args.dim, bias=False)
             self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2)
 
-        self.layer_id = layer_id
         self.attention_norm = args.norm_type(
             args.dim,
@@ -199,23 +192,21 @@ def __init__(self, layer_id, args: Params):
             eps=args.norm_eps,
         )
         self.attention.seq_len = args.seq_len
-        self.reset_parameters()
 
-    def reset_parameters(self):
-        if self._ffn_type == "swiglu":
+        if args.ffn_type == "swiglu":
             # initialize weights trunc_normal(1/sqrt(fan_in))
-            std = 1.0 / math.sqrt(self.dim)
+            std = 1.0 / math.sqrt(args.dim)
             torch.nn.init.trunc_normal_(self.feed_forward.w12.weight, std=std, a=-3 * std, b=3 * std)
             # scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better.
-            std = 1.0 / math.sqrt(self.hidden_dim)
-            std = std / math.sqrt(2 * (self.layer_id + 1))
+            std = 1.0 / math.sqrt(hidden_dim)
+            std = std / math.sqrt(2 * (layer_id + 1))
             torch.nn.init.trunc_normal_(self.feed_forward.w3.weight, std=std, a=-3 * std, b=3 * std)
-        elif self._ffn_type == "gelu":
-            std = 1.0 / math.sqrt(self.dim)
+        elif args.ffn_type == "gelu":
+            std = 1.0 / math.sqrt(args.dim)
             torch.nn.init.trunc_normal_(self._ff_w1.weight, std=std, a=-3 * std, b=3 * std)
-            std = 1.0 / math.sqrt(self.hidden_dim)
-            std = std / math.sqrt(2 * (self._layer_id + 1))
+            std = 1.0 / math.sqrt(hidden_dim)
+            std = std / math.sqrt(2 * (layer_id + 1))
             torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std)
 
     def forward(self, x):
@@ -257,13 +248,11 @@ def __init__(self, params):
         if self.weight_tying:
             self.tok_embeddings.weight = self.output.weight
         self.grad_checkpointing = False
-        self.reset_parameters()
 
-    def reset_parameters(self):
         # initialize weight 1/sqrt(dim)
         # this is 1/fan_in for output, as is default, and Maciej Kilian tried another option
         # for the embed layer (from RWKV paper) but this was better.
-        std = 1.0 / math.sqrt(self.params.dim)
+        std = 1.0 / math.sqrt(params.dim)
         torch.nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std)
         torch.nn.init.trunc_normal_(self.tok_embeddings.weight, std=std, a=-3 * std, b=3 * std)
diff --git a/open_lm/positional_embedding/rotary.py b/open_lm/positional_embedding/rotary.py
index c3043746..d247f78b 100644
--- a/open_lm/positional_embedding/rotary.py
+++ b/open_lm/positional_embedding/rotary.py
@@ -44,16 +44,13 @@ class RotaryEmbedding(torch.nn.Module):
     def __init__(self, dim_model: int, *_, **__):
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
-        self.dim_model = dim_model
-        self.register_buffer("inv_freq", torch.zeros(self.dim_model // 2))
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
+        self.register_buffer("inv_freq", inv_freq)
 
         self._seq_len_cached = None
         self._cos_cached = None
         self._sin_cached = None
 
-    def reset_parameters(self):
-        self.inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model))
-
     def _update_cos_sin_tables(self, x, seq_dimension=1):
         seq_len = x.shape[seq_dimension]
 
diff --git a/tests/shared.py b/tests/shared.py
index fb7bdbff..e0f6967f 100644
--- a/tests/shared.py
+++ b/tests/shared.py
@@ -1,6 +1,5 @@
 import torch
 from torch import optim
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
 from open_lm.main import random_seed
 from open_lm.model import create_model
@@ -92,7 +91,7 @@ def __init__(self):
         self.ignore_parse_errors = False
 
 
-def create_train_fixtures(model="open_lm_11m", fsdp=False):
+def create_train_fixtures(model="open_lm_11m"):
     # Setup data, optimizer, and other basic settings
     args = MockTrainArgs(model)
 
@@ -105,14 +104,7 @@ def create_train_fixtures(model="open_lm_11m", fsdp=False):
 
     # create base models
     random_seed()
-    if fsdp:
-        with torch.device("meta"):
-            model = create_model(args)
-        model = FSDP(model)
-    else:
-        model = create_model(args)
-        model.reset_parameters()
-        model = model.to(args.device)
+    model = create_model(args).to(args.device)
 
     # create dataloader
     data = get_data(