Skip to content

Commit 958f8fa

Browse files
committed
Fixes some bugs in requirement checkers.
1 parent e5f0be2 commit 958f8fa

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

mellea/stdlib/reqlib/tools.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from typing import Callable, Optional
2-
from mellea.stdlib.requirement import Context, ValidationResult, Requirement
2+
from mellea.stdlib.base import Context
3+
from mellea.stdlib.requirement import Requirement, ValidationResult
34

45

56
def _name2str(tool_name: str | Callable) -> str:
67
match tool_name:
7-
case Callable():
8+
case tool_name if callable(tool_name):
89
return tool_name.__name__
910
case str():
1011
return tool_name
@@ -72,11 +73,11 @@ def _validate(ctx: Context):
7273
return ValidationResult(result=False, reason=f"Valiudation did not pass for {tool_name}.{arg_name}. Arg value: {arg_value}. Argument validation result: {validate_result}")
7374
else:
7475
for tool in ctx.last_output().tool_calls.keys():
75-
if arg_name in output.tool_calls[tool_name].args:
76-
arg_value = output.tool_calls[tool_name].args[arg_name]
76+
if arg_name in output.tool_calls[tool].args:
77+
arg_value = output.tool_calls[tool].args[arg_name]
7778
validate_result = validation_fn(arg_value)
7879
if not validate_result:
7980
return ValidationResult(result=False, reason=f"Valiudation did not pass for {tool_name}.{arg_name}. Arg value: {arg_value}. Argument validation result: {validate_result}")
8081
return ValidationResult(result=True)
8182

82-
return Requirement(description=description, validation_fn=_validate, check_only=check_only)
83+
return Requirement(description=description, validation_fn=_validate, check_only=check_only)

0 commit comments

Comments
 (0)