Vaishaal/revert fsdp loading #147

Merged · 2 commits · Dec 11, 2023
88 changes: 80 additions & 8 deletions open_lm/main.py
@@ -107,10 +107,7 @@ def load_model(args, model):
global_step = checkpoint.get("step", None)
if next(iter(sd.items()))[0].startswith("module"):
sd = {k[len("module.") :]: v for k, v in sd.items()}
if args.distributed:
model.module.load_state_dict(sd)
else:
model.load_state_dict(sd)
model.load_state_dict(sd)
logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
else:
# loading a bare (model only) checkpoint for fine-tune or evaluation
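
For context, the key remapping above handles checkpoints saved from a DDP-wrapped model, whose keys all carry a "module." prefix; after stripping it, a single load_state_dict call on the unwrapped module is enough. A minimal, self-contained sketch of that remapping (illustrative only, not code from this PR):

```python
import torch

model = torch.nn.Linear(4, 4)

# A DDP-wrapped model saves its state dict with every key prefixed by "module.".
wrapped_sd = {f"module.{k}": v for k, v in model.state_dict().items()}

# Same remapping as in load_model above: drop the prefix, then load directly.
sd = {k[len("module."):]: v for k, v in wrapped_sd.items()}
model.load_state_dict(sd)
```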
@@ -391,10 +388,7 @@ def main(args):
if args.hf_model is not None:
model = create_wrapped_hf_model(args)
else:
with torch.device("meta" if args.fsdp else args.device):
model = create_model(args)
if not args.fsdp:
model.reset_parameters()
model = create_model(args)

args.vocab_size = model.vocab_size
args.seq_len = model.seq_len
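
The lines removed above built the model on PyTorch's meta device (no real parameter storage) so that FSDP could materialize shards later, which is why the non-FSDP path needed an explicit reset_parameters() call. A rough sketch of that pattern, for reference only (assumes PyTorch >= 2.0; torch.nn.Linear stands in for the open_lm model):

```python
import torch

# Construct on the meta device: parameters exist but have no backing storage.
with torch.device("meta"):
    layer = torch.nn.Linear(1024, 1024)

# Materialize real (uninitialized) storage on a concrete device ...
layer = layer.to_empty(device="cpu")
# ... and only then fill in actual values, which is the role reset_parameters() played here.
layer.reset_parameters()
```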
@@ -403,6 +397,8 @@
if args.val_num_samples is not None:
args.val_num_samples //= args.seq_len

model = model.to(device)

random_seed(args.seed, args.rank)

if args.distributed:
@@ -540,6 +536,82 @@ def main(args):
if samples_seen >= args.train_num_samples * args.epochs:
raise RuntimeError("Loaded a checkpoint which has already seen the desired number of tokens.")

if args.distributed:
if args.fsdp:
transformer_layer_cls = None

if args.hf_model is not None:
# retrieve the user-specified block class for FSDP
for _, target_cls in model.named_modules():
if args.hf_fsdp_block in type(target_cls).__name__:
transformer_layer_cls = {type(target_cls)}
break

if transformer_layer_cls is None:
print(f"--hf-fsdp-block {args.hf_fsdp_block} not found in --hf-model {args.hf_model}")
return -1

else:
transformer_layer_cls = {Block}
# from https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/
transformer_auto_wrapper_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=transformer_layer_cls,
)
# tries to follow gopher...
mp_policy = None
if args.fsdp_amp:
print("=> using bfloat16 params as part of fsdp amp policy.")
mp_policy = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.bfloat16,
)
elif args.fsdp_pure_bf16:
print("=> using pure bfloat16 params as part of fsdp amp policy.")
mp_policy = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
)

if args.rank == 0:
print(f"Before FSDP parameter num: {sum(p.numel() for p in model.parameters())}")
print(f"Before FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB")

fsdp_kwargs = {}
assert not (
args.fsdp_hybrid and args.fsdp_hybrid_o2
), "Only --fsdp-hybrid or --fsdp-hybrid-o2 should be set."
if args.fsdp_backward_prefetch:
fsdp_kwargs["backward_prefetch"] = BackwardPrefetch.BACKWARD_PRE
if args.fsdp_hybrid:
fsdp_kwargs["sharding_strategy"] = ShardingStrategy.HYBRID_SHARD
if args.fsdp_hybrid_o2:
fsdp_kwargs["sharding_strategy"] = ShardingStrategy._HYBRID_SHARD_ZERO2
print("=> FSDP kwargs: ", fsdp_kwargs)

# init FSDP
model = FSDP(
model,
auto_wrap_policy=transformer_auto_wrapper_policy,
device_id=device,
mixed_precision=mp_policy,
cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload),
use_orig_params=args.fsdp_use_orig_params,
limit_all_gathers=args.fsdp_limit_all_gathers,
**fsdp_kwargs,
)

print(f"After FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}")
print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}")
else:
ddp_args = {}
if args.ddp_static_graph:
# this doesn't exist in older PyTorch, arg only added if enabled
ddp_args["static_graph"] = True
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args)

# create optimizer and scaler
optimizer = None
scaler = None
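The FSDP branch added above shards the model one transformer block per FSDP unit via transformer_auto_wrap_policy, with optional bf16 mixed precision. A condensed sketch of the same wrapping recipe (hypothetical helper, assumes torch.distributed is already initialized and a CUDA device id is available):

```python
import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def wrap_with_fsdp(model, block_cls, device_id, bf16_amp=False):
    # Shard at the granularity of one transformer block per FSDP unit.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={block_cls},
    )
    mp_policy = None
    if bf16_amp:
        # bf16 params/buffers with fp32 gradient reduction (mirrors --fsdp-amp above).
        mp_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.bfloat16,
        )
    return FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        device_id=device_id,
        mixed_precision=mp_policy,
    )
```
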
53 changes: 21 additions & 32 deletions open_lm/model.py
@@ -119,6 +119,13 @@ def __init__(self, layer_id, args: Params):
self.attn_fn = xformers_attn if torch.cuda.is_available() else torch_attn
self.apply_qk_norm = args.apply_qk_norm

# initialize weights by trunc_normal(1/sqrt(fan_in))
std = 1.0 / math.sqrt(args.dim)
torch.nn.init.trunc_normal_(self.in_proj.weight, std=std, a=-3 * std, b=3 * std)
# scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better.
std = std / math.sqrt(2 * (layer_id + 1))
torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std)

# initialize norm layers for queries and keys if needed
self.q_norm = (
args.norm_type(
@@ -137,18 +144,6 @@ def __init__(self, layer_id, args: Params):
else nn.Identity()
)

self.layer_id = layer_id
self.dim = args.dim
self.reset_parameters()

def reset_parameters(self):
# initialize weights by trunc_normal(1/sqrt(fan_in))
std = 1.0 / math.sqrt(self.dim)
torch.nn.init.trunc_normal_(self.in_proj.weight, std=std, a=-3 * std, b=3 * std)
# scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better.
std = std / math.sqrt(2 * (self.layer_id + 1))
torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std)

def forward(self, x: torch.Tensor, is_causal=True):
batchsize, seqlen, _ = x.shape
queries, keys, vals = self.in_proj(x).chunk(3, dim=-1)
@@ -176,19 +171,17 @@ def __init__(self, layer_id, args: Params):
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = CustomAttn(layer_id, args)
self._ffn_type = args.ffn_type

if args.ffn_type == "swiglu":
# this follows llama / lit llama -- go to multiple of 256
self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
self.feed_forward = xops.SwiGLU(args.dim, self.hidden_dim, args.dim, bias=False)
hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
self.feed_forward = xops.SwiGLU(args.dim, hidden_dim, args.dim, bias=False)
elif args.ffn_type == "gelu":
# Follows mosaic mpt7b, but without a bias.
self.hidden_dim = args.dim * 4
self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False)
self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False)
hidden_dim = args.dim * 4
self._ff_w1 = nn.Linear(args.dim, hidden_dim, bias=False)
self._ff_w2 = nn.Linear(hidden_dim, args.dim, bias=False)
self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2)

self.layer_id = layer_id
self.attention_norm = args.norm_type(
args.dim,
@@ -199,23 +192,21 @@ def __init__(self, layer_id, args: Params):
eps=args.norm_eps,
)
self.attention.seq_len = args.seq_len
self.reset_parameters()

def reset_parameters(self):
if self._ffn_type == "swiglu":
if args.ffn_type == "swiglu":
# initialize weights trunc_normal(1/sqrt(fan_in))
std = 1.0 / math.sqrt(self.dim)
std = 1.0 / math.sqrt(args.dim)
torch.nn.init.trunc_normal_(self.feed_forward.w12.weight, std=std, a=-3 * std, b=3 * std)
# scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better.
std = 1.0 / math.sqrt(self.hidden_dim)
std = std / math.sqrt(2 * (self.layer_id + 1))
std = 1.0 / math.sqrt(hidden_dim)
std = std / math.sqrt(2 * (layer_id + 1))
torch.nn.init.trunc_normal_(self.feed_forward.w3.weight, std=std, a=-3 * std, b=3 * std)
elif self._ffn_type == "gelu":
std = 1.0 / math.sqrt(self.dim)
elif args.ffn_type == "gelu":
std = 1.0 / math.sqrt(args.dim)
torch.nn.init.trunc_normal_(self._ff_w1.weight, std=std, a=-3 * std, b=3 * std)

std = 1.0 / math.sqrt(self.hidden_dim)
std = std / math.sqrt(2 * (self._layer_id + 1))
std = 1.0 / math.sqrt(hidden_dim)
std = std / math.sqrt(2 * (layer_id + 1))
torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std)

def forward(self, x):
@@ -257,13 +248,11 @@ def __init__(self, params):
if self.weight_tying:
self.tok_embeddings.weight = self.output.weight
self.grad_checkpointing = False
self.reset_parameters()

def reset_parameters(self):
# initialize weight 1/sqrt(dim)
# this is 1/fan_in for output, as is default, and Maciej Kilian tried another option
# for the embed layer (from RWKV paper) but this was better.
std = 1.0 / math.sqrt(self.params.dim)
std = 1.0 / math.sqrt(params.dim)
torch.nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std)
torch.nn.init.trunc_normal_(self.tok_embeddings.weight, std=std, a=-3 * std, b=3 * std)

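The model.py changes above move the depth-scaled truncated-normal initialization back into __init__ (instead of separate reset_parameters() methods); the recipe itself is unchanged: std = 1/sqrt(fan_in), further divided by sqrt(2 * (layer_id + 1)) for projections that write into the residual stream. A small illustrative helper (hypothetical, not from the repo):

```python
import math

import torch


def trunc_normal_depth_scaled_(weight, fan_in, layer_id=None):
    """Truncated-normal init with std = 1/sqrt(fan_in), optionally scaled by depth."""
    std = 1.0 / math.sqrt(fan_in)
    if layer_id is not None:
        # Depth scaling as in https://arxiv.org/abs/1908.11365, same factor as the diff.
        std = std / math.sqrt(2 * (layer_id + 1))
    torch.nn.init.trunc_normal_(weight, std=std, a=-3 * std, b=3 * std)


# e.g. for a block's output projection at depth layer_id:
# trunc_normal_depth_scaled_(out_proj.weight, fan_in=dim, layer_id=layer_id)
```
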
7 changes: 2 additions & 5 deletions open_lm/positional_embedding/rotary.py
@@ -44,16 +44,13 @@ class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim_model: int, *_, **__):
super().__init__()
# Generate and save the inverse frequency buffer (non trainable)
self.dim_model = dim_model
self.register_buffer("inv_freq", torch.zeros(self.dim_model // 2))
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
self.register_buffer("inv_freq", inv_freq)

self._seq_len_cached = None
self._cos_cached = None
self._sin_cached = None

def reset_parameters(self):
self.inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model))

def _update_cos_sin_tables(self, x, seq_dimension=1):
seq_len = x.shape[seq_dimension]

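The rotary change above computes the inverse-frequency buffer directly in __init__ rather than deferring it to a reset_parameters() hook. The buffer is the standard RoPE geometric frequency schedule; a standalone sketch of how it feeds the cached cos/sin tables (illustrative, not the repo's exact caching code):

```python
import torch


def rotary_inv_freq(dim_model, base=10000.0):
    # One frequency per channel pair: base**(-2i / dim_model) for i = 0 .. dim_model // 2 - 1.
    return 1.0 / (base ** (torch.arange(0, dim_model, 2).float() / dim_model))


inv_freq = rotary_inv_freq(64)
positions = torch.arange(128).float()
# Outer product gives the per-position rotation angles behind the cos/sin caches.
angles = torch.outer(positions, inv_freq)   # shape: (seq_len, dim_model // 2)
cos_cached, sin_cached = angles.cos(), angles.sin()
```
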
12 changes: 2 additions & 10 deletions tests/shared.py
@@ -1,6 +1,5 @@
import torch
from torch import optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from open_lm.main import random_seed
from open_lm.model import create_model
@@ -92,7 +91,7 @@ def __init__(self):
self.ignore_parse_errors = False


def create_train_fixtures(model="open_lm_11m", fsdp=False):
def create_train_fixtures(model="open_lm_11m"):
# Setup data, optimizer, and other basic settings
args = MockTrainArgs(model)

@@ -105,14 +104,7 @@ def create_train_fixtures(model="open_lm_11m", fsdp=False):

# create base models
random_seed()
if fsdp:
with torch.device("meta"):
model = create_model(args)
model = FSDP(model)
else:
model = create_model(args)
model.reset_parameters()
model = model.to(args.device)
model = create_model(args).to(args.device)

# create dataloader
data = get_data(
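
With the fsdp flag dropped from create_train_fixtures, a test that still needs a sharded model would have to wrap the plain fixture model itself. A hypothetical helper for such a test (assumes a process group is already initialized; not part of this PR):

```python
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def maybe_wrap_fsdp(model, use_fsdp):
    # Hypothetical test-side helper: shard only when explicitly requested.
    if use_fsdp:
        assert dist.is_initialized(), "FSDP requires an initialized process group"
        return FSDP(model)
    return model
```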