3 changes: 3 additions & 0 deletions code-of-conduct.md
@@ -0,0 +1,3 @@
# Foundation Model Stack Community Code of Conduct

Please refer to [Foundation Model Stack Community Code of Conduct](https://github.com/foundation-model-stack/foundation-model-stack/blob/main/code-of-conduct.md).
18 changes: 18 additions & 0 deletions fms_fsdp/config/training.py
@@ -26,6 +26,18 @@ class train_config:
strip_tokens: str = ""
logical_shards: int = 1024
num_workers: int = 1
doc_cutoff: int = 1_000_000
doc_breakpoint: int = 65_536
filter_exp: int = 2
target_doclen: int = 8192
slice_rate: float = 0.0

# FIM training
psm_rate: float = 0.0
spm_rate: float = 0.0
fim_pre: int = 1
fim_mid: int = 2
fim_suf: int = 3

# fsdp policies
sharding_strategy: str = "hsdp"
@@ -72,3 +84,9 @@ class train_config:
stage2_prompt_length: int = 64
stage2_batch_size: int = 96
stage2_seq_length: int = 256

# context parallel
cp: bool = False
cp_mamba_impl: str = "allgather" # "allgather" or "serial"
cp_attn_impl: str = "zigzag" # "zigzag" or "ring"
cp_over_world: bool = False
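
As a quick orientation on the new knobs, a hedged sketch (editorial, not part of this PR) of a config that enables FIM training and context parallelism. It assumes train_config remains a defaults-only dataclass as the visible fields suggest, and that bos_token is one of its existing fields (the dataloader below reads cfg.bos_token and asserts it is None whenever psm_rate + spm_rate > 0).

from fms_fsdp.config.training import train_config  # module path per the file above

cfg = train_config(
    # Fill-in-the-middle: a quarter of documents each get the PSM and SPM treatment,
    # delimited by the fim_pre / fim_mid / fim_suf sentinel token ids defined above.
    psm_rate=0.25,
    spm_rate=0.25,
    bos_token=None,             # required for FIM per the assert in get_data_loader
    # Context parallelism: shard each packed sequence across the ranks of a CP group.
    cp=True,
    cp_attn_impl="zigzag",      # or "ring"
    cp_mamba_impl="allgather",  # or "serial"
)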
3 changes: 1 addition & 2 deletions fms_fsdp/utils/config_utils.py
@@ -126,7 +126,6 @@ def get_model_config(model_variant):
nlayers=24,
hidden_grow_factor=8 / 3,
max_expected_seq_len=4096,
rope_theta=500000.0,
)
elif model_variant == "llama3_70b":
model_config = LLaMAConfig(
@@ -175,7 +174,7 @@ def get_model_config(model_variant):
"num_heads_kv": 8,
"out_proj_bias": False,
"qkv_proj_bias": False,
"rotary_emb_dim": 64,
"rotary_emb_dim": 0,
},
"rms_norm": True,
"residual_in_fp32": True,
71 changes: 54 additions & 17 deletions fms_fsdp/utils/dataloader_utils.py
@@ -5,13 +5,16 @@
AutoHandler,
BufferDataset,
CheckpointDataset,
DocSliceDataset,
FIMDataset,
ParquetHandler,
PreloadBufferDataset,
PreprocessDataset,
SamplingDataset,
ScalableShardDataset,
StreamingDocDataset,
)
from math import ceil


_handler_map = {
@@ -57,9 +60,9 @@ def __iter__(self):
return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size)


def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
def get_data_loader(cfg, rank, world_size, dp_degree):
"""
Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training.
Pytorch dataloader for stateful, distributed, and rescalable language model training.
Assumes underlying data is sequences of integer values.
...
Args
@@ -70,12 +73,21 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
Rank of current distributed worker. Used for handling dataset sharding logic.
world_size : int
Number of distributed workers. Used for handling dataset sharding logic.
postprocess : List[Callable]
Any task-specific postprocessing to apply before handing over data. Steps will apply in
the order provided by the user. For CLM training, use postprocess=[causal_lm].
"""

datasets, weights = parse_data_args(cfg.datasets, cfg.weights)
do_cp = False
if dp_degree != world_size:
do_cp = True
cp_worldsize = world_size // dp_degree
cp_rank = rank % cp_worldsize
world_size = dp_degree
rank = rank // cp_worldsize

fim_training = cfg.psm_rate + cfg.spm_rate > 0
if fim_training:
assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?"

datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name)

# Base streaming dataset. Returns doc chunks in sequence.
# Implements dataset sampling and rescalability.
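
A small editorial illustration (not in the PR) of the rank bookkeeping above: when dp_degree is smaller than world_size, consecutive ranks form context-parallel groups of size world_size // dp_degree; every rank in a group reads the same data shard, and cp_rank later selects which slice of each sequence that rank keeps.

# 8 workers with 4-way data parallelism gives CP groups of 2 consecutive ranks
world_size, dp_degree = 8, 4
cp_worldsize = world_size // dp_degree
for rank in range(world_size):
    cp_rank = rank % cp_worldsize    # position inside the CP group
    dp_rank = rank // cp_worldsize   # shard index handed to the datasets below
    print(f"global rank {rank}: data shard {dp_rank}, sequence slice {cp_rank}")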
@@ -87,9 +99,9 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
cfg.file_type in _handler_map
), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})"
if cfg.file_type == "hf_parquet" or cfg.file_type == "auto":
filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cfg.col_name)
filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cols, cfg.doc_cutoff)
else:
filehandler = _handler_map[cfg.file_type]
filehandler = _handler_map[cfg.file_type](cols)
# Base reader layer
data = StreamingDocDataset(
cfg.data_path,
@@ -99,8 +111,10 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
cfg.eos_token,
bos_token=cfg.bos_token,
strip_tokens=set(droplist),
min_length=3,
min_length=cfg.target_doclen,
seed=cfg.seed,
filter_exp=cfg.filter_exp,
max_consecutive_chunks=ceil(cfg.doc_breakpoint/1024),
)
# Add rescaling/resharding
data = ScalableShardDataset(
@@ -120,18 +134,40 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
# Wrap above dataset in packing logic to form constant-length lines.
data = BufferDataset(
data,
cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1,
cfg.seq_length + 1,
bos_token=cfg.bol_token,
eos_token=cfg.eol_token,
pack_hard=True,
)
# Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average.
data = PreloadBufferDataset(data, 10000)

# Apply desired postprocessing steps in sequence
data = PreloadBufferDataset(data, 1000)
# Slice and rearrange docs to force long-context retrieval
if cfg.slice_rate > 0:
data = DocSliceDataset(
data,
cfg.eos_token,
slice_rate=cfg.slice_rate,
)
# Apply FIM transformation if needed
if fim_training:
data = FIMDataset(
data,
cfg.eos_token,
cfg.psm_rate,
cfg.spm_rate,
pre_token=cfg.fim_pre,
mid_token=cfg.fim_mid,
suf_token=cfg.fim_suf,
)
# Transform to tensors
data = PreprocessDataset(data, torch.IntTensor)
for p in postprocess:
data = PreprocessDataset(data, p)
# Apply CLM transformation
data = PreprocessDataset(data, causal_lm)
# Apply CP chunking if using CP
if do_cp:
def chunk(x):
return x[(cp_rank*x.size(0))//cp_worldsize : ((cp_rank+1)*x.size(0))//cp_worldsize]
data = PreprocessDataset(data, lambda x: (chunk(x[0]), chunk(x[1])))
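# Editorial note, not part of the PR: for a length-L tensor and W = cp_worldsize,
# rank r keeps x[(r*L)//W : ((r+1)*L)//W], so the W slices are contiguous,
# non-overlapping, and tile the whole tensor; e.g. L=10, W=4 yields
# x[0:2], x[2:5], x[5:7], x[7:10].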

# Enable auto-saving
data = CheckpointDataset(
@@ -146,7 +182,7 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
)


def parse_data_args(datas, weights):
def parse_data_args(datas, weights, cols):
# Convert csv inputs into corresponding lists of values
def splitstrip(x):
if isinstance(x, str):
@@ -160,4 +196,5 @@ def splitstrip(x):

datas = splitstrip(datas)
weights = [float(x) for x in splitstrip(weights)]
return datas, weights
cols = splitstrip(cols)
return datas, weights, cols
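
Finally, to make the cfg.seq_length + 1 packing and the CP chunking concrete, a hedged sketch (editorial; causal_lm itself is not shown in this diff, so the usual next-token convention is assumed): each packed line is split into shifted input and label views, and under CP the same slice is taken from both views so they stay aligned on every rank.

import torch

def causal_lm_sketch(line: torch.Tensor):
    # Assumed convention: inputs are every token but the last, labels every token but the first.
    return line[:-1], line[1:]

line = torch.arange(8193, dtype=torch.int32)   # one packed line for seq_length = 8192
x, y = causal_lm_sketch(line)                  # each of length 8192

cp_rank, cp_worldsize = 1, 4                   # as derived in get_data_loader above
def chunk(t):
    return t[(cp_rank * t.size(0)) // cp_worldsize : ((cp_rank + 1) * t.size(0)) // cp_worldsize]

x_local, y_local = chunk(x), chunk(y)          # rank 1 holds positions 2048..4095 of both views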