-
Notifications
You must be signed in to change notification settings - Fork 291
/
Copy pathagent.py
32 lines (24 loc) · 890 Bytes
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""Agent functionality."""
import pandas as pd
from config import set_environment
from langchain.agents import AgentExecutor
from langchain_core.prompts import PromptTemplate
from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent
from langchain_openai import OpenAI
from data_science.prompts import PROMPT
# Module-level side effect: runs once when this module is imported, before
# any agent is created.
# NOTE(review): presumably loads credentials such as OPENAI_API_KEY that the
# OpenAI() client in create_agent relies on — confirm in config.set_environment.
set_environment()
def create_agent(
    csv_file: str,
    *,
    verbose: bool = True,
    allow_dangerous_code: bool = True,
) -> AgentExecutor:
    """Create a pandas DataFrame agent over the contents of a CSV file.

    The CSV is loaded eagerly with ``pd.read_csv``. The resulting DataFrame is
    handed, together with an ``OpenAI`` LLM, to LangChain's
    ``create_pandas_dataframe_agent``.

    Security: this agent answers questions by having the LLM write Python
    code that is executed in-process against the DataFrame. Only expose it to
    trusted users and trusted data. Recent ``langchain_experimental`` releases
    refuse to build this agent unless ``allow_dangerous_code=True`` is passed,
    so that opt-in is made explicit here.

    Args:
        csv_file: Path to the CSV file to load.
        verbose: Forwarded to the agent. When true, the agent prints its
            intermediate steps. Defaults to ``True``, the previously
            hard-coded value.
        allow_dangerous_code: Explicit opt-in to executing LLM-generated
            Python code. Pass ``False`` to withhold that opt-in.

    Returns:
        An agent executor wrapping the loaded DataFrame.

    Raises:
        Whatever ``pd.read_csv`` raises for a missing or unparsable file
        (e.g. ``FileNotFoundError``, ``pandas.errors.ParserError``).
    """
    # NOTE(review): OpenAI() presumably reads OPENAI_API_KEY from the
    # environment prepared by set_environment() at import time — confirm.
    llm = OpenAI()
    df = pd.read_csv(csv_file)
    # NOTE(review): older langchain_experimental versions that predate the
    # allow_dangerous_code parameter may warn about an unsupported kwarg —
    # confirm against the pinned version.
    return create_pandas_dataframe_agent(
        llm,
        df,
        verbose=verbose,
        allow_dangerous_code=allow_dangerous_code,
    )
def query_agent(agent: AgentExecutor, query: str) -> str:
    """Ask the agent a question wrapped in the project's prompt template.

    The user's ``query`` is substituted into ``PROMPT``, which is expected to
    contain a ``{query}`` placeholder. The formatted text is then sent to the
    agent as its input.

    Args:
        agent: An agent executor, e.g. one returned by ``create_agent``.
        query: The user's question about the data.

    Returns:
        The agent's final answer, i.e. the executor's ``"output"`` value.
    """
    prompt = PromptTemplate(template=PROMPT, input_variables=["query"])
    # ``Chain.run`` is deprecated. ``invoke`` returns a dict holding both the
    # inputs and the outputs, so pick out the final answer explicitly.
    result = agent.invoke({"input": prompt.format(query=query)})
    return result["output"]