switch from autogen to plain openai
etwk committed Aug 3, 2024
1 parent 6ef1589 commit f37dd9e
Showing 4 changed files with 29 additions and 47 deletions.
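
At its core, the commit swaps the per-task pyautogen AssistantAgent instances for one shared OpenAI client plus a small get_llm_reply helper. A minimal sketch of the new call pattern, using the names from the src/llm.py diff below (the fallback model name and the placeholder API key mirror the updated .env; treat this as an illustration, not a verbatim copy of the file):

import os
from openai import OpenAI

# One shared client; each former agent.generate_reply(...) call becomes a
# single-turn chat completion routed through get_llm_reply().
llm_client = OpenAI(
    base_url=os.environ.get("OPENAI_BASE_URL"),  # e.g. http://localhost:8000/v1
    api_key="token",  # placeholder; assumed to be ignored by the local endpoint
)

def get_llm_reply(prompt):
    completion = llm_client.chat.completions.create(
        model=os.environ.get("LLM_MODEL_NAME") or "google/gemma-2-27b-it",
        messages=[{"role": "user", "content": prompt}],
    )
    return completion.choices[0].message.content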
6 changes: 4 additions & 2 deletions .env
@@ -1,6 +1,8 @@
-HOSTING_CHECK_BASE_URL=http://127.0.0.1:8000
 LLM_LOCAL_BASE_URL=http://xinference:9997/v1
 LLM_MODEL_NAME=google/gemma-2-27b-it
 OLLAMA_BASE_URL=http://ollama:11434
+HOSTING_CHECK_BASE_URL=http://127.0.0.1:8000
+SEARCH_BASE_URL=https://s.jina.ai
+OPENAI_API_KEY=sk-proj-aaaaaaaaaaaaaaaaa
+OPENAI_BASE_URL=http://localhost:8000/v1
 RAG_MODEL_DEPLOY=local
-SEARCH_BASE_URL=https://s.jina.ai
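
A note on the two new OPENAI_* variables: src/llm.py (below) passes OPENAI_BASE_URL to the client explicitly, while the v1 openai Python SDK is also expected to pick up OPENAI_API_KEY and OPENAI_BASE_URL from the environment on its own. A small sketch of both wirings (assumption: the placeholder key is never validated by the locally served endpoint):

import os
from openai import OpenAI

# Explicit wiring, as src/llm.py does after this commit.
client = OpenAI(
    base_url=os.environ.get("OPENAI_BASE_URL"),
    api_key=os.environ.get("OPENAI_API_KEY", "token"),
)

# Shortcut: with OPENAI_BASE_URL and OPENAI_API_KEY exported (as in this .env),
# OpenAI() with no arguments should pick both up from the environment.
# client = OpenAI()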
2 changes: 1 addition & 1 deletion requirements.local.txt
@@ -4,5 +4,5 @@ llama-index
 llama-index-embeddings-huggingface
 llama-index-embeddings-ollama
 llama-index-postprocessor-jinaai-rerank
-pyautogen
+openai
 uvicorn
2 changes: 1 addition & 1 deletion requirements.remote.txt
@@ -3,5 +3,5 @@ fastapi
 llama-index
 llama-index-embeddings-ollama
 llama-index-postprocessor-jinaai-rerank
-pyautogen
+openai
 uvicorn
66 changes: 23 additions & 43 deletions src/llm.py
@@ -1,28 +1,29 @@
 import os
-from autogen import AssistantAgent
+from openai import OpenAI
 import logging

 import utils

 logger = logging.getLogger(__name__)

 """
 About models:
 - Gemma 2 does not support system rule
-config_list:
-- {"price": [prompt_price_per_1k, completion_token_price_per_1k]}
-Todo:
-- With xinference + Gemma 2 + AutoGen, why 'system message' does not work well
 """
-
 LLM_MODEL_NAME = os.environ.get("LLM_MODEL_NAME") or "google/gemma-2-27b-it"
-config_list_local = [
-    # set prices, otherwise there will be warnings
-    {"model": LLM_MODEL_NAME, "base_url": os.environ.get("OLLAMA_BASE_URL") + "/v1", "tags": ["gemma", "local"], "price": [0, 0]},
-]
-
-llm_config = {"config_list": config_list_local}
+llm_client = OpenAI(
+    base_url=os.environ.get("OPENAI_BASE_URL"),
+    api_key="token",
+)

+def get_llm_reply(prompt):
+    completion = llm_client.chat.completions.create(
+        model=LLM_MODEL_NAME,
+        messages=[
+            {"role": "user", "content": prompt}
+        ]
+    )
+    return completion.choices[0].message.content

 """
 Get list of statements from input.
@@ -33,21 +34,14 @@ def get_statements(input):
 Extract key facts from the given content.
 Provide a list of the facts in array format as response only.'''

-    statement_extract_agent = AssistantAgent(
-        name="statement_extract_agent",
-        system_message='',
-        llm_config=llm_config,
-        human_input_mode="NEVER",
-    )
-
-    content = f'''{system_message}
+    prompt = f'''{system_message}
 ```
 Content:
 {input}
 ```'''

-    reply = statement_extract_agent.generate_reply(messages=[{"content": content, "role": "user"}])
-    logger.debug(f"get_statements LLM reply: {reply}")
+    reply = get_llm_reply(prompt)
+    logging.debug(f"get_statements LLM reply: {reply}")
     return utils.llm2list(reply)

 """
@@ -59,19 +53,12 @@ def get_search_keywords(statement):
 Generate search keyword used for fact check on the given statement.
 Include only the keyword in your response.'''

-    search_keywords_agent = AssistantAgent(
-        name="search_keywords_agent",
-        system_message='',
-        llm_config=llm_config,
-        human_input_mode="NEVER",
-    )
-
-    content = f'''{system_message}
+    prompt = f'''{system_message}
 ```
 Statement:
 {statement}
 ```'''
-    reply = search_keywords_agent.generate_reply(messages=[{"content": content, "role": "user"}])
+    reply = get_llm_reply(prompt)
     return reply.strip()

 def get_verdict(statement, contexts):
@@ -85,14 +72,7 @@ def get_verdict(statement, contexts):
 Be thorough in your explanations, avoiding any duplication of information.
 Provide the response as JSON with the structure:{verdict, reason}'''

-    fact_check_agent = AssistantAgent(
-        name="fact_check_agent",
-        system_message='',
-        llm_config=llm_config,
-        human_input_mode="NEVER",
-    )
-
-    content = f'''{system_message}
+    prompt = f'''{system_message}
 ```
 Statement:
 {statement}
@@ -103,13 +83,13 @@ def get_verdict(statement, contexts):
         _text = node.get('text')
         if not _text:
             continue
-        content = f"""{content}
+        prompt = f"""{prompt}
 ```
 Context {ind + 1}:
 {_text}
 ```"""

-    reply = fact_check_agent.generate_reply(messages=[{"content": content, "role": "user"}])
+    reply = get_llm_reply(prompt)
     verdict = utils.llm2json(reply)
     if verdict:
         verdict['statement'] = statement
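
Taken together, the three public helpers keep their signatures; only the transport underneath changes from AssistantAgent.generate_reply to get_llm_reply. A hypothetical end-to-end usage sketch (the sample claim and the single context dict are invented for illustration; in the real pipeline the contexts come from search and retrieval):

from llm import get_statements, get_search_keywords, get_verdict  # assumes src/ is on the path

claim = "The Eiffel Tower was completed in 1889."  # invented example input

for statement in get_statements(claim):
    keywords = get_search_keywords(statement)  # query string for the search backend
    contexts = [{"text": "The Eiffel Tower opened to the public on 31 March 1889."}]  # stand-in for retrieved nodes
    verdict = get_verdict(statement, contexts)  # dict like {"verdict": ..., "reason": ..., "statement": ...}
    print(keywords, verdict)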
