Skip to content

Commit

Permalink
get_raw_dataset fix
Browse files Browse the repository at this point in the history
  • Loading branch information
horheynm committed Feb 3, 2025
1 parent 317c10c commit a7dc55a
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
5 changes: 3 additions & 2 deletions src/llmcompressor/transformers/finetune/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,11 @@ def load_dataset(self):

logger.debug(f"Loading dataset {self.data_args.dataset}")
return get_raw_dataset(
self.data_args,
None,
self.data_args.dataset,
name=self.data_args.dataset_config_name,
split=self.split,
streaming=self.data_args.streaming,
trust_remote_code=self.data_args.trust_remote_code_data,
**self.data_args.raw_kwargs,
)

Expand Down
13 changes: 3 additions & 10 deletions src/llmcompressor/transformers/finetune/data/data_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def format_calibration_data(


def get_raw_dataset(
data_args,
cache_dir: Optional[str] = None,
streaming: Optional[bool] = False,
path: str,
**kwargs,
) -> Dataset:
"""
Expand All @@ -80,15 +78,10 @@ def get_raw_dataset(
:return: the requested dataset
"""
raw_datasets = load_dataset(
data_args.dataset,
data_args.dataset_config_name,
cache_dir=cache_dir,
streaming=streaming,
trust_remote_code=data_args.trust_remote_code_data,
return load_dataset(
path,
**kwargs,
)
return raw_datasets


def make_dataset_splits(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ def test_combined_datasets():
data_args = DataTrainingArguments(
dataset="wikitext", dataset_config_name="wikitext-2-raw-v1"
)
raw_wikitext2 = get_raw_dataset(data_args)
raw_wikitext2 = get_raw_dataset(
data_args.dataset,
name=data_args.dataset_config_name,
splits=data_args.splits,
streaming=data_args.streaming,
)
datasets = {"all": raw_wikitext2}

split_datasets = make_dataset_splits(
Expand All @@ -37,8 +42,13 @@ def test_separate_datasets():
dataset="wikitext", dataset_config_name="wikitext-2-raw-v1"
)
datasets = {}
for split_name, split_str in splits.items():
raw_wikitext2 = get_raw_dataset(data_args, split=split_str)
for split_name, _ in splits.items():
raw_wikitext2 = get_raw_dataset(
data_args.dataset,
name=data_args.dataset_config_name,
splits=data_args.splits,
streaming=data_args.streaming,
)
datasets[split_name] = raw_wikitext2

split_datasets = make_dataset_splits(
Expand Down

0 comments on commit a7dc55a

Please sign in to comment.