diff --git a/examples/evaluate_tool_calling_with_reflection_and_correction.py b/examples/evaluate_tool_calling_with_reflection_and_correction.py
new file mode 100644
index 0000000000..856b489bc6
--- /dev/null
+++ b/examples/evaluate_tool_calling_with_reflection_and_correction.py
@@ -0,0 +1,136 @@
+import json
+
+from unitxt import get_logger
+from unitxt.api import create_dataset, evaluate
+from unitxt.inference import (
+    CrossProviderInferenceEngine,
+)
+
+logger = get_logger()
+
+sum_tool = {
+    "name": "add_numbers",
+    "description": "Add two numbers together and return the result.",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "num1": {"type": "number", "description": "The first number to add."},
+            "num2": {"type": "number", "description": "The second number to add."},
+        },
+        "required": ["num1", "num2"],
+    },
+}
+
+subtract_tool = {
+    "name": "subtract_numbers",
+    "description": "Subtract one number from another and return the result.",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "num1": {"type": "number", "description": "Number to subtract from"},
+            "num2": {"type": "number", "description": "Number to subtract"},
+        },
+        "required": ["num1", "num2"],
+    },
+}
+
+data = [
+    {
+        "query": "What is the sum of 1212382 and 834672?",
+        "tools": [sum_tool, subtract_tool],
+        "reference_calls": [
+            {
+                "name": "add_numbers",
+                "arguments": {"num1": 1212382, "num2": 834672},
+            },
+            {
+                "name": "add_numbers",
+                "arguments": {"num1": 834672, "num2": 1212382},
+            },
+        ],
+    },
+    {
+        "query": "Subtract 12123 from 83467",
+        "tools": [sum_tool, subtract_tool],
+        "reference_calls": [
+            {
+                "name": "subtract_numbers",
+                "arguments": {"num1": 83467, "num2": 12123},
+            }
+        ],
+    },
+]
+
+dataset = create_dataset(
+    task="tasks.tool_calling.supervised",
+    test_set=data,
+    split="test",
+    format="formats.chat_api",
+    metrics=[
+        "metrics.tool_calling.reflection",
+    ],
+    max_test_instances=10,
+)
+
+model = CrossProviderInferenceEngine(
+    model="granite-3-3-8b-instruct", provider="watsonx", temperature=0
+)
+
+
+predictions = model(dataset)
+# Insert errors into the model predictions to check the reflector
+predictions[0] = predictions[0].replace("8", "7")
+predictions[0] = predictions[0].replace("num1", "num3")
+predictions[1] = predictions[1].replace("subtract_numbers", "multiply_numbers")
+
+results = evaluate(predictions=predictions, data=dataset)
+print("Instance Results:")
+print(results.instance_scores.summary)
+
+print("Global Results:")
+print(results.global_scores.summary)
+
+
+def find_elements(obj, element_name):
+    results = []
+    if isinstance(obj, dict):
+        for key, value in obj.items():
+            if key == element_name:
+                results.append(value)
+            else:
+                results.extend(find_elements(value, element_name))
+    elif isinstance(obj, list):
+        for item in obj:
+            results.extend(find_elements(item, element_name))
+    return results
+
+
+# Collect the corrected tool calls where they exist
+
+
+corrected_predictions = []
+for instance_scores, prediction in zip(results.instance_scores, predictions):
+    raw_responses = find_elements(instance_scores, "raw_response")
+    corrected_tool_calls = find_elements(raw_responses, "tool_call")
+    if len(corrected_tool_calls) > 0:
+        corrected_predictions.append(json.dumps(corrected_tool_calls[0]))
+    else:
+        corrected_predictions.append(prediction)
+
+# Re-evaluate the corrected predictions on the same data with the key-value accuracy metric
+dataset = create_dataset(
+    task="tasks.tool_calling.supervised",
+    test_set=data,
+    split="test",
+    format="formats.chat_api",
+    metrics=["metrics.tool_calling.key_value.accuracy"],
+    max_test_instances=10,
+)
+
+corrected_results = evaluate(predictions=corrected_predictions, data=dataset)
+
+print("Instance Results (After reflection):")
+print(corrected_results.instance_scores.summary)
+
+print("Global Results (After reflection):")
+print(corrected_results.global_scores.summary)
diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py
index d6f5de9609..33e4b033d1 100644
--- a/src/unitxt/metrics.py
+++ b/src/unitxt/metrics.py
@@ -1113,7 +1113,9 @@ async def map(
             retries=3,
             continue_on_static=True,
         )
-        return result.model_dump()
+        result_dict = result.model_dump()
+        result_dict["overall_valid"] = float(result_dict["overall_valid"])
+        return result_dict

     def map_stream(
         self,
@@ -1173,7 +1175,7 @@ def reduce(self, intermidates: List[Dict[str, Any]]) -> Dict[str, float]:
                 for metric, metric_dict in metric_type_dict.get(
                     "metrics", {}
                 ).items():
-                    flat_instance_dict[f"semantic_{metric}"] = float(
+                    flat_instance_dict[f"semantic_{metric}"] = 1 - float(
                         metric_dict["is_issue"]
                     )
                 flat_instance_dict[f"semantic_avg_score_{metric_type}"] = float(
@@ -1239,7 +1241,7 @@ def map(
         static_result = ReflectionPipeline.static_only(tools_inventory, tool_call)

         result_dict = static_result.model_dump()
-        result_dict["overall_valid"] = result_dict.pop("final_decision")
+        result_dict["overall_valid"] = float(result_dict.pop("final_decision"))
         result_dict["metrics"]["json_schema_violation"] = result_dict["metrics"].pop(
             "json_schema_validation"
         )
@@ -1254,7 +1256,7 @@ def reduce(self, intermediates: List[Dict[str, float]]) -> Dict[str, float]:
             flat_instance_dict = {}
             for metric, metric_dict in instance.get("metrics", {}).items():
                 flat_instance_dict[metric] = float(metric_dict["valid"])
-            flat_instance_dict["overall_valid"] = instance["overall_valid"]
+            flat_instance_dict["overall_valid"] = float(instance["overall_valid"])
             flat_instances.append(flat_instance_dict)

         return self.reduction.reduce(flat_instances)
diff --git a/tests/library/test_metrics.py b/tests/library/test_metrics.py
index 2bf2a94eac..160b9d2543 100644
--- a/tests/library/test_metrics.py
+++ b/tests/library/test_metrics.py
@@ -1892,9 +1892,9 @@ def test_reflection_tool_calling_metric_reduce(self):
         self.assertAlmostEqual(reduced["overall_valid"], expected_overall_valid)

         # Static aggregated metrics, now flattened
-        # non_existent_function: True, False, True -> 2/3
+        # non_existent_function validity: True, False, True -> 2/3
         nef_expected = (1 + 0 + 1) / 3
-        # missing_required_parameter: True, False, False -> 1/3
+        # missing_required_parameter validity: True, False, False -> 1/3
         mrp_expected = (1 + 0 + 0) / 3
         self.assertIn("static_non_existent_function", reduced)
         self.assertIn("static_missing_required_parameter", reduced)
@@ -1914,15 +1914,15 @@ def test_reflection_tool_calling_metric_reduce(self):
             reduced["semantic_avg_score_function_selection"], semantic_avg_expected
         )

-        # Semantic issue rates for individual checks, flattened, mean of is_issue
-        # general_hallucination_check: False, True, False -> 1/3
-        ghc_expected = (0 + 1 + 0) / 3
-        # general_value_format_alignment: False, True, True -> 2/3
-        gvfa_expected = (0 + 1 + 1) / 3
-        # function_selection_appropriateness: False, True, False -> 1/3
-        fsa_expected = (0 + 1 + 0) / 3
-        # agentic_constraints_satisfaction: False, True, True -> 2/3
-        acs_expected = (0 + 1 + 1) / 3
+        # Semantic validity scores for individual checks, flattened, mean of 1 - is_issue
+        # general_hallucination_check: False, True, False -> 2/3
+        ghc_expected = (1 + 0 + 1) / 3
+        # general_value_format_alignment: False, True, True -> 1/3
+        gvfa_expected = (1 + 0 + 0) / 3
+        # function_selection_appropriateness: False, True, False -> 2/3
+        fsa_expected = (1 + 0 + 1) / 3
+        # agentic_constraints_satisfaction: False, True, True -> 1/3
+        acs_expected = (1 + 0 + 0) / 3

         self.assertIn("semantic_general_hallucination_check", reduced)
         self.assertIn("semantic_general_value_format_alignment", reduced)