     AutoHandler,
     BufferDataset,
     CheckpointDataset,
+    DocSliceDataset,
+    FIMDataset,
     ParquetHandler,
     PreloadBufferDataset,
     PreprocessDataset,
     SamplingDataset,
     ScalableShardDataset,
     StreamingDocDataset,
 )
+from math import ceil
 
 
 _handler_map = {
@@ -57,9 +60,9 @@ def __iter__(self):
         return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size)
 
 
-def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
+def get_data_loader(cfg, rank, world_size, dp_degree):
     """
-    Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training.
+    Pytorch dataloader for stateful, distributed, and rescalable language model training.
     Assumes underlying data is sequences of integer values.
     ...
     Args
@@ -70,12 +73,21 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
         Rank of current distributed worker. Used for handling dataset sharding logic.
     world_size : int
         Number of distributed workers. Used for handling dataset sharding logic.
-    postprocess : List[Callable]
-        Any task-specific postprocessing to apply before handing over data. Steps will apply in
-        the order provided by the user. For CLM training, use postprocess=[causal_lm].
     """
 
-    datasets, weights = parse_data_args(cfg.datasets, cfg.weights)
+    do_cp = False
+    if dp_degree != world_size:
+        do_cp = True
+        cp_worldsize = world_size // dp_degree
+        cp_rank = rank % cp_worldsize
+        world_size = dp_degree
+        rank = rank // cp_worldsize
+
+    fim_training = cfg.psm_rate + cfg.spm_rate > 0
+    if fim_training:
+        assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?"
+
+    datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name)
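The new `dp_degree` argument splits the global ranks into data-parallel and context-parallel (CP) groups: when `dp_degree` is smaller than `world_size`, each group of `cp_worldsize` consecutive ranks shares one data-parallel shard, and `rank`/`world_size` are re-indexed to address shards rather than GPUs. A small worked illustration of the remapping above, with hypothetical sizes not taken from the PR:

```python
# Hypothetical example: 8 global ranks, dp_degree=4 -> cp_worldsize=2, so
# ranks (0,1) share data shard 0, ranks (2,3) share shard 1, and so on.
world_size, dp_degree = 8, 4
cp_worldsize = world_size // dp_degree  # ranks per context-parallel group
for rank in range(world_size):
    cp_rank = rank % cp_worldsize   # position within the CP group
    dp_rank = rank // cp_worldsize  # which data shard this CP group reads
    print(f"global rank {rank} -> dp shard {dp_rank}, cp position {cp_rank}")
```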
 
     # Base streaming dataset. Returns doc chunks in sequence.
     # Implements dataset sampling and rescalability.
@@ -87,9 +99,9 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
         cfg.file_type in _handler_map
     ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})"
     if cfg.file_type == "hf_parquet" or cfg.file_type == "auto":
-        filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cfg.col_name)
+        filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cols, cfg.doc_cutoff)
     else:
-        filehandler = _handler_map[cfg.file_type]
+        filehandler = _handler_map[cfg.file_type](cols)
     # Base reader layer
     data = StreamingDocDataset(
         cfg.data_path,
@@ -99,8 +111,10 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
         cfg.eos_token,
         bos_token=cfg.bos_token,
         strip_tokens=set(droplist),
-        min_length=3,
+        min_length=cfg.target_doclen,
         seed=cfg.seed,
+        filter_exp=cfg.filter_exp,
+        max_consecutive_chunks=ceil(cfg.doc_breakpoint / 1024),
     )
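The `max_consecutive_chunks` value rounds `cfg.doc_breakpoint` up to whole chunks, so a breakpoint that is not an exact multiple of the chunk size is still fully covered. A quick check, assuming the reader-level chunk size is the 1024 tokens used in the expression above and a hypothetical breakpoint value:

```python
from math import ceil

doc_breakpoint = 3000  # hypothetical config value, in tokens
max_consecutive_chunks = ceil(doc_breakpoint / 1024)  # ceil(2.93) == 3
assert max_consecutive_chunks == 3  # 3 chunks of 1024 cover 3000 tokens
```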
     # Add rescaling/resharding
     data = ScalableShardDataset(
@@ -120,18 +134,40 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
     # Wrap above dataset in packing logic to form constant-length lines.
     data = BufferDataset(
         data,
-        cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1,
+        cfg.seq_length + 1,
         bos_token=cfg.bol_token,
         eos_token=cfg.eol_token,
         pack_hard=True,
     )
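Packing to `cfg.seq_length + 1` is no longer conditional because the `causal_lm` step applied downstream always consumes one extra token: inputs and labels are the same packed line shifted by one position. A minimal sketch of that shift (assuming `causal_lm` behaves this way; its body is not shown in this diff):

```python
import torch

line = torch.arange(9)                # a packed line of seq_length + 1 = 9 tokens
inputs, labels = line[:-1], line[1:]  # both have exactly seq_length = 8 tokens
# inputs [0..7], labels [1..8]: each position predicts the next token
```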
     # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average.
-    data = PreloadBufferDataset(data, 10000)
-
-    # Apply desired postprocessing steps in sequence
+    data = PreloadBufferDataset(data, 1000)
+    # Slice and rearrange docs to force long-context retrieval
+    if cfg.slice_rate > 0:
+        data = DocSliceDataset(
+            data,
+            cfg.eos_token,
+            slice_rate=cfg.slice_rate,
+        )
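`DocSliceDataset` itself is not shown in this diff, so the following is only a toy reading of its comment: with probability `slice_rate`, cut a document apart and rearrange the pieces so that completing the text requires retrieving content from much earlier in the packed line. All names here are illustrative, not the actual implementation:

```python
import random

def toy_doc_slice(tokens, eos, slice_rate=0.5, rng=random.Random(0)):
    # Toy illustration only: with probability slice_rate, split a document in
    # two and move the head after the tail, creating a long-range dependency.
    if rng.random() >= slice_rate or len(tokens) < 2:
        return tokens
    cut = rng.randrange(1, len(tokens))
    return tokens[cut:] + [eos] + tokens[:cut]
```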
+    # Apply FIM transformation if needed
+    if fim_training:
+        data = FIMDataset(
+            data,
+            cfg.eos_token,
+            cfg.psm_rate,
+            cfg.spm_rate,
+            pre_token=cfg.fim_pre,
+            mid_token=cfg.fim_mid,
+            suf_token=cfg.fim_suf,
+        )
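FIM (fill-in-the-middle) training rearranges each document into prefix/suffix/middle order with sentinel tokens, which is why the earlier assert disallows a separate BOS token: the `fim_pre` sentinel takes over the sequence-opening role. A toy sketch of the standard PSM/SPM orderings (after Bavarian et al., 2022); `FIMDataset` itself is not shown in this diff, so this is a generic illustration:

```python
import random

def toy_fim(tokens, pre, mid, suf, psm_rate, spm_rate, rng=random.Random(0)):
    # Toy sketch of fill-in-the-middle, not the FIMDataset implementation:
    # split a document into prefix/middle/suffix and emit it in PSM or SPM
    # order, marked with the pre/mid/suf sentinel tokens.
    roll = rng.random()
    if roll >= psm_rate + spm_rate or len(tokens) < 2:
        return tokens  # leave the document in plain causal order
    i, j = sorted(rng.sample(range(len(tokens) + 1), 2))
    prefix, middle, suffix = tokens[:i], tokens[i:j], tokens[j:]
    if roll < psm_rate:
        # PSM: prefix, suffix, then the middle to be infilled
        return [pre] + prefix + [suf] + suffix + [mid] + middle
    # SPM: suffix first, then prefix and middle
    return [pre, suf] + suffix + [mid] + prefix + middle
```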
+    # Transform to tensors
     data = PreprocessDataset(data, torch.IntTensor)
-    for p in postprocess:
-        data = PreprocessDataset(data, p)
+    # Apply CLM transformation
+    data = PreprocessDataset(data, causal_lm)
+    # Apply CP chunking if using CP
+    if do_cp:
+        def chunk(x):
+            return x[(cp_rank * x.size(0)) // cp_worldsize : ((cp_rank + 1) * x.size(0)) // cp_worldsize]
+        data = PreprocessDataset(data, lambda x: (chunk(x[0]), chunk(x[1])))
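The integer arithmetic in `chunk` partitions each (input, label) tensor across CP ranks without gaps or overlap, even when the sequence length is not divisible by `cp_worldsize`. A quick check with hypothetical sizes:

```python
import torch

x = torch.arange(10)  # length not divisible by cp_worldsize
cp_worldsize = 3
for cp_rank in range(cp_worldsize):
    lo = (cp_rank * x.size(0)) // cp_worldsize
    hi = ((cp_rank + 1) * x.size(0)) // cp_worldsize
    print(cp_rank, x[lo:hi].tolist())
# 0 [0, 1, 2]
# 1 [3, 4, 5]
# 2 [6, 7, 8, 9]
```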
 
     # Enable auto-saving
     data = CheckpointDataset(
@@ -146,7 +182,7 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
     )
 
 
-def parse_data_args(datas, weights):
+def parse_data_args(datas, weights, cols):
     # Convert csv inputs into corresponding lists of values
     def splitstrip(x):
         if isinstance(x, str):
@@ -160,4 +196,5 @@ def splitstrip(x):
 
     datas = splitstrip(datas)
     weights = [float(x) for x in splitstrip(weights)]
-    return datas, weights
+    cols = splitstrip(cols)
+    return datas, weights, cols
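With the third field added, all of `cfg.datasets`, `cfg.weights`, and `cfg.col_name` accept the same comma-separated string form. A hypothetical call, with illustrative values:

```python
# Hypothetical config values illustrating the csv parsing:
datas, weights, cols = parse_data_args(
    "commoncrawl, github",   # -> ["commoncrawl", "github"]
    "0.7, 0.3",              # -> [0.7, 0.3]
    "text, content",         # -> ["text", "content"]
)
```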