Skip to content

Commit e7fe3d6

Browse files
committed
⚡ [agents] Refactor tool handling and add dependency management support
- Add utilize `safe_fileio` for secure file I/O within tool handling, and introduced a dependency retrieval method (`deps`) to manage runtime requirements. Improves code modularity for better aligns with agent processing flows for code generation and execution.
1 parent 8127254 commit e7fe3d6

File tree

3 files changed

+148
-54
lines changed

3 files changed

+148
-54
lines changed

exp/agents.py

Lines changed: 96 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@
44

55
from jinja2 import Template
66
from langchain_core.language_models import LanguageModelInput
7-
from langchain_core.messages import ToolMessage, HumanMessage, SystemMessage, BaseMessage, AIMessage
7+
from langchain_core.messages import ToolMessage, HumanMessage, SystemMessage, BaseMessage, AIMessage, ToolCall
88
from langchain_core.runnables import Runnable
99
from langgraph.checkpoint.memory import MemorySaver
1010
from langgraph.func import task, entrypoint
1111
from langgraph.prebuilt import ToolNode
1212
from rich.console import Console
1313

14-
from exp.utils import get_vllm_with_tools
14+
from exp.utils import get_vllm_with_tools, safe_fileio
1515
from mle.function import (
16-
read_file, create_file, write_file, list_files,
16+
read_file, create_file, list_files,
1717
create_directory, preview_csv_data, preview_zip_structure, unzip_data
1818
)
1919
from mle.utils import clean_json_string
@@ -149,39 +149,52 @@
149149
textwrap.dedent(
150150
"""
151151
You are a **Machine Learning Engineer** tasked with implementing a solution based on the provided requirements by the advisor.
152-
You will be given the whole project plan, and each task will be provided to you one by one.
153152
Requirements: {{ advisor_report | tojson(indent=2) }}
154-
Implementation Plan: {{ plan | tojson(indent=2) }}
155153
Working Directory: {{ working_dir }}
156154
Environment: {{ env | tojson(indent=2) }}
157155
158-
Your task is to generate the complete, working Python code that implements the solution. Call the write_file, mkdir, and read_file function tools to inspect and generate the necessary files.
159-
IMPORTANT:
160-
1. Generate a single file `solution.py` that contains all the code for the solution, including imports, constants, functions, classes, main guard (`if __name__ == "__main__":`), argument parsing (if needed), execution logic, and docstrings.
161-
Focus on:
156+
Your task is to generate the complete, working Python code. Focus points to consider:
162157
1. Clean, readable code
163158
2. Proper data handling
164159
3. Model implementation
165160
4. Training and evaluation logic
166161
5. Kaggle submission format
162+
""".strip()
163+
)
164+
)
165+
166+
CODE_PROMPT = Template(
167+
textwrap.dedent(
168+
"""
169+
Implement Python code to solve the following task:
170+
## Task: {{ task }}
171+
{{ description }}
167172
168-
After finalizing the code, you will call the `create_file` function to save the code to `solution.py`
169-
After the tool calling result is given back, you should also provide the dependencies required to run the code and the command to run the code in a JSON format:
173+
Make sure to follow the requirements and provide the code in a single Python file. Call any necessary tools to inspect the data.
174+
Once ready, call the ` create_file ` tool to save the code.
175+
176+
The code should include:
177+
1. A single file `solution.py` that contains all the code for the solution, including imports, constants, functions, classes, main guard (`if __name__ == "__main__":`), argument parsing (if needed), execution logic, and docstrings.
178+
2. Overwrite existing code; do not supply diffs or partial patches.
179+
""".strip()
180+
)
181+
)
182+
183+
CODER_DEPS_PROMPT = Template(
184+
textwrap.dedent(
185+
"""
186+
Look at the latest created code in the chat history and analyze the dependencies required to run the code.
187+
Providing the dependencies required and the command to run the code in a JSON format.
188+
Example (for JSON schema illustration only):
170189
{
171190
"dependency": ["pkg1", "pkg2", "..."],
172-
"command": "python solution.py"
191+
"command": "python solution.py",
192+
"entryfile": "solution.py"
173193
}
174194
""".strip()
175195
)
176196
)
177197

178-
CODE_PROMPT = Template(
179-
"""
180-
## Task: {{ task }}
181-
{{ description }}
182-
"""
183-
)
184-
185198

186199
class AdviseAgent:
187200
console: Console = None
@@ -342,7 +355,6 @@ class CodeAgent:
342355

343356
class State(TypedDict):
344357
advisor_report: dict
345-
plan: dict
346358
task: str
347359
description: str
348360
env: dict
@@ -358,16 +370,15 @@ def __new__(cls, model_name, working_dir='.', console=None):
358370
working_dir: the working directory.
359371
console: the console to use.
360372
"""
361-
tools = [
362-
read_file,
363-
create_file,
364-
write_file,
365-
list_files,
366-
create_directory,
367-
preview_csv_data,
368-
preview_zip_structure,
369-
unzip_data,
370-
]
373+
tools = [
374+
safe_fileio(working_dir)(read_file),
375+
safe_fileio(working_dir, path_params=["path"])(create_file),
376+
safe_fileio(working_dir)(list_files),
377+
safe_fileio(working_dir, path_params=["path"])(create_directory),
378+
safe_fileio(working_dir, path_params=["path"])(preview_csv_data),
379+
safe_fileio(working_dir, path_params=["path"])(preview_zip_structure),
380+
safe_fileio(working_dir, path_params=["extract_path"])(unzip_data),
381+
]
371382

372383
cls.model = get_vllm_with_tools(model_name, tools)
373384
cls.working_dir = working_dir
@@ -377,7 +388,22 @@ def __new__(cls, model_name, working_dir='.', console=None):
377388

378389
@staticmethod
379390
@task
380-
def code(task: str, description: str, first_call=True) -> AIMessage:
391+
def setup(advisor_report: dict, env: dict):
392+
# Set up the chat history with the system prompt if not already set
393+
if len(CodeAgent.chat_history) == 0:
394+
CodeAgent.chat_history.append(
395+
SystemMessage(
396+
content=CODER_SYSTEM_PROMPT.render(
397+
working_dir=CodeAgent.working_dir,
398+
advisor_report=advisor_report,
399+
env=env,
400+
)
401+
)
402+
)
403+
404+
@staticmethod
405+
@task
406+
def code(task: str, description: str, first_call=True) -> AIMessage | dict:
381407
"""
382408
Handle the query from the model query response.
383409
Args:
@@ -396,10 +422,29 @@ def code(task: str, description: str, first_call=True) -> AIMessage:
396422
)
397423
)
398424
)
399-
message = CodeAgent.model.invoke(CodeAgent.chat_history)
425+
message: AIMessage = CodeAgent.model.invoke(CodeAgent.chat_history)
400426

401427
CodeAgent.chat_history.append(message)
402-
return message
428+
return message
429+
430+
@staticmethod
431+
@task
432+
def deps() -> dict:
433+
"""
434+
Get the dependencies required to run the code and the command to run the code.
435+
Returns:
436+
A dictionary containing the dependencies and the command to run the code.
437+
"""
438+
CodeAgent.chat_history.append(
439+
HumanMessage(content=CODER_DEPS_PROMPT.render())
440+
)
441+
message = CodeAgent.model.invoke(CodeAgent.chat_history)
442+
443+
CodeAgent.chat_history.append(message)
444+
try:
445+
return json.loads(message.content)
446+
except json.JSONDecodeError as e:
447+
return clean_json_string(message.content)
403448

404449
@staticmethod
405450
@entrypoint(checkpointer=checkpointer)
@@ -411,18 +456,7 @@ def graph(state: State) -> dict:
411456
Returns:
412457
The code for the task.
413458
"""
414-
# Set up the chat history with the system prompt if not already set
415-
if len(CodeAgent.chat_history) == 0:
416-
CodeAgent.chat_history.append(
417-
SystemMessage(
418-
content=CODER_SYSTEM_PROMPT.render(
419-
working_dir=CodeAgent.working_dir,
420-
advisor_report=state['advisor_report'],
421-
plan=state['plan'],
422-
env=state['env'],
423-
)
424-
)
425-
)
459+
CodeAgent.setup(state['advisor_report'], state['env'])
426460

427461
try_times = 5
428462
while try_times > 0:
@@ -432,20 +466,28 @@ def graph(state: State) -> dict:
432466
first_call=(try_times == 5)
433467
).result()
434468
try_times -= 1
435-
436469
if isinstance(message, AIMessage):
437-
# If the message is an AIMessage, check if it need tool calls
438470
if message.tool_calls:
439-
message = CodeAgent.tool_node.invoke({
440-
"messages": CodeAgent.chat_history[-1:],
441-
})
471+
CodeAgent.console.print(f"Calling tools {[tool['name'] for tool in message.tool_calls]}")
472+
message = CodeAgent.tool_node.invoke(
473+
{
474+
"messages": [message],
475+
}
476+
)
442477
CodeAgent.chat_history.extend(message['messages'])
478+
479+
# If the tool `create_file` succeeded, break the loop
480+
if any(
481+
isinstance(msg, ToolMessage) and msg.name == "create_file" and
482+
msg.content and "error" not in msg.content.lower()
483+
for msg in message['messages']
484+
):
485+
break
443486
else:
444-
# If no tool calls, return the message
445-
CodeAgent.chat_history.append(message)
446487
break
447488
else:
448489
break
449490

450-
CodeAgent.console.print(CodeAgent.chat_history)
451-
return message
491+
# Check the dependencies and command to run the code
492+
deps = CodeAgent.deps().result()
493+
return deps

exp/kaggle_solver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343

4444
from mlebench.registry import registry
4545

46+
from mle.function import read_file
4647
from mle.utils import print_in_box
4748

4849

exp/utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,19 @@
44
55
Date: Jul 12, 2025
66
"""
7+
import inspect
78
import logging
89
import mimetypes
910
import os
1011
import random
12+
import tempfile
1113
import venv
1214
import zipfile
1315
from collections import defaultdict
16+
from functools import wraps
1417
from io import StringIO
1518
from pathlib import Path
19+
from typing import Callable, Any
1620

1721
import pandas as pd
1822
from langchain.chat_models import init_chat_model
@@ -223,3 +227,50 @@ def create_virtualenv(cwd='.', path='.venv'):
223227
builder = venv.EnvBuilder(with_pip=True)
224228
builder.create(venv_path)
225229
return venv_path.absolute() / 'bin' / 'python' if os.name != 'nt' else venv_path / 'Scripts' / 'python.exe'
230+
231+
232+
233+
def safe_fileio(working_dir: str, path_params: str | list[str] | None = None) -> Callable:
234+
working_dir = Path(working_dir).resolve()
235+
temp_dir = Path(tempfile.gettempdir()).resolve()
236+
237+
def decorator(func: Callable):
238+
@wraps(func)
239+
def wrapper(*args, **kwargs) -> Any:
240+
sig = inspect.signature(func)
241+
bound = sig.bind(*args, **kwargs)
242+
bound.apply_defaults()
243+
244+
# Validate and rewrite specified path parameters
245+
if isinstance(path_params, str):
246+
path_params_ = [path_params]
247+
else:
248+
path_params_ = path_params or []
249+
for param in path_params_:
250+
if param in bound.arguments:
251+
original = bound.arguments[param]
252+
if original is None:
253+
continue
254+
if not isinstance(original, (str, Path)):
255+
raise TypeError(f"Expected str or Path for '{param}', got {type(original)}")
256+
if not Path(original).is_absolute():
257+
abs_path = (working_dir / original).resolve()
258+
else:
259+
abs_path = Path(original).resolve()
260+
if not (
261+
str(abs_path).startswith(str(working_dir)) or
262+
str(abs_path).startswith(str(temp_dir))
263+
):
264+
raise PermissionError(f"Access denied: {abs_path} is outside allowed directories")
265+
bound.arguments[param] = abs_path
266+
267+
# Change CWD temporarily
268+
prev_cwd = os.getcwd()
269+
os.chdir(working_dir)
270+
try:
271+
return func(*bound.args, **bound.kwargs)
272+
finally:
273+
os.chdir(prev_cwd)
274+
275+
return wrapper
276+
return decorator

0 commit comments

Comments
 (0)