Skip to content

Commit 1f0270f

Browse files
update (#2537)
Add support for agentic retrieval
1 parent 986009c commit 1f0270f

File tree

110 files changed

+1778
-158
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

110 files changed

+1778
-158
lines changed

app/backend/app.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
get_bearer_token_provider,
2424
)
2525
from azure.monitor.opentelemetry import configure_azure_monitor
26+
from azure.search.documents.agent.aio import KnowledgeAgentRetrievalClient
2627
from azure.search.documents.aio import SearchClient
2728
from azure.search.documents.indexes.aio import SearchIndexClient
2829
from azure.storage.blob.aio import ContainerClient
@@ -57,6 +58,8 @@
5758
from approaches.retrievethenreadvision import RetrieveThenReadVisionApproach
5859
from chat_history.cosmosdb import chat_history_cosmosdb_bp
5960
from config import (
61+
CONFIG_AGENT_CLIENT,
62+
CONFIG_AGENTIC_RETRIEVAL_ENABLED,
6063
CONFIG_ASK_APPROACH,
6164
CONFIG_ASK_VISION_APPROACH,
6265
CONFIG_AUTH_CLIENT,
@@ -308,6 +311,7 @@ def config():
308311
"showSpeechOutputAzure": current_app.config[CONFIG_SPEECH_OUTPUT_AZURE_ENABLED],
309312
"showChatHistoryBrowser": current_app.config[CONFIG_CHAT_HISTORY_BROWSER_ENABLED],
310313
"showChatHistoryCosmos": current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED],
314+
"showAgenticRetrievalOption": current_app.config[CONFIG_AGENTIC_RETRIEVAL_ENABLED],
311315
}
312316
)
313317

@@ -424,10 +428,14 @@ async def setup_clients():
424428
AZURE_USERSTORAGE_ACCOUNT = os.environ.get("AZURE_USERSTORAGE_ACCOUNT")
425429
AZURE_USERSTORAGE_CONTAINER = os.environ.get("AZURE_USERSTORAGE_CONTAINER")
426430
AZURE_SEARCH_SERVICE = os.environ["AZURE_SEARCH_SERVICE"]
431+
AZURE_SEARCH_ENDPOINT = f"https://{AZURE_SEARCH_SERVICE}.search.windows.net"
427432
AZURE_SEARCH_INDEX = os.environ["AZURE_SEARCH_INDEX"]
433+
AZURE_SEARCH_AGENT = os.getenv("AZURE_SEARCH_AGENT", "")
428434
# Shared by all OpenAI deployments
429435
OPENAI_HOST = os.getenv("OPENAI_HOST", "azure")
430436
OPENAI_CHATGPT_MODEL = os.environ["AZURE_OPENAI_CHATGPT_MODEL"]
437+
AZURE_OPENAI_SEARCHAGENT_MODEL = os.getenv("AZURE_OPENAI_SEARCHAGENT_MODEL")
438+
AZURE_OPENAI_SEARCHAGENT_DEPLOYMENT = os.getenv("AZURE_OPENAI_SEARCHAGENT_DEPLOYMENT")
431439
OPENAI_EMB_MODEL = os.getenv("AZURE_OPENAI_EMB_MODEL_NAME", "text-embedding-ada-002")
432440
OPENAI_EMB_DIMENSIONS = int(os.getenv("AZURE_OPENAI_EMB_DIMENSIONS") or 1536)
433441
OPENAI_REASONING_EFFORT = os.getenv("AZURE_OPENAI_REASONING_EFFORT")
@@ -479,6 +487,7 @@ async def setup_clients():
479487
USE_SPEECH_OUTPUT_AZURE = os.getenv("USE_SPEECH_OUTPUT_AZURE", "").lower() == "true"
480488
USE_CHAT_HISTORY_BROWSER = os.getenv("USE_CHAT_HISTORY_BROWSER", "").lower() == "true"
481489
USE_CHAT_HISTORY_COSMOS = os.getenv("USE_CHAT_HISTORY_COSMOS", "").lower() == "true"
490+
USE_AGENTIC_RETRIEVAL = os.getenv("USE_AGENTIC_RETRIEVAL", "").lower() == "true"
482491

483492
# WEBSITE_HOSTNAME is always set by App Service, RUNNING_IN_PRODUCTION is set in main.bicep
484493
RUNNING_ON_AZURE = os.getenv("WEBSITE_HOSTNAME") is not None or os.getenv("RUNNING_IN_PRODUCTION") is not None
@@ -513,10 +522,13 @@ async def setup_clients():
513522

514523
# Set up clients for AI Search and Storage
515524
search_client = SearchClient(
516-
endpoint=f"https://{AZURE_SEARCH_SERVICE}.search.windows.net",
525+
endpoint=AZURE_SEARCH_ENDPOINT,
517526
index_name=AZURE_SEARCH_INDEX,
518527
credential=azure_credential,
519528
)
529+
agent_client = KnowledgeAgentRetrievalClient(
530+
endpoint=AZURE_SEARCH_ENDPOINT, agent_name=AZURE_SEARCH_AGENT, credential=azure_credential
531+
)
520532

521533
blob_container_client = ContainerClient(
522534
f"https://{AZURE_STORAGE_ACCOUNT}.blob.core.windows.net", AZURE_STORAGE_CONTAINER, credential=azure_credential
@@ -527,7 +539,7 @@ async def setup_clients():
527539
if AZURE_USE_AUTHENTICATION:
528540
current_app.logger.info("AZURE_USE_AUTHENTICATION is true, setting up search index client")
529541
search_index_client = SearchIndexClient(
530-
endpoint=f"https://{AZURE_SEARCH_SERVICE}.search.windows.net",
542+
endpoint=AZURE_SEARCH_ENDPOINT,
531543
credential=azure_credential,
532544
)
533545
search_index = await search_index_client.get_index(AZURE_SEARCH_INDEX)
@@ -645,6 +657,7 @@ async def setup_clients():
645657

646658
current_app.config[CONFIG_OPENAI_CLIENT] = openai_client
647659
current_app.config[CONFIG_SEARCH_CLIENT] = search_client
660+
current_app.config[CONFIG_AGENT_CLIENT] = agent_client
648661
current_app.config[CONFIG_BLOB_CONTAINER_CLIENT] = blob_container_client
649662
current_app.config[CONFIG_AUTH_CLIENT] = auth_helper
650663

@@ -668,13 +681,18 @@ async def setup_clients():
668681
current_app.config[CONFIG_SPEECH_OUTPUT_AZURE_ENABLED] = USE_SPEECH_OUTPUT_AZURE
669682
current_app.config[CONFIG_CHAT_HISTORY_BROWSER_ENABLED] = USE_CHAT_HISTORY_BROWSER
670683
current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED] = USE_CHAT_HISTORY_COSMOS
684+
current_app.config[CONFIG_AGENTIC_RETRIEVAL_ENABLED] = USE_AGENTIC_RETRIEVAL
671685

672686
prompt_manager = PromptyManager()
673687

674688
# Set up the two default RAG approaches for /ask and /chat
675689
# RetrieveThenReadApproach is used by /ask for single-turn Q&A
676690
current_app.config[CONFIG_ASK_APPROACH] = RetrieveThenReadApproach(
677691
search_client=search_client,
692+
search_index_name=AZURE_SEARCH_INDEX,
693+
agent_model=AZURE_OPENAI_SEARCHAGENT_MODEL,
694+
agent_deployment=AZURE_OPENAI_SEARCHAGENT_DEPLOYMENT,
695+
agent_client=agent_client,
678696
openai_client=openai_client,
679697
auth_helper=auth_helper,
680698
chatgpt_model=OPENAI_CHATGPT_MODEL,
@@ -694,6 +712,10 @@ async def setup_clients():
694712
# ChatReadRetrieveReadApproach is used by /chat for multi-turn conversation
695713
current_app.config[CONFIG_CHAT_APPROACH] = ChatReadRetrieveReadApproach(
696714
search_client=search_client,
715+
search_index_name=AZURE_SEARCH_INDEX,
716+
agent_model=AZURE_OPENAI_SEARCHAGENT_MODEL,
717+
agent_deployment=AZURE_OPENAI_SEARCHAGENT_DEPLOYMENT,
718+
agent_client=agent_client,
697719
openai_client=openai_client,
698720
auth_helper=auth_helper,
699721
chatgpt_model=OPENAI_CHATGPT_MODEL,

app/backend/approaches/approach.py

Lines changed: 82 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,20 @@
22
from abc import ABC
33
from collections.abc import AsyncGenerator, Awaitable
44
from dataclasses import dataclass
5-
from typing import (
6-
Any,
7-
Callable,
8-
Optional,
9-
TypedDict,
10-
Union,
11-
cast,
12-
)
5+
from typing import Any, Callable, Optional, TypedDict, Union, cast
136
from urllib.parse import urljoin
147

158
import aiohttp
9+
from azure.search.documents.agent.aio import KnowledgeAgentRetrievalClient
10+
from azure.search.documents.agent.models import (
11+
KnowledgeAgentAzureSearchDocReference,
12+
KnowledgeAgentIndexParams,
13+
KnowledgeAgentMessage,
14+
KnowledgeAgentMessageTextContent,
15+
KnowledgeAgentRetrievalRequest,
16+
KnowledgeAgentRetrievalResponse,
17+
KnowledgeAgentSearchActivityRecord,
18+
)
1619
from azure.search.documents.aio import SearchClient
1720
from azure.search.documents.models import (
1821
QueryCaptionResult,
@@ -36,16 +39,17 @@
3639

3740
@dataclass
3841
class Document:
39-
id: Optional[str]
40-
content: Optional[str]
41-
category: Optional[str]
42-
sourcepage: Optional[str]
43-
sourcefile: Optional[str]
44-
oids: Optional[list[str]]
45-
groups: Optional[list[str]]
46-
captions: list[QueryCaptionResult]
42+
id: Optional[str] = None
43+
content: Optional[str] = None
44+
category: Optional[str] = None
45+
sourcepage: Optional[str] = None
46+
sourcefile: Optional[str] = None
47+
oids: Optional[list[str]] = None
48+
groups: Optional[list[str]] = None
49+
captions: Optional[list[QueryCaptionResult]] = None
4750
score: Optional[float] = None
4851
reranker_score: Optional[float] = None
52+
search_agent_query: Optional[str] = None
4953

5054
def serialize_for_results(self) -> dict[str, Any]:
5155
result_dict = {
@@ -70,6 +74,7 @@ def serialize_for_results(self) -> dict[str, Any]:
7074
),
7175
"score": self.score,
7276
"reranker_score": self.reranker_score,
77+
"search_agent_query": self.search_agent_query,
7378
}
7479
return result_dict
7580

@@ -247,6 +252,67 @@ async def search(
247252

248253
return qualified_documents
249254

255+
async def run_agentic_retrieval(
    self,
    messages: list[ChatCompletionMessageParam],
    agent_client: KnowledgeAgentRetrievalClient,
    search_index_name: str,
    top: Optional[int] = None,
    filter_add_on: Optional[str] = None,
    minimum_reranker_score: Optional[float] = None,
    max_docs_for_reranker: Optional[int] = None,
) -> tuple[KnowledgeAgentRetrievalResponse, list[Document]]:
    """Run agentic retrieval for a conversation and convert the references to Documents.

    Args:
        messages: Chat messages; every message whose role is not "system" is
            forwarded to the agent as a single text content item.
        agent_client: Client used to issue the retrieval request.
        search_index_name: Index name passed as the only target index.
        top: If truthy, stop collecting documents once this many were gathered.
        filter_add_on: Passed through as the index params' ``filter_add_on``.
        minimum_reranker_score: Passed through as ``reranker_threshold``.
        max_docs_for_reranker: Passed through as ``max_docs_for_reranker``.

    Returns:
        A tuple of the raw retrieval response and the list of Documents built
        from the response's search-document references that carry source data.

    Raises:
        KeyError: If a reference's source data lacks "content" or "sourcepage".
    """
    # STEP 1: Invoke agentic retrieval (system messages are not sent to the agent)
    agent_messages = [
        KnowledgeAgentMessage(
            role=str(msg["role"]),
            content=[KnowledgeAgentMessageTextContent(text=str(msg["content"]))],
        )
        for msg in messages
        if msg["role"] != "system"
    ]
    response = await agent_client.retrieve(
        retrieval_request=KnowledgeAgentRetrievalRequest(
            messages=agent_messages,
            target_index_params=[
                KnowledgeAgentIndexParams(
                    index_name=search_index_name,
                    reranker_threshold=minimum_reranker_score,
                    max_docs_for_reranker=max_docs_for_reranker,
                    filter_add_on=filter_add_on,
                    include_reference_source_data=True,
                )
            ],
        )
    )

    # STEP 2: Map each search activity id to the query text the agent issued,
    # so every returned document can be attributed to the query that found it.
    activity_mapping = {
        activity.id: activity.query.search if activity.query else ""
        for activity in (response.activity or [])
        if isinstance(activity, KnowledgeAgentSearchActivityRecord)
    }

    # STEP 3: Convert search-document references that include source data into Documents
    results: list[Document] = []
    for reference in response.references or []:
        if not (isinstance(reference, KnowledgeAgentAzureSearchDocReference) and reference.source_data):
            continue
        results.append(
            Document(
                id=reference.doc_key,
                content=reference.source_data["content"],
                sourcepage=reference.source_data["sourcepage"],
                # A reference may point at an activity that is not in the mapping
                # (non-search activity, or no activity list returned); leave the
                # query unset instead of failing the whole request with KeyError.
                search_agent_query=activity_mapping.get(reference.activity_source),
            )
        )
        if top and len(results) == top:
            break

    return response, results
315+
250316
def get_sources_content(
251317
self, results: list[Document], use_semantic_captions: bool, use_image_citation: bool
252318
) -> list[str]:

0 commit comments

Comments
 (0)