
Commit c967bb5

Pull in datafixes
See PR #144
1 parent 1feca7c commit c967bb5

File tree

fms_fsdp/config/training.py
fms_fsdp/utils/dataloader_utils.py
fms_fsdp/utils/dataset_utils.py

3 files changed: +66 -39 lines changed


fms_fsdp/config/training.py

Lines changed: 4 additions & 2 deletions
@@ -15,7 +15,9 @@ class train_config:
     file_type: str = "arrow"
     col_name: str = "tokens"
     tokenizer_path: str = "/fsx/tokenizer"
-    datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
+    datasets: str = (
+        "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
+    )
     weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100"
     seq_length: int = 4096
     vocab_size: int = 32000
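The datasets and weights defaults are parallel comma-separated lists (13 entries each): the weight in position i is the sampling weight for the subdataset in position i. Below is a minimal standalone sketch of that pairing (illustration only, not the repo's parse_data_args; the two-entry lists are abbreviated from the defaults).

# Illustration only: pair each subdataset path with its sampling weight.
datasets = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose".split(",")
weights = [float(w) for w in "7725,500".split(",")]
assert len(datasets) == len(weights), "one weight per subdataset"
for subdataset, weight in zip(datasets, weights):
    print(subdataset, weight)   # e.g. lang=en/dataset=commoncrawl 7725.0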
@@ -26,6 +28,7 @@ class train_config:
     strip_tokens: str = ""
     logical_shards: int = 1024
     num_workers: int = 1
+    doc_cutoff: int = 1_000_000

     # fsdp policies
     sharding_strategy: str = "hsdp"
@@ -74,7 +77,6 @@ class train_config:
     stage2_seq_length: int = 256

     # FIM training
-    fim_training: bool = False
     psm_rate: float = 0.0
     spm_rate: float = 0.0
     fim_pre: int = 1
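With the fim_training flag removed here, FIM training is enabled implicitly whenever the PSM/SPM rates are nonzero; the data loader derives the flag from the rates (see dataloader_utils.py below). A minimal sketch with example rates:

psm_rate, spm_rate = 0.5, 0.25          # example values; both default to 0.0
fim_training = psm_rate + spm_rate > 0  # mirrors the check in get_data_loader
print(fim_training)                     # True -> the FIM transform is applied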

fms_fsdp/utils/dataloader_utils.py

Lines changed: 8 additions & 4 deletions
@@ -72,7 +72,9 @@ def get_data_loader(cfg, rank, world_size):
     world_size : int
         Number of distributed workers. Used for handling dataset sharding logic.
     """
-    if cfg.fim_training:
+
+    fim_training = cfg.psm_rate + cfg.spm_rate > 0
+    if fim_training:
         assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?"

     datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name)
@@ -87,8 +89,10 @@ def get_data_loader(cfg, rank, world_size):
         cfg.file_type in _handler_map
     ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})"
     if cfg.file_type == "hf_parquet" or cfg.file_type == "auto":
-        filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cols)
-    elif cfg.file_type == "arrow":
+        filehandler = _handler_map[cfg.file_type](
+            cfg.tokenizer_path, cols, cfg.doc_cutoff
+        )
+    else:
         filehandler = _handler_map[cfg.file_type](cols)

     # Base reader layer
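Tokenizer-backed handlers (hf_parquet, auto) now also receive cfg.doc_cutoff, while pre-tokenized handlers such as arrow are still constructed from column names alone. A hedged sketch of that dispatch with stand-in handler classes (the real ones live in fms_fsdp/utils/dataset_utils.py):

# Stand-in classes for illustration only; not the real handlers.
class TokenizingHandler:          # stands in for ParquetHandler / AutoHandler
    def __init__(self, tokenizer_path, col_names, max_doclen):
        self.tokenizer_path = tokenizer_path
        self.col_names = col_names
        self.max_doclen = max_doclen

class PretokenizedHandler:        # stands in for ArrowHandler
    def __init__(self, col_names):
        self.col_names = col_names

handler_map = {"hf_parquet": TokenizingHandler, "auto": TokenizingHandler,
               "arrow": PretokenizedHandler}

file_type, cols, doc_cutoff = "hf_parquet", ["text"], 1_000_000
if file_type == "hf_parquet" or file_type == "auto":
    filehandler = handler_map[file_type]("/fsx/tokenizer", cols, doc_cutoff)
else:
    filehandler = handler_map[file_type](cols)
print(type(filehandler).__name__, filehandler.max_doclen)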
@@ -131,7 +135,7 @@ def get_data_loader(cfg, rank, world_size):
     data = PreloadBufferDataset(data, 10000)

     # Apply FIM transformation if needed
-    if cfg.fim_training:
+    if fim_training:
         data = FIMDataset(
             data,
             cfg.eos_token,

fms_fsdp/utils/dataset_utils.py

Lines changed: 54 additions & 33 deletions
@@ -343,9 +343,9 @@ class ArrowHandler(_ShardFileHandler):
     Non-standard data format, though.
     """

-    def __init__(self, col_names: List[str] = ["tokens"]):
+    def __init__(self, col_names: List[str] = ["text", "contents", "tokens"]):
         self.col_names = col_names
-
+
     def is_legal(self, filepath: str):
         return "arrow" in os.path.splitext(filepath)[1]

@@ -356,14 +356,18 @@ def length(self, path: str):
         return self.open(path).num_record_batches

     def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
+        assert (
+            index < reader.num_record_batches
+        ), f"Illegal index {index} in set of {reader.num_record_batches} documents"
         frame = reader.get_batch(index)
-
         doc = None
         for name in self.col_names:
             if name in frame.column_names:
                 doc = frame[name]
                 break
-        assert doc is not None, f"None of column names {self.col_names} found in file headers {frame.column_names}"
+        assert (
+            doc is not None
+        ), f"None of column names {self.col_names} found in file headers {frame.column_names}"
         if len(doc) > 0 and doc[0].as_py() in drop_tokens:
             doc = doc.slice(1, len(doc) - 1)
         # Recheck len for edge case where doc=[eos]
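ArrowHandler.get now bounds-checks the requested index and, with the widened default, falls back across the column names in order. A small pyarrow sketch of the column fallback (assumes pyarrow is installed, as the handler already requires):

import pyarrow as pa

# Use the first of the default column names that exists in the record batch.
frame = pa.RecordBatch.from_pydict({"contents": [5, 6, 7]})
doc = None
for name in ["text", "contents", "tokens"]:
    if name in frame.column_names:
        doc = frame[name]
        break
assert doc is not None, f"No expected column in {frame.column_names}"
print(doc.to_pylist())   # [5, 6, 7]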
@@ -378,14 +382,20 @@ def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List:
 class ParquetHandler(_ShardFileHandler):
     """
     Reader for indexable parquet shard files, common in HF datasets.
-    Here we assume reasonably small shard files (<5Gb) and documents (<100k tokens),
+    Here we assume reasonably small shard files (<5Gb) and truncate docs to max_doclen characters,
     as we rely on parquet/pandas for efficient file reading, and tokenize entire documents
     before getting/slicing. However, this is a standard and widely-used data format.
     """

-    def __init__(self, tokenizer_path: str, col_names: List[str] = ["text"]):
+    def __init__(
+        self,
+        tokenizer_path: str,
+        col_names: List[str] = ["text", "contents", "tokens"],
+        max_doclen: int = 1_000_000,
+    ):
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
         self.col_names = col_names
+        self.max_doclen = max_doclen

     def is_legal(self, filepath: str):
         return "parquet" in os.path.splitext(filepath)[1]
@@ -397,18 +407,19 @@ def open(self, path: str):
             if name in names:
                 match = name
                 break
-        assert match is not None, f"None of column names {self.col_names} found in file headers {names}"
+        assert (
+            match is not None
+        ), f"None of column names {self.col_names} found in file headers {names}"
         return pq.read_pandas(path, columns=[match], partitioning=None)[match]

     def length(self, path: str):
         return pq.read_metadata(path).num_rows

     def get(self, reader, index: int, drop_tokens: Set):
-
-        document_str = str(reader[index])
-
-        doc = self.tokenizer(str(reader[index]))["input_ids"]
-
+        assert (
+            index < reader.length()
+        ), f"Illegal index {index} in set of {reader.length()} documents"
+        doc = self.tokenizer(str(reader[index])[: self.max_doclen])["input_ids"]
         if len(doc) > 0 and doc[0] in drop_tokens:
             doc = doc[1:]
         # Recheck len for edge case where doc=[eos]
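ParquetHandler now clips each document to max_doclen characters before tokenizing, bounding the cost of pathologically long records (the loader passes cfg.doc_cutoff, which defaults to 1_000_000). A standalone sketch of the clipping, with a whitespace split standing in for the real tokenizer:

# Hypothetical illustration; a real run uses AutoTokenizer.from_pretrained(...).
max_doclen = 1_000_000
document = "word " * 300_000                 # ~1.5M characters
clipped = document[:max_doclen]              # cheap slice before tokenization
tokens = clipped.split()                     # stand-in for tokenizer(...)["input_ids"]
print(len(document), len(clipped), len(tokens))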
@@ -421,8 +432,13 @@ def slice(self, doc: List, index: int, n_pull: int) -> List:


 class AutoHandler(_ShardFileHandler):
-    def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]):
-        self.PHandler = ParquetHandler(tokenizer_path, col_names)
+    def __init__(
+        self,
+        tokenizer_path: str,
+        col_names: List[str] = ["text", "contents", "tokens"],
+        max_doclen: int = 1_000_000,
+    ):
+        self.PHandler = ParquetHandler(tokenizer_path, col_names, max_doclen)
         self.AHandler = ArrowHandler(col_names)
         self.current = _ShardFileHandler()

@@ -979,10 +995,10 @@ class StreamingDocDataset(_StatefulDataset):
         Documents below this length are skipped
     max_chunksize : int
         Maximum sequence length to return. Break long docs into chunks of this size or shorter.
+    max_consecutive_chunks : int
+        Number of doc chunks to emit before manually inserting EOS and resuming later.
     verbose : bool
         Track setup progress?
-    shuffle : bool
-        Shuffle shard file and document orders? (Disable for simple testing)
     """

     def __init__(
@@ -1095,12 +1111,11 @@ def setup(self):
                for row in reader:
                    fullpath = row["dataset/filename"]
                    prefix = fullpath.find(dataset + "/")
-                    if prefix > 0:
+                    if prefix >= 0:
                        key = fullpath[prefix + len(dataset) + 1 :]
                        sizes[key] = int(row["size"])
                shard_sizes = [sizes[shard] for shard in shards]
            else:
-                # Count file does not exist, touch every owned file for length
                shard_sizes = [
                    os.path.getsize(os.path.join(datapath, shard)) for shard in shards
                ]
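The guard changed from prefix > 0 to prefix >= 0 because str.find returns 0 when the dataset name sits at the very start of the counted path, a legitimate match the old check silently skipped. A tiny reproduction with a hypothetical count-file entry:

# Hypothetical count-file row; find() == 0 is a valid hit, not a miss.
fullpath = "commoncrawl/shard_000.arrow"
dataset = "commoncrawl"
prefix = fullpath.find(dataset + "/")
print(prefix)                                  # 0 -> old `prefix > 0` skipped it
if prefix >= 0:
    key = fullpath[prefix + len(dataset) + 1:]
    print(key)                                 # shard_000.arrow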
@@ -1125,7 +1140,7 @@ def setup(self):
                reader = csv.DictReader(csvfile)
                for row in reader:
                    fullpath = row["dataset/filename"]
-                    prefix = fullpath.find(dataset + "/")
+                    prefix = fullpath.find(dataset)
                    if prefix >= 0:
                        key = fullpath[prefix + len(dataset) + 1 :]
                        doc_counts[key] = int(row["documents"])
@@ -1141,10 +1156,13 @@ def setup(self):
            doccount = 0
            for shard in shardset:
                ndocs = doc_counts[shard]
-                doc_start = int(ndocs * shardset[shard][0])
-                doc_end = max(doc_start, int(ndocs * shardset[shard][1]) - 1)  # inclusive upper bound
-                self.docset.append([shard, doc_start, doc_end])
-                doccount += doc_end - doc_start + 1
+                if ndocs > 0:
+                    doc_start = int(ndocs * shardset[shard][0])
+                    doc_end = max(
+                        doc_start, int(ndocs * shardset[shard][1]) - 1
+                    )  # inclusive upper bound
+                    self.docset.append([shard, doc_start, doc_end])
+                    doccount += doc_end - doc_start + 1
            self._len = doccount

            if self.verbose:
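Each logical shard owns a fractional slice of every file's documents; files with zero counted documents are now skipped instead of being folded into the docset. A worked example with hypothetical ownership fractions:

# Hypothetical ownership fractions for one logical shard over one file.
ndocs = 100
frac_start, frac_end = 0.25, 0.75
if ndocs > 0:
    doc_start = int(ndocs * frac_start)                   # 25
    doc_end = max(doc_start, int(ndocs * frac_end) - 1)   # 74 (inclusive)
    print(doc_start, doc_end, doc_end - doc_start + 1)    # 25 74 50
# With ndocs == 0 the old arithmetic still appended a [0, 0] range and counted
# one phantom document; the new guard skips the file entirely.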
@@ -1253,10 +1271,8 @@ def __iter__(self):
                    doclcg = self._random_map_docid(docrange)
                    docid = doclcg + mindoc
                    doc = self.filehandler.get(reader, docid, self.drop)
-                    if len(doc) == 0:
-                        continue
                    doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
-                    if doclen >= self.min_length:
+                    if len(doc) > 0 and doclen >= self.min_length:
                        n_chunks = math.ceil(doclen / self.chunksize)
                        for j in range(n_chunks):
                            if i == 0 and j < residual_chunks:
12831299
newpath = os.path.join(self.datapath, shardid)
12841300
path, reader = self._get_reader(path, newpath, reader)
12851301
doc = self.filehandler.get(reader, docid, self.drop)
1286-
if len(doc) == 0:
1287-
continue
12881302
doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
1289-
if doclen >= self.min_length:
1303+
if len(doc) > 0 and doclen >= self.min_length:
12901304
n_chunks = math.ceil(doclen / self.chunksize)
12911305
for j in range(residual_chunks):
12921306
self.chunk_index = j
12931307
self.has_yielded = True
12941308
yield self._construct_chunk(j, doc, n_chunks)
12951309

12961310
# Check that epoch was non-empty
1297-
assert self.has_yielded, f"Empty logical shard detected: {self.dataset, self.docset}"
1311+
assert (
1312+
self.has_yielded
1313+
), f"Empty logical shard detected: {self.dataset, self.docset}"
12981314

12991315
def load_state_dict(self, state_dicts, sharded_input=False):
13001316
self.setup()
@@ -1367,12 +1383,12 @@ def setup(self):
        if not self.is_setup:
            _StatefulDataset.setup(self)
            n_logical_shards = self.total_shards
+            assert (
+                n_logical_shards % self.worldsize == 0
+            ), f"Total workers {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly"
            logicals = list(range(n_logical_shards))
            self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize)
            self.n_logicals = n_logical_shards // self.worldsize
-            assert (
-                len(self.logicals_owned) == self.n_logicals
-            ), "(world size * num workers) does not divide logical shards evenly"

            # Build logical shards
            for i in range(self.n_logicals):
14031419
# (i.e. [1,1,0,0,0,0] into [1,1,0] [0,0,0] )
14041420
if sum(self.n_docs_remaining) == 0:
14051421
self.n_docs_remaining = [d._len for d in self.data]
1422+
self.generator.manual_seed(self.rank)
14061423
while True:
14071424
# Sample logical shard (or load from ckp)
14081425
if self.current_reader is not None:
@@ -1499,6 +1516,10 @@ def __init__(
            ]
        )
        assert len(self.datasets) > 0, "You must specify at least one dataset"
+        for d in datasets:
+            assert os.path.exists(
+                os.path.join(datapath, d)
+            ), f"Invalid subdataset path: {os.path.join(datapath, d)}"

        if weights is not None:
            assert len(weights) == len(
