diff --git a/ragas/src/ragas/metrics/_tool_call_accuracy.py b/ragas/src/ragas/metrics/_tool_call_accuracy.py index 8f29c81be..e7ae8db67 100644 --- a/ragas/src/ragas/metrics/_tool_call_accuracy.py +++ b/ragas/src/ragas/metrics/_tool_call_accuracy.py @@ -55,13 +55,7 @@ async def _get_arg_score( def is_sequence_aligned( self, pred_sequence: t.List[str], ref_sequence: t.List[str] ) -> bool: - ref_index = 0 # Index to track position in reference sequence - for pred in pred_sequence: - if ref_index < len(ref_sequence) and pred == ref_sequence[ref_index]: - ref_index += 1 - if ref_index == len(ref_sequence): - return True - return False + return pred_sequence == ref_sequence async def _multi_turn_ascore( self, sample: MultiTurnSample, callbacks: Callbacks @@ -87,13 +81,14 @@ async def _multi_turn_ascore( if pred_tool_calls: score = 0.0 reference_tool_calls = sample.reference_tool_calls - for ref_tool_call in reference_tool_calls: - for pred_tool_call in pred_tool_calls: - if ref_tool_call.name == pred_tool_call.name: - arg_score = await self._get_arg_score( - pred_tool_call.args, ref_tool_call.args, callbacks - ) - score += arg_score + for ref_tool_call, pred_tool_call in zip( + reference_tool_calls, pred_tool_calls + ): + if ref_tool_call.name == pred_tool_call.name: + arg_score = await self._get_arg_score( + pred_tool_call.args, ref_tool_call.args, callbacks + ) + score += arg_score score /= len(reference_tool_calls) else: