6 changes: 6 additions & 0 deletions examples/bert/main.py
@@ -47,6 +47,11 @@ def main(cfg: DictConfig,
    # Get batch size info
    cfg = update_batch_size_info(cfg)

    # Read FSDP Config as a dict
    fsdp_config = cfg.get('fsdp_config', None)
    fsdp_config = om.to_container(fsdp_config,
                                  resolve=True) if fsdp_config else None

    # Build Model
    print('Initializing model...')
    model = build_model(cfg.model)
@@ -112,6 +117,7 @@ def main(cfg: DictConfig,
        device=cfg.get('device', None),
        device_train_microbatch_size=cfg.get('device_train_microbatch_size',
                                             'auto'),
        fsdp_config=fsdp_config,  # type: ignore
        save_folder=cfg.get('save_folder', None),
        save_interval=cfg.get('save_interval', '1000ba'),
        save_num_checkpoints_to_keep=cfg.get('save_num_checkpoints_to_keep',
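For reference, a minimal sketch of what the resolved fsdp_config dict might contain once om.to_container runs. The keys below are assumptions based on common Composer FSDP options; the actual contents come from the fsdp_config section of the training YAML, and passing None disables FSDP entirely.

# Hypothetical example only: these keys mirror common Composer FSDP options,
# but the YAML's `fsdp_config` section is the source of truth.
fsdp_config = {
    'sharding_strategy': 'FULL_SHARD',   # shard params, grads, and optimizer state
    'mixed_precision': 'DEFAULT',        # Composer-managed mixed-precision settings
    'activation_checkpointing': True,    # pairs with activation_checkpointing_fn below
    'activation_cpu_offload': False,
    'verbose': False,
}
# This dict is what gets forwarded to the Trainer via `fsdp_config=fsdp_config`.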
65 changes: 65 additions & 0 deletions examples/bert/src/bert_layers.py
@@ -37,6 +37,7 @@
import logging
import math
import warnings
from functools import partial
from typing import List, Optional, Tuple, Union

import torch
@@ -545,6 +546,70 @@ def forward(
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers

    # Param Initialization, needed for device='meta' fast initialization
    def param_init_fn(self, module):
Contributor
@jacobfulano a few comments:

  • param_init_fn is not needed unless you plan to make your models meta-initializable. For a first draft, I would remove it entirely.
  • fsdp_wrap_fn and activation_checkpointing_fn need to be defined on a root module underneath the ComposerModel, so they have to be defined on, say, BertForMaskedLM. I would recommend moving the definitions there, and I think it will work! (A sketch of this is included after the diff below.)

        init_fn = partial(torch.nn.init.normal_,
                          mean=0.0,
                          std=self.cfg.init_std)
        # Linear
        if isinstance(module, nn.Linear):
            init_fn(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

            if getattr(module, '_is_residual', False):
                module.weight.data.normal_(
                    mean=0.0,
                    std=(self.cfg.init_std / math.sqrt(2 * self.cfg.n_layers)))

        # Embedding
        if isinstance(module, nn.Embedding):
            init_fn(module.weight)

        # LayerNorm
        if isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

        # torch's MultiheadAttention
        if isinstance(module, nn.MultiheadAttention):
            if module._qkv_same_embed_dim:
                assert module.in_proj_weight is not None
                assert module.q_proj_weight is None and module.k_proj_weight is None and module.v_proj_weight is None
                init_fn(module.in_proj_weight)
            else:
                assert module.q_proj_weight is not None and module.k_proj_weight is not None and module.v_proj_weight is not None
                assert module.in_proj_weight is None
                init_fn(module.q_proj_weight)
                init_fn(module.k_proj_weight)
                init_fn(module.v_proj_weight)

            # bias
            if module.in_proj_bias is not None:
                torch.nn.init.zeros_(module.in_proj_bias)
            if module.bias_k is not None:
                torch.nn.init.zeros_(module.bias_k)
            if module.bias_v is not None:
                torch.nn.init.zeros_(module.bias_v)

            # out proj
            if module.out_proj._is_residual:
                module.out_proj.weight.data.normal_(
                    mean=0.0,
                    std=(self.cfg.init_std / math.sqrt(2 * self.cfg.n_layers)))
            else:
                init_fn(module.out_proj.weight)
            if module.out_proj.bias is not None:
                torch.nn.init.zeros_(module.out_proj.bias)

    # FSDP Wrap function
    def fsdp_wrap_fn(self, module):
        return isinstance(module, BertLayer)

    # Activation Checkpointing
    def activation_checkpointing_fn(self, module):
        return isinstance(module, BertLayer)


class BertPooler(nn.Module):

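Following the reviewer's suggestion above, a minimal sketch (not part of this PR) of what moving the FSDP hooks onto BertForMaskedLM might look like. The class here is purely schematic (no real constructor or forward pass), and BertLayer is assumed to be the transformer block already defined in bert_layers.py.

import torch.nn as nn

# Schematic only: shows where the hooks would live, per the review comment.
# BertLayer is assumed to be the transformer block defined in bert_layers.py.
class BertForMaskedLM(nn.Module):

    def fsdp_wrap_fn(self, module):
        # Composer's FSDP integration wraps each module for which this
        # returns True in its own FSDP unit: one shard boundary per block.
        return isinstance(module, BertLayer)

    def activation_checkpointing_fn(self, module):
        # Same per-block granularity for activation checkpointing, used when
        # activation checkpointing is enabled in fsdp_config.
        return isinstance(module, BertLayer)

Because these hooks must live on a root module underneath the ComposerModel, defining them on BertForMaskedLM (rather than on the encoder) keeps the wrapping and checkpointing policy where Composer will find it.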