feat : Airtable connector Support (#635)
* feat : Airtable connector Support

* Removed unnecessary print statements

* Fixed tests and implementation

* where filter and pagination added

* Removed Pagination

* Refactored where filter functionality

* Docs and example change

* Fixed cached methods and code improvements

* Added Pagination to connector

* Removed preprocess method

* add offset not in res break condition

* Added more test coverage for where clause and column hashing

---------

Co-authored-by: ArslanSaleem <[email protected]>
Tanmaypatil123 and ArslanSaleem authored Oct 18, 2023
1 parent 16e5241 commit cec0a61
Showing 8 changed files with 482 additions and 6 deletions.
29 changes: 29 additions & 0 deletions docs/connectors.md
@@ -187,3 +187,32 @@ yahoo_connector = YahooFinanceConnector("MSFT")
df = SmartDataframe(yahoo_connector)
df.chat("What is the closing price for yesterday?")
```

## Airtable Connector

The Airtable connector allows you to connect to tables in your Airtable bases by simply passing the `api_key`, `base_id` and `table` name of the table you want to analyze.

To use the Airtable connector, you only need to import it into your Python code and pass it to a `SmartDataframe` or `SmartDatalake` object:

```python
from pandasai.connectors import AirtableConnector
from pandasai import SmartDataframe


airtable_connector = AirtableConnector(
    config={
        "api_key": "AIRTABLE_API_TOKEN",
        "table": "AIRTABLE_TABLE_NAME",
        "base_id": "AIRTABLE_BASE_ID",
        "where": [
            # this is optional and filters the data to
            # reduce the size of the dataframe
            ["Status", "=", "In progress"]
        ],
    }
)

df = SmartDataframe(airtable_connector)

df.chat("How many rows are there in the data?")
```
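
If you prefer not to hard-code credentials, the connector can also be instantiated without a `config`; in that case it falls back to the `AIRTABLE_API_TOKEN`, `AIRTABLE_BASE_ID` and `AIRTABLE_TABLE_NAME` environment variables (see the connector's `__init__`). A minimal sketch, using placeholder values:

```python
import os

from pandasai import SmartDataframe
from pandasai.connectors import AirtableConnector

# Placeholder values; in practice, export these in your shell or .env file.
os.environ["AIRTABLE_API_TOKEN"] = "your-airtable-token"
os.environ["AIRTABLE_BASE_ID"] = "appXXXXXXXXXXXXXX"
os.environ["AIRTABLE_TABLE_NAME"] = "your-table-name"

# With no config argument, the connector reads the environment variables above.
airtable_connector = AirtableConnector()

df = SmartDataframe(airtable_connector)
df.chat("How many rows are there in the data?")
```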
23 changes: 23 additions & 0 deletions examples/from_airtable.py
@@ -0,0 +1,23 @@
from pandasai.connectors import AirtableConnector
from pandasai.llm import OpenAI
from pandasai import SmartDataframe


airtable_connector = AirtableConnector(
    config={
        "api_key": "AIRTABLE_API_TOKEN",
        "table": "AIRTABLE_TABLE_NAME",
        "base_id": "AIRTABLE_BASE_ID",
        "where": [
            # this is optional and filters the data to
            # reduce the size of the dataframe
            ["Status", "=", "In progress"]
        ],
    }
)

llm = OpenAI("OPENAI_API_KEY")
df = SmartDataframe(airtable_connector, config={"llm": llm})

response = df.chat("How many rows are there in the data?")
print(response)
2 changes: 2 additions & 0 deletions pandasai/connectors/__init__.py
@@ -9,6 +9,7 @@
from .snowflake import SnowFlakeConnector
from .databricks import DatabricksConnector
from .yahoo_finance import YahooFinanceConnector
from .airtable import AirtableConnector

__all__ = [
"BaseConnector",
@@ -18,4 +19,5 @@
"YahooFinanceConnector",
"SnowFlakeConnector",
"DatabricksConnector",
"AirtableConnector",
]
279 changes: 279 additions & 0 deletions pandasai/connectors/airtable.py
@@ -0,0 +1,279 @@
"""
Airtable connectors are used to connect to Airtable records.
"""

from .base import AirtableConnectorConfig, BaseConnector, BaseConnectorConfig
from typing import Union, Optional
import requests
import pandas as pd
import os
from ..helpers.path import find_project_root
import time
import hashlib
from ..exceptions import InvalidRequestError
from functools import cache, cached_property


class AirtableConnector(BaseConnector):
"""
Airtable connector for retrieving record data.
"""

_rows_count: int = None
_columns_count: int = None
_instance = None

def __init__(
self,
config: Optional[Union[AirtableConnectorConfig, dict]] = None,
cache_interval: int = 600,
):
if isinstance(config, dict):
if "api_key" in config and "base_id" in config and "table" in config:
config = AirtableConnectorConfig(**config)
else:
raise KeyError(
"Please specify all api_key,table,base_id properly in config ."
)

elif not config:
airtable_env_vars = {
"api_key": "AIRTABLE_API_TOKEN",
"base_id": "AIRTABLE_BASE_ID",
"table": "AIRTABLE_TABLE_NAME",
}
config = AirtableConnectorConfig(
**self._populate_config_from_env(config, airtable_env_vars)
)

self._root_url: str = "https://api.airtable.com/v0/"
self._cache_interval = cache_interval

super().__init__(config)

def _init_connection(self, config: BaseConnectorConfig):
"""
Validate the connection to the Airtable base.
"""
config = config.dict()
url = f"{self._root_url}{config['base_id']}/{config['table']}"
response = requests.head(
url=url, headers={"Authorization": f"Bearer {config['api_key']}"}
)
if response.status_code == 200:
self.logger.log(
"""
Connected to Airtable.
"""
)
else:
raise InvalidRequestError(
f"""Failed to connect to Airtable.
Status code: {response.status_code},
message: {response.text}"""
)

def _get_cache_path(self, include_additional_filters: bool = False):
"""
Return the path of the cache file.
Returns:
    str: The path of the cache file.
"""
cache_dir = os.path.join(os.getcwd(), "")
try:
cache_dir = os.path.join((find_project_root()), "cache")
except ValueError:
cache_dir = os.path.join(os.getcwd(), "cache")
return os.path.join(cache_dir, f"{self._config.table}_data.parquet")

def _cached(self, include_additional_filters: bool = False):
"""
Return the path to the cached Airtable data if the cache file
exists and is not older than the cache interval.
Returns:
    str | None: The cache file path if it exists and is not
    older than the cache interval, None otherwise.
"""
cache_path = self._get_cache_path(include_additional_filters)
if not os.path.exists(cache_path):
return None

# If the file is older than the cache interval, delete it.
if os.path.getmtime(cache_path) < time.time() - self._cache_interval:
if self.logger:
self.logger.log(f"Deleting expired cached data from {cache_path}")
os.remove(cache_path)
return None

if self.logger:
self.logger.log(f"Loading cached data from {cache_path}")

return cache_path

def _save_cache(self, df):
"""
Save the given DataFrame to the cache.
Args:
df (DataFrame): The DataFrame to save to the cache.
"""
filename = self._get_cache_path(
include_additional_filters=self._additional_filters is not None
and len(self._additional_filters) > 0
)
df.to_parquet(filename)

@property
def fallback_name(self):
"""
Returns the fallback table name of the connector.
Returns:
    str: The fallback table name of the connector.
"""
return self._config.table

def execute(self):
"""
Execute the connector and return the result.
Returns:
DataFrameType: The result of the connector.
"""
cached = self._cached() or self._cached(include_additional_filters=True)
if cached:
return pd.read_parquet(cached)

if isinstance(self._instance, pd.DataFrame):
return self._instance
else:
self._instance = self._fetch_data()

return self._instance

def _build_formula(self):
"""
Build Airtable query formula for filtering.
"""

condition_strings = []
if self._config.where is not None:
for i in self._config.where:
filter_query = f"{i[0]}{i[1]}'{i[2]}'"
condition_strings.append(filter_query)
filter_formula = f'AND({",".join(condition_strings)})'
return filter_formula

def _request_api(self, params):
url = f"{self._root_url}{self._config.base_id}/{self._config.table}"
response = requests.get(
url=url,
headers={"Authorization": f"Bearer {self._config.api_key}"},
params=params,
)
return response

def _fetch_data(self):
"""
Fetches data from the Airtable server via API and converts it to a DataFrame.
"""

params = {
"pageSize": 100,
"offset": "0"
}

if self._config.where is not None:
params["filterByFormula"] = self._build_formula()

data = []
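        # Airtable paginates results: each response contains at most `pageSize`
        # records plus an "offset" token when more pages remain; that token is
        # passed back as the "offset" parameter to fetch the next page, and the
        # loop stops once a short page arrives or no offset is returned.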
while True:
response = self._request_api(params=params)

if response.status_code != 200:
raise InvalidRequestError(
f"Failed to connect to Airtable. "
f"Status code: {response.status_code}, "
f"message: {response.text}"
)

res = response.json()
records = res.get("records", [])
data.extend({"id": record["id"], **record["fields"]} for record in records)

if len(records) < 100 or "offset" not in res:
break

if "offset" in res:
params["offset"] = res["offset"]

return pd.DataFrame(data)

@cache
def head(self):
"""
Return the head of the table that
the connector is connected to.
Returns:
    DataFrameType: The head of the data source
    that the connector is connected to.
"""
data = self._request_api(params={"maxRecords": 5})
return pd.DataFrame(
[
{"id": record["id"], **record["fields"]}
for record in data.json()["records"]
]
)

@cached_property
def rows_count(self):
"""
Return the number of rows in the data source that the connector is
connected to.
Returns:
int: The number of rows in the data source that the connector is
connected to.
"""
if self._rows_count is not None:
return self._rows_count
data = self.execute()
self._rows_count = len(data)
return self._rows_count

@cached_property
def columns_count(self):
"""
Return the number of columns in the data source that the connector is
connected to.
Returns:
int: The number of columns in the data source that the connector is
connected to.
"""
if self._columns_count is not None:
return self._columns_count
data = self.head()
self._columns_count = len(data.columns)
return self._columns_count

@property
def column_hash(self):
"""
Return the hash code that is unique to the columns of the data source
that the connector is connected to.
Returns:
str: The hash code that is unique to the columns of the data source
that the connector is connected to.
"""
if not isinstance(self._instance, pd.DataFrame):
self._instance = self.execute()
columns_str = "|".join(self._instance.columns)
columns_str += "WHERE" + self._build_formula()
return hashlib.sha256(columns_str.encode("utf-8")).hexdigest()
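    # Note: connectors over the same table but configured with different "where"
    # filters produce different hashes, because the filter formula is appended to
    # the column names before hashing.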
10 changes: 10 additions & 0 deletions pandasai/connectors/base.py
@@ -20,6 +20,16 @@ class BaseConnectorConfig(BaseModel):
where: list[list[str]] = None


class AirtableConnectorConfig(BaseConnectorConfig):
"""
Connector configuration for Airtable data.
"""

api_key: str
base_id: str
database: str = "airtable_data"
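    # A config object can also be built directly and passed to the connector,
    # e.g. (illustrative values only):
    #   AirtableConnector(
    #       config=AirtableConnectorConfig(
    #           api_key="your-token", base_id="appXXXXXXXXXXXXXX", table="tasks"
    #       )
    #   )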


class SQLBaseConnectorConfig(BaseConnectorConfig):
"""
Base Connector configuration.
10 changes: 10 additions & 0 deletions pandasai/exceptions.py
@@ -5,6 +5,16 @@
"""


class InvalidRequestError(Exception):

"""
Raised when the request is not successful.
Args:
    Exception (Exception): InvalidRequestError
"""


class APIKeyNotFoundError(Exception):

"""
