diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py
index fa8e434d4..94e6673d7 100644
--- a/src/llmcompressor/transformers/finetune/data/base.py
+++ b/src/llmcompressor/transformers/finetune/data/base.py
@@ -180,10 +180,11 @@ def load_dataset(self):
 
         logger.debug(f"Loading dataset {self.data_args.dataset}")
         return get_raw_dataset(
-            self.data_args,
-            None,
+            self.data_args.dataset,
+            name=self.data_args.dataset_config_name,
             split=self.split,
             streaming=self.data_args.streaming,
+            trust_remote_code=self.data_args.trust_remote_code_data,
             **self.data_args.raw_kwargs,
         )
 
diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py
index 23c70e561..3b4a93cfd 100644
--- a/src/llmcompressor/transformers/finetune/data/data_helpers.py
+++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py
@@ -67,9 +67,7 @@ def format_calibration_data(
 
 
 def get_raw_dataset(
-    data_args,
-    cache_dir: Optional[str] = None,
-    streaming: Optional[bool] = False,
+    path: str,
     **kwargs,
 ) -> Dataset:
     """
@@ -80,15 +78,10 @@ def get_raw_dataset(
 
     :return: the requested dataset
     """
-    raw_datasets = load_dataset(
-        data_args.dataset,
-        data_args.dataset_config_name,
-        cache_dir=cache_dir,
-        streaming=streaming,
-        trust_remote_code=data_args.trust_remote_code_data,
+    return load_dataset(
+        path,
         **kwargs,
     )
-    return raw_datasets
 
 
 def make_dataset_splits(
diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
index 812b26a56..37f7b8821 100644
--- a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
+++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
@@ -12,7 +12,12 @@ def test_combined_datasets():
     data_args = DataTrainingArguments(
         dataset="wikitext", dataset_config_name="wikitext-2-raw-v1"
     )
-    raw_wikitext2 = get_raw_dataset(data_args)
+    raw_wikitext2 = get_raw_dataset(
+        data_args.dataset,
+        name=data_args.dataset_config_name,
+        split=data_args.splits,
+        streaming=data_args.streaming,
+    )
     datasets = {"all": raw_wikitext2}
 
     split_datasets = make_dataset_splits(
@@ -37,8 +42,13 @@ def test_separate_datasets():
         dataset="wikitext", dataset_config_name="wikitext-2-raw-v1"
     )
     datasets = {}
     for split_name, split_str in splits.items():
-        raw_wikitext2 = get_raw_dataset(data_args, split=split_str)
+        raw_wikitext2 = get_raw_dataset(
+            data_args.dataset,
+            name=data_args.dataset_config_name,
+            split=split_str,
+            streaming=data_args.streaming,
+        )
         datasets[split_name] = raw_wikitext2
 
     split_datasets = make_dataset_splits(