Skip to content

Commit 39ed399

Browse files
Feat/train tool agent (#17)
* train tool agent * Feat/add search scrape tool (#14) * migrate search, scrape tool into verl and add code to spin up retrieval server * add debug code to train tool agent * Update async_sglang_server.py --------- Co-authored-by: nguyenhoangthuan99 <35255081+nguyenhoangthuan99@users.noreply.github.com> * add data code (#15) * add data code * push reward code * add system prompt and user prompt * fix prompt * fix prompt * update to train 30b model * fix training bug * add reward remove repeate tool call and bulk tool calls --------- Co-authored-by: bachvudinh <bachvudinh02@gmail.com> Co-authored-by: bachvudinh <89349141+bachvudinh@users.noreply.github.com>
1 parent 63f2ffb commit 39ed399

File tree

5 files changed

+78
-9
lines changed

5 files changed

+78
-9
lines changed

examples/vllm_multiturn/config/search_multiturn_grpo.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ actor_rollout_ref:
1919
name: sglang
2020
multi_turn:
2121
enable: True
22-
max_assistant_turns: 100
22+
max_assistant_turns: 50
2323
format: hermes
24+
max_parallel_calls: 10
25+
max_tool_response_length: 16384
26+
2427

verl/trainer/config/rollout/rollout.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,10 @@ multi_turn:
160160
max_user_turns: null
161161

162162
# max parallel call for tools in single turn
163-
max_parallel_calls: 1
163+
max_parallel_calls: 10
164164

165165
# max length of tool response
166-
max_tool_response_length: 256
166+
max_tool_response_length: 16384
167167

168168
# truncate side of tool response: left, middle, right
169169
tool_response_truncate_side: middle

verl/utils/dataset/rl_dataset.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,19 @@
4646
- Consider what information is needed to provide a complete answer
4747
4848
2. Mandatory Logical Analysis (Say It Out Loud):
49-
- Before engaging tools, you must articulate your complete thought process in natural language. You must act as a "professional tool caller," demonstrating extreme logic.
49+
- Before engaging tools, you must articulate your complete thought process in natural language. You must act as a "professional tool caller," demonstrating extreme logic and NEVER repeat your thought process across tool calls.
5050
- Analyze the Information Gap: Explicitly state what data is missing.
5151
- Derive the Strategy: Explain why a specific tool is the logical next step.
5252
- Justify Parameters: Explain why you chose those specific search keywords or that specific URL.
5353
54-
3. When you need to search for information, call the "web_search" tool using this exact XML format:
54+
3. Never give an unclear answer immediately right after searching; try scraping the URL to get more information.
55+
56+
4. When you need to search for information, call the "web_search" tool using this exact XML format:
5557
<tool_call>
5658
{{"name": "web_search", "arguments": {{"query": "your search query here"}}}}
5759
</tool_call>
5860
59-
4. If search results show promising URLs/documents but you need more detailed information, use the "scrape" tool:
61+
5. If search results show promising URLs/documents but you need more detailed information, use the "scrape" tool:
6062
<tool_call>
6163
{{"name": "scrape", "arguments": {{"url": "doc_1 or specific URL from search results"}}}}
6264
</tool_call>

verl/utils/reward_score/jan_v2_reward/format_reward.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
1-
import re
1+
import re, json
22
from scipy.stats import skewnorm
3+
import Levenshtein as levenshtein
34

5+
def count_tool_call_turns(text):
    """Count assistant turns whose first action is a tool call.

    Looks for the literal role marker ``assistant`` followed by whitespace
    and an opening ``<tool_call>`` tag in the rendered chat transcript.

    Args:
        text: Full rendered conversation string.

    Returns:
        Number of non-overlapping ``assistant ... <tool_call>`` occurrences.
    """
    # \s+ tolerates any run of spaces/tabs/newlines between the role
    # marker and the opening tag.
    turn_pattern = r'assistant\s+<tool_call>'
    return len(re.findall(turn_pattern, text))
416

517
def calculate_skewed_penalty(x, center=70, skewness=5, scale=200):
618
"""
@@ -39,6 +51,18 @@ def parse_assistant_thoughts(text):
3951
return [match.strip() for match in matches if match.strip()]
4052

4153

54+
def parse_all_tool_calls(solution_str):
    """Extract and JSON-decode every ``<tool_call>...</tool_call>`` payload.

    Args:
        solution_str: Full model output containing zero or more tool-call
            blocks.

    Returns:
        A list of decoded tool-call objects, in order of appearance. If ANY
        block fails to parse as JSON, the whole rollout is treated as
        malformed and an empty list is returned (callers award zero reward
        in that case). Empty/whitespace-only blocks are skipped.
    """
    pattern = r'<tool_call>(.*?)</tool_call>'
    # DOTALL lets a payload span multiple lines.
    matches = re.findall(pattern, solution_str, flags=re.DOTALL)
    tool_calls = []
    for match in matches:
        payload = match.strip()
        if not payload:
            continue
        try:
            tool_calls.append(json.loads(payload))
        # Narrowed from a bare ``except:`` — a bare except also swallows
        # KeyboardInterrupt/SystemExit. Only a bad JSON payload should
        # invalidate the set.
        except json.JSONDecodeError:
            return []
    return tool_calls
65+
4266
def compute_format_reward(solution_str):
4367
if "<answer>" not in solution_str and "</answer>" not in solution_str:
4468
return 0
@@ -48,7 +72,32 @@ def compute_format_reward(solution_str):
4872
if solution_str.count("<tool_call>") == 0:
4973
return 0.
5074

75+
tool_calls = parse_all_tool_calls(solution_str)
76+
if len(tool_calls) == 0:
77+
return 0.
78+
79+
processed_tool_calls = []
80+
for tool_call in tool_calls:
81+
if "name" not in tool_call or "arguments" not in tool_call:
82+
return 0.
83+
if type(tool_call["arguments"]) != dict:
84+
return 0.
85+
tool = str(tool_call).lower().replace(" ", "")
86+
if tool in processed_tool_calls:
87+
return 0.
88+
processed_tool_calls.append(tool)
89+
90+
91+
5192
assistant_thoughts = parse_assistant_thoughts(solution_str)
93+
for i, str1 in enumerate(assistant_thoughts):
94+
for j, str2 in enumerate(assistant_thoughts):
95+
if i != j:
96+
ratio = levenshtein.ratio(str2, str1)
97+
if ratio > 0.8:
98+
return 0.0
99+
100+
52101
length = [len(x.split(" ")) for x in assistant_thoughts]
53102
if len(length):
54103
avg_length = float(sum(length))/len(length)
@@ -57,3 +106,15 @@ def compute_format_reward(solution_str):
57106
else:
58107
return 0.
59108

109+
def compute_bulk_tool_call_reward(solution_str):
    """Score how densely tool calls are packed into assistant turns.

    Computes the average number of ``<tool_call>`` blocks per assistant
    turn that opens with a tool call. Returns that average when it exceeds
    one, and 0.0 otherwise — including transcripts that contain no
    tool-calling assistant turns at all.

    Args:
        solution_str: Full rendered conversation string.

    Returns:
        float: calls-per-turn ratio if > 1, else 0.0.
    """
    turns = count_tool_call_turns(solution_str)
    if not turns:
        return 0.0
    calls_per_turn = float(solution_str.count("<tool_call>")) / turns
    return calls_per_turn if calls_per_turn > 1 else 0.0

verl/utils/reward_score/jan_v2_reward/qa_reward.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import re
33
import string
44
from .llm_judge_utils import evaluate_answer
5-
from .format_reward import compute_format_reward
5+
from .format_reward import compute_format_reward, compute_bulk_tool_call_reward
6+
67

78
def normalize_answer(s):
89
def remove_articles(text):
@@ -114,10 +115,12 @@ def compute_score(solution_str, ground_truth, question, format_score=0.0, score=
114115
if evaluate_result['grade_description'] == "CORRECT":
115116

116117
format_score = compute_format_reward(solution_str)
118+
bulk_tool_score = compute_bulk_tool_call_reward(solution_str)
117119
print(f"Solution string with format score {format_score}: {solution_str}")
118120
if format_score == 0.0:
119121
return 0.0
120-
return 1.0 + format_score*0.2
122+
return 1.0 + format_score*0.2 + bulk_tool_score* 0.2
123+
121124
else:
122125

123126
print(f"Solution string: {solution_str}")

0 commit comments

Comments
 (0)