Vaishaal/revert fsdp loading (#147)
Vaishaal authored Dec 11, 2023
1 parent 158124d commit 007e3a4
Showing 4 changed files with 105 additions and 55 deletions.
88 changes: 80 additions & 8 deletions open_lm/main.py
@@ -107,10 +107,7 @@ def load_model(args, model):
global_step = checkpoint.get("step", None)
if next(iter(sd.items()))[0].startswith("module"):
sd = {k[len("module.") :]: v for k, v in sd.items()}
- if args.distributed:
-     model.module.load_state_dict(sd)
- else:
-     model.load_state_dict(sd)
+ model.load_state_dict(sd)
logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
else:
# loading a bare (model only) checkpoint for fine-tune or evaluation
@@ -391,10 +388,7 @@ def main(args):
if args.hf_model is not None:
model = create_wrapped_hf_model(args)
else:
- with torch.device("meta" if args.fsdp else args.device):
-     model = create_model(args)
- if not args.fsdp:
-     model.reset_parameters()
+ model = create_model(args)

args.vocab_size = model.vocab_size
args.seq_len = model.seq_len
@@ -403,6 +397,8 @@
if args.val_num_samples is not None:
args.val_num_samples //= args.seq_len

+ model = model.to(device)

random_seed(args.seed, args.rank)

if args.distributed:
@@ -540,6 +536,82 @@ def main(args):
if samples_seen >= args.train_num_samples * args.epochs:
raise RuntimeError("Loaded a checkpoint which has already seen the desired number of tokens.")

if args.distributed:
if args.fsdp:
transformer_layer_cls = None

if args.hf_model is not None:
# retrive the user specified block class for fsdp
for _, target_cls in model.named_modules():
if args.hf_fsdp_block in type(target_cls).__name__:
transformer_layer_cls = {type(target_cls)}
break

if transformer_layer_cls is None:
print(f"--hf-fsdp-block {args.hf_fsdp_block} not found in --hf-model {args.hf_model}")
return -1

else:
transformer_layer_cls = {Block}
# from https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/
transformer_auto_wrapper_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=transformer_layer_cls,
)
# tries to follow gopher...
mp_policy = None
if args.fsdp_amp:
print("=> using bfloat16 params as part of fsdp amp policy.")
mp_policy = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.bfloat16,
)
elif args.fsdp_pure_bf16:
print("=> using pure bfloat16 params as part of fsdp amp policy.")
mp_policy = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
)

if args.rank == 0:
print(f"Before FSDP parameter num: {sum(p.numel() for p in model.parameters())}")
print(f"Before FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB")

fsdp_kwargs = {}
assert not (
args.fsdp_hybrid and args.fsdp_hybrid_o2
), "Only --fsdp-hybrid or --fsdp-hybrid-o2 should be set."
if args.fsdp_backward_prefetch:
fsdp_kwargs["backward_prefetch"] = BackwardPrefetch.BACKWARD_PRE
if args.fsdp_hybrid:
fsdp_kwargs["sharding_strategy"] = ShardingStrategy.HYBRID_SHARD
if args.fsdp_hybrid_o2:
fsdp_kwargs["sharding_strategy"] = ShardingStrategy._HYBRID_SHARD_ZERO2
print("=> FSDP kwargs: ", fsdp_kwargs)

# init FSDP
model = FSDP(
model,
auto_wrap_policy=transformer_auto_wrapper_policy,
device_id=device,
mixed_precision=mp_policy,
cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload),
use_orig_params=args.fsdp_use_orig_params,
limit_all_gathers=args.fsdp_limit_all_gathers,
**fsdp_kwargs,
)

print(f"After FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}")
print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}")
else:
ddp_args = {}
if args.ddp_static_graph:
# this doesn't exist in older PyTorch, arg only added if enabled
ddp_args["static_graph"] = True
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args)

# create optimizer and scaler
optimizer = None
scaler = None
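Note on the block added to main() above: after checkpoint loading, the model is wrapped either in FSDP (sharded at the granularity of the transformer Block, or of the user-specified --hf-fsdp-block class, with optional bf16 mixed precision, CPU offload, and hybrid sharding) or in plain DistributedDataParallel. Below is a minimal, self-contained sketch of that wrapping pattern; ToyBlock, the toy sizes, and the torchrun launch are illustrative assumptions rather than open_lm code.

import functools

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ToyBlock(nn.Module):
    """Stand-in for open_lm's Block; only its class identity matters for auto-wrapping."""

    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.ff(x)


def main():
    use_cuda = torch.cuda.is_available()
    dist.init_process_group("nccl" if use_cuda else "gloo")
    if use_cuda:
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = nn.Sequential(*[ToyBlock(128) for _ in range(4)])

    # Shard at the granularity of ToyBlock, analogous to {Block} / the --hf-fsdp-block class.
    wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={ToyBlock})

    # bf16 params/buffers with fp32 gradient reduction, as in the --fsdp-amp branch.
    mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.bfloat16,
    )

    model = FSDP(
        model,
        auto_wrap_policy=wrap_policy,
        mixed_precision=mp_policy,
        device_id=torch.cuda.current_device() if use_cuda else None,
    )
    if dist.get_rank() == 0:
        print(model)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()  # e.g. torchrun --nproc_per_node=2 this_script.py

Wrapping at the Block level keeps each FSDP unit to a single transformer layer; the sharding strategies selected via fsdp_kwargs in the diff are layered on top of the same call.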
53 changes: 21 additions & 32 deletions open_lm/model.py
@@ -119,6 +119,13 @@ def __init__(self, layer_id, args: Params):
self.attn_fn = xformers_attn if torch.cuda.is_available() else torch_attn
self.apply_qk_norm = args.apply_qk_norm

+ # initialize weights by trunc_normal(1/sqrt(fan_in))
+ std = 1.0 / math.sqrt(args.dim)
+ torch.nn.init.trunc_normal_(self.in_proj.weight, std=std, a=-3 * std, b=3 * std)
+ # scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better.
+ std = std / math.sqrt(2 * (layer_id + 1))
+ torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std)

# initialize norm layers for queries and keys if needed
self.q_norm = (
args.norm_type(
@@ -137,18 +144,6 @@ def __init__(self, layer_id, args: Params):
else nn.Identity()
)

- self.layer_id = layer_id
- self.dim = args.dim
- self.reset_parameters()
-
- def reset_parameters(self):
-     # initialize weights by trunc_normal(1/sqrt(fan_in))
-     std = 1.0 / math.sqrt(self.dim)
-     torch.nn.init.trunc_normal_(self.in_proj.weight, std=std, a=-3 * std, b=3 * std)
-     # scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better.
-     std = std / math.sqrt(2 * (self.layer_id + 1))
-     torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std)

def forward(self, x: torch.Tensor, is_causal=True):
batchsize, seqlen, _ = x.shape
queries, keys, vals = self.in_proj(x).chunk(3, dim=-1)
@@ -176,19 +171,17 @@ def __init__(self, layer_id, args: Params):
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = CustomAttn(layer_id, args)
- self._ffn_type = args.ffn_type

if args.ffn_type == "swiglu":
# this follows llama / lit llama -- go to multiple of 256
- self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
- self.feed_forward = xops.SwiGLU(args.dim, self.hidden_dim, args.dim, bias=False)
+ hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
+ self.feed_forward = xops.SwiGLU(args.dim, hidden_dim, args.dim, bias=False)
elif args.ffn_type == "gelu":
# Follows mosaic mpt7b, but without a bias.
- self.hidden_dim = args.dim * 4
- self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False)
- self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False)
+ hidden_dim = args.dim * 4
+ self._ff_w1 = nn.Linear(args.dim, hidden_dim, bias=False)
+ self._ff_w2 = nn.Linear(hidden_dim, args.dim, bias=False)
self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2)

self.layer_id = layer_id
self.attention_norm = args.norm_type(
args.dim,
@@ -199,23 +192,21 @@ def __init__(self, layer_id, args: Params):
eps=args.norm_eps,
)
self.attention.seq_len = args.seq_len
- self.reset_parameters()

- def reset_parameters(self):
-     if self._ffn_type == "swiglu":
+ if args.ffn_type == "swiglu":
# initialize weights trunc_normal(1/sqrt(fan_in))
- std = 1.0 / math.sqrt(self.dim)
+ std = 1.0 / math.sqrt(args.dim)
torch.nn.init.trunc_normal_(self.feed_forward.w12.weight, std=std, a=-3 * std, b=3 * std)
# scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better.
- std = 1.0 / math.sqrt(self.hidden_dim)
- std = std / math.sqrt(2 * (self.layer_id + 1))
+ std = 1.0 / math.sqrt(hidden_dim)
+ std = std / math.sqrt(2 * (layer_id + 1))
torch.nn.init.trunc_normal_(self.feed_forward.w3.weight, std=std, a=-3 * std, b=3 * std)
- elif self._ffn_type == "gelu":
-     std = 1.0 / math.sqrt(self.dim)
+ elif args.ffn_type == "gelu":
+     std = 1.0 / math.sqrt(args.dim)
torch.nn.init.trunc_normal_(self._ff_w1.weight, std=std, a=-3 * std, b=3 * std)

- std = 1.0 / math.sqrt(self.hidden_dim)
- std = std / math.sqrt(2 * (self._layer_id + 1))
+ std = 1.0 / math.sqrt(hidden_dim)
+ std = std / math.sqrt(2 * (layer_id + 1))
torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std)

def forward(self, x):
@@ -257,13 +248,11 @@ def __init__(self, params):
if self.weight_tying:
self.tok_embeddings.weight = self.output.weight
self.grad_checkpointing = False
- self.reset_parameters()

- def reset_parameters(self):
# initialize weight 1/sqrt(dim)
# this is 1/fan_in for output, as is default, and Maciej Kilian tried another option
# for the embed layer (from RWKV paper) but this was better.
- std = 1.0 / math.sqrt(self.params.dim)
+ std = 1.0 / math.sqrt(params.dim)
torch.nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std)
torch.nn.init.trunc_normal_(self.tok_embeddings.weight, std=std, a=-3 * std, b=3 * std)

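The model.py changes move weight initialization back inline into the constructors, dropping the reset_parameters() methods that the meta-device FSDP path relied on. The scheme itself is unchanged: truncated normal with std = 1/sqrt(fan_in), truncated at 3 standard deviations, and each layer's output projection further scaled by 1/sqrt(2*(layer_id+1)). A small sketch of that scheme on stand-in nn.Linear layers follows; the helper names and the dim=768 example are illustrative assumptions, not open_lm code.

import math

import torch.nn as nn


def swiglu_hidden_dim(dim: int) -> int:
    # Same rounding as the diff: roughly (8/3) * dim, rounded up to a multiple of 256.
    # e.g. dim=768 -> 2048, dim=1024 -> 2816.
    return 256 * ((int(2 * 4 * dim / 3) + 256 - 1) // 256)


def init_attn_projections(in_proj: nn.Linear, out_proj: nn.Linear, dim: int, layer_id: int) -> None:
    # trunc_normal(1/sqrt(fan_in)), truncated at +/- 3 std.
    std = 1.0 / math.sqrt(dim)
    nn.init.trunc_normal_(in_proj.weight, std=std, a=-3 * std, b=3 * std)
    # Depth-dependent rescaling of the output projection (https://arxiv.org/abs/1908.11365).
    std = std / math.sqrt(2 * (layer_id + 1))
    nn.init.trunc_normal_(out_proj.weight, std=std, a=-3 * std, b=3 * std)


if __name__ == "__main__":
    dim, layer_id = 768, 3
    in_proj = nn.Linear(dim, 3 * dim, bias=False)   # fused q/k/v projection
    out_proj = nn.Linear(dim, dim, bias=False)
    init_attn_projections(in_proj, out_proj, dim, layer_id)
    print(swiglu_hidden_dim(dim), out_proj.weight.std().item())

The FFN weights in the diff follow the same pattern, with fan_in equal to the hidden width for the second projection.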
7 changes: 2 additions & 5 deletions open_lm/positional_embedding/rotary.py
@@ -44,16 +44,13 @@ class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim_model: int, *_, **__):
super().__init__()
# Generate and save the inverse frequency buffer (non trainable)
- self.dim_model = dim_model
- self.register_buffer("inv_freq", torch.zeros(self.dim_model // 2))
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
+ self.register_buffer("inv_freq", inv_freq)

self._seq_len_cached = None
self._cos_cached = None
self._sin_cached = None

- def reset_parameters(self):
-     self.inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model))

def _update_cos_sin_tables(self, x, seq_dimension=1):
seq_len = x.shape[seq_dimension]

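In rotary.py the inverse-frequency buffer is again computed eagerly in __init__ rather than zero-initialized and filled in by reset_parameters(). For reference, here is a minimal sketch of how those inverse frequencies become the cached cos/sin tables used by rotary position embeddings; the standalone helper is an illustration of the standard computation, not the class's actual _update_cos_sin_tables method.

import torch


def rotary_cos_sin(seq_len: int, dim_model: int, base: float = 10000.0):
    # One inverse frequency per pair of channels, as in the __init__ shown above.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim_model, 2).float() / dim_model))
    t = torch.arange(seq_len, dtype=inv_freq.dtype)
    freqs = torch.einsum("i,j->ij", t, inv_freq)   # (seq_len, dim_model // 2)
    emb = torch.cat((freqs, freqs), dim=-1)        # (seq_len, dim_model)
    return emb.cos(), emb.sin()


if __name__ == "__main__":
    cos, sin = rotary_cos_sin(seq_len=8, dim_model=16)
    print(cos.shape, sin.shape)  # torch.Size([8, 16]) torch.Size([8, 16])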
12 changes: 2 additions & 10 deletions tests/shared.py
@@ -1,6 +1,5 @@
import torch
from torch import optim
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from open_lm.main import random_seed
from open_lm.model import create_model
@@ -92,7 +91,7 @@ def __init__(self):
self.ignore_parse_errors = False


- def create_train_fixtures(model="open_lm_11m", fsdp=False):
+ def create_train_fixtures(model="open_lm_11m"):
# Setup data, optimizer, and other basic settings
args = MockTrainArgs(model)

@@ -105,14 +104,7 @@ def create_train_fixtures(model="open_lm_11m", fsdp=False):

# create base models
random_seed()
- if fsdp:
-     with torch.device("meta"):
-         model = create_model(args)
-     model = FSDP(model)
- else:
-     model = create_model(args)
-     model.reset_parameters()
-     model = model.to(args.device)
+ model = create_model(args).to(args.device)

# create dataloader
data = get_data(
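The tests/shared.py change mirrors the main.py revert: fixtures go back to materializing the model normally and moving it to the target device instead of building it on the meta device and wrapping it in FSDP. The short illustration below shows why the two paths differ; nn.Linear stands in for create_model and is an assumption, not open_lm code.

import torch
import torch.nn as nn

# Meta-device construction allocates no storage, so weights must be materialized
# and re-initialized later (e.g. via reset_parameters() or FSDP's init machinery).
with torch.device("meta"):
    meta_model = nn.Linear(8, 8)
print(meta_model.weight.is_meta)  # True: no real data behind the parameter

# The reverted-to path: build the model normally, then move it to the device.
model = nn.Linear(8, 8).to("cpu")
print(model.weight.is_meta)  # False: parameters are initialized and usable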
