
feat(SemanticLayerSchema): Refactor to use SemanticLayerSchema throughout the code instead of the raw dictionary #1519

Closed
wants to merge 8 commits
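In essence (a paraphrased sketch, not code taken from the diff), lookups on the raw YAML dictionary give way to typed attribute access on the pydantic model; the schema path below is hypothetical:

```python
import yaml

from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema

with open("datasets/heart/schema.yaml") as schema_file:  # hypothetical path
    raw_schema = yaml.safe_load(schema_file)

source_type = raw_schema["source"]["type"]  # before: stringly-typed dict lookups

schema = SemanticLayerSchema(**raw_schema)  # after: a validated pydantic model
source_type = schema.source.type            # typed attribute access
cache_format = schema.destination.format
```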
84 changes: 35 additions & 49 deletions pandasai/data_loader/loader.py
@@ -3,10 +3,11 @@
import importlib
import os
from datetime import datetime, timedelta
from typing import Any
from typing import Any, Optional

import pandas as pd
import yaml
from sympy.parsing.sympy_parser import transformations

from pandasai.dataframe.base import DataFrame
from pandasai.dataframe.virtual_dataframe import VirtualDataFrame
@@ -19,11 +20,12 @@
SUPPORTED_SOURCE_CONNECTORS,
)
from .query_builder import QueryBuilder
from .semantic_layer_schema import SemanticLayerSchema


class DatasetLoader:
def __init__(self):
self.schema = None
self.schema: Optional[SemanticLayerSchema] = None
self.dataset_path = None

def load(self, dataset_path: str) -> DataFrame:
@@ -37,13 +39,13 @@ def load(self, dataset_path: str) -> DataFrame:
"""
self.dataset_path = dataset_path
self._load_schema()
self._validate_source_type()

if self.schema["source"]["type"] in LOCAL_SOURCE_TYPES:
source_type = self.schema.source.type
if source_type in LOCAL_SOURCE_TYPES:
cache_file = self._get_cache_file_path()

if self._is_cache_valid(cache_file):
cache_format = self.schema["destination"]["format"]
cache_format = self.schema.destination.format
return self._read_csv_or_parquet(cache_file, cache_format)

df = self._load_from_local_source()
@@ -53,27 +55,20 @@ def load(self, dataset_path: str) -> DataFrame:
df = pd.DataFrame(df._data)
self._cache_data(df, cache_file)

table_name = self.schema["source"].get("table", None) or self.schema["name"]
table_description = self.schema.get("description", None)

return DataFrame(
df._data,
schema=self.schema,
name=table_name,
description=table_description,
name=self.schema.name,
description=self.schema.description,
path=dataset_path,
)
elif self.schema["source"]["type"] in REMOTE_SOURCE_TYPES:
else:
data_loader = self.copy()
return VirtualDataFrame(
schema=self.schema,
data_loader=data_loader,
path=dataset_path,
)
else:
raise ValueError(
f"Unsupported source type: {self.schema['source']['type']}"
)

def _get_abs_dataset_path(self):
return os.path.join(find_project_root(), "datasets", self.dataset_path)
@@ -84,33 +79,26 @@ def _load_schema(self):
raise FileNotFoundError(f"Schema file not found: {schema_path}")

with open(schema_path, "r") as file:
self.schema = yaml.safe_load(file)

def _validate_source_type(self):
source_type = self.schema["source"]["type"]
if source_type not in SUPPORTED_SOURCE_CONNECTORS and source_type not in [
"csv",
"parquet",
]:
raise ValueError(f"Unsupported database type: {source_type}")
raw_schema = yaml.safe_load(file)
self.schema = SemanticLayerSchema(**raw_schema)

def _get_cache_file_path(self) -> str:
if "path" in self.schema["destination"]:
if self.schema.destination.path:
return os.path.join(
self._get_abs_dataset_path(), self.schema["destination"]["path"]
str(self._get_abs_dataset_path()), self.schema.destination.path
)

file_extension = (
"parquet" if self.schema["destination"]["format"] == "parquet" else "csv"
"parquet" if self.schema.destination.format == "parquet" else "csv"
)
return os.path.join(self._get_abs_dataset_path(), f"data.{file_extension}")
return os.path.join(str(self._get_abs_dataset_path()), f"data.{file_extension}")

def _is_cache_valid(self, cache_file: str) -> bool:
if not os.path.exists(cache_file):
return False

file_mtime = datetime.fromtimestamp(os.path.getmtime(cache_file))
update_frequency = self.schema.get("update_frequency", None)
update_frequency = self.schema.update_frequency

if update_frequency and update_frequency == "weekly":
return file_mtime > datetime.now() - timedelta(weeks=1)
@@ -148,29 +136,27 @@ def _get_loader_function(self, source_type: str):
) from e

def _read_csv_or_parquet(self, file_path: str, format: str) -> DataFrame:
table_name = self.schema["source"].get("table") or self.schema.get("name", None)
table_description = self.schema.get("description", None)
if format == "parquet":
return DataFrame(
pd.read_parquet(file_path),
schema=self.schema,
path=self.dataset_path,
name=table_name,
description=table_description,
name=self.schema.name,
description=self.schema.description,
)
elif format == "csv":
return DataFrame(
pd.read_csv(file_path),
schema=self.schema,
path=self.dataset_path,
name=table_name,
description=table_description,
name=self.schema.name,
description=self.schema.description,
)
else:
raise ValueError(f"Unsupported file format: {format}")

def _load_from_local_source(self) -> pd.DataFrame:
source_type = self.schema["source"]["type"]
source_type = self.schema.source.type

if source_type not in LOCAL_SOURCE_TYPES:
raise InvalidDataSourceType(
Expand All @@ -179,7 +165,7 @@ def _load_from_local_source(self) -> pd.DataFrame:

filepath = os.path.join(
str(self._get_abs_dataset_path()),
self.schema["source"]["path"],
self.schema.source.path,
)

return self._read_csv_or_parquet(filepath, source_type)
@@ -213,15 +199,17 @@ def execute_query(self, query: str) -> pd.DataFrame:
) from e

def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame:
for transform in self.schema.get("transformations", []):
if transform["type"] == "anonymize":
df[transform["params"]["column"]] = df[
transform["params"]["column"]
].apply(self._anonymize)
elif transform["type"] == "convert_timezone":
df[transform["params"]["column"]] = pd.to_datetime(
df[transform["params"]["column"]]
).dt.tz_convert(transform["params"]["to"])
for transformation in self.schema.transformations or []:
transformation_type = transformation.type
transformation_column = transformation.params["column"]
if transformation_type == "anonymize":
df[transformation_column] = df[transformation_column].apply(
self._anonymize
)
elif transformation_type == "convert_timezone":
df[transformation_column] = pd.to_datetime(
df[transformation_column]
).dt.tz_convert(transformation.params["to"])
return df

@staticmethod
@@ -235,13 +223,11 @@ def _anonymize(value: Any) -> Any:
return value

def _cache_data(self, df: pd.DataFrame, cache_file: str):
cache_format = self.schema["destination"]["format"]
cache_format = self.schema.destination.format
if cache_format == "parquet":
df.to_parquet(cache_file, index=False)
elif cache_format == "csv":
df.to_csv(cache_file, index=False)
else:
raise ValueError(f"Unsupported cache format: {cache_format}")

def copy(self) -> "DatasetLoader":
"""
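As a rough usage sketch of the refactored loader (the dataset path is a hypothetical example; loading requires a matching datasets/<path>/schema.yaml on disk):

```python
from pandasai.data_loader.loader import DatasetLoader

loader = DatasetLoader()

# load() now parses schema.yaml into a SemanticLayerSchema: local csv/parquet
# sources are read (and cached according to destination.format), while every
# other source type is served lazily through a VirtualDataFrame.
df = loader.load("my-org/heart")  # hypothetical dataset path under datasets/
```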
35 changes: 15 additions & 20 deletions pandasai/data_loader/query_builder.py
@@ -1,13 +1,15 @@
from typing import Any, Dict, List, Union

from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema


class QueryBuilder:
def __init__(self, schema: Dict[str, Any]):
def __init__(self, schema: SemanticLayerSchema):
self.schema = schema

def build_query(self) -> str:
columns = self._get_columns()
table_name = self.schema["source"]["table"]
table_name = self._get_table_name()
query = f"SELECT {columns} FROM {table_name}"

query += self._add_order_by()
@@ -16,42 +18,35 @@ def build_query(self) -> str:
return query

def _get_columns(self) -> str:
if "columns" in self.schema:
return ", ".join([col["name"] for col in self.schema["columns"]])
if self.schema.columns:
return ", ".join([col.name for col in self.schema.columns])
else:
return "*"

def _get_table_name(self):
table_name = self.schema["source"].get("table", None) or self.schema["name"]

if not table_name:
raise ValueError("Table name not found in schema!")

table_name = self.schema.source.table
table_name = table_name.lower()

return table_name

def _add_order_by(self) -> str:
if "order_by" not in self.schema:
if not self.schema.order_by:
return ""

order_by = self.schema["order_by"]
order_by = self.schema.order_by
order_by_clause = self._format_order_by(order_by)
return f" ORDER BY {order_by_clause}"

def _format_order_by(self, order_by: Union[List[str], str]) -> str:
return ", ".join(order_by) if isinstance(order_by, list) else order_by
@staticmethod
def _format_order_by(order_by: List[str]) -> str:
return ", ".join(order_by)

def _add_limit(self, n=None) -> str:
limit = n if n else (self.schema["limit"] if "limit" in self.schema else "")
return f" LIMIT {self.schema['limit']}" if limit else ""
limit = n or self.schema.limit
return f" LIMIT {limit}" if limit else ""

def get_head_query(self, n=5):
source = self.schema.get("source", {})
source_type = source.get("type")

source_type = self.schema.source.type
table_name = self._get_table_name()

columns = self._get_columns()

order_by = "RANDOM()" if source_type in {"sqlite", "postgres"} else "RAND()"
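A hedged sketch of how the updated QueryBuilder consumes the typed schema; `schema` is assumed to be a SemanticLayerSchema already parsed elsewhere, and the column names, table, order_by and limit values are invented for illustration:

```python
from pandasai.data_loader.query_builder import QueryBuilder

builder = QueryBuilder(schema)  # schema: SemanticLayerSchema for a remote source

# With columns id/email, table "Users", order_by ["id"] and limit 10, the
# builder lowercases the table name and is expected to produce something like:
#   SELECT id, email FROM users ORDER BY id LIMIT 10
query = builder.build_query()

# get_head_query() samples a handful of rows, ordering by RANDOM() on
# sqlite/postgres and RAND() on other engines.
head_query = builder.get_head_query(n=5)
```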
24 changes: 18 additions & 6 deletions pandasai/data_loader/semantic_layer_schema.py
@@ -1,13 +1,12 @@
import json
from typing import Dict, List, Optional, Union

import yaml
from pydantic import (
BaseModel,
Field,
ValidationError,
field_validator,
model_validator,
root_validator,
)

from pandasai.constants import (
@@ -21,7 +20,7 @@
class Column(BaseModel):
name: str = Field(..., description="Name of the column.")
type: str = Field(..., description="Data type of the column.")
description: str = Field(..., description="Description of the column")
description: Optional[str] = Field(None, description="Description of the column")

@field_validator("type")
@classmethod
@@ -51,15 +50,14 @@ class Source(BaseModel):
None, description="Connection object of the data source."
)
path: Optional[str] = Field(None, description="Path of the local data source.")
query: Optional[str] = Field(
None, description="Query to retrieve data from the data source"
)
table: Optional[str] = Field(None, description="Table of the data source.")

@model_validator(mode="before")
@classmethod
def validate_type_and_fields(cls, values):
_type = values.get("type")
path = values.get("path")
table = values.get("table")
connection = values.get("connection")

if _type in LOCAL_SOURCE_TYPES:
@@ -72,6 +70,10 @@ def validate_type_and_fields(cls, values):
raise ValueError(
f"For remote source type '{_type}', 'connection' must be defined."
)
if not table:
raise ValueError(
f"For remote source type '{_type}', 'table' must be defined."
)
else:
raise ValueError(f"Unsupported source type: {_type}")

@@ -83,6 +85,13 @@ class Destination(BaseModel):
format: str = Field(..., description="Format of the output file.")
path: str = Field(..., description="Path to save the output file.")

@field_validator("format")
@classmethod
def is_format_supported(cls, format: str) -> str:
if format not in LOCAL_SOURCE_TYPES:
Review comment: The is_format_supported validator checks if the format is in LOCAL_SOURCE_TYPES, which might be misleading. Consider using a more appropriate constant or list for format validation.

raise ValueError(f"Unsupported destination format: {format}")
return format


class SemanticLayerSchema(BaseModel):
name: str = Field(..., description="Dataset name.")
@@ -109,6 +118,9 @@
None, description="Frequency of dataset updates."
)

def to_yaml(self) -> str:
return yaml.dump(self.model_dump(), sort_keys=False)


def is_schema_source_same(
schema1: SemanticLayerSchema, schema2: SemanticLayerSchema
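Finally, a round-trip sketch for the new to_yaml() helper; the YAML document is illustrative (field names follow the models shown in this diff, and any required fields hidden in the collapsed hunks would need to be added):

```python
import yaml

from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema

raw = yaml.safe_load(
    """
name: heart
source:
  type: csv
  path: heart.csv
destination:
  type: local
  format: parquet
  path: data.parquet
"""
)

schema = SemanticLayerSchema(**raw)
print(schema.source.type)  # typed attribute access instead of raw dict lookups
print(schema.to_yaml())    # dumps model_dump() back to YAML without sorting keys
```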