feat: Support citation for agentic template #642

Merged

merged 29 commits on May 28, 2025

Changes from 13 commits

Commits
ec06a12
feat: add citation processing and prompts for query engine
leehuwuj May 22, 2025
ff9acac
add example and improve tool description
leehuwuj May 23, 2025
a46d655
better cond
leehuwuj May 23, 2025
ad0c2f3
Improve citation instructions in prompts and enhance logging in query…
leehuwuj May 26, 2025
0960b14
Merge remote-tracking branch 'origin/main' into lee/citation-agentic
leehuwuj May 26, 2025
491dd8e
Enhance citation instructions in prompts, improve error handling in s…
leehuwuj May 26, 2025
22ca4ba
Update @llamaindex/chat-ui to version 0.4.6 in package.json and pnpm-…
leehuwuj May 27, 2025
eedffb0
Enhance system prompt and tool description for improved clarity on kn…
leehuwuj May 27, 2025
13a1454
introduce preconfigured agent for citation answering
leehuwuj May 27, 2025
8799669
Refactor workflow creation to utilize query tool with citation suppor…
leehuwuj May 27, 2025
fd82563
Implement AgentCallTool for event handling in chat router; enhance to…
leehuwuj May 27, 2025
7a463d0
better display llamacloud file name
leehuwuj May 27, 2025
f6f5f23
Merge remote-tracking branch 'origin/main' into lee/citation-agentic
leehuwuj May 27, 2025
1c83fa2
Remove citation agent implementation and update dependencies in pypro…
leehuwuj May 27, 2025
7a33a58
update createllama and add changesets
leehuwuj May 27, 2025
3a2be5c
Refactor citation handling in query tools
leehuwuj May 28, 2025
641a2be
Refactor SourceNodesFromToolCall initialization to use optional tool_…
leehuwuj May 28, 2025
46dec12
Refactor query engine and citation handling; enable citation in workf…
leehuwuj May 28, 2025
b95dcc7
refactor llamacloud file
leehuwuj May 28, 2025
b73680f
Refactor SourceNodesFromToolCall to remove deprecated tool_name param…
leehuwuj May 28, 2025
31576d9
fix mypy
leehuwuj May 28, 2025
356bbb3
add test for local python package
leehuwuj May 28, 2025
e6e2f78
remove tool name constraint
leehuwuj May 28, 2025
f9f3437
add missing working-directory
leehuwuj May 28, 2025
0ad7581
Update e2e workflow to build server package and set SERVER_PACKAGE_PA…
leehuwuj May 28, 2025
c3ad902
Update dependency handling for llama-index-server template in CI; rem…
leehuwuj May 28, 2025
c32b7f3
config hatch to fix script
leehuwuj May 28, 2025
f7c4ed3
fix mkdir windows
leehuwuj May 28, 2025
8764d93
fix wrong build command
leehuwuj May 28, 2025
2 changes: 1 addition & 1 deletion packages/server/package.json
@@ -59,7 +59,7 @@
"@babel/traverse": "^7.27.0",
"@babel/types": "^7.27.0",
"@hookform/resolvers": "^5.0.1",
"@llamaindex/chat-ui": "0.4.5",
"@llamaindex/chat-ui": "0.4.6",
"@radix-ui/react-accordion": "^1.2.3",
"@radix-ui/react-alert-dialog": "^1.1.7",
"@radix-ui/react-aspect-ratio": "^1.1.3",
10 changes: 5 additions & 5 deletions pnpm-lock.yaml

Some generated files are not rendered by default.

60 changes: 60 additions & 0 deletions python/llama-index-server/examples/llamacloud/main.py
@@ -0,0 +1,60 @@
import os
from typing import Optional

from fastapi import FastAPI
from llama_index.core.agent.workflow import AgentWorkflow
from llama_index.core.settings import Settings
from llama_index.llms.openai import OpenAI
from llama_index.server import LlamaIndexServer, UIConfig
from llama_index.server.api.models import ChatRequest
from llama_index.server.services.llamacloud import get_index
from llama_index.server.tools.index.query import get_query_engine_tool

# Please set the following environment variables to use LlamaCloud
if os.getenv("LLAMA_CLOUD_API_KEY") is None:
raise ValueError("LLAMA_CLOUD_API_KEY is not set")
if os.getenv("LLAMA_CLOUD_PROJECT_NAME") is None:
raise ValueError("LLAMA_CLOUD_PROJECT_NAME is not set")
if os.getenv("LLAMA_CLOUD_INDEX_NAME") is None:
raise ValueError("LLAMA_CLOUD_INDEX_NAME is not set")

Settings.llm = OpenAI(model="gpt-4.1")


def create_workflow(chat_request: Optional[ChatRequest] = None) -> AgentWorkflow:
index = get_index(chat_request=chat_request)
if index is None:
raise RuntimeError("Index not found!")
# Create a query tool with citations enabled
query_tool = get_query_engine_tool(index=index, enable_citation=True)

# Append the citation system prompt to the system prompt
system_prompt = """
You are a helpful assistant that has access to a knowledge base.
"""
system_prompt += query_tool.citation_system_prompt
return AgentWorkflow.from_tools_or_functions(
tools_or_functions=[query_tool],
system_prompt=system_prompt,
)


def create_app() -> FastAPI:
app = LlamaIndexServer(
workflow_factory=create_workflow,
env="dev",
suggest_next_questions=False,
ui_config=UIConfig(
llamacloud_index_selector=True, # to select different indexes in the UI
),
)
return app


app = create_app()


if __name__ == "__main__":
import uvicorn

uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
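
For readers running this example, the sketch below shows a minimal smoke test. It rests on assumptions not confirmed by this diff: that the server exposes `POST /api/chat` and accepts a `messages` list in the body, which is the usual llama-index-server layout.

```python
# Hedged smoke test for the example server above.
# Assumptions (not verified in this PR): the chat endpoint is POST /api/chat
# and the request body carries a ChatRequest-compatible `messages` list.
import httpx

payload = {
    "messages": [
        {"role": "user", "content": "What does the knowledge base say about pricing?"}
    ]
}

with httpx.Client(base_url="http://localhost:8000", timeout=60.0) as client:
    response = client.post("/api/chat", json=payload)
    response.raise_for_status()
    # The endpoint streams its answer; dump the raw body for a quick check.
    print(response.text)
```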
2 changes: 2 additions & 0 deletions python/llama-index-server/llama_index/server/api/callbacks/__init__.py
@@ -1,3 +1,4 @@
+from llama_index.server.api.callbacks.agent_call_tool import AgentCallTool
from llama_index.server.api.callbacks.base import EventCallback
from llama_index.server.api.callbacks.llamacloud import LlamaCloudFileDownload
from llama_index.server.api.callbacks.source_nodes import SourceNodesFromToolCall
@@ -10,4 +11,5 @@
"SourceNodesFromToolCall",
"SuggestNextQuestions",
"LlamaCloudFileDownload",
"AgentCallTool",
]
26 changes: 26 additions & 0 deletions python/llama-index-server/llama_index/server/api/callbacks/agent_call_tool.py
@@ -0,0 +1,26 @@
import logging
from typing import Any

from llama_index.core.agent.workflow.workflow_events import ToolCall, ToolCallResult
from llama_index.server.api.callbacks.base import EventCallback
from llama_index.server.api.models import AgentRunEvent

logger = logging.getLogger("uvicorn")


class AgentCallTool(EventCallback):
"""
Adapter for convert tool call events to agent run events.
"""

async def run(self, event: Any) -> Any:
if isinstance(event, ToolCall) and not isinstance(event, ToolCallResult):
return AgentRunEvent(
name="Agent",
msg=f"Calling tool: {event.tool_name} with: {event.tool_kwargs}",
)
return event

@classmethod
def from_default(cls, *args: Any, **kwargs: Any) -> "AgentCallTool":
return cls()
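
To see what the adapter emits in isolation, here is a minimal sketch; the `ToolCall` constructor fields (`tool_name`, `tool_kwargs`, `tool_id`) are assumed from llama_index's workflow events and may vary by version.

```python
# Minimal sketch of AgentCallTool in isolation.
# Assumption: ToolCall accepts tool_name, tool_kwargs, and tool_id keywords.
import asyncio

from llama_index.core.agent.workflow.workflow_events import ToolCall
from llama_index.server.api.callbacks import AgentCallTool


async def main() -> None:
    callback = AgentCallTool()
    event = ToolCall(
        tool_name="query_index",
        tool_kwargs={"input": "Why is the sky blue?"},
        tool_id="call-1",
    )
    converted = await callback.run(event)
    # Prints: Calling tool: query_index with: {'input': 'Why is the sky blue?'}
    print(converted.msg)


asyncio.run(main())
```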
15 changes: 11 additions & 4 deletions python/llama-index-server/llama_index/server/api/callbacks/source_nodes.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Optional

from llama_index.core.agent.workflow.workflow_events import ToolCallResult
from llama_index.server.api.callbacks.base import EventCallback
@@ -17,9 +17,16 @@ class SourceNodesFromToolCall(EventCallback):
def __init__(self, query_tool_name: str = "query_index"):
self.query_tool_name = query_tool_name

-    def transform_tool_call_result(self, event: ToolCallResult) -> SourceNodesEvent:
-        source_nodes = event.tool_output.raw_output.source_nodes
-        return SourceNodesEvent(nodes=source_nodes)
+    def transform_tool_call_result(
+        self, event: ToolCallResult
+    ) -> Optional[SourceNodesEvent]:
+        # Skip error results: they carry no source nodes
+        tool_output = event.tool_output
+        if tool_output.is_error:
+            return None
+        else:
+            source_nodes = tool_output.raw_output.source_nodes
+            return SourceNodesEvent(nodes=source_nodes)

async def run(self, event: Any) -> Any:
if isinstance(event, ToolCallResult):
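
To illustrate the new error path, a small sketch follows; the `ToolCallResult` and `ToolOutput` constructor fields are assumed from llama_index core and may differ across versions.

```python
# Sketch: error tool outputs are now skipped instead of raising.
# Assumed field names for ToolOutput / ToolCallResult (check your llama_index version).
from llama_index.core.agent.workflow.workflow_events import ToolCallResult
from llama_index.core.tools import ToolOutput
from llama_index.server.api.callbacks import SourceNodesFromToolCall

callback = SourceNodesFromToolCall()
event = ToolCallResult(
    tool_name="query_index",
    tool_kwargs={"input": "Why is the sky blue?"},
    tool_id="call-1",
    tool_output=ToolOutput(
        content="query failed",
        tool_name="query_index",
        raw_input={"input": "Why is the sky blue?"},
        raw_output=None,
        is_error=True,
    ),
    return_direct=False,
)
# Previously this raised AttributeError on raw_output.source_nodes;
# now it returns None and the event passes through unchanged.
assert callback.transform_tool_call_result(event) is None
```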
16 changes: 12 additions & 4 deletions python/llama-index-server/llama_index/server/api/models.py
@@ -1,14 +1,14 @@
import logging
import os
+import re
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union

-from pydantic import BaseModel, field_validator

from llama_index.core.schema import NodeWithScore
from llama_index.core.types import ChatMessage, MessageRole
from llama_index.core.workflow import Event
from llama_index.server.settings import server_settings
+from pydantic import BaseModel, field_validator

logger = logging.getLogger("uvicorn")

@@ -91,6 +91,15 @@ def from_source_node(cls, source_node: NodeWithScore) -> "SourceNodes":
url=url,
)

+    @classmethod
+    def get_local_llamacloud_file_name(
+        cls, llamacloud_file_name: str, pipeline_id: str
+    ) -> str:
+        file_ext = os.path.splitext(llamacloud_file_name)[1]
+        file_name = llamacloud_file_name.replace(file_ext, "")
+        sanitized_file_name = re.sub(r"[^A-Za-z0-9_\-]", "_", file_name)
+        return f"{sanitized_file_name}_{pipeline_id}{file_ext}"
+
@classmethod
def get_url_from_metadata(
cls,
@@ -103,11 +112,10 @@ def get_url_from_metadata(
file_name = metadata.get("file_name")

if file_name and url_prefix:
-            # file_name exists and file server is configured
pipeline_id = metadata.get("pipeline_id")
if pipeline_id:
# file is from LlamaCloud
-                file_name = f"{pipeline_id}${file_name}"
+                file_name = cls.get_local_llamacloud_file_name(file_name, pipeline_id)
return f"{url_prefix}/output/llamacloud/{file_name}"
is_private = metadata.get("private", "false") == "true"
if is_private:
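
As a concrete illustration of the new naming scheme (the file name and pipeline id below are invented):

```python
# Hypothetical values, for illustration only.
from llama_index.server.api.models import SourceNodes

name = SourceNodes.get_local_llamacloud_file_name("my report.pdf", pipeline_id="abc123")
# The extension is preserved, unsafe characters become underscores,
# and the pipeline id is appended before the extension:
print(name)  # my_report_abc123.pdf
```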
@@ -6,14 +6,14 @@

from fastapi import APIRouter, BackgroundTasks, HTTPException
from fastapi.responses import StreamingResponse

from llama_index.core.agent.workflow.workflow_events import (
AgentInput,
AgentSetup,
AgentStream,
)
from llama_index.core.workflow import StopEvent, Workflow
from llama_index.server.api.callbacks import (
+    AgentCallTool,
EventCallback,
LlamaCloudFileDownload,
SourceNodesFromToolCall,
@@ -54,6 +54,7 @@ async def chat(
)

callbacks: list[EventCallback] = [
+        AgentCallTool(),
SourceNodesFromToolCall(),
LlamaCloudFileDownload(background_tasks),
]
40 changes: 38 additions & 2 deletions python/llama-index-server/llama_index/server/prompts.py
@@ -7,9 +7,45 @@
---------------------
Given the conversation history, please give me 3 questions that user might ask next!
Your answer should be wrapped in three sticks without any index numbers and follows the following format:
-\`\`\`
+```
<question 1>
<question 2>
<question 3>
-\`\`\`
+```
"""

+# Used as a prompt for the synthesizer
+# Override this prompt by setting the `CITATION_PROMPT` environment variable
+CITATION_PROMPT = """
+Context information is below.
+------------------
+{context_str}
+------------------
+A citation_id is associated with each text chunk (at the beginning) or with a previous response (wrapped in a `[citation:]` block).
+Use the citation_id for citation construction.
+
+Answer the following query with citations:
+------------------
+{query_str}
+------------------
+
+# Citation format
+
+[citation:id]
+
+Where:
+- `[citation:]` is the matching pattern required for all citations.
+- `id` is the `citation_id` provided in the context or previous response.
+
+Example:
+```
+Here is a response that uses context information [citation:90ca859f-4f32-40ca-8cd0-edfad4fb298b]
+and other ideas that don't use context information [citation:17b2cc9a-27ae-4b6d-bede-5ca60fc00ff4] .\n
+The citation block will be displayed automatically with useful information for the user in the UI [citation:1c606612-e75f-490e-8374-44e79f818d19] .
+```
+
+## Requirements:
+1. Always include citations for every fact from the context information in your response.
+2. Make sure the citation_id matches the context; don't mix up the citation_id with other information.
+Now, answer the query with citations:
+"""
@@ -33,7 +33,6 @@ def __hash__(self) -> int:

class LlamaCloudFileService:
LOCAL_STORE_PATH = "output/llamacloud"
-    DOWNLOAD_FILE_NAME_TPL = "{pipeline_id}${filename}"

@classmethod
def get_all_projects_with_pipelines(cls) -> List[Dict[str, Any]]:
@@ -155,13 +154,10 @@ def _get_files_to_download(cls, nodes: List[NodeWithScore]) -> Set[LlamaCloudFile]:
# Remove duplicates and return
return set(llama_cloud_files)

-    @classmethod
-    def _get_file_name(cls, name: str, pipeline_id: str) -> str:
-        return cls.DOWNLOAD_FILE_NAME_TPL.format(pipeline_id=pipeline_id, filename=name)

@classmethod
def _get_file_path(cls, name: str, pipeline_id: str) -> str:
-        return os.path.join(cls.LOCAL_STORE_PATH, cls._get_file_name(name, pipeline_id))
+        file_name = SourceNodes.get_local_llamacloud_file_name(name, pipeline_id)
+        return os.path.join(cls.LOCAL_STORE_PATH, file_name)

@classmethod
def _download_file(cls, url: str, local_file_path: str) -> None:
3 changes: 2 additions & 1 deletion python/llama-index-server/llama_index/server/tools/index/__init__.py
@@ -1,3 +1,4 @@
from .query import get_query_engine_tool
+from .node_citation_processor import NodeCitationProcessor

__all__ = ["get_query_engine_tool"]
__all__ = ["get_query_engine_tool", "NodeCitationProcessor"]
@@ -0,0 +1,46 @@
from typing import Optional

from llama_index.core.agent.workflow import FunctionAgent, ReActAgent
from llama_index.core.llms import LLM
from llama_index.core.settings import Settings
from llama_index.server.tools.index import get_query_engine_tool


def create_citation_agent(
index,
llm: Optional[LLM] = None,
name: Optional[str] = None,
description: Optional[str] = None,
system_prompt: Optional[str] = None,
) -> FunctionAgent | ReActAgent:
"""
    Create a citation agent that can answer questions with citations using information from the provided index.
Example:
```python
citation_agent = create_citation_agent(index=index)
my_workflow = AgentWorkflow(agents=[citation_agent], root_agent=citation_agent.name)
my_workflow.run(user_msg="Why is sky blue?")
```
"""
llm = llm or Settings.llm
agent_cls = FunctionAgent if llm.metadata.is_function_calling_model else ReActAgent
name = name or "citation_agent"
description = (
description
or "An agent that can answer questions with citations using information from provided index. Do not change the citations when restructuring the answer."
)
system_prompt = (
system_prompt
or """
You are a helpful assistant that has access to a knowledge base.
You can use the query_index tool to get the information you need.
Answer the user question with citations for the parts that use the information from the knowledge base.
"""
)
return agent_cls(
name=name,
description=description,
tools=[get_query_engine_tool(index=index, enable_citation=True)],
llm=llm or Settings.llm,
system_prompt=system_prompt,
)
21 changes: 21 additions & 0 deletions python/llama-index-server/llama_index/server/tools/index/node_citation_processor.py
@@ -0,0 +1,21 @@
from typing import List, Optional

from llama_index.core import QueryBundle
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore


class NodeCitationProcessor(BaseNodePostprocessor):
"""
    Add a `citation_id` field to each node's metadata by copying the node's id.
Useful for citation construction.
"""

def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
for node_score in nodes:
node_score.node.metadata["citation_id"] = node_score.node.node_id
return nodes
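
A minimal sketch of wiring the processor into a query engine follows; the index construction is invented, it assumes `Settings.llm` and embeddings are configured (for example via an OpenAI key), and `as_query_engine(node_postprocessors=...)` is standard llama_index rather than part of this diff.

```python
# Sketch: attach NodeCitationProcessor so retrieved nodes carry citation ids.
from llama_index.core import Document, VectorStoreIndex
from llama_index.server.tools.index import NodeCitationProcessor

index = VectorStoreIndex.from_documents(
    [Document(text="The sky is blue due to Rayleigh scattering.")]
)
query_engine = index.as_query_engine(
    node_postprocessors=[NodeCitationProcessor()],
)
response = query_engine.query("Why is the sky blue?")
# Each retrieved node now has metadata["citation_id"] equal to its node_id,
# which responses can reference as [citation:<id>].
print(response.source_nodes[0].node.metadata["citation_id"])
```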