@@ -343,9 +343,9 @@ class ArrowHandler(_ShardFileHandler):
     Non-standard data format, though.
     """
 
-    def __init__(self, col_names: List[str] = ["tokens"]):
+    def __init__(self, col_names: List[str] = ["text", "contents", "tokens"]):
         self.col_names = col_names
-
+
     def is_legal(self, filepath: str):
         return "arrow" in os.path.splitext(filepath)[1]
 
@@ -356,14 +356,18 @@ def length(self, path: str):
         return self.open(path).num_record_batches
 
     def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
+        assert (
+            index < reader.num_record_batches
+        ), f"Illegal index {index} in set of {reader.num_record_batches} documents"
         frame = reader.get_batch(index)
-
         doc = None
         for name in self.col_names:
             if name in frame.column_names:
                 doc = frame[name]
                 break
-        assert doc is not None, f"None of column names {self.col_names} found in file headers {frame.column_names}"
+        assert (
+            doc is not None
+        ), f"None of column names {self.col_names} found in file headers {frame.column_names}"
         if len(doc) > 0 and doc[0].as_py() in drop_tokens:
             doc = doc.slice(1, len(doc) - 1)
         # Recheck len for edge case where doc=[eos]
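
A usage sketch of the updated handler (the .arrow path and drop-token id below are placeholders, and ArrowHandler is assumed importable from this module): get() now bounds-checks the index, then returns the first column from col_names present in the record batch, dropping a leading special token if requested.

    handler = ArrowHandler()  # col_names defaults to ["text", "contents", "tokens"]
    if handler.is_legal("data/shard_00.arrow"):
        reader = handler.open("data/shard_00.arrow")    # pa.RecordBatchFileReader
        n_docs = handler.length("data/shard_00.arrow")  # num_record_batches
        doc = handler.get(reader, 0, drop_tokens={0})   # asserts index < n_docs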
@@ -378,14 +382,20 @@ def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List:
 class ParquetHandler(_ShardFileHandler):
     """
     Reader for indexable parquet shard files, common in HF datasets.
-    Here we assume reasonably small shard files (<5Gb) and documents (<100k tokens),
+    Here we assume reasonably small shard files (<5Gb) and truncate docs to max_doclen characters,
     as we rely on parquet/pandas for efficient file reading, and tokenize entire documents
     before getting/slicing. However, this is a standard and widely-used data format.
     """
 
-    def __init__(self, tokenizer_path: str, col_names: List[str] = ["text"]):
+    def __init__(
+        self,
+        tokenizer_path: str,
+        col_names: List[str] = ["text", "contents", "tokens"],
+        max_doclen: int = 1_000_000,
+    ):
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
         self.col_names = col_names
+        self.max_doclen = max_doclen
 
     def is_legal(self, filepath: str):
         return "parquet" in os.path.splitext(filepath)[1]
@@ -397,18 +407,19 @@ def open(self, path: str):
             if name in names:
                 match = name
                 break
-        assert match is not None, f"None of column names {self.col_names} found in file headers {names}"
+        assert (
+            match is not None
+        ), f"None of column names {self.col_names} found in file headers {names}"
         return pq.read_pandas(path, columns=[match], partitioning=None)[match]
 
     def length(self, path: str):
         return pq.read_metadata(path).num_rows
 
     def get(self, reader, index: int, drop_tokens: Set):
-
-        document_str = str(reader[index])
-
-        doc = self.tokenizer(str(reader[index]))["input_ids"]
-
+        assert (
+            index < reader.length()
+        ), f"Illegal index {index} in set of {reader.length()} documents"
+        doc = self.tokenizer(str(reader[index])[: self.max_doclen])["input_ids"]
         if len(doc) > 0 and doc[0] in drop_tokens:
             doc = doc[1:]
         # Recheck len for edge case where doc=[eos]
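
A corresponding sketch for the reworked parquet path (tokenizer path, file name, and drop-token id are placeholders): each document string is truncated to max_doclen characters before tokenization, matching the updated class docstring.

    handler = ParquetHandler("path/to/tokenizer", max_doclen=1_000_000)
    reader = handler.open("data/shard_00.parquet")     # single pyarrow column of documents
    n_docs = handler.length("data/shard_00.parquet")   # row count from parquet metadata
    doc = handler.get(reader, 0, drop_tokens={0})      # token ids of the truncated text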
@@ -421,8 +432,13 @@ def slice(self, doc: List, index: int, n_pull: int) -> List:
 
 
 class AutoHandler(_ShardFileHandler):
-    def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]):
-        self.PHandler = ParquetHandler(tokenizer_path, col_names)
+    def __init__(
+        self,
+        tokenizer_path: str,
+        col_names: List[str] = ["text", "contents", "tokens"],
+        max_doclen: int = 1_000_000,
+    ):
+        self.PHandler = ParquetHandler(tokenizer_path, col_names, max_doclen)
         self.AHandler = ArrowHandler(col_names)
         self.current = _ShardFileHandler()
 
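
Construction sketch for the wrapper (tokenizer path is a placeholder): the new max_doclen argument is forwarded only to the wrapped ParquetHandler, since ArrowHandler takes no tokenizer and its documents are already token arrays (per the pa.UInt32Array slice signature above).

    auto = AutoHandler(
        "path/to/tokenizer",
        col_names=["text", "contents", "tokens"],
        max_doclen=1_000_000,
    )  # auto.PHandler truncates raw text; auto.AHandler is unaffected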
@@ -979,10 +995,10 @@ class StreamingDocDataset(_StatefulDataset):
         Documents below this length are skipped
     max_chunksize : int
         Maximum sequence length to return. Break long docs into chunks of this size or shorter.
+    max_consecutive_chunks : int
+        Number of doc chunks to emit before manually inserting EOS and resuming later.
     verbose : bool
         Track setup progress?
-    shuffle : bool
-        Shuffle shard file and document orders? (Disable for simple testing)
     """
 
     def __init__(
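
For reference, the chunk count used downstream follows from max_chunksize plus the appended special tokens; a worked example with hypothetical sizes:

    import math
    doclen = 10_000 + 1                  # document tokens plus EOS (+2 if BOS is also set)
    n_chunks = math.ceil(doclen / 4096)  # max_chunksize=4096 -> 3 chunks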
@@ -1095,12 +1111,11 @@ def setup(self):
                 for row in reader:
                     fullpath = row["dataset/filename"]
                     prefix = fullpath.find(dataset + "/")
-                    if prefix > 0:
+                    if prefix >= 0:
                         key = fullpath[prefix + len(dataset) + 1 :]
                         sizes[key] = int(row["size"])
             shard_sizes = [sizes[shard] for shard in shards]
         else:
-            # Count file does not exist, touch every owned file for length
             shard_sizes = [
                 os.path.getsize(os.path.join(datapath, shard)) for shard in shards
             ]
@@ -1125,7 +1140,7 @@ def setup(self):
                 reader = csv.DictReader(csvfile)
                 for row in reader:
                     fullpath = row["dataset/filename"]
-                    prefix = fullpath.find(dataset + "/")
+                    prefix = fullpath.find(dataset)
                     if prefix >= 0:
                         key = fullpath[prefix + len(dataset) + 1 :]
                         doc_counts[key] = int(row["documents"])
@@ -1141,10 +1156,13 @@ def setup(self):
         doccount = 0
         for shard in shardset:
             ndocs = doc_counts[shard]
-            doc_start = int(ndocs * shardset[shard][0])
-            doc_end = max(doc_start, int(ndocs * shardset[shard][1]) - 1)  # inclusive upper bound
-            self.docset.append([shard, doc_start, doc_end])
-            doccount += doc_end - doc_start + 1
+            if ndocs > 0:
+                doc_start = int(ndocs * shardset[shard][0])
+                doc_end = max(
+                    doc_start, int(ndocs * shardset[shard][1]) - 1
+                )  # inclusive upper bound
+                self.docset.append([shard, doc_start, doc_end])
+                doccount += doc_end - doc_start + 1
         self._len = doccount
 
         if self.verbose:
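
A worked example of the guarded doc-range computation above (numbers hypothetical): a shard with ndocs=100 owned over the fractional range [0.25, 0.75] maps to inclusive document ids 25 through 74, and empty shards are now skipped entirely.

    ndocs = 100
    frac = (0.25, 0.75)                                 # shardset[shard]
    doc_start = int(ndocs * frac[0])                    # 25
    doc_end = max(doc_start, int(ndocs * frac[1]) - 1)  # 74, inclusive upper bound
    owned = doc_end - doc_start + 1                     # 50 documents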
@@ -1253,10 +1271,8 @@ def __iter__(self):
                 doclcg = self._random_map_docid(docrange)
                 docid = doclcg + mindoc
                 doc = self.filehandler.get(reader, docid, self.drop)
-                if len(doc) == 0:
-                    continue
                 doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
-                if doclen >= self.min_length:
+                if len(doc) > 0 and doclen >= self.min_length:
                     n_chunks = math.ceil(doclen / self.chunksize)
                     for j in range(n_chunks):
                         if i == 0 and j < residual_chunks:
@@ -1283,18 +1299,18 @@ def __iter__(self):
         newpath = os.path.join(self.datapath, shardid)
         path, reader = self._get_reader(path, newpath, reader)
         doc = self.filehandler.get(reader, docid, self.drop)
-        if len(doc) == 0:
-            continue
         doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
-        if doclen >= self.min_length:
+        if len(doc) > 0 and doclen >= self.min_length:
             n_chunks = math.ceil(doclen / self.chunksize)
             for j in range(residual_chunks):
                 self.chunk_index = j
                 self.has_yielded = True
                 yield self._construct_chunk(j, doc, n_chunks)
 
         # Check that epoch was non-empty
-        assert self.has_yielded, f"Empty logical shard detected: {self.dataset, self.docset}"
+        assert (
+            self.has_yielded
+        ), f"Empty logical shard detected: {self.dataset, self.docset}"
 
     def load_state_dict(self, state_dicts, sharded_input=False):
         self.setup()
@@ -1367,12 +1383,12 @@ def setup(self):
         if not self.is_setup:
             _StatefulDataset.setup(self)
             n_logical_shards = self.total_shards
+            assert (
+                n_logical_shards % self.worldsize == 0
+            ), f"Total workers {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly"
             logicals = list(range(n_logical_shards))
             self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize)
             self.n_logicals = n_logical_shards // self.worldsize
-            assert (
-                len(self.logicals_owned) == self.n_logicals
-            ), "(world size * num workers) does not divide logical shards evenly"
 
             # Build logical shards
             for i in range(self.n_logicals):
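
The relocated assert makes the divisibility requirement explicit before partitioning; a worked example with hypothetical values:

    n_logical_shards, worldsize = 64, 8
    assert n_logical_shards % worldsize == 0
    n_logicals = n_logical_shards // worldsize  # 8 logical shards per worker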
@@ -1403,6 +1419,7 @@ def __iter__(self):
         # (i.e. [1,1,0,0,0,0] into [1,1,0] [0,0,0] )
         if sum(self.n_docs_remaining) == 0:
             self.n_docs_remaining = [d._len for d in self.data]
+            self.generator.manual_seed(self.rank)
         while True:
             # Sample logical shard (or load from ckp)
             if self.current_reader is not None:
@@ -1499,6 +1516,10 @@ def __init__(
             ]
         )
         assert len(self.datasets) > 0, "You must specify at least one dataset"
+        for d in datasets:
+            assert os.path.exists(
+                os.path.join(datapath, d)
+            ), f"Invalid subdataset path: {os.path.join(datapath, d)}"
 
         if weights is not None:
             assert len(weights) == len(