Add streaming #61

Merged: 17 commits, Jul 23, 2024
2 changes: 2 additions & 0 deletions requirements-dev.txt
@@ -1,10 +1,12 @@
-r src/backend/requirements.txt
ruff
mypy
pre-commit
pip-tools
pip-compile-cross-platform
pytest
pytest-cov
pytest-asyncio
pytest-snapshot
mypy
locust
44 changes: 42 additions & 2 deletions src/backend/fastapi_app/api_models.py
@@ -1,17 +1,42 @@
from enum import Enum
from typing import Any

from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel


class AIChatRoles(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"


class Message(BaseModel):
content: str
role: str = "user"
role: AIChatRoles = AIChatRoles.USER


class RetrievalMode(str, Enum):
TEXT = "text"
VECTORS = "vectors"
HYBRID = "hybrid"


class ChatRequestOverrides(BaseModel):
top: int = 3
temperature: float = 0.3
retrieval_mode: RetrievalMode = RetrievalMode.HYBRID
use_advanced_flow: bool = True
prompt_template: str | None = None


class ChatRequestContext(BaseModel):
overrides: ChatRequestOverrides


class ChatRequest(BaseModel):
messages: list[ChatCompletionMessageParam]
context: dict = {}
context: ChatRequestContext


class ThoughtStep(BaseModel):
@@ -32,6 +57,12 @@ class RetrievalResponse(BaseModel):
session_state: Any | None = None


class RetrievalResponseDelta(BaseModel):
delta: Message | None = None
context: RAGContext | None = None
session_state: Any | None = None


class ItemPublic(BaseModel):
id: int
type: str
@@ -43,3 +74,12 @@ class ItemPublic(BaseModel):

class ItemWithDistance(ItemPublic):
distance: float


class ChatParams(ChatRequestOverrides):
prompt_template: str
response_token_limit: int = 1024
enable_text_search: bool
enable_vector_search: bool
original_user_query: str
past_messages: list[ChatCompletionMessageParam]
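
For reference, a minimal sketch (not part of this PR) of how a client request body maps onto the new typed models. The message text and override values are illustrative, and model_dump_json assumes Pydantic v2:

from fastapi_app.api_models import (
    ChatRequest,
    ChatRequestContext,
    ChatRequestOverrides,
    RetrievalMode,
)

# Build the same structure a client would POST as JSON; validation now happens
# through ChatRequestContext/ChatRequestOverrides instead of a free-form dict.
request = ChatRequest(
    messages=[{"role": "user", "content": "Find me a waterproof tent"}],
    context=ChatRequestContext(
        overrides=ChatRequestOverrides(retrieval_mode=RetrievalMode.HYBRID, top=3)
    ),
)
print(request.model_dump_json(indent=2))  # Pydantic v2 serialization
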
175 changes: 118 additions & 57 deletions src/backend/fastapi_app/rag_advanced.py
@@ -1,19 +1,25 @@
import pathlib
from collections.abc import AsyncGenerator
from typing import (
Any,
)
from typing import Any

from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
from openai_messages_token_helper import build_messages, get_token_limit

from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
from .postgres_searcher import PostgresSearcher
from .query_rewriter import build_search_function, extract_search_arguments
from fastapi_app.api_models import (
AIChatRoles,
Message,
RAGContext,
RetrievalResponse,
RetrievalResponseDelta,
ThoughtStep,
)
from fastapi_app.postgres_models import Item
from fastapi_app.postgres_searcher import PostgresSearcher
from fastapi_app.query_rewriter import build_search_function, extract_search_arguments
from fastapi_app.rag_base import ChatParams, RAGChatBase


class AdvancedRAGChat:
class AdvancedRAGChat(RAGChatBase):
def __init__(
self,
*,
@@ -27,24 +33,11 @@ def __init__(
self.chat_model = chat_model
self.chat_deployment = chat_deployment
self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
current_dir = pathlib.Path(__file__).parent
self.query_prompt_template = open(current_dir / "prompts/query.txt").read()
self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read()

async def run(
self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {}
) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]:
text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
top = overrides.get("top", 3)

original_user_query = messages[-1]["content"]
if not isinstance(original_user_query, str):
raise ValueError("The most recent message content must be a string.")
past_messages = messages[:-1]

# Generate an optimized keyword search query based on the chat history and the last question
query_response_token_limit = 500

async def generate_search_query(
self, original_user_query: str, past_messages: list[ChatCompletionMessageParam], query_response_token_limit: int
) -> tuple[list[ChatCompletionMessageParam], Any | str | None, list]:
"""Generate an optimized keyword search query based on the chat history and the last question"""
query_messages: list[ChatCompletionMessageParam] = build_messages(
model=self.chat_model,
system_prompt=self.query_prompt_template,
@@ -67,68 +60,128 @@ async def run(

query_text, filters = extract_search_arguments(original_user_query, chat_completion)

return query_messages, query_text, filters

async def prepare_context(
self, chat_params: ChatParams
) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]:
query_messages, query_text, filters = await self.generate_search_query(
original_user_query=chat_params.original_user_query,
past_messages=chat_params.past_messages,
query_response_token_limit=500,
)

# Retrieve relevant items from the database with the GPT optimized query
results = await self.searcher.search_and_embed(
query_text,
top=top,
enable_vector_search=vector_search,
enable_text_search=text_search,
top=chat_params.top,
enable_vector_search=chat_params.enable_vector_search,
enable_text_search=chat_params.enable_text_search,
filters=filters,
)

sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
content = "\n".join(sources_content)

# Generate a contextual and content specific answer using the search results and chat history
response_token_limit = 1024
contextual_messages: list[ChatCompletionMessageParam] = build_messages(
model=self.chat_model,
system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
new_user_content=original_user_query + "\n\nSources:\n" + content,
past_messages=past_messages,
max_tokens=self.chat_token_limit - response_token_limit,
system_prompt=chat_params.prompt_template,
new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
past_messages=chat_params.past_messages,
max_tokens=self.chat_token_limit - chat_params.response_token_limit,
fallback_to_default=True,
)

thoughts = [
ThoughtStep(
title="Prompt to generate search arguments",
description=[str(message) for message in query_messages],
props=(
{"model": self.chat_model, "deployment": self.chat_deployment}
if self.chat_deployment
else {"model": self.chat_model}
),
),
ThoughtStep(
title="Search using generated search arguments",
description=query_text,
props={
"top": chat_params.top,
"vector_search": chat_params.enable_vector_search,
"text_search": chat_params.enable_text_search,
"filters": filters,
},
),
ThoughtStep(
title="Search results",
description=[result.to_dict() for result in results],
),
]
return contextual_messages, results, thoughts

async def answer(
self,
chat_params: ChatParams,
contextual_messages: list[ChatCompletionMessageParam],
results: list[Item],
earlier_thoughts: list[ThoughtStep],
) -> RetrievalResponse:
chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create(
# Azure OpenAI takes the deployment name as the model name
model=self.chat_deployment if self.chat_deployment else self.chat_model,
messages=contextual_messages,
temperature=overrides.get("temperature", 0.3),
max_tokens=response_token_limit,
temperature=chat_params.temperature,
max_tokens=chat_params.response_token_limit,
n=1,
stream=False,
)
first_choice_message = chat_completion_response.choices[0].message

return RetrievalResponse(
message=Message(content=str(first_choice_message.content), role=first_choice_message.role),
message=Message(
content=str(chat_completion_response.choices[0].message.content), role=AIChatRoles.ASSISTANT
),
context=RAGContext(
data_points={item.id: item.to_dict() for item in results},
thoughts=[
thoughts=earlier_thoughts
+ [
ThoughtStep(
title="Prompt to generate search arguments",
description=[str(message) for message in query_messages],
title="Prompt to generate answer",
description=[str(message) for message in contextual_messages],
props=(
{"model": self.chat_model, "deployment": self.chat_deployment}
if self.chat_deployment
else {"model": self.chat_model}
),
),
ThoughtStep(
title="Search using generated search arguments",
description=query_text,
props={
"top": top,
"vector_search": vector_search,
"text_search": text_search,
"filters": filters,
},
),
ThoughtStep(
title="Search results",
description=[result.to_dict() for result in results],
),
],
),
)

async def answer_stream(
self,
chat_params: ChatParams,
contextual_messages: list[ChatCompletionMessageParam],
results: list[Item],
earlier_thoughts: list[ThoughtStep],
) -> AsyncGenerator[RetrievalResponseDelta, None]:
chat_completion_async_stream: AsyncStream[
ChatCompletionChunk
] = await self.openai_chat_client.chat.completions.create(
# Azure OpenAI takes the deployment name as the model name
model=self.chat_deployment if self.chat_deployment else self.chat_model,
messages=contextual_messages,
temperature=chat_params.temperature,
max_tokens=chat_params.response_token_limit,
n=1,
stream=True,
)

yield RetrievalResponseDelta(
context=RAGContext(
data_points={item.id: item.to_dict() for item in results},
thoughts=earlier_thoughts
+ [
ThoughtStep(
title="Prompt to generate answer",
description=[str(message) for message in contextual_messages],
@@ -141,3 +194,11 @@ async def run(
],
),
)

async for response_chunk in chat_completion_async_stream:
# first response has empty choices and last response has empty content
if response_chunk.choices and response_chunk.choices[0].delta.content:
yield RetrievalResponseDelta(
delta=Message(content=str(response_chunk.choices[0].delta.content), role=AIChatRoles.ASSISTANT)
)
return
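
Not part of this diff: a sketch of how a route handler might consume answer_stream as newline-delimited JSON. The format_as_ndjson helper and the wiring function are assumptions for illustration, and model_dump_json assumes Pydantic v2:

from collections.abc import AsyncGenerator

from fastapi.responses import StreamingResponse

from fastapi_app.api_models import RetrievalResponseDelta


async def format_as_ndjson(
    deltas: AsyncGenerator[RetrievalResponseDelta, None],
) -> AsyncGenerator[str, None]:
    # Serialize each delta as one JSON object per line (application/x-ndjson).
    async for delta in deltas:
        yield delta.model_dump_json() + "\n"


def streaming_route_response(chat, chat_params, contextual_messages, results, thoughts) -> StreamingResponse:
    # Hypothetical wiring: wrap the async generator in a StreamingResponse
    # once prepare_context has produced the inputs.
    return StreamingResponse(
        format_as_ndjson(chat.answer_stream(chat_params, contextual_messages, results, thoughts)),
        media_type="application/x-ndjson",
    )
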
73 changes: 73 additions & 0 deletions src/backend/fastapi_app/rag_base.py
@@ -0,0 +1,73 @@
import pathlib
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator

from openai.types.chat import ChatCompletionMessageParam

from fastapi_app.api_models import (
ChatParams,
ChatRequestOverrides,
RetrievalResponse,
RetrievalResponseDelta,
ThoughtStep,
)
from fastapi_app.postgres_models import Item


class RAGChatBase(ABC):
current_dir = pathlib.Path(__file__).parent
query_prompt_template = open(current_dir / "prompts/query.txt").read()
answer_prompt_template = open(current_dir / "prompts/answer.txt").read()

def get_params(self, messages: list[ChatCompletionMessageParam], overrides: ChatRequestOverrides) -> ChatParams:
response_token_limit = 1024
prompt_template = overrides.prompt_template or self.answer_prompt_template

enable_text_search = overrides.retrieval_mode in ["text", "hybrid", None]
enable_vector_search = overrides.retrieval_mode in ["vectors", "hybrid", None]

original_user_query = messages[-1]["content"]
if not isinstance(original_user_query, str):
raise ValueError("The most recent message content must be a string.")
past_messages = messages[:-1]

return ChatParams(
top=overrides.top,
temperature=overrides.temperature,
retrieval_mode=overrides.retrieval_mode,
use_advanced_flow=overrides.use_advanced_flow,
response_token_limit=response_token_limit,
prompt_template=prompt_template,
enable_text_search=enable_text_search,
enable_vector_search=enable_vector_search,
original_user_query=original_user_query,
past_messages=past_messages,
)

@abstractmethod
async def prepare_context(
self, chat_params: ChatParams
) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]:
raise NotImplementedError

@abstractmethod
async def answer(
self,
chat_params: ChatParams,
contextual_messages: list[ChatCompletionMessageParam],
results: list[Item],
earlier_thoughts: list[ThoughtStep],
) -> RetrievalResponse:
raise NotImplementedError

@abstractmethod
async def answer_stream(
self,
chat_params: ChatParams,
contextual_messages: list[ChatCompletionMessageParam],
results: list[Item],
earlier_thoughts: list[ThoughtStep],
) -> AsyncGenerator[RetrievalResponseDelta, None]:
raise NotImplementedError
if False:
yield 0
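
Two notes on the base class, outside the diff itself. First, the unreachable "if False: yield 0" after raise NotImplementedError is what makes the abstract answer_stream an async generator function, so its signature matches the AsyncGenerator return type that subclasses implement. Second, a minimal sketch of how a caller could drive the interface end to end; the function names and call order here are assumptions based on the methods above, not code from this PR:

from collections.abc import AsyncGenerator

from openai.types.chat import ChatCompletionMessageParam

from fastapi_app.api_models import ChatRequestOverrides, RetrievalResponse, RetrievalResponseDelta
from fastapi_app.rag_base import RAGChatBase


async def run_once(
    chat: RAGChatBase, messages: list[ChatCompletionMessageParam], overrides: ChatRequestOverrides
) -> RetrievalResponse:
    # Non-streaming flow: normalize params, build the RAG context, return a single answer.
    chat_params = chat.get_params(messages, overrides)
    contextual_messages, results, thoughts = await chat.prepare_context(chat_params)
    return await chat.answer(chat_params, contextual_messages, results, thoughts)


async def run_streamed(
    chat: RAGChatBase, messages: list[ChatCompletionMessageParam], overrides: ChatRequestOverrides
) -> AsyncGenerator[RetrievalResponseDelta, None]:
    # Streaming flow: same preparation, but forward deltas as they arrive.
    chat_params = chat.get_params(messages, overrides)
    contextual_messages, results, thoughts = await chat.prepare_context(chat_params)
    async for delta in chat.answer_stream(chat_params, contextual_messages, results, thoughts):
        yield delta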