Add comment in workflow #104


Merged 3 commits on Oct 22, 2024

13 changes: 12 additions & 1 deletion .github/workflows/evaluate.yaml
@@ -46,6 +46,17 @@ jobs:
        run: |
          echo "Comment contains #evaluate hashtag"

      - name: Comment on pull request
        uses: actions/github-script@v7
        with:
          script: |
            github.rest.issues.createComment({
              issue_number: context.issue.number,
              owner: context.repo.owner,
              repo: context.repo.repo,
              body: "Starting evaluation! Check the Actions tab for progress, or wait for a comment with the results."
            })

      - uses: actions/checkout@v4

      - name: Install pgvector
@@ -133,7 +144,7 @@ jobs:

      - name: Evaluate local RAG flow
        run: |
          python evals/evaluate.py --targeturl=http://127.0.0.1:8000/chat --numquestions=2 --resultsdir=evals/results/pr${{ github.event.issue.number }}
          python evals/evaluate.py --targeturl=http://127.0.0.1:8000/chat --resultsdir=evals/results/pr${{ github.event.issue.number }}

      - name: Upload server logs as build artifact
        uses: actions/upload-artifact@v4
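The new step above only announces that the evaluation has started. A step that posts the finished results back to the PR is not visible in this hunk; a minimal sketch of what one could look like is below, assuming the evaluation writes a markdown summary under `evals/results/` (the path and filename are assumptions, not taken from this PR):

```yaml
      - name: Post evaluation results on pull request
        uses: actions/github-script@v7
        with:
          script: |
            const fs = require('fs');
            // Hypothetical summary path; point this at whatever evaluate.py actually writes
            const summary = fs.readFileSync(`evals/results/pr${context.issue.number}/summary.md`, 'utf8');
            github.rest.issues.createComment({
              issue_number: context.issue.number,
              owner: context.repo.owner,
              repo: context.repo.repo,
              body: summary
            });
```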
12 changes: 6 additions & 6 deletions docs/evaluation.md
@@ -33,17 +33,19 @@ pip install -r requirements-dev.txt

## Generate ground truth data

Modify the prompt in `evals/generate_prompt.txt` to match your database table and RAG scenario.

Generate ground truth data by running the following command:

```bash
python evals/generate.py
python evals/generate_ground_truth.py
```

Review the generated data after running that script, removing any question/answer pairs that don't seem like realistic user input.
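For reference, each line of the resulting `ground_truth.jsonl` file is a JSON object with a `question` and a reference answer stored under `truth` (that is the shape the rewritten script writes); the product details below are made-up, illustrative values:

```json
{"question": "What do you sell for cold-weather camping?", "truth": "We carry insulated sleeping bags and four-season tents suitable for cold-weather camping. [12][27]"}
```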

## Evaluate the RAG answer quality

Review the configuration in `evals/eval_config.json` to ensure that everything is correctly setup. You may want to adjust the metrics used. [TODO: link to evaluator docs]
Review the configuration in `evals/eval_config.json` to ensure that everything is correctly set up. You may want to adjust the metrics used. See [the ai-rag-chat-evaluator README](https://github.com/Azure-Samples/ai-rag-chat-evaluator) for more information on the available metrics.

By default, the evaluation script will evaluate every question in the ground truth data.
Run the evaluation script by running the following command:
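(The command itself is collapsed in this view. Based on the CI workflow changed in this PR, it is likely of the form below; the target URL and results directory are examples, not taken from this section:)

```bash
python evals/evaluate.py --targeturl=http://127.0.0.1:8000/chat --resultsdir=evals/results/baseline
```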
@@ -68,8 +70,6 @@ Compare answers across runs by running the following command:
python -m evaltools diff evals/results/baseline/
```

## Run the evaluation in GitHub actions

## Run the evaluation on a PR

# TODO: Add GPT-4 deployment with high capacity for evaluation
# TODO: Add CI workflow that can be triggered to run the evaluate on the local app
To run the evaluation on the changes in a PR, add a `/evaluate` comment to the PR. This triggers the evaluation workflow, which runs the evaluation against the PR changes and posts the results back to the PR.
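A minimal sketch of how a comment-triggered workflow like this can be wired is shown below. The job and step names are illustrative, and note that the workflow changed in this PR checks for the text `#evaluate` in the comment body, so the trigger keyword in the docs and the workflow should be kept in sync:

```yaml
on:
  issue_comment:
    types: [created]

jobs:
  evaluate:
    # Only run for comments on pull requests that contain the trigger keyword
    if: github.event.issue.pull_request && contains(github.event.comment.body, '#evaluate')
    runs-on: ubuntu-latest
    steps:
      - name: Confirm trigger
        run: |
          echo "Comment contains #evaluate hashtag"
```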
145 changes: 101 additions & 44 deletions evals/generate_ground_truth.py
@@ -1,11 +1,13 @@
import json
import logging
import os
from collections.abc import Generator
from pathlib import Path

from azure.identity import AzureDeveloperCliCredential
from dotenv import load_dotenv
from evaltools.gen.generate import generate_test_qa_data
from azure.identity import AzureDeveloperCliCredential, get_bearer_token_provider
from dotenv_azd import load_azd_env
from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletionToolParam
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session

@@ -14,7 +16,37 @@
logger = logging.getLogger("ragapp")


def source_retriever() -> Generator[dict, None, None]:
def qa_pairs_tool(num_questions: int = 1) -> ChatCompletionToolParam:
    return {
        "type": "function",
        "function": {
            "name": "qa_pairs",
            "description": "Send in question and answer pairs for a customer-facing chat app",
            "parameters": {
                "type": "object",
                "properties": {
                    "qa_list": {
                        "type": "array",
                        "description": f"List of {num_questions} question and answer pairs",
                        "items": {
                            "type": "object",
                            "properties": {
                                "question": {"type": "string", "description": "The question text"},
                                "answer": {"type": "string", "description": "The answer text"},
                            },
                            "required": ["question", "answer"],
                        },
                        "minItems": num_questions,
                        "maxItems": num_questions,
                    }
                },
                "required": ["qa_list"],
            },
        },
    }


def source_retriever() -> Generator[str, None, None]:
    # Connect to the database
    DBHOST = os.environ["POSTGRES_HOST"]
    DBUSER = os.environ["POSTGRES_USERNAME"]
@@ -27,16 +59,14 @@ def source_retriever() -> Generator[dict, None, None]:
        item_types = session.scalars(select(Item.type).distinct())
        for item_type in item_types:
            records = list(session.scalars(select(Item).filter(Item.type == item_type).order_by(Item.id)))
            # logger.info(f"Processing database records for type: {item_type}")
            # yield {
            #     "citations": " ".join([f"[{record.id}] - {record.name}" for record in records]),
            #     "content": "\n\n".join([record.to_str_for_rag() for record in records]),
            # }
            logger.info(f"Processing database records for type: {item_type}")
            yield "\n\n".join([f"## Product ID: [{record.id}]\n" + record.to_str_for_rag() for record in records])
        # Fetch each item individually
        records = list(session.scalars(select(Item).order_by(Item.id)))
        for record in records:
            logger.info(f"Processing database record: {record.name}")
            yield {"id": record.id, "content": record.to_str_for_rag()}
        # records = list(session.scalars(select(Item).order_by(Item.id)))
        # for record in records:
        #     logger.info(f"Processing database record: {record.name}")
        #     yield f"## Product ID: [{record.id}]\n" + record.to_str_for_rag()
        # await self.openai_chat_client.chat.completions.create(


def source_to_text(source) -> str:
@@ -47,49 +77,76 @@ def answer_formatter(answer, source) -> str:
return f"{answer} [{source['id']}]"


def get_openai_config_dict() -> dict:
    """Return a dictionary with OpenAI configuration based on environment variables."""
def get_openai_client() -> tuple[AzureOpenAI | OpenAI, str]:
    """Return an OpenAI client based on the environment variables"""
    openai_client: AzureOpenAI | OpenAI
    OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST")
    if OPENAI_CHAT_HOST == "azure":
        if api_key := os.getenv("AZURE_OPENAI_KEY"):
            logger.info("Using Azure OpenAI Service with API Key from AZURE_OPENAI_KEY")
            api_key = os.environ["AZURE_OPENAI_KEY"]
            openai_client = AzureOpenAI(
                api_version=os.environ["AZURE_OPENAI_VERSION"],
                azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
                api_key=api_key,
            )
        else:
            logger.info("Using Azure OpenAI Service with Azure Developer CLI Credential")
            azure_credential = AzureDeveloperCliCredential(process_timeout=60)
            api_key = azure_credential.get_token("https://cognitiveservices.azure.com/.default").token
        openai_config = {
            "api_type": "azure",
            "api_base": os.environ["AZURE_OPENAI_ENDPOINT"],
            "api_key": api_key,
            "api_version": os.environ["AZURE_OPENAI_VERSION"],
            "deployment": os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
            "model": os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
        }
            azure_credential = AzureDeveloperCliCredential(process_timeout=60, tenant_id=os.environ["AZURE_TENANT_ID"])
            token_provider = get_bearer_token_provider(azure_credential, "https://cognitiveservices.azure.com/.default")
            openai_client = AzureOpenAI(
                api_version=os.environ["AZURE_OPENAI_VERSION"],
                azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
                azure_ad_token_provider=token_provider,
            )
        model = os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"]
    elif OPENAI_CHAT_HOST == "ollama":
        raise NotImplementedError("Ollama OpenAI Service is not supported. Switch to Azure or OpenAI.com")
    else:
        logger.info("Using OpenAI Service with API Key from OPENAICOM_KEY")
        openai_config = {
            "api_type": "openai",
            "api_key": os.environ["OPENAICOM_KEY"],
            "model": os.environ["OPENAICOM_CHAT_MODEL"],
            "deployment": "none-needed-for-openaicom",
        }
    return openai_config
        openai_client = OpenAI(api_key=os.environ["OPENAICOM_KEY"])
        model = os.environ["OPENAICOM_CHAT_MODEL"]
    return openai_client, model


def generate_ground_truth_data(num_questions_total: int, num_questions_per_source: int = 5):
    logger.info("Generating %d questions total", num_questions_total)
    openai_client, model = get_openai_client()
    current_dir = Path(__file__).parent
    generate_prompt = open(current_dir / "generate_prompt.txt").read()
    output_file = Path(__file__).parent / "ground_truth.jsonl"

    qa: list[dict] = []
    for source in source_retriever():
        if len(qa) > num_questions_total:
            logger.info("Generated enough questions already, stopping")
            break
        result = openai_client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": generate_prompt},
                {"role": "user", "content": json.dumps(source)},
            ],
            tools=[qa_pairs_tool(num_questions=2)],
        )
        if not result.choices[0].message.tool_calls:
            logger.warning("No tool calls found in response, skipping")
            continue
        qa_pairs = json.loads(result.choices[0].message.tool_calls[0].function.arguments)["qa_list"]
        qa_pairs = [{"question": qa_pair["question"], "truth": qa_pair["answer"]} for qa_pair in qa_pairs]
        qa.extend(qa_pairs)

    logger.info("Writing %d questions to %s", num_questions_total, output_file)
    directory = Path(output_file).parent
    if not directory.exists():
        directory.mkdir(parents=True)
    with open(output_file, "w", encoding="utf-8") as f:
        for item in qa[0:num_questions_total]:
            f.write(json.dumps(item) + "\n")


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    logger.setLevel(logging.INFO)
    load_dotenv(".env", override=True)

    generate_test_qa_data(
        openai_config=get_openai_config_dict(),
        num_questions_total=202,
        num_questions_per_source=2,
        output_file=Path(__file__).parent / "ground_truth.jsonl",
        source_retriever=source_retriever,
        source_to_text=source_to_text,
        answer_formatter=answer_formatter,
    )
    load_azd_env()

    generate_ground_truth_data(num_questions_total=10)
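A note on running the rewritten script: it now loads configuration from the current azd environment via `load_azd_env()`, so the Postgres and OpenAI variables it reads (`POSTGRES_HOST`, `POSTGRES_USERNAME`, `AZURE_OPENAI_*`, or `OPENAICOM_*`) must already be set there. A minimal invocation, assuming a provisioned azd environment, looks like:

```bash
# Assumes `azd up` (or equivalent) has already populated the azd environment
# with the POSTGRES_* and OpenAI variables the script reads.
python evals/generate_ground_truth.py
```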
9 changes: 9 additions & 0 deletions evals/generate_prompt.txt
@@ -0,0 +1,9 @@
Your job is to generate example questions that a customer might ask about the products.
You should come up with a question and an answer based on the provided data.
The answer should include the product ID in square brackets.
For example,
'What climbing gear do you have?'
with answer:
'We have a variety of climbing gear, including ropes, harnesses, and carabiners. [1][2]'.
Remember that customers probably don't know the names of specific brands,
so your questions should be more general questions from someone who is shopping for these types of products.