diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 96380b689..dd35bd3f1 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -65,7 +65,25 @@ make spell_fix
 We use `pytest` to test our code. You can run the tests by running the following command:
 
 ```bash
-make tests
+make test_all
+```
+
+If you prefer, you can run only the core tests with the command:
+
+```bash
+make test_core
+```
+
+or only the extension tests with the command:
+
+```bash
+make test_extensions
+```
+
+You can also run the tests with coverage by running the following command:
+
+```bash
+make test-coverage
 ```
 
 Make sure that all tests pass before submitting a pull request.
diff --git a/pandasai/data_loader/semantic_layer_schema.py b/pandasai/data_loader/semantic_layer_schema.py
index 3049b645d..501ec479f 100644
--- a/pandasai/data_loader/semantic_layer_schema.py
+++ b/pandasai/data_loader/semantic_layer_schema.py
@@ -1,6 +1,6 @@
 import re
 from functools import partial
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List
 
 import yaml
 from pydantic import (
@@ -45,114 +45,107 @@ def __eq__(self, other):
 
 class Column(BaseModel):
     name: str = Field(..., description="Name of the column.")
-    type: Optional[str] = Field(None, description="Data type of the column.")
-    description: Optional[str] = Field(None, description="Description of the column")
-    expression: Optional[str] = Field(
-        None, description="Aggregation expression (avg, min, max, sum)"
-    )
-    alias: Optional[str] = Field(None, description="Alias for the column")
+    type: str | None = Field(None, description="Data type of the column.")
+    description: str | None = Field(None, description="Description of the column")
+    expression: str | None = Field(None, description="Aggregation expression (avg, min, max, sum)")
+    alias: str | None = Field(None, description="Alias for the column")
 
     @field_validator("type")
-    @classmethod
-    def is_column_type_supported(cls, type: str) -> str:
-        if type and type not in VALID_COLUMN_TYPES:
+    def is_column_type_supported(cls, v: str) -> str:
+        if v and v not in VALID_COLUMN_TYPES:
             raise ValueError(
-                f"Unsupported column type: {type}. Supported types are: {VALID_COLUMN_TYPES}"
+                f"Unsupported column type: {v}. Supported types are: {VALID_COLUMN_TYPES}"
             )
-        return type
+        return v
 
     @field_validator("expression")
-    @classmethod
-    def is_expression_valid(cls, expr: str) -> str:
-        try:
-            parse_one(expr)
-            return expr
-        except ParseError as e:
-            raise ValueError(f"Invalid SQL expression: {expr}. Error: {str(e)}")
+    def is_expression_valid(cls, v: str) -> str | None:
+        if v is not None:
+            try:
+                parse_one(v)
+                return v
+            except ParseError as e:
+                raise ValueError(f"Invalid SQL expression: {v}. Error: {str(e)}")
 
 
 class Relation(BaseModel):
-    name: Optional[str] = Field(None, description="Name of the relationship.")
-    description: Optional[str] = Field(
-        None, description="Description of the relationship."
-    )
-    from_: str = Field(
-        ..., alias="from", description="Source column for the relationship."
- ) + name: str | None = Field(None, description="Name of the relationship.") + description: str | None = Field(None, description="Description of the relationship.") + from_: str = Field(..., alias="from", description="Source column for the relationship.") to: str = Field(..., description="Target column for the relationship.") class TransformationParams(BaseModel): - column: Optional[str] = Field(None, description="Column to transform") - value: Optional[Union[str, int, float, bool]] = Field( + column: str | None = Field(None, description="Column to transform") + value: str | int | float | bool | None = Field( None, description="Value for fill_na and other transformations" ) - mapping: Optional[Dict[str, str]] = Field( + mapping: Dict[str, str] | None = Field( None, description="Mapping dictionary for map_values transformation" ) - format: Optional[str] = Field(None, description="Format string for date formatting") - decimals: Optional[int] = Field( + format: str | None = Field(None, description="Format string for date formatting") + decimals: int | None = Field( None, description="Number of decimal places for rounding" ) - factor: Optional[Union[int, float]] = Field(None, description="Scaling factor") - to_tz: Optional[str] = Field(None, description="Target timezone or format") - from_tz: Optional[str] = Field(None, description="From timezone or format") - errors: Optional[str] = Field( + factor: int | float | None = Field(None, description="Scaling factor") + to_tz: str | None = Field(None, description="Target timezone or format") + from_tz: str | None = Field(None, description="From timezone or format") + errors: str | None = Field( None, description="Error handling mode for numeric/datetime conversion" ) - old_value: Optional[Any] = Field( + old_value: Any | None = Field( None, description="Old value for replace transformation" ) - new_value: Optional[Any] = Field( + new_value: Any | None = Field( None, description="New value for replace transformation" ) - new_name: Optional[str] = Field( + new_name: str | None = Field( None, description="New name for column in rename transformation" ) - pattern: Optional[str] = Field( + pattern: str | None = Field( None, description="Pattern for extract transformation" ) - length: Optional[int] = Field( + length: int | None = Field( None, description="Length for truncate transformation" ) - add_ellipsis: Optional[bool] = Field( + add_ellipsis: bool | None = Field( True, description="Whether to add ellipsis in truncate" ) - width: Optional[int] = Field(None, description="Width for pad transformation") - side: Optional[str] = Field("left", description="Side for pad transformation") - pad_char: Optional[str] = Field(" ", description="Character for pad transformation") - lower: Optional[Union[int, float]] = Field(None, description="Lower bound for clip") - upper: Optional[Union[int, float]] = Field(None, description="Upper bound for clip") - bins: Optional[Union[int, List[Union[int, float]]]] = Field( + width: int | None = Field(None, description="Width for pad transformation") + side: str | None = Field("left", description="Side for pad transformation") + pad_char: str | None = Field(" ", description="Character for pad transformation") + lower: int | float | None = Field(None, description="Lower bound for clip") + upper: int | float | None = Field(None, description="Upper bound for clip") + bins: int | List[int | float] | None = Field( None, description="Bins for binning" ) - labels: Optional[List[str]] = Field(None, description="Labels for bins") - 
drop_first: Optional[bool] = Field( + labels: List[str] | None = Field(None, description="Labels for bins") + drop_first: bool | None = Field( True, description="Whether to drop first category in encoding" ) - drop_invalid: Optional[bool] = Field( + drop_invalid: bool | None = Field( False, description="Whether to drop invalid values" ) - start_date: Optional[str] = Field( + start_date: str | None = Field( None, description="Start date for date range validation" ) - end_date: Optional[str] = Field( + end_date: str | None = Field( None, description="End date for date range validation" ) - country_code: Optional[str] = Field( + country_code: str | None = Field( "+1", description="Country code for phone normalization" ) - columns: Optional[List[str]] = Field( + columns: List[str] | None = Field( None, description="List of columns for multi-column operations" ) - keep: Optional[str] = Field("first", description="Which duplicates to keep") - ref_table: Optional[Any] = Field( + keep: str | None = Field("first", description="Which duplicates to keep") + ref_table: Any | None = Field( None, description="Reference DataFrame for foreign key validation" ) - ref_column: Optional[str] = Field( + ref_column: str | None = Field( None, description="Reference column for foreign key validation" ) - drop_negative: Optional[bool] = Field( + drop_negative: bool | None = Field( False, description="Whether to drop negative values" ) @@ -172,7 +165,7 @@ def validate_required_params(cls, values: dict) -> dict: class Transformation(BaseModel): type: str = Field(..., description="Type of transformation to be applied.") - params: Optional[TransformationParams] = Field( + params: TransformationParams | None = Field( None, description="Parameters for the transformation." ) @@ -195,11 +188,11 @@ def set_transform_type(cls, values: dict) -> dict: class Source(BaseModel): type: str = Field(..., description="Type of the data source.") - path: Optional[str] = Field(None, description="Path of the local data source.") - connection: Optional[SQLConnectionConfig] = Field( + path: str | None = Field(None, description="Path of the local data source.") + connection: SQLConnectionConfig | None = Field( None, description="Connection object of the data source." ) - table: Optional[str] = Field(None, description="Table of the data source.") + table: str | None = Field(None, description="Table of the data source.") def is_compatible_source(self, source2: "Source"): """ @@ -267,33 +260,33 @@ def is_format_supported(cls, format: str) -> str: class SemanticLayerSchema(BaseModel): name: str = Field(..., description="Dataset name.") - source: Optional[Source] = Field(None, description="Data source for your dataset.") - view: Optional[bool] = Field(None, description="Whether table is a view") - description: Optional[str] = Field( + source: Source | None = Field(None, description="Data source for your dataset.") + view: bool | None = Field(None, description="Whether table is a view") + description: str | None = Field( None, description="Dataset’s contents and purpose description." ) - columns: Optional[List[Column]] = Field( + columns: List[Column] | None = Field( None, description="Structure and metadata of your dataset’s columns" ) - relations: Optional[List[Relation]] = Field( + relations: List[Relation] | None = Field( None, description="Relationships between columns and tables." ) - order_by: Optional[List[str]] = Field( + order_by: List[str] | None = Field( None, description="Ordering criteria for the dataset." 
) - limit: Optional[int] = Field( + limit: int | None = Field( None, description="Maximum number of records to retrieve." ) - transformations: Optional[List[Transformation]] = Field( + transformations: List[Transformation] | None = Field( None, description="List of transformations to apply to the data." ) - destination: Optional[Destination] = Field( + destination: Destination | None = Field( None, description="Destination for saving the dataset." ) - update_frequency: Optional[str] = Field( + update_frequency: str | None = Field( None, description="Frequency of dataset updates." ) - group_by: Optional[List[str]] = Field( + group_by: List[str] | None = Field( None, description="List of columns to group by. Every non-aggregated column must be included in group_by.", ) diff --git a/pandasai/helpers/dataframe_serializer.py b/pandasai/helpers/dataframe_serializer.py index 2debb55c1..92cb08b14 100644 --- a/pandasai/helpers/dataframe_serializer.py +++ b/pandasai/helpers/dataframe_serializer.py @@ -28,6 +28,10 @@ def serialize(cls, df: "DataFrame", dialect: str = "postgres") -> str: if df.schema.description is not None: dataframe_info += f' description="{df.schema.description}"' + if df.schema.columns: + columns = [column.model_dump() for column in df.schema.columns] + dataframe_info += f' columns="{json.dumps(columns, ensure_ascii=False)}"' + dataframe_info += f' dimensions="{df.rows_count}x{df.columns_count}">' # Truncate long values diff --git a/poetry.lock b/poetry.lock index f4b30706e..2e129d437 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.0.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. [[package]] name = "annotated-types" @@ -35,7 +35,7 @@ typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} [package.extras] doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21.0b1)"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1) ; python_version >= \"3.10\"", "uvloop (>=0.21.0b1) ; platform_python_implementation == \"CPython\" and platform_system != \"Windows\""] trio = ["trio (>=0.26.1)"] [[package]] @@ -206,7 +206,7 @@ files = [ [package.extras] dev = ["Pygments", "build", "chardet", "pre-commit", "pytest", "pytest-cov", "pytest-dependency", "ruff", "tomli", "twine"] hard-encoding-detection = ["chardet"] -toml = ["tomli"] +toml = ["tomli ; python_version < \"3.11\""] types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency"] [[package]] @@ -377,7 +377,7 @@ files = [ ] [package.extras] -toml = ["tomli"] +toml = ["tomli ; python_full_version <= \"3.11.0a6\""] [[package]] name = "cycler" @@ -488,7 +488,7 @@ description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" groups = ["dev"] -markers = "python_version < \"3.11\"" +markers = "python_version <= \"3.10\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = 
"sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -512,7 +512,7 @@ files = [ [package.extras] docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4.1)"] testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] -typing = ["typing-extensions (>=4.12.2)"] +typing = ["typing-extensions (>=4.12.2) ; python_version < \"3.11\""] [[package]] name = "fonttools" @@ -575,18 +575,18 @@ files = [ ] [package.extras] -all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "pycairo", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0)", "xattr", "zopfli (>=0.1.4)"] +all = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres ; platform_python_implementation == \"PyPy\"", "pycairo", "scipy ; platform_python_implementation != \"PyPy\"", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0) ; python_version <= \"3.12\"", "xattr ; sys_platform == \"darwin\"", "zopfli (>=0.1.4)"] graphite = ["lz4 (>=1.7.4.2)"] -interpolatable = ["munkres", "pycairo", "scipy"] +interpolatable = ["munkres ; platform_python_implementation == \"PyPy\"", "pycairo", "scipy ; platform_python_implementation != \"PyPy\""] lxml = ["lxml (>=4.0)"] pathops = ["skia-pathops (>=0.5.0)"] plot = ["matplotlib"] repacker = ["uharfbuzz (>=0.23.0)"] symfont = ["sympy"] -type1 = ["xattr"] +type1 = ["xattr ; sys_platform == \"darwin\""] ufo = ["fs (>=2.2.0,<3)"] -unicode = ["unicodedata2 (>=15.1.0)"] -woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] +unicode = ["unicodedata2 (>=15.1.0) ; python_version <= \"3.12\""] +woff = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "zopfli (>=0.1.4)"] [[package]] name = "h11" @@ -641,7 +641,7 @@ httpcore = "==1.*" idna = "*" [package.extras] -brotli = ["brotli", "brotlicffi"] +brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] @@ -694,7 +694,7 @@ files = [ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] enabler = ["pytest-enabler (>=2.2)"] @@ -1207,7 +1207,7 @@ files = [ numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.21.0", markers = "python_version == \"3.10\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -1331,7 +1331,7 @@ docs = ["furo", "olefile", "sphinx (>=7.3)", "sphinx-copybutton", "sphinx-inline fpx = ["olefile"] mic = ["olefile"] 
tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] -typing = ["typing-extensions"] +typing = ["typing-extensions ; python_version < \"3.10\""] xmp = ["defusedxml"] [[package]] @@ -1454,7 +1454,7 @@ typing-extensions = ">=4.12.2" [package.extras] email = ["email-validator (>=2.0.0)"] -timezone = ["tzdata"] +timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""] [[package]] name = "pydantic-core" @@ -1918,7 +1918,7 @@ description = "A lil' TOML parser" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "python_version < \"3.11\"" +markers = "python_version <= \"3.10\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, @@ -2013,7 +2013,7 @@ files = [ ] [package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""] h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] @@ -2037,7 +2037,7 @@ platformdirs = ">=3.9.1,<5" [package.extras] docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] -test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\" or platform_python_implementation == \"CPython\" and sys_platform == \"win32\" and python_version >= \"3.13\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""] [[package]] name = "zipp" @@ -2053,11 +2053,11 @@ files = [ ] [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] enabler = ["pytest-enabler (>=2.2)"] -test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] +test = ["big-O", "importlib-resources ; python_version < \"3.9\"", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] type = ["pytest-mypy"] [metadata] diff --git a/tests/unit_tests/helpers/test_dataframe_serializer.py b/tests/unit_tests/helpers/test_dataframe_serializer.py index 46bd6ea14..2354be5dd 100644 --- a/tests/unit_tests/helpers/test_dataframe_serializer.py +++ b/tests/unit_tests/helpers/test_dataframe_serializer.py @@ -1,5 +1,3 @@ -import pandas 
as pd - from pandasai.helpers.dataframe_serializer import DataframeSerializer @@ -8,7 +6,7 @@ def test_serialize_with_name_and_description(self, sample_df): """Test serialization with name and description attributes.""" result = DataframeSerializer.serialize(sample_df) - expected = """ + expected = """
A,B 1,4 2,5 @@ -21,7 +19,7 @@ def test_serialize_with_name_and_description_with_dialect(self, sample_df): """Test serialization with name and description attributes.""" result = DataframeSerializer.serialize(sample_df, dialect="mysql") - expected = """
+ expected = """
A,B 1,4 2,5 @@ -44,7 +42,7 @@ def test_serialize_with_dataframe_long_strings(self, sample_df): truncated_text = long_text[: DataframeSerializer.MAX_COLUMN_TEXT_LENGTH] + "…" # Expected output - expected = f"""
+ expected = f"""
A,B {truncated_text},4 2,5
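
---

A note on the `Optional[X]` → `X | None` migration in `semantic_layer_schema.py`: the two spellings are equivalent for Pydantic validation, but PEP 604 unions are evaluated at class-definition time, so they require Python 3.10+ unless annotations are postponed (on older interpreters, Pydantic can resolve stringified `X | None` annotations via the optional `eval_type_backport` package). A minimal standalone sketch, with illustrative `ColumnOld`/`ColumnNew` names that are not part of this PR:

```python
from typing import Optional

from pydantic import BaseModel, Field


class ColumnOld(BaseModel):
    # Pre-PR spelling: works on every Python version Pydantic supports.
    type: Optional[str] = Field(None, description="Data type of the column.")


class ColumnNew(BaseModel):
    # Post-PR spelling: PEP 604 union, equivalent to Optional[str],
    # but evaluated at class-definition time (Python 3.10+).
    type: str | None = Field(None, description="Data type of the column.")


# Both models validate identically:
assert ColumnOld(type="string").type == ColumnNew(type="string").type
assert ColumnOld().type is None and ColumnNew().type is None
```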
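The reworked `is_expression_valid` validator now lets a missing expression pass through instead of handing `None` to `parse_one`. Here is a standalone sketch of that behavior using sqlglot directly; the `check_expression` helper is hypothetical and only mirrors the validator's logic:

```python
from sqlglot import parse_one
from sqlglot.errors import ParseError


def check_expression(expr: str | None) -> str | None:
    # Mirrors the updated validator: None passes through untouched,
    # anything else must parse as a valid SQL expression.
    if expr is not None:
        try:
            parse_one(expr)
            return expr
        except ParseError as e:
            raise ValueError(f"Invalid SQL expression: {expr}. Error: {str(e)}")
    return None


print(check_expression(None))          # None is now accepted
print(check_expression("avg(price)"))  # valid SQL passes through
# check_expression("avg(")             # would raise ValueError
```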
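The `dataframe_serializer.py` change injects column metadata into the `<table ...>` header that the surrounding code builds. A rough sketch of the resulting string, assuming an opening tag of the shape the visible context suggests — the `table_name` value and the column dicts below are made up for illustration, though `model_dump()` on a `Column` would produce dicts of this shape:

```python
import json

# Hypothetical output of [column.model_dump() for column in df.schema.columns]
columns = [
    {"name": "A", "type": "integer", "description": None, "expression": None, "alias": None},
    {"name": "B", "type": "integer", "description": None, "expression": None, "alias": None},
]

dataframe_info = '<table dialect="postgres" table_name="sales"'
dataframe_info += f' columns="{json.dumps(columns, ensure_ascii=False)}"'
dataframe_info += ' dimensions="3x2">'
print(dataframe_info)
```

Note that the raw JSON (with unescaped double quotes) lands inside a double-quoted attribute; that matches the serializer code in the diff, presumably because the output feeds an LLM prompt rather than a strict XML parser.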