import json

from unitxt import get_logger
from unitxt.api import create_dataset, evaluate
from unitxt.inference import (
    CrossProviderInferenceEngine,
)

logger = get_logger()

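# Tool schemas in the JSON-schema style used by the tool_calling task: each tool
# declares a name, a description, and typed parameters with required fields.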
sum_tool = {
    "name": "add_numbers",
    "description": "Add two numbers together and return the result.",
    "parameters": {
        "type": "object",
        "properties": {
            "num1": {"type": "number", "description": "The first number to add."},
            "num2": {"type": "number", "description": "The second number to add."},
        },
        "required": ["num1", "num2"],
    },
}

subtract_tool = {
    "name": "subtract_numbers",
    "description": "Subtract one number from another and return the result.",
    "parameters": {
        "type": "object",
        "properties": {
            "num1": {"type": "number", "description": "Number to subtract from."},
            "num2": {"type": "number", "description": "Number to subtract."},
        },
        "required": ["num1", "num2"],
    },
}

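# Each instance pairs a query with the available tools and one or more acceptable
# reference calls; for addition, both argument orders are listed as valid answers.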
data = [
    {
        "query": "What is the sum of 1212382 and 834672?",
        "tools": [sum_tool, subtract_tool],
        "reference_calls": [
            {
                "name": "add_numbers",
                "arguments": {"num1": 1212382, "num2": 834672},
            },
            {
                "name": "add_numbers",
                "arguments": {"num1": 834672, "num2": 1212382},
            },
        ],
    },
    {
        "query": "Subtract 12123 from 83467",
        "tools": [sum_tool, subtract_tool],
        "reference_calls": [
            {
                "name": "subtract_numbers",
                "arguments": {"num1": 83467, "num2": 12123},
            }
        ],
    },
]

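# The reflection metric evaluates each predicted tool call and, when it detects an
# issue, may include a corrected call in its raw response (extracted further below).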
dataset = create_dataset(
    task="tasks.tool_calling.supervised",
    test_set=data,
    split="test",
    format="formats.chat_api",
    metrics=[
        "metrics.tool_calling.reflection",
    ],
    max_test_instances=10,
)

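# CrossProviderInferenceEngine routes the request to the chosen provider (watsonx
# here); temperature=0 keeps the generated tool calls as deterministic as possible.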
model = CrossProviderInferenceEngine(
    model="granite-3-3-8b-instruct", provider="watsonx", temperature=0
)

predictions = model(dataset)

# Inject errors into the model predictions to exercise the reflector:
# a corrupted digit, a misnamed argument, and a non-existent tool name.
predictions[0] = predictions[0].replace("8", "7")
predictions[0] = predictions[0].replace("num1", "num3")
predictions[1] = predictions[1].replace("subtract_numbers", "multiply_numbers")

results = evaluate(predictions=predictions, data=dataset)
print("Instance Results:")
print(results.instance_scores.summary)

print("Global Results:")
print(results.global_scores.summary)

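# Recursively walk a nested dict/list structure and collect every value stored
# under the given key name.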
def find_elements(obj, element_name):
    results = []
    if isinstance(obj, dict):
        for key, value in obj.items():
            if key == element_name:
                results.append(value)
            else:
                results.extend(find_elements(value, element_name))
    elif isinstance(obj, list):
        for item in obj:
            results.extend(find_elements(item, element_name))
    return results


# Collect the corrected tool calls proposed by the reflector when they exist;
# otherwise fall back to the original prediction.
corrected_predictions = []
for instance_scores, prediction in zip(results.instance_scores, predictions):
    raw_responses = find_elements(instance_scores, "raw_response")
    corrected_tool_calls = find_elements(raw_responses, "tool_call")
    if len(corrected_tool_calls) > 0:
        corrected_predictions.append(json.dumps(corrected_tool_calls[0]))
    else:
        corrected_predictions.append(prediction)

# Re-evaluate the corrected predictions on the same data, this time with the
# standard key/value accuracy metric.
dataset = create_dataset(
    task="tasks.tool_calling.supervised",
    test_set=data,
    split="test",
    format="formats.chat_api",
    metrics=["metrics.tool_calling.key_value.accuracy"],
    max_test_instances=10,
)

corrected_results = evaluate(predictions=corrected_predictions, data=dataset)

print("Instance Results (After reflection):")
print(corrected_results.instance_scores.summary)

print("Global Results (After reflection):")
print(corrected_results.global_scores.summary)
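
# Optional sanity check (assumes the aggregate accuracy is exposed under the
# "score" key of global_scores; adjust the key if your unitxt version differs):
# logger.info(f"Accuracy after reflection: {corrected_results.global_scores['score']}")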