Skip to content

Commit 1171cf2

Browse files
authored
Mamba tiktoken extend (#150)
Pull in merged mamba-tiktoken / mamba-500k-cp-fullslice
1 parent 503da7e commit 1171cf2

File tree

7 files changed

+880
-124
lines changed

7 files changed

+880
-124
lines changed

fms_fsdp/config/training.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,18 @@ class train_config:
2626
strip_tokens: str = ""
2727
logical_shards: int = 1024
2828
num_workers: int = 1
29+
doc_cutoff: int = 1_000_000
30+
doc_breakpoint: int = 65_536
31+
filter_exp: int = 2
32+
target_doclen: int = 8192
33+
slice_rate: float = 0.0
34+
35+
# FIM training
36+
psm_rate: float = 0.0
37+
spm_rate: float = 0.0
38+
fim_pre: int = 1
39+
fim_mid: int = 2
40+
fim_suf: int = 3
2941

3042
# fsdp policies
3143
sharding_strategy: str = "hsdp"
@@ -72,3 +84,9 @@ class train_config:
7284
stage2_prompt_length: int = 64
7385
stage2_batch_size: int = 96
7486
stage2_seq_length: int = 256
87+
88+
# context parallel
89+
cp: bool = False
90+
cp_mamba_impl: str = "allgather" # "allgather" or "serial"
91+
cp_attn_impl: str = "zigzag" # "zigzag" or "ring"
92+
cp_over_world: bool = False

fms_fsdp/utils/config_utils.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -126,7 +126,6 @@ def get_model_config(model_variant):
126126
nlayers=24,
127127
hidden_grow_factor=8 / 3,
128128
max_expected_seq_len=4096,
129-
rope_theta=500000.0,
130129
)
131130
elif model_variant == "llama3_70b":
132131
model_config = LLaMAConfig(
@@ -175,7 +174,7 @@ def get_model_config(model_variant):
175174
"num_heads_kv": 8,
176175
"out_proj_bias": False,
177176
"qkv_proj_bias": False,
178-
"rotary_emb_dim": 64,
177+
"rotary_emb_dim": 0,
179178
},
180179
"rms_norm": True,
181180
"residual_in_fp32": True,

fms_fsdp/utils/dataloader_utils.py

Lines changed: 54 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -5,13 +5,16 @@
55
AutoHandler,
66
BufferDataset,
77
CheckpointDataset,
8+
DocSliceDataset,
9+
FIMDataset,
810
ParquetHandler,
911
PreloadBufferDataset,
1012
PreprocessDataset,
1113
SamplingDataset,
1214
ScalableShardDataset,
1315
StreamingDocDataset,
1416
)
17+
from math import ceil
1518

1619

1720
_handler_map = {
@@ -57,9 +60,9 @@ def __iter__(self):
5760
return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size)
5861

5962

60-
def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
63+
def get_data_loader(cfg, rank, world_size, dp_degree):
6164
"""
62-
Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training.
65+
PyTorch dataloader for stateful, distributed, and rescalable language model training.
6366
Assumes underlying data is sequences of integer values.
6467
...
6568
Args
@@ -70,12 +73,21 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
7073
Rank of current distributed worker. Used for handling dataset sharding logic.
7174
world_size : int
7275
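Number of distributed workers. Used for handling dataset sharding logic.
dp_degree : int
Number of data-parallel groups. If this differs from world_size, context parallelism is
enabled: each run of world_size // dp_degree consecutive ranks is treated as a single
data-parallel worker, and every rank in that run receives its own contiguous slice of
each input/label sequence.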
73-
postprocess : List[Callable]
74-
Any task-specific postprocessing to apply before handing over data. Steps will apply in
75-
the order provided by the user. For CLM training, use postprocess=[causal_lm].
7676
"""
7777

78-
datasets, weights = parse_data_args(cfg.datasets, cfg.weights)
78+
do_cp = False
79+
if dp_degree != world_size:
80+
do_cp = True
81+
cp_worldsize = world_size // dp_degree
82+
cp_rank = rank % cp_worldsize
83+
world_size = dp_degree
84+
rank = rank // cp_worldsize
85+
86+
fim_training = cfg.psm_rate + cfg.spm_rate > 0
87+
if fim_training:
88+
assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?"
89+
90+
datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name)
7991

8092
# Base streaming dataset. Returns doc chunks in sequence.
8193
# Implements dataset sampling and rescalability.
@@ -87,9 +99,9 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
8799
cfg.file_type in _handler_map
88100
), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})"
89101
if cfg.file_type == "hf_parquet" or cfg.file_type == "auto":
90-
filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cfg.col_name)
102+
filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cols, cfg.doc_cutoff)
91103
else:
92-
filehandler = _handler_map[cfg.file_type]
104+
filehandler = _handler_map[cfg.file_type](cols)
93105
# Base reader layer
94106
data = StreamingDocDataset(
95107
cfg.data_path,
@@ -99,8 +111,10 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
99111
cfg.eos_token,
100112
bos_token=cfg.bos_token,
101113
strip_tokens=set(droplist),
102-
min_length=3,
114+
min_length=cfg.target_doclen,
103115
seed=cfg.seed,
116+
filter_exp=cfg.filter_exp,
117+
max_consecutive_chunks=ceil(cfg.doc_breakpoint/1024),
104118
)
105119
# Add rescaling/resharding
106120
data = ScalableShardDataset(
@@ -120,18 +134,40 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
120134
# Wrap above dataset in packing logic to form constant-length lines.
121135
data = BufferDataset(
122136
data,
123-
cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1,
137+
cfg.seq_length + 1,
124138
bos_token=cfg.bol_token,
125139
eos_token=cfg.eol_token,
126140
pack_hard=True,
127141
)
128142
# Shuffle outputs in length 1k buffer. Consecutive lines appear 1k steps apart on average.
129-
data = PreloadBufferDataset(data, 10000)
130-
131-
# Apply desired postprocessing steps in sequence
143+
data = PreloadBufferDataset(data, 1000)
144+
# Slice and rearrange docs to force long-context retrieval
145+
if cfg.slice_rate > 0:
146+
data = DocSliceDataset(
147+
data,
148+
cfg.eos_token,
149+
slice_rate=cfg.slice_rate,
150+
)
151+
# Apply FIM transformation if needed
152+
if fim_training:
153+
data = FIMDataset(
154+
data,
155+
cfg.eos_token,
156+
cfg.psm_rate,
157+
cfg.spm_rate,
158+
pre_token=cfg.fim_pre,
159+
mid_token=cfg.fim_mid,
160+
suf_token=cfg.fim_suf,
161+
)
162+
# Transform to tensors
132163
data = PreprocessDataset(data, torch.IntTensor)
133-
for p in postprocess:
134-
data = PreprocessDataset(data, p)
164+
# Apply CLM transformation
165+
data = PreprocessDataset(data, causal_lm)
166+
# Apply CP chunking if using CP
167+
if do_cp:
168+
def chunk(x):
169+
return x[(cp_rank*x.size(0))//cp_worldsize : ((cp_rank+1)*x.size(0))//cp_worldsize]
170+
data = PreprocessDataset(data, lambda x: (chunk(x[0]), chunk(x[1])))
135171

136172
# Enable auto-saving
137173
data = CheckpointDataset(
@@ -146,7 +182,7 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
146182
)
147183

148184

149-
def parse_data_args(datas, weights):
185+
def parse_data_args(datas, weights, cols):
150186
# Convert csv inputs into corresponding lists of values
151187
def splitstrip(x):
152188
if isinstance(x, str):
@@ -160,4 +196,5 @@ def splitstrip(x):
160196

161197
datas = splitstrip(datas)
162198
weights = [float(x) for x in splitstrip(weights)]
163-
return datas, weights
199+
cols = splitstrip(cols)
200+
return datas, weights, cols

0 commit comments

Comments
 (0)