Commit d9bb084

Combine GPTHuggingfaceDatasetConfig input sources into source_schema (#255)

Authored by nitsanluke, jlamypoirier, and tscholak
Co-authored-by: Joel Lamy-Poirier <[email protected]>
Co-authored-by: Torsten Scholak <[email protected]>

1 parent 24ac566 commit d9bb084

File tree: 2 files changed (+41, -18 lines)

fast_llm/data/preparator/gpt_memmap/config.py (19 additions, 6 deletions)

@@ -25,6 +25,23 @@
 MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00"
 
 
+@config_class(registry=True)
+class SourceSchemaConfig(Config):
+    pass
+
+
+@config_class(dynamic_type={SourceSchemaConfig: "text_column"})
+class TextColumnConfig(SourceSchemaConfig):
+    input_column: str = Field(
+        default="text",
+        desc="Field of the dataset to use.",
+        hint=FieldHint.optional,
+    )
+    loss_masking_spans_column: None | str = Field(
+        default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional
+    )
+
+
 @config_class()
 class GPTHuggingfaceDatasetConfig(Config):
     path: str = Field(
@@ -52,14 +69,10 @@ class GPTHuggingfaceDatasetConfig(Config):
         desc="Split of the dataset to use.",
         hint=FieldHint.optional,
     )
-    field: str = Field(
-        default="text",
-        desc="Field of the dataset to use.",
+    source_schema: SourceSchemaConfig = Field(
+        desc="Configuration for the data source.",
         hint=FieldHint.optional,
     )
-    loss_masking_spans: None | str = Field(
-        default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional
-    )
     chosen_text: None | str = Field(
         default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional
    )
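The `@config_class(registry=True)` / `dynamic_type` pair makes `SourceSchemaConfig` an extension point: future source schemas can register under their own type keys without touching `GPTHuggingfaceDatasetConfig`. As a rough sketch of the idea only (this is not fast_llm's actual `config_class` machinery; the `register` and `from_dict` helpers and the `"type"` dispatch key are hypothetical), the pattern looks like:

    # Sketch of a dynamic-type config registry: the base class maps a "type"
    # key to registered subclasses, so a nested section like source_schema
    # can be polymorphic. Hypothetical helpers, not fast_llm's implementation.
    import dataclasses
    import typing


    @dataclasses.dataclass
    class SourceSchemaConfig:
        # Registry shared via the base class; subclasses register themselves.
        _registry: typing.ClassVar[dict[str, type["SourceSchemaConfig"]]] = {}

        @classmethod
        def register(cls, name: str):
            # Register a subclass under its dynamic-type key.
            def decorator(subclass: type["SourceSchemaConfig"]):
                cls._registry[name] = subclass
                return subclass

            return decorator

        @classmethod
        def from_dict(cls, data: dict) -> "SourceSchemaConfig":
            # Dispatch on the "type" key; "text_column" is the default schema.
            subclass = cls._registry[data.pop("type", "text_column")]
            return subclass(**data)


    @SourceSchemaConfig.register("text_column")
    @dataclasses.dataclass
    class TextColumnConfig(SourceSchemaConfig):
        input_column: str = "text"
        loss_masking_spans_column: None | str = None


    schema = SourceSchemaConfig.from_dict({"type": "text_column", "input_column": "prompt"})
    assert isinstance(schema, TextColumnConfig) and schema.input_column == "prompt"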

fast_llm/data/preparator/gpt_memmap/prepare.py (22 additions, 12 deletions)

@@ -24,7 +24,7 @@
 from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
 from fast_llm.data.dataset.gpt.sampled import GPTSample
 from fast_llm.data.preparator.config import DatasetPreparator
-from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig
+from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig
 from fast_llm.data.tokenizer import Tokenizer
 from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
 from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum
@@ -37,11 +37,12 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D
 
     _tokenizer: Tokenizer
     _data_type: DataType
+    _text_column: str
+    _loss_masking_spans_column: str | None
 
     def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]:
         input_ids = [
-            np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy)
-            for text in batch[self._config.dataset.field]
+            np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) for text in batch[self._text_column]
         ]
         num_tokens = [len(x) for x in input_ids]
         return {
@@ -60,9 +61,7 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict
             )
             for input_ids, token_spans in [
                 self._tokenizer.tokenize_with_spans(text, char_spans)
-                for text, char_spans in zip(
-                    batch[self._config.dataset.field], batch[self._config.dataset.loss_masking_spans]
-                )
+                for text, char_spans in zip(batch[self._text_column], batch[self._loss_masking_spans_column])
             ]
         ]
     ),
@@ -144,7 +143,7 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon
         shard_output_path = self._config.output_path / prefix
 
         def _document_generator():
-            if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None:
+            if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None:
                 for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"):
                     yield GPTSample(
                         np.array(item["input_ids"], dtype=self._data_type.numpy),
@@ -288,8 +287,19 @@ def run(self) -> None:
             num_shards=self._config.distributed.world_size,
             index=self._config.distributed.rank,
         )
-        if self._config.dataset.field not in dataset.column_names:
-            raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.")
+
+        # Set data column and loss masking spans column based on source schema
+        if isinstance(self._config.dataset.source_schema, TextColumnConfig):
+            self._text_column = self._config.dataset.source_schema.input_column
+            self._loss_masking_spans_column = self._config.dataset.source_schema.loss_masking_spans_column
+        else:
+            raise ValueError(
+                f"Dataset source_schema set incorrectly. source_schema: '{self._config.dataset.source_schema}'."
+            )
+
+        if self._text_column not in dataset.column_names:
+            raise ValueError(f"Dataset does not have field '{self._text_column}'.")
+
         if self._config.dataset.loss_masking_spans is not None and (
             self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None
         ):
@@ -298,9 +308,9 @@ def run(self) -> None:
             raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.")
 
         # route tokenize function
-        if self._config.dataset.loss_masking_spans is not None:
-            if self._config.dataset.loss_masking_spans not in dataset.column_names:
-                raise ValueError(f"Dataset does not have spans field '{self._config.dataset.loss_masking_spans}'.")
+        if self._loss_masking_spans_column is not None:
+            if self._loss_masking_spans_column not in dataset.column_names:
+                raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.")
             tokenize_fn = self._tokenize_batch_with_spans
         elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None:
             if self._config.dataset.chosen_text not in dataset.column_names:
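For users of the memmap preparator, the practical effect is that the flat `field` and `loss_masking_spans` keys move under a nested `source_schema` section with renamed fields. A hypothetical before/after of the dataset portion of a preparator config, shown as Python dicts (the dataset path is a placeholder and the `"type"` dispatch key is an assumption; only the renamed keys are taken from this diff):

    # Before this commit: flat column selectors on the dataset config.
    dataset_before = {
        "path": "org/my-dataset",  # placeholder HF dataset path
        "field": "text",
        "loss_masking_spans": "spans",
    }

    # After: the same selectors live under source_schema. Assumes the dynamic
    # type registered by @config_class(dynamic_type={SourceSchemaConfig:
    # "text_column"}) is selected via a "type" key.
    dataset_after = {
        "path": "org/my-dataset",
        "source_schema": {
            "type": "text_column",
            "input_column": "text",
            "loss_masking_spans_column": "spans",
        },
    }

`run()` then resolves the schema with an `isinstance` check against `TextColumnConfig` and copies the two column names onto the preparator, so the tokenize helpers read `self._text_column` and `self._loss_masking_spans_column` instead of `self._config.dataset.field` directly.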
