Skip to content

Commit

Permalink
refactor: make amazon bedrock llm extension
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed Oct 17, 2024
1 parent 4e0a6a8 commit 8c3b23d
Show file tree
Hide file tree
Showing 8 changed files with 276 additions and 11 deletions.
11 changes: 11 additions & 0 deletions extensions/llms/bedrock/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Amazon Bedrock Extension for PandasAI

This extension integrates Amazon Bedrock with PandasAI, providing Amazon Bedrock LLMs support.

## Installation

You can install this extension using poetry:

```bash
poetry add pandasai-bedrock
```
4 changes: 4 additions & 0 deletions extensions/llms/bedrock/pandasai_bedrock/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .claude import BedrockClaude


__all__ = ["BedrockClaude"]
94 changes: 94 additions & 0 deletions extensions/llms/bedrock/pandasai_bedrock/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from __future__ import annotations

from abc import abstractmethod
from typing import TYPE_CHECKING, Optional

from pandasai.helpers.memory import Memory

from pandasai.exceptions import (
MethodNotImplementedError,
)
from pandasai.prompts.base import BasePrompt
from pandasai.llm.base import LLM

if TYPE_CHECKING:
from pandasai.pipelines.pipeline_context import PipelineContext


class BaseGoogle(LLM):
    """Abstract base for Google-style generative LLM wrappers.

    Holds the common sampling parameters (``temperature``, ``top_p``,
    ``top_k``, ``max_output_tokens``), validates their ranges, and routes
    ``call`` through the provider-specific ``_generate_text`` hook.

    NOTE(review): this class lives in the Bedrock extension but keeps the
    ``BaseGoogle`` name — presumably carried over from the Google
    extension during the refactor; confirm whether a rename is intended.
    """

    # Default sampling configuration; subclasses/users may override via
    # `_set_params`.
    temperature: Optional[float] = 0
    top_p: Optional[float] = 0.8
    top_k: Optional[int] = 40
    max_output_tokens: Optional[int] = 1000

    def _valid_params(self):
        # The only keyword arguments `_set_params` will honor.
        return ["temperature", "top_p", "top_k", "max_output_tokens"]

    def _set_params(self, **kwargs):
        """
        Assign recognized generation parameters onto the instance.

        Args:
            **kwargs:
                Possible keyword arguments: "temperature", "top_p", "top_k",
                "max_output_tokens". Unrecognized keys are silently ignored.
        Returns:
            None.
        """
        allowed = self._valid_params()
        for name, value in kwargs.items():
            if name in allowed:
                setattr(self, name, value)

    def _validate(self):
        """Raise ``ValueError`` when any configured parameter is out of range."""

        if self.temperature is not None:
            if not (0 <= self.temperature <= 1):
                raise ValueError("temperature must be in the range [0.0, 1.0]")

        if self.top_p is not None:
            if not (0 <= self.top_p <= 1):
                raise ValueError("top_p must be in the range [0.0, 1.0]")

        if self.top_k is not None:
            if not (0 <= self.top_k <= 100):
                raise ValueError("top_k must be in the range [0.0, 100.0]")

        if self.max_output_tokens is not None:
            if self.max_output_tokens <= 0:
                raise ValueError("max_output_tokens must be greater than zero")

    @abstractmethod
    def _generate_text(self, prompt: str, memory: Optional[Memory] = None) -> str:
        """
        Produce the model completion for *prompt*; provider-specific.

        Args:
            prompt (str): A string representation of the prompt.
            memory (Optional[Memory]): Conversation memory, if available.
        Returns:
            str: LLM response.
        """
        raise MethodNotImplementedError("method has not been implemented")

    def call(self, instruction: BasePrompt, context: PipelineContext = None) -> str:
        """
        Call the LLM with the rendered instruction.

        Args:
            instruction (BasePrompt): Instruction to pass.
            context (PipelineContext): Pass PipelineContext.
        Returns:
            str: LLM response.
        """
        self.last_prompt = instruction.to_string()
        return self._generate_text(
            self.last_prompt,
            context.memory if context else None,
        )
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import json
from typing import TYPE_CHECKING, Any, Dict, Optional

from ..exceptions import APIKeyNotFoundError, UnsupportedModelError
from ..helpers import load_dotenv
from ..prompts.base import BasePrompt
from .base import LLM
from pandasai.exceptions import APIKeyNotFoundError, UnsupportedModelError
from pandasai.helpers import load_dotenv
from pandasai.prompts.base import BasePrompt
from pandasai.llm.base import LLM

if TYPE_CHECKING:
from pandasai.pipelines.pipeline_context import PipelineContext
Expand Down
146 changes: 146 additions & 0 deletions extensions/llms/bedrock/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions extensions/llms/bedrock/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[tool.poetry]
name = "pandasai-bedrock"
version = "0.1.0"
description = "Amazon Bedrock integration for PandasAI"
authors = ["Gabriele Venturi"]
license = "MIT"
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
pandasai = "^3.0.0"
boto3 = "^1.34.59"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
6 changes: 0 additions & 6 deletions pandasai/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
from .bamboo_llm import BambooLLM
from .base import LLM
from .bedrock_claude import BedrockClaude
from .google_gemini import GoogleGemini
from .google_vertexai import GoogleVertexAI
from .huggingface_text_gen import HuggingFaceTextGen
from .ibm_watsonx import IBMwatsonx
from .langchain import LangchainLLM

__all__ = [
"LLM",
"BambooLLM",
"GoogleVertexAI",
"GoogleGemini",
"HuggingFaceTextGen",
"LangchainLLM",
"BedrockClaude",
"IBMwatsonx",
]
2 changes: 1 addition & 1 deletion tests/unit_tests/llms/test_bedrock_claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

from pandasai.exceptions import APIKeyNotFoundError, UnsupportedModelError
from pandasai.llm import BedrockClaude
from extensions.llms.bedrock.pandasai_bedrock.claude import BedrockClaude
from pandasai.prompts import BasePrompt


Expand Down

0 comments on commit 8c3b23d

Please sign in to comment.