Commit 800b2ba

Authored by yoavkatz and Koren Lazar
Correct reflection based tool calling metrics so valid results will be 1. (#1940)
* Correct reflection based tool calling metrics so valid results will be 1. Also added example of using reflection based tool calling to correct tool calls and reevaluate.

* Fixed the test of the tool calling evaluation metrics accordingly.

Signed-off-by: Yoav Katz <[email protected]>
Co-authored-by: Koren Lazar <[email protected]>
1 parent c90534f commit 800b2ba

File tree

3 files changed: 153 additions, 15 deletions

New example file: using reflection based tool calling to correct tool calls and reevaluate
Lines changed: 136 additions & 0 deletions

@@ -0,0 +1,136 @@
import json

from unitxt import get_logger
from unitxt.api import create_dataset, evaluate
from unitxt.inference import (
    CrossProviderInferenceEngine,
)

logger = get_logger()

sum_tool = {
    "name": "add_numbers",
    "description": "Add two numbers together and return the result.",
    "parameters": {
        "type": "object",
        "properties": {
            "num1": {"type": "number", "description": "The first number to add."},
            "num2": {"type": "number", "description": "The second number to add."},
        },
        "required": ["num1", "num2"],
    },
}

subtract_tool = {
    "name": "subtract_numbers",
    "description": "Subtracts one number from another.",
    "parameters": {
        "type": "object",
        "properties": {
            "num1": {"type": "number", "description": "Number to subtract from"},
            "num2": {"type": "number", "description": "Number to subtract"},
        },
        "required": ["num1", "num2"],
    },
}

data = [
    {
        "query": "What is the sum of 1212382 and 834672?",
        "tools": [sum_tool, subtract_tool],
        "reference_calls": [
            {
                "name": "add_numbers",
                "arguments": {"num1": 1212382, "num2": 834672},
            },
            {
                "name": "add_numbers",
                "arguments": {"num1": 834672, "num2": 1212382},
            },
        ],
    },
    {
        "query": "Subtract 12123 from 83467",
        "tools": [sum_tool, subtract_tool],
        "reference_calls": [
            {
                "name": "subtract_numbers",
                "arguments": {"num1": 83467, "num2": 12123},
            }
        ],
    },
]

dataset = create_dataset(
    task="tasks.tool_calling.supervised",
    test_set=data,
    split="test",
    format="formats.chat_api",
    metrics=[
        "metrics.tool_calling.reflection",
    ],
    max_test_instances=10,
)

model = CrossProviderInferenceEngine(
    model="granite-3-3-8b-instruct", provider="watsonx", temperature=0
)


predictions = model(dataset)
# Insert errors into the model predictions to check the reflection metric
predictions[0] = predictions[0].replace("8", "7")
predictions[0] = predictions[0].replace("num1", "num3")
predictions[1] = predictions[1].replace("subtract_numbers", "multiply_numbers")

results = evaluate(predictions=predictions, data=dataset)
print("Instance Results:")
print(results.instance_scores.summary)

print("Global Results:")
print(results.global_scores.summary)


def find_elements(obj, element_name):
    """Recursively collect every value stored under element_name in nested dicts/lists."""
    results = []
    if isinstance(obj, dict):
        for key, value in obj.items():
            if key == element_name:
                results.append(value)
            else:
                results.extend(find_elements(value, element_name))
    elif isinstance(obj, list):
        for item in obj:
            results.extend(find_elements(item, element_name))
    return results


# Collect the corrected tool calls suggested by the reflection metric, when they exist


corrected_predictions = []
for instance_scores, prediction in zip(results.instance_scores, predictions):
    raw_responses = find_elements(instance_scores, "raw_response")
    corrected_tool_calls = find_elements(raw_responses, "tool_call")
    if len(corrected_tool_calls) > 0:
        corrected_predictions.append(json.dumps(corrected_tool_calls[0]))
    else:
        corrected_predictions.append(prediction)

# Re-evaluate the corrected predictions on the same data with the key-value accuracy metric
dataset = create_dataset(
    task="tasks.tool_calling.supervised",
    test_set=data,
    split="test",
    format="formats.chat_api",
    metrics=["metrics.tool_calling.key_value.accuracy"],
    max_test_instances=10,
)

corrected_results = evaluate(predictions=corrected_predictions, data=dataset)

print("Instance Results (After reflection):")
print(corrected_results.instance_scores.summary)

print("Global Results (After reflection):")
print(corrected_results.global_scores.summary)
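
A note on the correction step: the loop above only assumes that, when the reflection metric proposes a fix, the per-instance scores expose a "raw_response" entry containing a "tool_call" object. The exact nesting produced by "metrics.tool_calling.reflection" is not shown in this commit, so the sketch below is illustrative only; it reuses find_elements from the example, and the values are taken from the example data.

# Hypothetical shape of one entry of results.instance_scores (illustrative only;
# the real nesting produced by metrics.tool_calling.reflection may differ).
example_instance_scores = {
    "raw_response": {
        "tool_call": {
            "name": "add_numbers",
            "arguments": {"num1": 1212382, "num2": 834672},
        }
    }
}

# With that shape, the correction loop above behaves as follows:
raw_responses = find_elements(example_instance_scores, "raw_response")
corrected_tool_calls = find_elements(raw_responses, "tool_call")
assert corrected_tool_calls[0]["name"] == "add_numbers"
# json.dumps(corrected_tool_calls[0]) is what replaces the faulty prediction.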

src/unitxt/metrics.py

Lines changed: 6 additions & 4 deletions

@@ -1113,7 +1113,9 @@ async def map(
             retries=3,
             continue_on_static=True,
         )
-        return result.model_dump()
+        result_dict = result.model_dump()
+        result_dict["overall_valid"] = float(result_dict["overall_valid"])
+        return result_dict
 
     def map_stream(
         self,
@@ -1173,7 +1175,7 @@ def reduce(self, intermidates: List[Dict[str, Any]]) -> Dict[str, float]:
                 for metric, metric_dict in metric_type_dict.get(
                     "metrics", {}
                 ).items():
-                    flat_instance_dict[f"semantic_{metric}"] = float(
+                    flat_instance_dict[f"semantic_{metric}"] = 1 - float(
                         metric_dict["is_issue"]
                     )
                 flat_instance_dict[f"semantic_avg_score_{metric_type}"] = float(
@@ -1239,7 +1241,7 @@ def map(
         static_result = ReflectionPipeline.static_only(tools_inventory, tool_call)
 
         result_dict = static_result.model_dump()
-        result_dict["overall_valid"] = result_dict.pop("final_decision")
+        result_dict["overall_valid"] = float(result_dict.pop("final_decision"))
         result_dict["metrics"]["json_schema_violation"] = result_dict["metrics"].pop(
             "json_schema_validation"
         )
@@ -1254,7 +1256,7 @@ def reduce(self, intermediates: List[Dict[str, float]]) -> Dict[str, float]:
             flat_instance_dict = {}
             for metric, metric_dict in instance.get("metrics", {}).items():
                 flat_instance_dict[metric] = float(metric_dict["valid"])
-            flat_instance_dict["overall_valid"] = instance["overall_valid"]
+            flat_instance_dict["overall_valid"] = float(instance["overall_valid"])
             flat_instances.append(flat_instance_dict)
 
         return self.reduction.reduce(flat_instances)
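
The net effect of these changes is a consistent "higher is better" convention for the flattened scores: overall_valid is reported as 1.0 for a valid call instead of a boolean, and each semantic check is reported as 1 - is_issue rather than the raw issue flag. A minimal sketch of that flattening, mirroring the lines changed above (the surrounding class and reduction machinery are omitted):

# Illustrative only: flatten one reflection result the way the updated code does.
instance = {
    "overall_valid": True,
    "metrics": {"general_hallucination_check": {"is_issue": False}},
}

flat = {"overall_valid": float(instance["overall_valid"])}  # 1.0 rather than True
for metric, metric_dict in instance["metrics"].items():
    # valid result -> 1.0, issue detected -> 0.0
    flat[f"semantic_{metric}"] = 1 - float(metric_dict["is_issue"])

print(flat)  # {'overall_valid': 1.0, 'semantic_general_hallucination_check': 1.0}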

tests/library/test_metrics.py

Lines changed: 11 additions & 11 deletions

@@ -1892,9 +1892,9 @@ def test_reflection_tool_calling_metric_reduce(self):
         self.assertAlmostEqual(reduced["overall_valid"], expected_overall_valid)
 
         # Static aggregated metrics, now flattened
-        # non_existent_function: True, False, True -> 2/3
+        # non_existent_function validity: True, False, True -> 2/3
         nef_expected = (1 + 0 + 1) / 3
-        # missing_required_parameter: True, False, False -> 1/3
+        # missing_required_parameter validity: True, False, False -> 1/3
         mrp_expected = (1 + 0 + 0) / 3
         self.assertIn("static_non_existent_function", reduced)
         self.assertIn("static_missing_required_parameter", reduced)
@@ -1914,15 +1914,15 @@ def test_reflection_tool_calling_metric_reduce(self):
             reduced["semantic_avg_score_function_selection"], semantic_avg_expected
         )
 
-        # Semantic issue rates for individual checks, flattened, mean of is_issue
-        # general_hallucination_check: False, True, False -> 1/3
-        ghc_expected = (0 + 1 + 0) / 3
-        # general_value_format_alignment: False, True, True -> 2/3
-        gvfa_expected = (0 + 1 + 1) / 3
-        # function_selection_appropriateness: False, True, False -> 1/3
-        fsa_expected = (0 + 1 + 0) / 3
-        # agentic_constraints_satisfaction: False, True, True -> 2/3
-        acs_expected = (0 + 1 + 1) / 3
+        # Semantic issue rates for individual checks, flattened, mean of 1 - is_issue
+        # general_hallucination_check: False, True, False -> 2/3
+        ghc_expected = (1 + 0 + 1) / 3
+        # general_value_format_alignment: False, True, True -> 1/3
+        gvfa_expected = (1 + 0 + 0) / 3
+        # function_selection_appropriateness: False, True, False -> 2/3
+        fsa_expected = (1 + 0 + 1) / 3
+        # agentic_constraints_satisfaction: False, True, True -> 1/3
+        acs_expected = (1 + 0 + 0) / 3
 
         self.assertIn("semantic_general_hallucination_check", reduced)
         self.assertIn("semantic_general_value_format_alignment", reduced)

0 commit comments
