refactor(loader): renames _load_from_source to _load_from_local_source, removes unreachable code and adds tests for _load_from_local_source (#1514)

* refactor(loader): renames _load_from_source to _load_from_local_source, removes unreachable code and adds tests for _load_from_local_source

* chore: fixing ruff formatting

* Update pandasai/data_loader/loader.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
scaliseraoul and ellipsis-dev[bot] authored Jan 13, 2025
1 parent 9fe0d32 commit 5030471
Showing 2 changed files with 43 additions and 13 deletions.
22 changes: 12 additions & 10 deletions pandasai/data_loader/loader.py
@@ -46,7 +46,7 @@ def load(self, dataset_path: str) -> DataFrame:
             cache_format = self.schema["destination"]["format"]
             return self._read_csv_or_parquet(cache_file, cache_format)
 
-        df = self._load_from_source()
+        df = self._load_from_local_source()
         df = self._apply_transformations(df)
 
         # Convert to pandas DataFrame while preserving internal data
@@ -169,18 +169,20 @@ def _read_csv_or_parquet(self, file_path: str, format: str) -> DataFrame:
         else:
             raise ValueError(f"Unsupported file format: {format}")
 
-    def _load_from_source(self) -> pd.DataFrame:
+    def _load_from_local_source(self) -> pd.DataFrame:
         source_type = self.schema["source"]["type"]
-        if source_type in ["csv", "parquet"]:
-            filepath = os.path.join(
-                self._get_abs_dataset_path(),
-                self.schema["source"]["path"],
+
+        if source_type not in LOCAL_SOURCE_TYPES:
+            raise InvalidDataSourceType(
+                f"Unsupported local source type: {source_type}. Supported types are: {LOCAL_SOURCE_TYPES}."
             )
-            return self._read_csv_or_parquet(filepath, source_type)
 
-        query_builder = QueryBuilder(self.schema)
-        query = query_builder.build_query()
-        return self.execute_query(query)
+        filepath = os.path.join(
+            str(self._get_abs_dataset_path()),
+            self.schema["source"]["path"],
+        )
+
+        return self._read_csv_or_parquet(filepath, source_type)
 
     def load_head(self) -> pd.DataFrame:
         query_builder = QueryBuilder(self.schema)
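For context, a minimal standalone sketch of the fail-fast validation introduced above. This is not the library's code: LOCAL_SOURCE_TYPES is assumed here to be ["csv", "parquet"], and the exception class and helper below are stand-ins that mirror the diff.

import os

# Assumption: mirrors pandasai's constant for illustration; the real value may differ.
LOCAL_SOURCE_TYPES = ["csv", "parquet"]


class InvalidDataSourceType(Exception):
    """Stand-in for pandasai.exceptions.InvalidDataSourceType."""


def load_from_local_source(schema: dict, dataset_root: str) -> str:
    """Reject non-local sources before touching the filesystem, as in the new method."""
    source_type = schema["source"]["type"]
    if source_type not in LOCAL_SOURCE_TYPES:
        raise InvalidDataSourceType(
            f"Unsupported local source type: {source_type}. "
            f"Supported types are: {LOCAL_SOURCE_TYPES}."
        )
    # A remote source never reaches this point; only local files are resolved.
    return os.path.join(dataset_root, schema["source"]["path"])


if __name__ == "__main__":
    try:
        # Hypothetical schema and path, used only to show the error path.
        load_from_local_source({"source": {"type": "mysql", "path": "users.csv"}}, "/tmp/ds")
    except InvalidDataSourceType as exc:
        print(exc)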
34 changes: 31 additions & 3 deletions tests/unit_tests/dataframe/test_loader.py
@@ -7,6 +7,7 @@
 
 from pandasai.data_loader.loader import DatasetLoader
 from pandasai.dataframe.base import DataFrame
+from pandasai.exceptions import InvalidDataSourceType
 
 
 class TestDatasetLoader:
@@ -109,8 +110,8 @@ def test_load_from_cache(self, sample_schema):
         ) as mock_read_cache, patch(
             "builtins.open", mock_open(read_data=str(sample_schema))
         ), patch(
-            "pandasai.data_loader.loader.DatasetLoader._load_from_source"
-        ) as mock_load_source, patch(
+            "pandasai.data_loader.loader.DatasetLoader._load_from_local_source"
+        ) as mock_load_local_source, patch(
            "pandasai.data_loader.loader.DatasetLoader.load_head"
         ) as mock_load_head:
             loader = DatasetLoader()
@@ -126,7 +127,34 @@ def test_load_from_cache(self, sample_schema):
             assert isinstance(result, DataFrame)
             assert "email" in result.columns
             mock_read_cache.assert_called_once()
-            mock_load_source.assert_not_called()
+            mock_load_local_source.assert_not_called()
+
+    def test_load_from_local_source_valid(self, sample_schema):
+        with patch("os.path.exists", return_value=True), patch(
+            "pandasai.data_loader.loader.DatasetLoader._read_csv_or_parquet"
+        ) as mock_read_csv_or_parquet:
+            loader = DatasetLoader()
+            loader.dataset_path = "test"
+            loader.schema = sample_schema
+
+            mock_read_csv_or_parquet.return_value = DataFrame(
+                {"email": ["[email protected]"]}
+            )
+
+            result = loader._load_from_local_source()
+
+            assert isinstance(result, DataFrame)
+            assert "email" in result.columns
+
+    def test_load_from_local_source_invalid_source_type(self, sample_schema):
+        loader = DatasetLoader()
+        sample_schema["source"]["type"] = "mysql"
+        loader.schema = sample_schema
+
+        with pytest.raises(
+            InvalidDataSourceType, match="Unsupported local source type"
+        ):
+            loader._load_from_local_source()
 
     def test_anonymize_method(self):
         loader = DatasetLoader()
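A minimal, self-contained illustration (not the repository's test) of the pytest.raises pattern used in test_load_from_local_source_invalid_source_type: the match argument is a regular expression searched against the exception message, so a stable message prefix is enough to pin the error path. The exception class and helper below are stand-ins.

import pytest


class InvalidDataSourceType(Exception):
    """Stand-in for pandasai.exceptions.InvalidDataSourceType."""


def reject(source_type: str) -> None:
    # Mimics the loader's error path for a non-local source such as "mysql".
    raise InvalidDataSourceType(f"Unsupported local source type: {source_type}.")


def test_unsupported_source_type_message():
    # match= is applied with re.search against str(exception).
    with pytest.raises(InvalidDataSourceType, match="Unsupported local source type"):
        reject("mysql")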
