[usability] dataset check approach update
wheresmyhair committed Jan 26, 2025
1 parent 2b0adba commit 93e9df4
Showing 2 changed files with 72 additions and 44 deletions.
60 changes: 32 additions & 28 deletions src/lmflow/datasets/dataset.py
@@ -17,7 +17,7 @@

from cmath import e
from pathlib import Path
from typing import Optional
from typing import Optional, List

from datasets import load_dataset
from datasets import Dataset as HFDataset
@@ -32,7 +32,7 @@
INSTANCE_FIELDS_MAP,
)
from lmflow.utils.versioning import is_multimodal_available
from lmflow.utils.data_utils import get_dataset_type_fast
from lmflow.utils.data_utils import get_dataset_type_fast, check_dataset_instances_key_fast

if is_multimodal_available():
from .multi_modal_dataset import CustomMultiModalDataset
@@ -92,32 +92,8 @@ def __init__(self, data_args: DatasetArguments=None, backend: str="huggingface",
]
logger.info(f"Data files: \n{data_files}")

# Iterate through all the files and ensure they have the same data type
for single_file in tqdm(data_files, desc='Checking dataset keys'):
# check keys: type, instances
json_data_type = get_dataset_type_fast(single_file)
if not json_data_type:
raise ValueError(
f'"{KEY_TYPE}" must be provided to initialize a dataset,'
f' e.g.\n'
f' {TEXT_ONLY_DATASET_DESCRIPTION}'
)
if self.type is None: # TODO: out of skip_dataset_check
self.type = json_data_type
elif self.type != json_data_type:
raise ValueError(
'All task files must have same data types. Previous'
f' files have type "{self.type}", but in file'
f' {single_file}, it has type "{self.type}".'
)
# json_data_instances = get_dataset_type_fast(single_file)
# if not json_data_instances:
# raise ValueError(
# f'"{KEY_INSTANCES}" must be provided to initialize a'
# f' dataset, e.g.\n'
# f' {TEXT_ONLY_DATASET_DESCRIPTION}'
# )

# check if the dataset is in the correct format and get the dataset type (text_only, text2text, etc.)
self._check_hf_json_format(data_files)
# Load the dataset using the HuggingFace dataset library
print('loading datasets')
extensions = "json"
@@ -161,6 +137,34 @@ def _check_instance_format(self):
f'data instance fields incorrect'
f' {list(correct_fields)} are required.'
)


def _check_hf_json_format(self, data_files: List[str]):
for single_file in tqdm(data_files, desc='Checking dataset keys'):
# get type and check if it is consistent
json_data_type = get_dataset_type_fast(single_file)
if not json_data_type:
raise ValueError(
f'"{KEY_TYPE}" must be provided to initialize a dataset,'
f' e.g.\n'
f' {TEXT_ONLY_DATASET_DESCRIPTION}'
)
if self.type is None:
self.type = json_data_type
elif self.type != json_data_type:
raise ValueError(
                    'All task files must have the same data type. Previous'
                    f' files have type "{self.type}", but file'
                    f' {single_file} has type "{json_data_type}".'
)
# check if instances key is provided
key_instances_exists_flag = check_dataset_instances_key_fast(single_file, KEY_INSTANCES)
if not key_instances_exists_flag:
raise ValueError(
f'"{KEY_INSTANCES}" must be provided to initialize a'
f' dataset, e.g.\n'
f' {TEXT_ONLY_DATASET_DESCRIPTION}'
)
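
For reference, a minimal sketch (not part of this commit) of a dataset file that would satisfy both checks; the instance fields shown are illustrative and follow LMFlow's text_only layout:

    import json

    example = {
        "type": "text_only",   # found by get_dataset_type_fast
        "instances": [         # found by check_dataset_instances_key_fast
            {"text": "The quick brown fox jumps over the lazy dog."},
            {"text": "Samples live under the instances key."},
        ],
    }

    with open("train.json", "w", encoding="utf-8") as f:  # hypothetical path
        json.dump(example, f, indent=2)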


def from_dict(self, dict_obj: dict, *args, **kwargs):
56 changes: 40 additions & 16 deletions src/lmflow/utils/data_utils.py
@@ -94,7 +94,7 @@ def batchlize(examples: list, batch_size: int, random_shuffle: bool):
return dataloader


def read_last_n_lines_large_file(file_path, n=10):
def read_last_n_lines_large_file(file_path: str, n: int = 10) -> List[str]:
with open(file_path, 'rb') as f:
f.seek(0, os.SEEK_END)
buffer = bytearray()
@@ -107,22 +107,46 @@ def read_last_n_lines_large_file(file_path, n=10):
return buffer[::-1].decode('utf-8').splitlines()[-n:]


def get_dataset_type_fast(file_path, max_lines=100):
    type_values = []
    # first n lines
    with open(file_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if i >= max_lines:
                break
            try:
                data = json.loads(line.strip())
                if isinstance(data, dict) and 'type' in data:
                    type_values.append(data['type'])
            except json.JSONDecodeError:
                continue
    # last n lines
    # TODO
    return type_values


def read_first_n_lines_large_file(file_path: str, n: int = 10) -> List[str]:
    with open(file_path, 'rb') as f:
        f.seek(0)
        lines = []
        for i in range(n):
            line = f.readline()
            if not line:
                break
            lines.append(line.decode('utf-8').strip())
    return lines
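
As a quick usage sketch (illustrative; the path is hypothetical), the two helpers read only the head and tail of a file, so the cost does not grow with file size:

    head = read_first_n_lines_large_file("data/train.json", n=5)
    tail = read_last_n_lines_large_file("data/train.json", n=5)
    print(head[0])   # usually the opening '{' or the line carrying "type"
    print(tail[-1])  # usually the closing '}' of the JSON document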


def get_dataset_type_fast(file_path: str, max_lines: int = 100) -> Union[str, None]:
    '''Get the dataset type from the first and last `max_lines` lines of a
    large JSON dataset file, or return None if no "type" field is found.
    '''
lines = []
dataset_type = None
dataset_type_pattern = re.compile(r'[\"\']type[\"\']:\s*[\'\"]([^"]+)[\'\"]')
lines.extend(read_first_n_lines_large_file(file_path, max_lines))
lines.extend(read_last_n_lines_large_file(file_path, max_lines))
for line in lines:
try:
dataset_type = dataset_type_pattern.search(line).group(1)
break
except AttributeError:
continue
return dataset_type
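
For illustration (the file and its contents are hypothetical), the regex scan behaves as follows:

    # If data/train.json begins with
    #   {
    #     "type": "text_only",
    #     "instances": [
    # the pattern matches the line carrying "type" and the function returns
    # "text_only"; if no sampled head or tail line contains a "type" field,
    # it returns None instead.
    dataset_type = get_dataset_type_fast("data/train.json")  # -> "text_only"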


def check_dataset_instances_key_fast(file_path: str, instances_key: str, max_lines: int = 100) -> bool:
    '''Check whether the given instances key appears in the first or last
    `max_lines` lines of a large JSON dataset file.
    '''
lines = []
instance_key_pattern = re.compile(r'[\"\']' + instances_key + r'[\"\']')
lines.extend(read_first_n_lines_large_file(file_path, max_lines))
lines.extend(read_last_n_lines_large_file(file_path, max_lines))
for line in lines:
if instance_key_pattern.search(line):
return True
return False
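
Taken together, the two fast checks can vet a file before it is handed to load_dataset. A rough sketch (hypothetical helper and paths; it mirrors, but is not, Dataset._check_hf_json_format above):

    def quick_validate(file_path: str) -> str:
        # Hypothetical standalone helper combining the two utilities.
        dataset_type = get_dataset_type_fast(file_path)
        if dataset_type is None:
            raise ValueError(f'"type" key not found in {file_path}')
        if not check_dataset_instances_key_fast(file_path, "instances"):
            raise ValueError(f'"instances" key not found in {file_path}')
        return dataset_type

    # e.g. quick_validate("data/alpaca/train.json") might return "text2text"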


def answer_extraction(response, answer_type=None):  # use this function to extract answers from generated text
