diff --git a/pandasai/data_loader/loader.py b/pandasai/data_loader/loader.py
index 2aac1517b..60295d2d4 100644
--- a/pandasai/data_loader/loader.py
+++ b/pandasai/data_loader/loader.py
@@ -46,7 +46,7 @@ def load(self, dataset_path: str) -> DataFrame:
             cache_format = self.schema["destination"]["format"]
             return self._read_csv_or_parquet(cache_file, cache_format)
 
-        df = self._load_from_source()
+        df = self._load_from_local_source()
         df = self._apply_transformations(df)
 
         # Convert to pandas DataFrame while preserving internal data
@@ -169,18 +169,20 @@ def _read_csv_or_parquet(self, file_path: str, format: str) -> DataFrame:
         else:
             raise ValueError(f"Unsupported file format: {format}")
 
-    def _load_from_source(self) -> pd.DataFrame:
+    def _load_from_local_source(self) -> pd.DataFrame:
         source_type = self.schema["source"]["type"]
-        if source_type in ["csv", "parquet"]:
-            filepath = os.path.join(
-                self._get_abs_dataset_path(),
-                self.schema["source"]["path"],
+
+        if source_type not in LOCAL_SOURCE_TYPES:
+            raise InvalidDataSourceType(
+                f"Unsupported local source type: {source_type}. Supported types are: {LOCAL_SOURCE_TYPES}."
             )
-            return self._read_csv_or_parquet(filepath, source_type)
-
-        query_builder = QueryBuilder(self.schema)
-        query = query_builder.build_query()
-        return self.execute_query(query)
+
+        filepath = os.path.join(
+            str(self._get_abs_dataset_path()),
+            self.schema["source"]["path"],
+        )
+
+        return self._read_csv_or_parquet(filepath, source_type)
 
     def load_head(self) -> pd.DataFrame:
         query_builder = QueryBuilder(self.schema)
diff --git a/tests/unit_tests/dataframe/test_loader.py b/tests/unit_tests/dataframe/test_loader.py
index 2715c57b7..4f8d8c132 100644
--- a/tests/unit_tests/dataframe/test_loader.py
+++ b/tests/unit_tests/dataframe/test_loader.py
@@ -7,6 +7,7 @@
 
 from pandasai.data_loader.loader import DatasetLoader
 from pandasai.dataframe.base import DataFrame
+from pandasai.exceptions import InvalidDataSourceType
 
 
 class TestDatasetLoader:
@@ -109,8 +110,8 @@ def test_load_from_cache(self, sample_schema):
         ) as mock_read_cache, patch(
             "builtins.open", mock_open(read_data=str(sample_schema))
         ), patch(
-            "pandasai.data_loader.loader.DatasetLoader._load_from_source"
-        ) as mock_load_source, patch(
+            "pandasai.data_loader.loader.DatasetLoader._load_from_local_source"
+        ) as mock_load_local_source, patch(
             "pandasai.data_loader.loader.DatasetLoader.load_head"
         ) as mock_load_head:
             loader = DatasetLoader()
@@ -126,7 +127,34 @@ def test_load_from_cache(self, sample_schema):
             assert isinstance(result, DataFrame)
             assert "email" in result.columns
             mock_read_cache.assert_called_once()
-            mock_load_source.assert_not_called()
+            mock_load_local_source.assert_not_called()
+
+    def test_load_from_local_source_valid(self, sample_schema):
+        with patch("os.path.exists", return_value=True), patch(
+            "pandasai.data_loader.loader.DatasetLoader._read_csv_or_parquet"
+        ) as mock_read_csv_or_parquet:
+            loader = DatasetLoader()
+            loader.dataset_path = "test"
+            loader.schema = sample_schema
+
+            mock_read_csv_or_parquet.return_value = DataFrame(
+                {"email": ["test@example.com"]}
+            )
+
+            result = loader._load_from_local_source()
+
+            assert isinstance(result, DataFrame)
+            assert "email" in result.columns
+
+    def test_load_from_local_source_invalid_source_type(self, sample_schema):
+        loader = DatasetLoader()
+        sample_schema["source"]["type"] = "mysql"
+        loader.schema = sample_schema
+
+        with pytest.raises(
+            InvalidDataSourceType, match="Unsupported local source type"
+        ):
+            loader._load_from_local_source()
 
     def test_anonymize_method(self):
         loader = DatasetLoader()