
Commit c48bf1b

daviswer, lchu6, and divya-kumari32 authored
Mamba new data fixes2 (#149)
* make mamba
* add quick debug
* add quick debug
* revert debug verbosity
* Learning rate scheduler changed (Constant)
* Add AutoHandler
* Add Auto cfg option for AutoHAndler
* Len gets called before open
* path/filepath typo fix
* Partitioning fix from mup-search
* Cosine 0.01 decay
* Warmup interval change
* Schedule change
* Constant schedule
* LR schedule change (cool down and constant lr)
* Update dataset_utils.py: Added a check for length of doc
* LR schedule change (Warmup + constant)
* Update dataset_utils.py
* Cosine schedule
* For constant lr 1.5e5
* Schedule change
* Schedule change
* Final singlefile checkpoint saves one folder up (#127)
  * Final singlefile checkpoint saves one folder up
  * save file under new pth subfolder
  * Repath for easier consumption/conversion
  Signed-off-by: Davis Wertheimer <[email protected]>
* Added cool down
* length of doc check
* splitstrip cols and pass to fhandler
* fhandler col_names support
* Warmup for annealing
* Debugging
* Debugging II
* Empty shard check
* Added constant lr schedule with warmup
* added print for lenght of doc
* added print for lenght of doc II
* Update dataset_utils.py
* Update dataset_utils.py
* Update dataset_utils.py
* Update dataset_utils.py
* Adding print for debug
* Revert "Pulled from data-fixes branch" (this reverts commit ac5194b, reversing changes made to 1b50708)
* Revert all changes made after March 6 (before merge)
* Revert all changes made after March 6 (before merge)
* removed print

---------

Signed-off-by: Davis Wertheimer <[email protected]>
Co-authored-by: Linsong Chu <[email protected]>
Co-authored-by: divykum2 <[email protected]>
Co-authored-by: divya-kumari32 <[email protected]>
1 parent 6d751e5 commit c48bf1b


8 files changed (+144, -400 lines)


fms_fsdp/config/training.py

Lines changed: 0 additions & 7 deletions

@@ -76,10 +76,3 @@ class train_config:
     stage2_prompt_length: int = 64
     stage2_batch_size: int = 96
     stage2_seq_length: int = 256
-
-    # FIM training
-    psm_rate: float = 0.0
-    spm_rate: float = 0.0
-    fim_pre: int = 1
-    fim_mid: int = 2
-    fim_suf: int = 3

fms_fsdp/utils/checkpointing_utils.py

Lines changed: 3 additions & 1 deletion

@@ -324,7 +324,9 @@ def save_single_file(
     ):
         # Note: metadata kwargs cannot contain any of:
         # (step, model)
-        save_name = os.path.join(self.ckp_path, "step_" + str(step) + "_ckp.pth")
+        pth_path = os.path.join(self.ckp_path[:-12], "pth", "step_" + str(step))
+        os.makedirs(pth_path, exist_ok=True)
+        save_name = os.path.join(pth_path, "consolidated.00.pth")
         save_time = time.time()
         with FSDP.state_dict_type(
             model,
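
The repathed save above strips the tail of self.ckp_path (the [:-12] slice lines up with a 12-character suffix such as "/checkpoints") and writes the consolidated weights one folder up, under a pth/step_<step> subfolder, instead of step_<step>_ckp.pth inside the checkpoint directory. A minimal sketch of the resulting layout, with a hypothetical checkpoint directory chosen only for illustration:

    import os

    # Hypothetical values; the real ckp_path comes from the checkpointer's setup.
    # The [:-12] slice assumes the path ends in the 12-character "/checkpoints",
    # mirroring the diff above.
    ckp_path = "/fsdp_output/run1/checkpoints"
    step = 1500

    pth_path = os.path.join(ckp_path[:-12], "pth", "step_" + str(step))
    save_name = os.path.join(pth_path, "consolidated.00.pth")

    print(pth_path)   # /fsdp_output/run1/pth/step_1500
    print(save_name)  # /fsdp_output/run1/pth/step_1500/consolidated.00.pth

This matches the #127 items in the commit message ("save file under new pth subfolder", "Repath for easier consumption/conversion"): the single-file checkpoint lands beside, not inside, the sharded checkpoint folder.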

fms_fsdp/utils/dataloader_utils.py

Lines changed: 11 additions & 27 deletions

@@ -6,7 +6,6 @@
     AutoHandler,
     BufferDataset,
     CheckpointDataset,
-    FIMDataset,
     ParquetHandler,
     PreloadBufferDataset,
     PreprocessDataset,
@@ -59,9 +58,9 @@ def __iter__(self):
         return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size)
 
 
-def get_data_loader(cfg, rank, world_size):
+def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
     """
-    Pytorch dataloader for stateful, distributed, and rescalable language model training.
+    Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training.
     Assumes underlying data is sequences of integer values.
     ...
     Args
@@ -72,12 +71,11 @@ def get_data_loader(cfg, rank, world_size):
         Rank of current distributed worker. Used for handling dataset sharding logic.
     world_size : int
         Number of distributed workers. Used for handling dataset sharding logic.
+    postprocess : List[Callable]
+        Any task-specific postprocessing to apply before handing over data. Steps will apply in
+        the order provided by the user. For CLM training, use postprocess=[causal_lm].
     """
 
-    fim_training = cfg.psm_rate + cfg.spm_rate > 0
-    if fim_training:
-        assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?"
-
     datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name)
 
     # Base streaming dataset. Returns doc chunks in sequence.
@@ -94,7 +92,7 @@ def get_data_loader(cfg, rank, world_size):
             cfg.tokenizer_path, cols, cfg.doc_cutoff
         )
     else:
-        filehandler = _handler_map[cfg.file_type](cols)
+        filehandler = _handler_map[cfg.file_type, cols]
     # Base reader layer
     data = StreamingDocDataset(
         cfg.data_path,
@@ -124,34 +122,20 @@ def get_data_loader(cfg, rank, world_size):
         verbose=(rank == 0),
     )
     # Wrap above dataset in packing logic to form constant-length lines.
-    # Increment seq len to counteract CLM's one token removal.
     data = BufferDataset(
         data,
-        cfg.seq_length + 1,
+        cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1,
         bos_token=cfg.bol_token,
         eos_token=cfg.eol_token,
         pack_hard=True,
     )
     # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average.
     data = PreloadBufferDataset(data, 10000)
 
-    # Apply FIM transformation if needed
-    if fim_training:
-        data = FIMDataset(
-            data,
-            cfg.eos_token,
-            cfg.psm_rate,
-            cfg.spm_rate,
-            pre_token=cfg.fim_pre,
-            mid_token=cfg.fim_mid,
-            suf_token=cfg.fim_suf,
-        )
-
-    # Transform to tensors
+    # Apply desired postprocessing steps in sequence
     data = PreprocessDataset(data, torch.IntTensor)
-
-    # Apply CLM transformation
-    data = PreprocessDataset(data, causal_lm)
+    for p in postprocess:
+        data = PreprocessDataset(data, p)
 
     # Enable auto-saving
     data = CheckpointDataset(
@@ -181,4 +165,4 @@ def splitstrip(x):
     datas = splitstrip(datas)
     weights = [float(x) for x in splitstrip(weights)]
     cols = splitstrip(cols)
-    return datas, weights, cols
+    return datas, weights, cols
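
Taken together, these hunks swap the hard-wired FIM/CLM branching for a generic postprocess hook: get_data_loader applies each callable in order after the IntTensor conversion, and the packer only reserves the extra token when causal_lm is in the list (the causal-LM shift into input/target pairs consumes one token per line). Below is a minimal usage sketch, not the repository's documented API: the import paths are inferred from the file names above, the config fields left at their defaults would need real data paths (cfg.data_path, cfg.datasets, cfg.weights, cfg.col_name) for the loader to produce batches, and the empty-postprocess call is purely hypothetical.

    import torch.distributed as dist

    from fms_fsdp.config.training import train_config
    from fms_fsdp.utils.dataloader_utils import get_data_loader

    # Hypothetical config: a real run would also point data_path / datasets /
    # weights / col_name at actual tokenized data shards.
    cfg = train_config()

    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    # Default: postprocess=[causal_lm]. The packer emits cfg.seq_length + 1 tokens
    # so the causal-LM shift still yields inputs/targets of length cfg.seq_length.
    clm_loader = get_data_loader(cfg, rank, world_size)

    # Hypothetical non-CLM use: with causal_lm absent from postprocess, no extra
    # token is reserved and packed lines come out at exactly cfg.seq_length tokens.
    raw_loader = get_data_loader(cfg, rank, world_size, postprocess=[])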
