136 changes: 136 additions & 0 deletions examples/evaluate_tool_calling_with_reflection_and_correction.py
@@ -0,0 +1,136 @@
import json

from unitxt import get_logger
from unitxt.api import create_dataset, evaluate
from unitxt.inference import (
CrossProviderInferenceEngine,
)

logger = get_logger()

sum_tool = {
"name": "add_numbers",
"description": "Add two numbers together and return the result.",
"parameters": {
"type": "object",
"properties": {
"num1": {"type": "number", "description": "The first number to add."},
"num2": {"type": "number", "description": "The second number to add."},
},
"required": ["num1", "num2"],
},
}

subtract_tool = {
"name": "subtract_numbers",
"description": "Subtracts one number from another from the other.",
"parameters": {
"type": "object",
"properties": {
"num1": {"type": "number", "description": "Number to subtract from"},
"num2": {"type": "number", "description": "Number to subtract"},
},
"required": ["num1", "num2"],
},
}

data = [
{
"query": "What is the sum of 1212382 and 834672?",
"tools": [sum_tool, subtract_tool],
"reference_calls": [
{
"name": "add_numbers",
"arguments": {"num1": 1212382, "num2": 834672},
},
{
"name": "add_numbers",
"arguments": {"num1": 834672, "num2": 1212382},
},
],
},
{
"query": "Subtract 12123 from 83467",
"tools": [sum_tool, subtract_tool],
"reference_calls": [
{
"name": "subtract_numbers",
"arguments": {"num1": 83467, "num2": 12123},
}
],
},
]

dataset = create_dataset(
task="tasks.tool_calling.supervised",
test_set=data,
split="test",
format="formats.chat_api",
metrics=[
"metrics.tool_calling.reflection",
],
max_test_instances=10,
)

model = CrossProviderInferenceEngine(
model="granite-3-3-8b-instruct", provider="watsonx", temperature=0
)


predictions = model(dataset)
# Insert errors into the model predictions to exercise the reflector:
# corrupt the digits and an argument name in the first prediction,
# and use a non-existent tool name in the second.
predictions[0] = predictions[0].replace("8", "7")
predictions[0] = predictions[0].replace("num1", "num3")
predictions[1] = predictions[1].replace("subtract_numbers", "multiply_numbers")

results = evaluate(predictions=predictions, data=dataset)
print("Instance Results:")
print(results.instance_scores.summary)

print("Global Results:")
print(results.global_scores.summary)


def find_elements(obj, element_name):
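    """Recursively collect every value stored under ``element_name`` in nested dicts and lists."""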
results = []
if isinstance(obj, dict):
for key, value in obj.items():
if key == element_name:
results.append(value)
else:
results.extend(find_elements(value, element_name))
elif isinstance(obj, list):
for item in obj:
results.extend(find_elements(item, element_name))
return results


# Collect the corrected tool calls when they exist: if the reflection scores contain
# a raw_response with a tool_call, use it as the corrected prediction;
# otherwise keep the original prediction.


corrected_predictions = []
for instance_scores, prediction in zip(results.instance_scores, predictions):
raw_responses = find_elements(instance_scores, "raw_response")
corrected_tool_calls = find_elements(raw_responses, "tool_call")
if len(corrected_tool_calls) > 0:
corrected_predictions.append(json.dumps(corrected_tool_calls[0]))
else:
corrected_predictions.append(prediction)

# Evaluate the corrected predictions on the same data, this time with the key-value accuracy metric
dataset = create_dataset(
task="tasks.tool_calling.supervised",
test_set=data,
split="test",
format="formats.chat_api",
metrics=["metrics.tool_calling.key_value.accuracy"],
max_test_instances=10,
)

corrected_results = evaluate(predictions=corrected_predictions, data=dataset)

print("Instance Results (After reflection):")
print(corrected_results.instance_scores.summary)

print("Global Results (After reflection):")
print(corrected_results.global_scores.summary)
10 changes: 6 additions & 4 deletions src/unitxt/metrics.py
@@ -1113,7 +1113,9 @@ async def map(
retries=3,
continue_on_static=True,
)
return result.model_dump()
result_dict = result.model_dump()
result_dict["overall_valid"] = float(result_dict["overall_valid"])
return result_dict

def map_stream(
self,
@@ -1173,7 +1175,7 @@ def reduce(self, intermidates: List[Dict[str, Any]]) -> Dict[str, float]:
for metric, metric_dict in metric_type_dict.get(
"metrics", {}
).items():
flat_instance_dict[f"semantic_{metric}"] = float(
flat_instance_dict[f"semantic_{metric}"] = 1 - float(
metric_dict["is_issue"]
)
flat_instance_dict[f"semantic_avg_score_{metric_type}"] = float(
@@ -1239,7 +1241,7 @@ def map(
static_result = ReflectionPipeline.static_only(tools_inventory, tool_call)

result_dict = static_result.model_dump()
result_dict["overall_valid"] = result_dict.pop("final_decision")
result_dict["overall_valid"] = float(result_dict.pop("final_decision"))
result_dict["metrics"]["json_schema_violation"] = result_dict["metrics"].pop(
"json_schema_validation"
)
@@ -1254,7 +1256,7 @@ def reduce(self, intermediates: List[Dict[str, float]]) -> Dict[str, float]:
flat_instance_dict = {}
for metric, metric_dict in instance.get("metrics", {}).items():
flat_instance_dict[metric] = float(metric_dict["valid"])
flat_instance_dict["overall_valid"] = instance["overall_valid"]
flat_instance_dict["overall_valid"] = float(instance["overall_valid"])
flat_instances.append(flat_instance_dict)

return self.reduction.reduce(flat_instances)