From 8c3b23d931201cdd5370b690629e9b2232f54a71 Mon Sep 17 00:00:00 2001
From: Gabriele Venturi
Date: Fri, 18 Oct 2024 00:54:14 +0200
Subject: [PATCH] refactor: make amazon bedrock llm extension

---
 extensions/llms/bedrock/README.md             |  11 ++
 .../llms/bedrock/pandasai_bedrock/__init__.py |   4 +
 .../llms/bedrock/pandasai_bedrock/base.py     |  94 +++++++++++
 .../llms/bedrock/pandasai_bedrock/claude.py   |   8 +-
 extensions/llms/bedrock/poetry.lock           | 146 ++++++++++++++++++
 extensions/llms/bedrock/pyproject.toml        |  16 ++
 pandasai/llm/__init__.py                      |   6 -
 tests/unit_tests/llms/test_bedrock_claude.py  |   2 +-
 8 files changed, 276 insertions(+), 11 deletions(-)
 create mode 100644 extensions/llms/bedrock/README.md
 create mode 100644 extensions/llms/bedrock/pandasai_bedrock/__init__.py
 create mode 100644 extensions/llms/bedrock/pandasai_bedrock/base.py
 rename pandasai/llm/bedrock_claude.py => extensions/llms/bedrock/pandasai_bedrock/claude.py (96%)
 create mode 100644 extensions/llms/bedrock/poetry.lock
 create mode 100644 extensions/llms/bedrock/pyproject.toml

diff --git a/extensions/llms/bedrock/README.md b/extensions/llms/bedrock/README.md
new file mode 100644
index 000000000..05129138e
--- /dev/null
+++ b/extensions/llms/bedrock/README.md
@@ -0,0 +1,11 @@
+# Amazon Bedrock Extension for PandasAI
+
+This extension integrates Amazon Bedrock with PandasAI, providing Amazon Bedrock LLMs support.
+
+## Installation
+
+You can install this extension using poetry:
+
+```bash
+poetry add pandasai-bedrock
+```
diff --git a/extensions/llms/bedrock/pandasai_bedrock/__init__.py b/extensions/llms/bedrock/pandasai_bedrock/__init__.py
new file mode 100644
index 000000000..39837eb86
--- /dev/null
+++ b/extensions/llms/bedrock/pandasai_bedrock/__init__.py
@@ -0,0 +1,4 @@
+from .claude import BedrockClaude
+
+
+__all__ = ["BedrockClaude"]
diff --git a/extensions/llms/bedrock/pandasai_bedrock/base.py b/extensions/llms/bedrock/pandasai_bedrock/base.py
new file mode 100644
index 000000000..28c1485b7
--- /dev/null
+++ b/extensions/llms/bedrock/pandasai_bedrock/base.py
@@ -0,0 +1,94 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing import TYPE_CHECKING, Optional
+
+from pandasai.helpers.memory import Memory
+
+from pandasai.exceptions import (
+    MethodNotImplementedError,
+)
+from pandasai.prompts.base import BasePrompt
+from pandasai.llm.base import LLM
+
+if TYPE_CHECKING:
+    from pandasai.pipelines.pipeline_context import PipelineContext
+
+
+class BaseGoogle(LLM):
+    """Base class to implement a new Google LLM
+
+    LLM base class is extended to be used with
+    """
+
+    temperature: Optional[float] = 0
+    top_p: Optional[float] = 0.8
+    top_k: Optional[int] = 40
+    max_output_tokens: Optional[int] = 1000
+
+    def _valid_params(self):
+        return ["temperature", "top_p", "top_k", "max_output_tokens"]
+
+    def _set_params(self, **kwargs):
+        """
+        Dynamically set Parameters for the object.
+
+        Args:
+            **kwargs:
+                Possible keyword arguments: "temperature", "top_p", "top_k",
+                "max_output_tokens".
+
+        Returns:
+            None.
+ + """ + + valid_params = self._valid_params() + for key, value in kwargs.items(): + if key in valid_params: + setattr(self, key, value) + + def _validate(self): + """Validates the parameters for Google""" + + if self.temperature is not None and not 0 <= self.temperature <= 1: + raise ValueError("temperature must be in the range [0.0, 1.0]") + + if self.top_p is not None and not 0 <= self.top_p <= 1: + raise ValueError("top_p must be in the range [0.0, 1.0]") + + if self.top_k is not None and not 0 <= self.top_k <= 100: + raise ValueError("top_k must be in the range [0.0, 100.0]") + + if self.max_output_tokens is not None and self.max_output_tokens <= 0: + raise ValueError("max_output_tokens must be greater than zero") + + @abstractmethod + def _generate_text(self, prompt: str, memory: Optional[Memory] = None) -> str: + """ + Generates text for prompt, specific to implementation. + + Args: + prompt (str): A string representation of the prompt. + + Returns: + str: LLM response. + + """ + raise MethodNotImplementedError("method has not been implemented") + + def call(self, instruction: BasePrompt, context: PipelineContext = None) -> str: + """ + Call the Google LLM. + + Args: + instruction (BasePrompt): Instruction to pass. + context (PipelineContext): Pass PipelineContext. + + Returns: + str: LLM response. 
+ + """ + self.last_prompt = instruction.to_string() + memory = context.memory if context else None + return self._generate_text(self.last_prompt, memory) diff --git a/pandasai/llm/bedrock_claude.py b/extensions/llms/bedrock/pandasai_bedrock/claude.py similarity index 96% rename from pandasai/llm/bedrock_claude.py rename to extensions/llms/bedrock/pandasai_bedrock/claude.py index 1dc4dd477..ce8fe0ca8 100644 --- a/pandasai/llm/bedrock_claude.py +++ b/extensions/llms/bedrock/pandasai_bedrock/claude.py @@ -3,10 +3,10 @@ import json from typing import TYPE_CHECKING, Any, Dict, Optional -from ..exceptions import APIKeyNotFoundError, UnsupportedModelError -from ..helpers import load_dotenv -from ..prompts.base import BasePrompt -from .base import LLM +from pandasai.exceptions import APIKeyNotFoundError, UnsupportedModelError +from pandasai.helpers import load_dotenv +from pandasai.prompts.base import BasePrompt +from pandasai.llm.base import LLM if TYPE_CHECKING: from pandasai.pipelines.pipeline_context import PipelineContext diff --git a/extensions/llms/bedrock/poetry.lock b/extensions/llms/bedrock/poetry.lock new file mode 100644 index 000000000..db0fae6f9 --- /dev/null +++ b/extensions/llms/bedrock/poetry.lock @@ -0,0 +1,146 @@ +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. 
+ +[[package]] +name = "boto3" +version = "1.35.43" +description = "The AWS SDK for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "boto3-1.35.43-py3-none-any.whl", hash = "sha256:e6a50a0599f75b21de0de1a551a0564793d25b304fa623e4052e527b268de734"}, + {file = "boto3-1.35.43.tar.gz", hash = "sha256:0197f460632804577aa78b2f6daf7b823bffa9d4d67a5cebb179efff0fe9631b"}, +] + +[package.dependencies] +botocore = ">=1.35.43,<1.36.0" +jmespath = ">=0.7.1,<2.0.0" +s3transfer = ">=0.10.0,<0.11.0" + +[package.extras] +crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] + +[[package]] +name = "botocore" +version = "1.35.43" +description = "Low-level, data-driven core of boto 3." +optional = false +python-versions = ">=3.8" +files = [ + {file = "botocore-1.35.43-py3-none-any.whl", hash = "sha256:7cfdee9117617da97daaf259dd8484bcdc259c59eb7d1ce7db9ecf8506b7d36c"}, + {file = "botocore-1.35.43.tar.gz", hash = "sha256:04539b85ade060601a3023cacb538fc17aad8c059a5a2e18fe4bc5d0d91fbd72"}, +] + +[package.dependencies] +jmespath = ">=0.7.1,<2.0.0" +python-dateutil = ">=2.1,<3.0.0" +urllib3 = [ + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, + {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}, +] + +[package.extras] +crt = ["awscrt (==0.22.0)"] + +[[package]] +name = "jmespath" +version = "1.0.1" +description = "JSON Matching Expressions" +optional = false +python-versions = ">=3.7" +files = [ + {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, + {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, +] + +[[package]] +name = "pandasai" +version = "3.0.0" +description = "Chat with your database (SQL, CSV, pandas, mongodb, noSQL, etc). PandasAI makes data analysis conversational using LLMs (GPT 3.5 / 4, Anthropic, VertexAI) and RAG." 
+optional = false +python-versions = "*" +files = [] +develop = true + +[package.source] +type = "directory" +url = "../../.." + +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +description = "Extensions to the standard Python datetime module" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, +] + +[package.dependencies] +six = ">=1.5" + +[[package]] +name = "s3transfer" +version = "0.10.3" +description = "An Amazon S3 Transfer Manager" +optional = false +python-versions = ">=3.8" +files = [ + {file = "s3transfer-0.10.3-py3-none-any.whl", hash = "sha256:263ed587a5803c6c708d3ce44dc4dfedaab4c1a32e8329bab818933d79ddcf5d"}, + {file = "s3transfer-0.10.3.tar.gz", hash = "sha256:4f50ed74ab84d474ce614475e0b8d5047ff080810aac5d01ea25231cfc944b0c"}, +] + +[package.dependencies] +botocore = ">=1.33.2,<2.0a.0" + +[package.extras] +crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] + +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + +[[package]] +name = "urllib3" +version = "1.26.20" +description = "HTTP library with thread-safe connection pooling, file post, and more." 
+optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "urllib3-1.26.20-py2.py3-none-any.whl", hash = "sha256:0ed14ccfbf1c30a9072c7ca157e4319b70d65f623e91e7b32fadb2853431016e"}, + {file = "urllib3-1.26.20.tar.gz", hash = "sha256:40c2dc0c681e47eb8f90e7e27bf6ff7df2e677421fd46756da1161c39ca70d32"}, +] + +[package.extras] +brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] +secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] +socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] + +[[package]] +name = "urllib3" +version = "2.2.3" +description = "HTTP library with thread-safe connection pooling, file post, and more." +optional = false +python-versions = ">=3.8" +files = [ + {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"}, + {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"}, +] + +[package.extras] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +h2 = ["h2 (>=4,<5)"] +socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] +zstd = ["zstandard (>=0.18.0)"] + +[metadata] +lock-version = "2.0" +python-versions = ">=3.9,<4.0" +content-hash = "dadd975897bbe36c0fb4f0945aa419e2ecc1b49dec8003a1cda1e3ec419df659" diff --git a/extensions/llms/bedrock/pyproject.toml b/extensions/llms/bedrock/pyproject.toml new file mode 100644 index 000000000..4b0c110ff --- /dev/null +++ b/extensions/llms/bedrock/pyproject.toml @@ -0,0 +1,16 @@ +[tool.poetry] +name = "pandasai-bedrock" +version = "0.1.0" +description = "Amazon bedrock integration for PandasAI" +authors = ["Gabriele Venturi"] +license = "MIT" +readme = "README.md" + +[tool.poetry.dependencies] +python = ">=3.9,<4.0" +pandasai = "^3.0.0" +boto3 = "^1.34.59" + +[build-system] +requires = ["poetry-core"] +build-backend = 
"poetry.core.masonry.api" diff --git a/pandasai/llm/__init__.py b/pandasai/llm/__init__.py index f507c6d45..e7d8beb23 100644 --- a/pandasai/llm/__init__.py +++ b/pandasai/llm/__init__.py @@ -1,8 +1,5 @@ from .bamboo_llm import BambooLLM from .base import LLM -from .bedrock_claude import BedrockClaude -from .google_gemini import GoogleGemini -from .google_vertexai import GoogleVertexAI from .huggingface_text_gen import HuggingFaceTextGen from .ibm_watsonx import IBMwatsonx from .langchain import LangchainLLM @@ -10,10 +7,7 @@ __all__ = [ "LLM", "BambooLLM", - "GoogleVertexAI", - "GoogleGemini", "HuggingFaceTextGen", "LangchainLLM", - "BedrockClaude", "IBMwatsonx", ] diff --git a/tests/unit_tests/llms/test_bedrock_claude.py b/tests/unit_tests/llms/test_bedrock_claude.py index 62f589bea..38627b2d1 100644 --- a/tests/unit_tests/llms/test_bedrock_claude.py +++ b/tests/unit_tests/llms/test_bedrock_claude.py @@ -6,7 +6,7 @@ import pytest from pandasai.exceptions import APIKeyNotFoundError, UnsupportedModelError -from pandasai.llm import BedrockClaude +from extensions.llms.bedrock.pandasai_bedrock.claude import BedrockClaude from pandasai.prompts import BasePrompt