Merge pull request #930 from OptimalScale/yizhenjia-data-debug

[usability] Change dataset check method

research4pan authored Feb 1, 2025
2 parents 9a9957d + bf59cec commit 0b3bbf6
Showing 8 changed files with 167 additions and 95 deletions.
9 changes: 9 additions & 0 deletions src/lmflow/args.py
@@ -492,6 +492,10 @@ class DatasetArguments:
conversation_template: str
a string representing the template for conversation datasets.
dataset_cache_dir: str
a string representing the path to the dataset cache directory. Useful when the default cache dir
(`~/.cache/huggingface/datasets`) has limited space.
The class also includes some additional parameters that can be used to configure the dataset further, such as `overwrite_cache`,
`validation_split_percentage`, `preprocessing_num_workers`, `disable_group_texts`, `demo_example_in_prompt`, `explanation_in_prompt`,
`keep_linebreaks`, and `prompt_structure`.
@@ -608,6 +612,11 @@ class DatasetArguments:
default=None,
metadata={"help": "The template for conversation datasets."}
)
dataset_cache_dir: Optional[str] = field(
default=None,
metadata={"help": ("The path to the dataset cache directory. Useful when the "
"default cache dir (`~/.cache/huggingface/datasets`) has limited space.")}
)

def __post_init__(self):
if self.streaming:
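For context, a minimal sketch of the new option in use (only dataset_cache_dir is introduced by this commit; the dataset_path value is a made-up example):

    from lmflow.args import DatasetArguments

    # Redirect the HF datasets cache to a disk with more free space than
    # the default ~/.cache/huggingface/datasets.
    data_args = DatasetArguments(
        dataset_path="data/example_dataset",           # hypothetical path
        dataset_cache_dir="/mnt/large_disk/hf_cache",  # hypothetical path
    )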
96 changes: 47 additions & 49 deletions src/lmflow/datasets/dataset.py
@@ -17,10 +17,11 @@

from cmath import e
from pathlib import Path
from typing import Optional
from typing import Optional, List

from datasets import load_dataset
from datasets import Dataset as HFDataset
from tqdm import tqdm

from lmflow.args import DatasetArguments
from lmflow.utils.constants import (
@@ -31,6 +32,7 @@
INSTANCE_FIELDS_MAP,
)
from lmflow.utils.versioning import is_multimodal_available
from lmflow.utils.data_utils import get_dataset_type_fast, check_dataset_instances_key_fast

if is_multimodal_available():
from .multi_modal_dataset import CustomMultiModalDataset
@@ -88,41 +90,22 @@ def __init__(self, data_args: DatasetArguments=None, backend: str="huggingface",
x.absolute().as_posix()
for x in Path(self.dataset_path).glob("*.json")
]

# Iterate through all the files and ensure they have the same data type
for single_file in data_files:
with open(single_file) as fin:
json_data = json.load(fin)
if KEY_TYPE not in json_data.keys():
raise ValueError(
f'"{KEY_TYPE}" field must be specified for data, e.g.'
'{\n'
f' "{KEY_TYPE}: "text_only",\n'
f' "{KEY_INSTANCES}": [\n'
' { "text": "Sentence 1: This is a sentence." }\n'
' { "text": "Sentence 2: This is another sentence." }\n'
f' ]\n'
'}'
)
if self.type is None:
self.type = json_data[KEY_TYPE]
elif self.type != json_data[KEY_TYPE]:
raise ValueError(
'All task files must have same data types. Previous'
f' files have type "{self.type}", but in file'
f' {single_file}, it has type "{self.type}".'
)

logger.info(f"Data files: \n{data_files}")

# check if the dataset is in the correct format and get the dataset type (text_only, text2text, etc.)
self._check_hf_json_format(data_files)
# Load the dataset using the HuggingFace dataset library
logger.info('Loading datasets')
extensions = "json"
raw_dataset = load_dataset(
extensions,
data_files=data_files,
field=KEY_INSTANCES,
split="train",
cache_dir=data_args.dataset_cache_dir,
)
self.backend_dataset = raw_dataset
self._check_data_format()
self._check_instance_format()
elif backend == "json":
# TODO (@Jiachun)
pass
@@ -137,36 +120,51 @@ def __init__(self, data_args: DatasetArguments=None, backend: str="huggingface",
else:
raise NotImplementedError(f'Unsupported dataset backend "{backend}"')


def __len__(self):
return len(self.backend_dataset)


def _check_data_format(self):
"""Checks if data type and data structure matches
Raise messages with hints if not matched.
def _check_instance_format(self):
"""
data_dict = self.to_dict()
if KEY_TYPE not in data_dict:
raise ValueError(
f'"{KEY_TYPE}" must be provided to initialize a dataset,'
f' e.g.\n'
f' {TEXT_ONLY_DATASET_DESCRIPTION}'
)
if KEY_INSTANCES not in data_dict:
raise ValueError(
f'"{KEY_INSTANCES}" must be provided to initialize a'
f' dataset, e.g.\n'
f' {TEXT_ONLY_DATASET_DESCRIPTION}'
)

data_type = data_dict[KEY_TYPE]
fields = self.get_backend_dataset().features
correct_fields = INSTANCE_FIELDS_MAP[data_type]
Checks whether the data instances have the required fields.
Raises a ValueError with hints if they do not.
"""
fields = self.backend_dataset.features
correct_fields = INSTANCE_FIELDS_MAP[self.type]
if not set(correct_fields).issubset(set(fields)):
raise ValueError(
f'Data instance fields are incorrect:'
f' {list(correct_fields)} are required.'
)


def _check_hf_json_format(self, data_files: List[str]):
for single_file in tqdm(data_files, desc='Checking dataset keys'):
# Get the dataset type and check that it is consistent across files
json_data_type = get_dataset_type_fast(single_file)
if not json_data_type:
raise ValueError(
f'"{KEY_TYPE}" must be provided to initialize a dataset,'
f' e.g.\n'
f' {TEXT_ONLY_DATASET_DESCRIPTION}'
)
if self.type is None:
self.type = json_data_type
elif self.type != json_data_type:
raise ValueError(
'All task files must have the same data type. Previous'
f' files have type "{self.type}", but file'
f' {single_file} has type "{json_data_type}".'
)
# Check that the instances key is provided
key_instances_exists_flag = check_dataset_instances_key_fast(single_file, KEY_INSTANCES)
if not key_instances_exists_flag:
raise ValueError(
f'"{KEY_INSTANCES}" must be provided to initialize a'
f' dataset, e.g.\n'
f' {TEXT_ONLY_DATASET_DESCRIPTION}'
)


def from_dict(self, dict_obj: dict, *args, **kwargs):
@@ -252,7 +250,7 @@ def from_dict(self, dict_obj: dict, *args, **kwargs):
f" follows:\n"
f" {DATASET_DESCRIPTION_MAP[self.type]}"
)
self._check_data_format()
self._check_instance_format()

return self
elif self.backend == "dict":
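For reference, the on-disk format these checks validate, reconstructed from the error-message examples above (field values are illustrative):

    {
        "type": "text_only",
        "instances": [
            { "text": "Sentence 1: This is a sentence." },
            { "text": "Sentence 2: This is another sentence." }
        ]
    }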
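The fast checks imported from lmflow.utils.data_utils are not shown in this excerpt. A minimal sketch of what they could look like, assuming they scan only the head of each file instead of json.load()-ing it whole (the byte limit and the regex are assumptions):

    import re
    from typing import Optional

    _TYPE_RE = re.compile(r'"type"\s*:\s*"([^"]+)"')

    def get_dataset_type_fast(path: str, max_bytes: int = 4096) -> Optional[str]:
        # Read only the first few KB; large instance lists are never parsed.
        with open(path, encoding="utf-8") as fin:
            head = fin.read(max_bytes)
        match = _TYPE_RE.search(head)
        return match.group(1) if match else None

    def check_dataset_instances_key_fast(path: str, key: str, max_bytes: int = 4096) -> bool:
        # Cheap substring test for the key near the top of the file.
        with open(path, encoding="utf-8") as fin:
            head = fin.read(max_bytes)
        return f'"{key}"' in head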
7 changes: 6 additions & 1 deletion src/lmflow/models/hf_decoder_model.py
@@ -117,7 +117,7 @@ def __init__(

def tokenize(
self,
dataset,
dataset: Dataset,
add_special_tokens=True,
*args,
**kwargs
@@ -236,6 +236,11 @@ def tokenize(
"new_fingerprint": fingerprint,
}

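# Warn once per tokenize() call; this replaces the per-sample warning
# counters removed from the blocking() helpers later in this commit.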
if data_args.block_size < self.tokenizer.model_max_length:
logger.warning(
f"block_size {data_args.block_size} < model_max_length {self.tokenizer.model_max_length}, "
"using block_size as the maximum tokenized sequence length."
)
tokenized_datasets = raw_datasets.map(
tokenize_fn,
batched=True,
11 changes: 9 additions & 2 deletions src/lmflow/pipeline/finetuner.py
@@ -33,7 +33,7 @@
import numpy as np

import lmflow.optim.optimizers as optim
from lmflow.args import OptimizerNames
from lmflow.args import OptimizerNames, DatasetArguments, ModelArguments, FinetunerArguments
from lmflow.datasets.dataset import Dataset
from lmflow.pipeline.base_tuner import BaseTuner
from lmflow.pipeline.utils.peft_trainer import PeftTrainer, PeftSavingCallback
@@ -64,7 +64,14 @@ class Finetuner(BaseTuner):
Keyword arguments.
"""
def __init__(self, model_args, data_args, finetuner_args, *args, **kwargs):
def __init__(
self,
model_args: ModelArguments,
data_args: DatasetArguments,
finetuner_args: FinetunerArguments,
*args,
**kwargs
):

self.model_args = model_args
self.data_args = data_args
10 changes: 0 additions & 10 deletions src/lmflow/tokenization/hf_decoder_model.py
@@ -27,13 +27,10 @@ def blocking(
padding_side: str,
truncation_side: str='right',
) -> Dict:
block_size_warning_num = 0
num_example = len(token_dict[list(token_dict.keys())[0]])
for i in range(num_example):
max_length = min(block_size, model_max_length)
pad_length = max_length - len(token_dict["input_ids"][i])
if block_size < model_max_length:
block_size_warning_num += 1
if pad_length < 0:
# Truncates too long samples
for key in ["input_ids", "attention_mask", "labels"]:
@@ -72,13 +69,6 @@ def blocking(
raise ValueError(
f"padding_side should be either 'right' or 'left', got {padding_side}"
)
if block_size_warning_num > 0:
logger.warning(
f"There are {block_size_warning_num} of {num_example} samples where"
f"block_size {block_size} < model_max_length"
f" {model_max_length}, use block_size"
" for maximum tokenized sequence length"
)

return token_dict

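To make the padding/truncation rule in blocking() concrete, a toy illustration (pad token id 0 is assumed; the real helper uses the tokenizer's pad token and also updates attention_mask and labels):

    block_size, model_max_length = 8, 1024
    max_length = min(block_size, model_max_length)

    input_ids = [101, 7592, 2088, 102]             # 4 tokens
    pad_length = max_length - len(input_ids)       # 4
    if pad_length < 0:
        input_ids = input_ids[:pad_length]         # truncate overlong samples
    else:
        input_ids = input_ids + [0] * pad_length   # right-side padding
    print(input_ids)  # [101, 7592, 2088, 102, 0, 0, 0, 0]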
30 changes: 0 additions & 30 deletions src/lmflow/tokenization/hf_text_regression_model.py
@@ -28,14 +28,11 @@ def blocking_paired(
padding_side: str,
truncation_side: str='right',
) -> Dict:
block_size_warning_num = 0
num_example = len(token_dict[list(token_dict.keys())[0]])
for i in range(num_example):
for column_name in column_names:
max_length = min(block_size, model_max_length)
pad_length = max_length - len(token_dict[f"input_ids_{column_name}"][i])
if block_size < model_max_length:
block_size_warning_num += 1
if pad_length < 0:
# Truncates too long samples
for key in [f"input_ids_{column_name}", f"attention_mask_{column_name}"]:
@@ -68,13 +65,6 @@
raise ValueError(
f"padding_side should be either 'right' or 'left', got {padding_side}"
)
if block_size_warning_num > 0:
logger.warning(
f"There are {block_size_warning_num} of {num_example} samples where"
f" block_size {block_size} < model_max_length"
f" {model_max_length}, use block_size"
" for maximum tokenized sequence length"
)

return token_dict

@@ -87,13 +77,10 @@ def blocking(
padding_side: str,
truncation_side: str='right',
) -> Dict:
block_size_warning_num = 0
num_example = len(token_dict[list(token_dict.keys())[0]])
for i in range(num_example):
max_length = min(block_size, model_max_length)
pad_length = max_length - len(token_dict["input_ids"][i])
if block_size < model_max_length:
block_size_warning_num += 1
if pad_length < 0:
# Truncates too long samples
for key in ["input_ids", "attention_mask", "labels"]:
@@ -132,13 +119,6 @@
raise ValueError(
f"padding_side should be either 'right' or 'left', got {padding_side}"
)
if block_size_warning_num > 0:
logger.warning(
f"There are {block_size_warning_num} of {num_example} samples where"
f" block_size {block_size} < model_max_length"
f" {model_max_length}, use block_size"
" for maximum tokenized sequence length"
)

return token_dict

@@ -151,15 +131,12 @@ def blocking_text_to_textlist(
padding_side: str,
truncation_side: str='right',
) -> Dict:
block_size_warning_num = 0
num_example = len(token_dict[list(token_dict.keys())[0]])
max_length = min(block_size, model_max_length)

for example_idx in range(num_example):
for content_idx in range(len(token_dict["input_ids"][example_idx])):
pad_length = max_length - len(token_dict["input_ids"][example_idx][content_idx])
if block_size < model_max_length:
block_size_warning_num += 1
if pad_length < 0:
# Truncates too long samples
if truncation_side == 'right':
@@ -185,13 +162,6 @@
raise ValueError(
f"padding_side should be either 'right' or 'left', got {padding_side}"
)
if block_size_warning_num > 0:
logger.warning(
f"There are {block_size_warning_num} of {num_example} samples where"
f" block_size {block_size} < model_max_length"
f" {model_max_length}, use block_size"
" for maximum tokenized sequence length"
)

return token_dict
