Skip to content

Fix/tool call accuracy (#2079) #2092

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 9 additions & 14 deletions ragas/src/ragas/metrics/_tool_call_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,7 @@ async def _get_arg_score(
def is_sequence_aligned(
self, pred_sequence: t.List[str], ref_sequence: t.List[str]
) -> bool:
- ref_index = 0 # Index to track position in reference sequence
- for pred in pred_sequence:
- if ref_index < len(ref_sequence) and pred == ref_sequence[ref_index]:
- ref_index += 1
- if ref_index == len(ref_sequence):
- return True
- return False
+ return pred_sequence == ref_sequence

async def _multi_turn_ascore(
self, sample: MultiTurnSample, callbacks: Callbacks
Expand All @@ -87,13 +81,14 @@ async def _multi_turn_ascore(
if pred_tool_calls:
score = 0.0
reference_tool_calls = sample.reference_tool_calls
- for ref_tool_call in reference_tool_calls:
- for pred_tool_call in pred_tool_calls:
- if ref_tool_call.name == pred_tool_call.name:
- arg_score = await self._get_arg_score(
- pred_tool_call.args, ref_tool_call.args, callbacks
- )
- score += arg_score
+ for ref_tool_call, pred_tool_call in zip(
+ reference_tool_calls, pred_tool_calls
+ ):
+ if ref_tool_call.name == pred_tool_call.name:
+ arg_score = await self._get_arg_score(
+ pred_tool_call.args, ref_tool_call.args, callbacks
+ )
+ score += arg_score

score /= len(reference_tool_calls)
else:
Expand Down