feat(Dataframe): add classmethod get_default_schema to generate default schema, refactor pai.create to accept name, description, columns, removes cache from datasetloader
scaliseraoul committed Jan 16, 2025
1 parent f9644cf commit 02f5cf8
Showing 7 changed files with 239 additions and 237 deletions.
38 changes: 25 additions & 13 deletions pandasai/__init__.py
@@ -13,34 +13,39 @@

from pandasai.config import APIKeyManager, ConfigManager
from pandasai.constants import DEFAULT_API_URL
from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
from pandasai.exceptions import DatasetNotFound, PandaAIApiKeyError
from pandasai.helpers.path import find_project_root
from pandasai.helpers.session import get_pandaai_session

from .agent import Agent
from .core.cache import Cache
from .data_loader.loader import DatasetLoader
from .data_loader.semantic_layer_schema import Column
from .dataframe import DataFrame, VirtualDataFrame
from .helpers.sql_sanitizer import sanitize_sql_table_name
from .smart_dataframe import SmartDataframe
from .smart_datalake import SmartDatalake


def create(path: str, df: pd.DataFrame, schema: SemanticLayerSchema):
def create(
path: str,
df: pd.DataFrame,
name: str = None,
description: str = None,
columns: List[dict] = None,
):
"""
Create a new dataset with the given DataFrame and schema.
Args:
path (str): Path in format 'organization/dataset'
df (pd.DataFrame): DataFrame to save
schema (SemanticLayerSchema): Schema containing dataset metadata
Returns:
DataFrame: A new PandaAI DataFrame instance with loaded data
Raises:
ValueError: If path format is invalid or dataset already exists
path (str): Path in the format 'organization/dataset'. This specifies
the location where the dataset should be created.
df (pd.DataFrame): The DataFrame containing the data to save.
name (str, optional): The name of the dataset. Defaults to None.
If not provided, a name will be automatically generated.
description (str, optional): A textual description of the dataset.
Defaults to None.
columns (List[dict], optional): A list of dictionaries defining the column schema.
Each dictionary should have keys like 'name', 'type', and optionally
'description' to describe individual columns. Defaults to None.
"""

# Validate path format
@@ -81,6 +86,13 @@ def create(path: str, df: pd.DataFrame, schema: SemanticLayerSchema):

# Save schema to yaml
schema_path = os.path.join(dataset_directory, "schema.yaml")

schema = df.schema
schema.name = name or schema.name
schema.description = description or schema.description
if columns:
schema.columns = list(map(lambda column: Column(**column), columns))

with open(schema_path, "w") as yml_file:
yml_file.write(schema.to_yaml())

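A minimal usage sketch of the refactored create() (not part of the diff). The "acme-corp/payments" path, the column entries, and the use of a pandasai DataFrame wrapper are illustrative assumptions; the wrapper is used because the new body reads df.schema, which plain pandas frames do not carry.

import pandasai as pai

# Wrapping the data in a pandasai DataFrame so df.schema is available
df = pai.DataFrame({"id": [1, 2, 3], "amount": [10.5, 20.0, 7.25]})

# "acme-corp/payments" is a placeholder in 'organization/dataset' format
pai.create(
    "acme-corp/payments",
    df,
    name="payments",
    description="Example payments table",
    columns=[
        {"name": "id", "type": "integer"},
        {"name": "amount", "type": "float", "description": "Payment amount in USD"},
    ],
)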
43 changes: 3 additions & 40 deletions pandasai/data_loader/loader.py
@@ -40,18 +40,11 @@ def load(self, dataset_path: str) -> DataFrame:

source_type = self.schema.source.type
if source_type in LOCAL_SOURCE_TYPES:
cache_file = self._get_cache_file_path()

if self._is_cache_valid(cache_file):
cache_format = self.schema.destination.format
return self._read_csv_or_parquet(cache_file, cache_format)

df = self._load_from_local_source()
df = self._apply_transformations(df)

# Convert to pandas DataFrame while preserving internal data
df = pd.DataFrame(df._data)
self._cache_data(df, cache_file)

return DataFrame(
df._data,
@@ -80,29 +73,6 @@ def _load_schema(self):
raw_schema = yaml.safe_load(file)
self.schema = SemanticLayerSchema(**raw_schema)

def _get_cache_file_path(self) -> str:
if self.schema.destination.path:
return os.path.join(
str(self._get_abs_dataset_path()), self.schema.destination.path
)

file_extension = (
"parquet" if self.schema.destination.format == "parquet" else "csv"
)
return os.path.join(str(self._get_abs_dataset_path()), f"data.{file_extension}")

def _is_cache_valid(self, cache_file: str) -> bool:
if not os.path.exists(cache_file):
return False

file_mtime = datetime.fromtimestamp(os.path.getmtime(cache_file))
update_frequency = self.schema.update_frequency

if update_frequency and update_frequency == "weekly":
return file_mtime > datetime.now() - timedelta(weeks=1)

return False

def _get_loader_function(self, source_type: str):
"""
Get the loader function for a specified data source type.
@@ -180,9 +150,9 @@ def get_row_count(self) -> int:
return result.iloc[0, 0]

def execute_query(self, query: str) -> pd.DataFrame:
source = self.schema.get("source", {})
source_type = source.get("type")
connection_info = source.get("connection", {})
source = self.schema.source
source_type = source.type
connection_info = source.connection

if not source_type:
raise ValueError("Source type is missing in the schema.")
@@ -220,13 +190,6 @@ def _anonymize(value: Any) -> Any:
except ValueError:
return value

def _cache_data(self, df: pd.DataFrame, cache_file: str):
cache_format = self.schema.destination.format
if cache_format == "parquet":
df.to_parquet(cache_file, index=False)
elif cache_format == "csv":
df.to_csv(cache_file, index=False)

def copy(self) -> "DatasetLoader":
"""
Create a new independent copy of the current DatasetLoader instance.
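A small sketch (assumed usage, not part of the diff) of the attribute-based schema access that execute_query now relies on; the field values are placeholders borrowed from the defaults used elsewhere in this commit.

from pandasai.data_loader.semantic_layer_schema import (
    Column,
    SemanticLayerSchema,
    Source,
)

schema = SemanticLayerSchema(
    name="example",
    source=Source(type="parquet", path="data.parquet"),
    columns=[Column(name="id", type="integer")],
)

# Source details are read as model attributes rather than dict lookups,
# mirroring the execute_query change above.
assert schema.source.type == "parquet"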
68 changes: 35 additions & 33 deletions pandasai/dataframe/base.py
@@ -3,7 +3,7 @@
import hashlib
import os
from io import BytesIO
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union
from typing import TYPE_CHECKING, ClassVar, Optional, Union
from zipfile import ZipFile

import pandas as pd
@@ -14,7 +14,6 @@
from pandasai.core.response import BaseResponse
from pandasai.data_loader.semantic_layer_schema import (
Column,
Destination,
SemanticLayerSchema,
Source,
)
@@ -67,11 +66,11 @@ def __init__(
if not self.name:
self.name = f"table_{self._column_hash}"

self.description: Optional[str] = kwargs.pop("description", None)
self.path: Optional[str] = kwargs.pop("path", None)
schema: Optional[SemanticLayerSchema] = kwargs.pop("schema", None)
self.schema = schema or DataFrame.get_default_schema(self)

self.schema = schema
self.description: Optional[str] = kwargs.pop("description", None)
self.path: Optional[str] = kwargs.pop("path", None)
self.config = pai.config.get()
self._agent: Optional[Agent] = None

@@ -147,34 +146,6 @@ def serialize_dataframe(self) -> str:
def get_head(self):
return self.head()

@staticmethod
def _create_yml_template(
name, description, columns_dict: List[dict]
) -> Dict[str, Any]:
"""
Generate a .yml file with a simplified metadata template from a pandas DataFrame.
Args:
name: dataset name
description: dataset description
columns_dict: dictionary with info about columns of the dataframe
"""

if columns_dict:
columns_dict = list(map(lambda column: Column(**column), columns_dict))

schema = SemanticLayerSchema(
name=name,
description=description,
columns=columns_dict,
source=Source(type="parquet", path="data.parquet"),
destination=Destination(
type="local", format="parquet", path="data.parquet"
),
)

return schema.to_dict()

def push(self):
if self.path is None:
raise ValueError(
@@ -264,3 +235,34 @@ def execute_sql_query(self, query: str) -> pd.DataFrame:
db = duckdb.connect(":memory:")
db.register(self.name, self)
return db.query(query).df()

@staticmethod
def get_column_type(column_dtype) -> Optional[str]:
"""
Map pandas dtype to a valid column type.
"""
if pd.api.types.is_string_dtype(column_dtype):
return "string"
elif pd.api.types.is_integer_dtype(column_dtype):
return "integer"
elif pd.api.types.is_float_dtype(column_dtype):
return "float"
elif pd.api.types.is_datetime64_any_dtype(column_dtype):
return "datetime"
elif pd.api.types.is_bool_dtype(column_dtype):
return "boolean"
else:
return None

@classmethod
def get_default_schema(cls, dataframe: DataFrame) -> SemanticLayerSchema:
columns_list = [
Column(name=str(name), type=DataFrame.get_column_type(dtype))
for name, dtype in dataframe.dtypes.items()
]

return SemanticLayerSchema(
name=dataframe.name,
source=Source(type="parquet", path="data.parquet"),
columns=columns_list,
)
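
A minimal sketch (not part of the diff) of the new classmethod; the column names and sample values are placeholders. Because __init__ now falls back to get_default_schema when no schema is passed, df.schema would carry the same default.

import pandas as pd
import pandasai as pai

df = pai.DataFrame(
    {
        "id": [1, 2],
        "signup": pd.to_datetime(["2024-01-01", "2024-01-02"]),
    }
)

# Column types are inferred via get_column_type ("integer" and "datetime" here),
# and the schema defaults to a local parquet source at data.parquet.
schema = pai.DataFrame.get_default_schema(df)
print(schema.to_yaml())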
