@@ -24,7 +24,7 @@
 from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
 from fast_llm.data.dataset.gpt.sampled import GPTSample
 from fast_llm.data.preparator.config import DatasetPreparator
-from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig
+from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig
 from fast_llm.data.tokenizer import Tokenizer
 from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
 from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum
@@ -37,11 +37,12 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D
 
     _tokenizer: Tokenizer
     _data_type: DataType
+    _text_column: str
+    _loss_masking_spans_column: str | None
 
     def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]:
         input_ids = [
-            np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy)
-            for text in batch[self._config.dataset.field]
+            np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) for text in batch[self._text_column]
        ]
         num_tokens = [len(x) for x in input_ids]
         return {
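For readers outside the Fast-LLM codebase, the batched tokenization above follows a common pattern: one numpy array of token ids per document, plus a per-document token count. A minimal self-contained sketch, with a stand-in tokenizer and column name (not Fast-LLM's actual API):

```python
# Sketch of the batched-tokenization shape in _tokenize_batch above.
# `tokenize` and `text_column` are stand-ins, not Fast-LLM's real interfaces.
import numpy as np

def tokenize_batch(batch: dict[str, list], tokenize, text_column: str, dtype=np.int32) -> dict[str, list]:
    # One numpy array of token ids per document in the batch.
    input_ids = [np.array(tokenize(text), dtype=dtype) for text in batch[text_column]]
    return {"input_ids": input_ids, "num_tokens": [len(ids) for ids in input_ids]}

# Example with a trivial whitespace "tokenizer":
vocab = {"hello": 0, "world": 1}
batch = {"text": ["hello world", "world"]}
out = tokenize_batch(batch, lambda s: [vocab[w] for w in s.split()], text_column="text")
assert out["num_tokens"] == [2, 1]
```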
@@ -60,9 +61,7 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict
                    )
                    for input_ids, token_spans in [
                        self._tokenizer.tokenize_with_spans(text, char_spans)
-                        for text, char_spans in zip(
-                            batch[self._config.dataset.field], batch[self._config.dataset.loss_masking_spans]
-                        )
+                        for text, char_spans in zip(batch[self._text_column], batch[self._loss_masking_spans_column])
                    ]
                ]
            ),
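Here `tokenize_with_spans` pairs each document's token ids with token-index spans derived from character spans. A toy illustration of that char-to-token span mapping, assuming a whitespace tokenizer (the real `Tokenizer` is more involved; this is not its implementation):

```python
# Toy illustration only: map character spans to (first, last) token-index spans
# under a whitespace tokenizer. Fast-LLM's tokenize_with_spans is more involved.
def tokenize_with_spans(text: str, char_spans: list[tuple[int, int]]):
    tokens, offsets, pos = [], [], 0
    for word in text.split():
        start = text.index(word, pos)  # character offset of this token
        tokens.append(word)
        offsets.append((start, start + len(word)))
        pos = start + len(word)
    token_spans = []
    for begin, end in char_spans:
        # Tokens whose character range overlaps the span.
        idx = [i for i, (s, e) in enumerate(offsets) if s < end and begin < e]
        if idx:
            token_spans.append((idx[0], idx[-1]))
    return tokens, token_spans

tokens, spans = tokenize_with_spans("mask this part only", [(5, 14)])
assert tokens == ["mask", "this", "part", "only"] and spans == [(1, 2)]
```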
@@ -144,7 +143,7 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon
         shard_output_path = self._config.output_path / prefix
 
         def _document_generator():
-            if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None:
+            if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None:
                 for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"):
                     yield GPTSample(
                         np.array(item["input_ids"], dtype=self._data_type.numpy),
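The generator only attaches spans when both the `token_spans` column exists in the shard and a spans column was configured. A minimal sketch of that guard, with plain tuples and a list of dicts standing in for `GPTSample` and the shard dataset:

```python
# Sketch of the guarded document generator; tuples stand in for GPTSample.
import numpy as np

def document_generator(shard: list[dict], spans_column: str | None):
    # Yield spans only if the data has them AND a spans column was configured.
    has_spans = bool(shard) and "token_spans" in shard[0] and spans_column is not None
    for item in shard:
        ids = np.array(item["input_ids"], dtype=np.int32)
        yield (ids, item["token_spans"]) if has_spans else (ids, None)

shard = [{"input_ids": [1, 2], "token_spans": [(0, 1)]}]
assert list(document_generator(shard, "spans"))[0][1] == [(0, 1)]
assert list(document_generator(shard, None))[0][1] is None
```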
@@ -288,8 +287,19 @@ def run(self) -> None:
             num_shards=self._config.distributed.world_size,
             index=self._config.distributed.rank,
         )
-        if self._config.dataset.field not in dataset.column_names:
-            raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.")
+
+        # Set data column and loss masking spans column based on source schema
+        if isinstance(self._config.dataset.source_schema, TextColumnConfig):
+            self._text_column = self._config.dataset.source_schema.input_column
+            self._loss_masking_spans_column = self._config.dataset.source_schema.loss_masking_spans_column
+        else:
+            raise ValueError(
+                f"Dataset source_schema set incorrectly. source_schema: '{self._config.dataset.source_schema}'."
+            )
+
+        if self._text_column not in dataset.column_names:
+            raise ValueError(f"Dataset does not have field '{self._text_column}'.")
+
         if self._config.dataset.loss_masking_spans is not None and (
             self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None
         ):
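This hunk replaces the old flat `field` option with a `source_schema` object that carries both the text column and the optional spans column. A hypothetical, simplified stand-in showing the shape of that dispatch (the real `TextColumnConfig` lives in `fast_llm.data.preparator.gpt_memmap.config`; the field names mirror the diff, the defaults are assumptions):

```python
# Hypothetical, simplified stand-in for the source_schema dispatch above.
import dataclasses

@dataclasses.dataclass
class TextColumnConfig:
    input_column: str = "text"                      # assumed default
    loss_masking_spans_column: str | None = None    # assumed default

def resolve_columns(source_schema) -> tuple[str, str | None]:
    if isinstance(source_schema, TextColumnConfig):
        return source_schema.input_column, source_schema.loss_masking_spans_column
    raise ValueError(f"Dataset source_schema set incorrectly. source_schema: '{source_schema}'.")

text_column, spans_column = resolve_columns(TextColumnConfig(input_column="content"))
assert (text_column, spans_column) == ("content", None)
```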
@@ -298,9 +308,9 @@ def run(self) -> None:
             raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.")
 
         # route tokenize function
-        if self._config.dataset.loss_masking_spans is not None:
-            if self._config.dataset.loss_masking_spans not in dataset.column_names:
-                raise ValueError(f"Dataset does not have spans field '{self._config.dataset.loss_masking_spans}'.")
+        if self._loss_masking_spans_column is not None:
+            if self._loss_masking_spans_column not in dataset.column_names:
+                raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.")
             tokenize_fn = self._tokenize_batch_with_spans
         elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None:
             if self._config.dataset.chosen_text not in dataset.column_names:
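The routed `tokenize_fn` is ultimately applied batch-wise over the Hugging Face dataset; the call site is outside these hunks, but with `datasets` that typically looks like the following illustrative sketch (not the exact Fast-LLM invocation):

```python
# Illustrative sketch: applying a batched tokenize function with HF datasets.
import datasets
import numpy as np

ds = datasets.Dataset.from_dict({"text": ["a b", "c"]})
vocab = {"a": 0, "b": 1, "c": 2}

def tokenize_fn(batch):
    # Mirrors the shape of _tokenize_batch: ids per document plus a count.
    input_ids = [np.array([vocab[w] for w in t.split()], dtype=np.int32) for t in batch["text"]]
    return {"input_ids": input_ids, "num_tokens": [len(x) for x in input_ids]}

ds = ds.map(tokenize_fn, batched=True)
assert ds["num_tokens"] == [2, 1]
```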