23 commits
9a15870
migrate reward structure from 7b to 3b
sunildkumar Apr 25, 2025
307bddc
defined regex, but running into errors trying to use it.
sunildkumar Apr 25, 2025
d60bf1e
guided decoding generally works, but it is very slow. Trying to use a…
sunildkumar Apr 25, 2025
17d1388
migrate tool calling off json and off structured outputs to simpler f…
sunildkumar Apr 26, 2025
f8c9bdf
fixed errors in prompt schema collector. zoom now scales image tokens…
sunildkumar Apr 26, 2025
ddcddc7
prep for a 7B run
sunildkumar Apr 27, 2025
28506ae
log time to api
sunildkumar Apr 27, 2025
78a59f1
no image returns from od if no objects in image
sunildkumar Apr 27, 2025
4519a06
fix error in od
sunildkumar Apr 27, 2025
8eda8a0
reduce batch size due to oom
sunildkumar Apr 27, 2025
b5c9212
working on migrating to triton
sunildkumar Apr 28, 2025
2591c06
prep for another fc run. no od as couldn't get the API to behave. Ref…
sunildkumar Apr 28, 2025
21a09d2
bump lr as it worked in experiments this morning
sunildkumar Apr 28, 2025
fd6331f
rename run
sunildkumar Apr 28, 2025
6786454
new prompt, new bootstrap prompt
sunildkumar Apr 28, 2025
41794ee
reduce LR, training diverged. set grad norm
sunildkumar Apr 28, 2025
ac869ab
zoom docstring clarity
sunildkumar Apr 28, 2025
7ffc740
improve examples
sunildkumar Apr 28, 2025
31d54c8
make reward more strict and prevent no tool call. more IFT-y bootstrap
sunildkumar Apr 28, 2025
6b6e0a8
progress on tool server, having interface issues
sunildkumar Apr 28, 2025
c7f7109
eval script
sunildkumar Apr 28, 2025
9c245de
fix bug in script
sunildkumar Apr 28, 2025
6b34551
changing branches
sunildkumar Apr 29, 2025
6 changes: 5 additions & 1 deletion .gitignore
@@ -185,4 +185,8 @@ dataset/*
attention_visualization*
src/r1_vlm/environments/real_iad_env/completion_results/completion_results.jsonl
src/r1_vlm/environments/real_iad_env/sft/dataset/*
src/r1_vlm/environments/real_iad_env/sft/output/*
src/r1_vlm/environments/real_iad_env/sft/output/*
*.engine
*.pbtxt
*.onnx
checkpoint*
1 change: 1 addition & 0 deletions pyproject.toml
@@ -47,6 +47,7 @@ dependencies = [
"sentence-transformers>=4.1.0",
"tiktoken>=0.9.0",
"openai>=1.65.4",
"opencv-python>=4.11.0.86",
]

[tool.hatch.metadata]
2 changes: 1 addition & 1 deletion src/r1_vlm/datasets/aok_vqa/aok_vqa_mc_tool_use_7B_r1.py
@@ -63,7 +63,7 @@ def generate_r1_messages(example):
"content": [
{
"type": "text",
"text": "\n<think> I'll collect as much visual evidence as possible from the image, and then consider all possible answers. Then, I'll select the most likely answer based on the evidence and my knowledge of the world. First, I'll consider the tool available to me and determine how to best call it to collect the evidence needed to answer the question.",
"text": "\n<think> I'll collect as much visual evidence as possible from the image, and then consider all possible answers. Then, I'll select the most likely answer based on the evidence and my knowledge of the world. First, I'll consider the tools available to me and determine which one is most likely to help me collect the evidence needed to answer the question and how to best call it.",
}
],
},
4 changes: 2 additions & 2 deletions src/r1_vlm/datasets/aok_vqa/aok_vqa_mc_tool_use_r1.py
@@ -37,7 +37,7 @@ def generate_r1_messages(example):

{choices_str}

You must use the tools to inspect the input image and gather visual evidence.
You must inspect the input image and gather visual evidence.
"""

r1_messages = [
@@ -63,7 +63,7 @@ def generate_r1_messages(example):
"content": [
{
"type": "text",
"text": "\n<think> I'll collect as much visual evidence as possible from the image, and then consider all possible answers. Then, I'll select the most likely answer based on the evidence and my knowledge of the world. First, I'll consider the tools available to me and determine which one is most likely to help me collect the evidence needed to answer the question and how to best call it.",
"text": "\n<think> I'll collect as much visual evidence as possible from the image. First, I'll consider what region of the image to zoom in on to get the most information. Then, I'll review and consider the four possible answers. Then, I'll select the most likely answer based on the evidence and my knowledge of the world.",
}
],
},
13 changes: 13 additions & 0 deletions src/r1_vlm/environments/regex_outline.txt
@@ -0,0 +1,13 @@

1. Message 1 - either the model ends in an answer or a tool
a. text </think> <tool> tool regex </tool>
b. text </think> <answer> text </answer>

2. Message 2 - if the model ends message 1 in a tool, then message 2 is the answer
<think> text </think> <answer> text </answer>


{think section without open think | think section with open think} {tool section | answer section}

tool section:
{ zoom tool | detect objects tool }
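
A minimal sketch (not part of this PR; names are hypothetical) of how the outline above composes into one pattern: a think section whose opening tag may be omitted, followed by either a tool section or an answer section. The concrete per-tool regexes this PR actually adds are in regex_for_aok.py further down.

import re

THINK_SECTION = r"(?:<think>)?.*?</think>"   # opening <think> may be omitted
TOOL_SECTION = r"<tool>.*?</tool>"           # stands in for the per-tool regexes
ANSWER_SECTION = r"<answer>.*?</answer>"

# (?s) lets '.' span newlines; both message shapes from the outline match.
MESSAGE_REGEX = rf"(?s){THINK_SECTION}\s*(?:{TOOL_SECTION}|{ANSWER_SECTION})\s*"

assert re.fullmatch(MESSAGE_REGEX, "reasoning text </think> <tool> zoom call </tool>")
assert re.fullmatch(MESSAGE_REGEX, "<think> reasoning text </think> <answer> B </answer>")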
Original file line number Diff line number Diff line change
@@ -118,8 +118,8 @@ def train():
save_steps=50,
save_total_limit=3,
num_train_epochs=10,
per_device_train_batch_size=2,
num_generations=12,
per_device_train_batch_size=1,
num_generations=6,
gradient_accumulation_steps=4,
gradient_checkpointing=gradient_checkpointing,
bf16=True,
@@ -143,6 +143,8 @@ def train():
epsilon_high=0.28,
# reward weights with schedules for some of the reward functions
reward_weights=reward_weights,
# clip gradients to avoid exploding gradients
max_grad_norm=1.0,
)

trainer = QwenGRPOTrainer(
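
The new max_grad_norm=1.0 adds gradient clipping after the divergence noted in the commit history ("reduce LR, training diverged. set grad norm"). A minimal sketch of the behaviour this setting relies on, assuming the trainer uses standard PyTorch global-norm clipping:

import torch

params = [torch.nn.Parameter(torch.randn(8, 8)) for _ in range(3)]
loss = sum((p ** 2).sum() for p in params)
loss.backward()

# Rescales every gradient so their combined L2 norm is at most 1.0,
# mirroring max_grad_norm=1.0 above; the pre-clipping norm is returned.
pre_clip_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
print(pre_clip_norm)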
Original file line number Diff line number Diff line change
@@ -10,23 +10,27 @@
)
from r1_vlm.datasets.utils import preprocess_r1_dataset
from r1_vlm.environments.multistep_vision_env import MultistepVisionEnv
from r1_vlm.environments.tool_vision_env import ToolVisionEnv
from r1_vlm.environments.tool_vision_env import ToolArgParser, ToolVisionEnv
from r1_vlm.tools.object_detection import detect_objects, parse_detect_objects_args
from r1_vlm.tools.tool_prompts import SINGLE_TOOL_PROMPT_TEMPLATE
from r1_vlm.tools.zoom import zoom
from r1_vlm.tools.zoom import parse_zoom_args, zoom


class AOKVQAToolEnv(ToolVisionEnv):
def __init__(
self,
processing_class: AutoProcessor,
dataset_name: str = "Groundlight/real-iad-toy-brick-tool-use-r1",
tools: list[Callable] = [zoom],
tools_with_parsers: list[tuple[Callable, ToolArgParser]] = [
(detect_objects, parse_detect_objects_args),
(zoom, parse_zoom_args),
],
max_steps: int = 3,
tool_prompt_template: str = SINGLE_TOOL_PROMPT_TEMPLATE,
):
super().__init__(
processing_class=processing_class,
tools=tools,
tools_with_parsers=tools_with_parsers,
max_steps=max_steps,
tool_prompt_template=tool_prompt_template,
)
@@ -200,16 +204,6 @@ def check_format(trajectory):
if valid_end: # Should end with tool/answer if structure is valid
format_score += 0.2

# Debug print (optional, can be removed)
print(
f"text_content: {text_content},\n"
f"has_think: {has_think}, has_tool: {has_tool}, has_answer: {has_answer}, "
f"is_valid_structure: {is_valid_structure},\n"
f"has_correct_spacing: {has_correct_spacing}, "
f"starts_with_think: {starts_with_think}, valid_end: {valid_end},\n"
f"format_score: {format_score}"
)

format_scores.append(format_score)
if not format_scores:
return 0.0
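
The environment now registers each tool together with a plain-text argument parser (tools_with_parsers), replacing the JSON / structured-output call format per the "migrate tool calling off json and off structured outputs" commit. A hypothetical sketch of that (tool, parser) contract; the real ToolArgParser in tool_vision_env.py and parse_zoom_args in tools/zoom.py may differ:

from typing import Any, Callable

# Hypothetical: a parser turns the model's raw argument text into kwargs for its tool.
ToolArgParser = Callable[[str], dict[str, Any]]

def toy_parse_zoom_args(raw: str) -> dict[str, Any]:
    # Expects text like "image_name: input_image, bbox: 10 20 110 120, magnification: 2.0"
    fields = {k.strip(): v.strip() for k, v in (part.split(":", 1) for part in raw.split(","))}
    return {
        "image_name": fields["image_name"],
        "bbox": [int(x) for x in fields["bbox"].split()],
        "magnification": float(fields["magnification"]),
    }

def dispatch(raw_args: str, tool: Callable[..., Any], parser: ToolArgParser) -> Any:
    # The parser, not JSON decoding, produces the keyword arguments for the tool call.
    return tool(**parser(raw_args))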
141 changes: 141 additions & 0 deletions src/r1_vlm/environments/tool_use_aokvqa_env/eval.py
@@ -0,0 +1,141 @@
import json
import os
import re

from imgcat import imgcat
from tqdm import tqdm
from transformers import AutoProcessor
from vllm import LLM, SamplingParams

from r1_vlm.environments.tool_use_aokvqa_env.tool_use_aokvqa_env import AOKVQAToolEnv


def extract_answer(generation: str):
"""Extracts the text between the first <answer> and </answer> tags."""
match = re.search(r"<answer>(.*?)</answer>", generation, re.DOTALL)
if match:
return match.group(1).strip()
else:
return None


def main():
checkpoint = (
"/millcreek/home/sunil/r1_vlm_bumbershoot2/r1_vlm/checkpoint-850-better-zoom"
)
processor = AutoProcessor.from_pretrained(checkpoint, padding_side="left")
vf_env = AOKVQAToolEnv(processing_class=processor)
train_dataset, val_dataset, test_dataset = vf_env.get_dataset()

if not os.path.exists("generations.json"):
vlm = LLM(
model=checkpoint,
gpu_memory_utilization=1.0,
dtype="bfloat16",
tensor_parallel_size=2,
enable_prefix_caching=True,
limit_mm_per_prompt={"image": 2, "video": 0},
)

sampling_params = SamplingParams(
temperature=0.1,
max_tokens=2048,
)

batch_size = 6
batches = []

for example in val_dataset:
if len(batches) == 0:
batches.append([example])
elif len(batches[-1]) < batch_size:
batches[-1].append(example)
else:
batches.append([example])

generations = []
for batch in tqdm(batches, desc="Generating completions"):
conversations, texts, processed_batch, vllm_inputs = vf_env.prepare_data(
inputs=batch, processing_class=processor
)

completion_ids = vf_env.generate(
conversations=conversations,
vlm_inputs=vllm_inputs,
vlm=vlm,
sampling_params=sampling_params,
)

generated_texts = processor.batch_decode(
completion_ids["ids"],
skip_special_tokens=False,
clean_up_tokenization_spaces=False,
)

print(generated_texts)

for example, generation in zip(batch, generated_texts):
data = {
"question_id": example["question_id"],
"question": example["question"],
"options": example["choices"],
"rationales": example["rationales"],
"gt_answer": example["multiple_choice_answer"],
"generation": generation,
"model_answer": extract_answer(generation),
}
generations.append(data)

# Save the generations list as a JSON array to a file
with open("generations.json", "w") as f:
json.dump(generations, f, indent=2) # Use indent for readability (optional)

else:
with open("generations.json", "r") as f:
generations = json.load(f)

generations_dict = {}
for generation in generations:
if generation["question_id"] in generations_dict:
raise ValueError(f"Duplicate question_id: {generation['question_id']}")
generations_dict[generation["question_id"]] = generation

total = 0
correct = 0
in_option_set = 0
for example in val_dataset:
question_id = example["question_id"]

if question_id not in generations_dict:
raise ValueError(f"Question_id not found in generations: {question_id}")

model_answer = generations_dict[question_id]["model_answer"]
gt_answer = example["multiple_choice_answer"]

options_set = example["choices"]

if model_answer in options_set:
in_option_set += 1

total += 1
if model_answer == gt_answer:
correct += 1

else:
print("--------------------------------")
print("Incorrect answer:")
print(f"Question: {example['question']}")
print(f"Model answer: {model_answer}")
print(f"GT answer: {gt_answer}")
print(f"Options: {options_set}")
print(f"Generation: {generations_dict[question_id]['generation']}")
print(f"Reasoning: {example['rationales']}")
imgcat(example["image"])
print("--------------------------------")

print(f"Accuracy: {correct / total}")
print(f"In option set: {in_option_set / total}")


if __name__ == "__main__":
main()
48 changes: 48 additions & 0 deletions src/r1_vlm/environments/tool_use_aokvqa_env/regex_for_aok.py
@@ -0,0 +1,48 @@
# Regex pattern for the complete <answer>...</answer> block.
# Allows any content within the tags. (?s) makes '.' match newlines.
ANSWER_BLOCK_REGEX = r"(?s)<answer>.*?</answer>"

# --- Tool Call Regexes ---

# Permissive pattern for JSON values, used to avoid validating value contents.
# Matches any characters non-greedily up to the next structural element (comma, brace).
_PERMISSIVE_JSON_VALUE = r".*?"

# Core JSON structure for the zoom tool call.
# Enforces keys and structure (colons, commas, braces) but not value formats.
_ZOOM_TOOL_JSON_CORE = (
r"\{\s*"
r"\"name\"\s*:\s*\"zoom\"\s*,"
r"\s*\"args\"\s*:\s*\{\s*"
r"\"image_name\"\s*:\s*\"input_image\"\s*,"
r"\s*\"bbox\"\s*:\s*" + _PERMISSIVE_JSON_VALUE + r"\s*,"
r"\s*\"magnification\"\s*:\s*" + _PERMISSIVE_JSON_VALUE + r"\s*\}\s*"
r"\}\s*"
)
# Full regex for the <tool> call using the zoom tool's core JSON structure.
ZOOM_TOOL_CALL_REGEX = rf"<tool>\s*{_ZOOM_TOOL_JSON_CORE}\s*</tool>"

# Core JSON structure for the detect_objects tool call.
# Enforces keys and structure but not the format of the 'classes' value.
_DETECT_OBJECTS_JSON_CORE = (
r"\{\s*"
r"\"name\"\s*:\s*\"detect_objects\"\s*,"
r"\s*\"args\"\s*:\s*\{\s*"
r"\"image_name\"\s*:\s*\"input_image\"\s*,"
r"\s*\"classes\"\s*:\s*" + _PERMISSIVE_JSON_VALUE + r"\s*\}\s*"
r"\}\s*"
)
# Full regex for the <tool> call using the detect_objects tool's core JSON structure.
DETECT_OBJECTS_TOOL_CALL_REGEX = rf"<tool>\s*{_DETECT_OBJECTS_JSON_CORE}\s*</tool>"

# Combined regex matching either a valid zoom OR detect_objects tool call.
EITHER_TOOL_CALL_REGEX = f"(?:{ZOOM_TOOL_CALL_REGEX}|{DETECT_OBJECTS_TOOL_CALL_REGEX})"

# --- Final Combined Regex for Model Output ---

# Matches optional arbitrary leading text (e.g., thoughts) followed by
# either a valid tool call OR an answer block, then optional whitespace.
# This regex is intended for use with vLLM's regex-guided generation.
# (?s) allows '.' to match newlines in the leading arbitrary text.
# The non-greedy .*? ensures it matches up to the first valid tool/answer block.
FINAL_OUTPUT_REGEX = rf"(?s).*?(?:{EITHER_TOOL_CALL_REGEX}|{ANSWER_BLOCK_REGEX})\s*"
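
A quick standard-library sanity check that representative outputs satisfy the component patterns above. (One caveat: FINAL_OUTPUT_REGEX embeds ANSWER_BLOCK_REGEX's leading (?s) mid-pattern, which Python's re module rejects from 3.11 onward; vLLM's guided-decoding backend may tolerate it, so the check below sticks to the component regexes.)

import re

from r1_vlm.environments.tool_use_aokvqa_env.regex_for_aok import (
    ANSWER_BLOCK_REGEX,
    ZOOM_TOOL_CALL_REGEX,
)

tool_call = (
    '<tool>{"name": "zoom", "args": {"image_name": "input_image", '
    '"bbox": [10, 20, 110, 120], "magnification": 2.0}}</tool>'
)
answer = "<answer>a red fire hydrant</answer>"

assert re.fullmatch(ZOOM_TOOL_CALL_REGEX, tool_call)
assert re.fullmatch(ANSWER_BLOCK_REGEX, answer)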
Original file line number Diff line number Diff line change
@@ -91,9 +91,12 @@ def train():
model, peft_config, processor, model_config, gradient_checkpointing = (
load_model_and_processor(gradient_checkpointing=True, use_peft=False)
)
print("loaded model")

vf_env = AOKVQAToolEnv(processing_class=processor, max_steps=3)

print("loaded env")

train_dataset, val_dataset, test_dataset = vf_env.get_dataset()

rubric = vf_env.get_rubric()
@@ -103,9 +106,10 @@ def train():
training_args = GRPOConfig(
model_init_kwargs=model_config,
# save path on the runpod instance
output_dir="vlm-r1-tool-use-aokvqa-env-reduced-beta-single-tool",
output_dir="vlm-r1-zoom-only-reward-refactor-oversampling",
# increase learning rate for PEFT - 1e-4
learning_rate=1e-4 if peft_config is not None else 1e-6,
max_grad_norm=1.0,
adam_beta2=0.98,
lr_scheduler_type="cosine",
warmup_steps=10,
@@ -148,6 +152,7 @@
train_dataset=train_dataset,
env=vf_env,
peft_config=peft_config,
guided_regex=None,
)

trainer.train()
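
guided_regex is disabled here; the commit history notes guided decoding "generally works, but it is very slow". If it were re-enabled, presumably the pattern from regex_for_aok.py would be passed through. A sketch only, assuming QwenGRPOTrainer simply forwards guided_regex to vLLM's guided decoding:

from r1_vlm.environments.tool_use_aokvqa_env.regex_for_aok import FINAL_OUTPUT_REGEX

trainer = QwenGRPOTrainer(
    # ...same arguments as in the diff above...
    train_dataset=train_dataset,
    env=vf_env,
    peft_config=peft_config,
    guided_regex=FINAL_OUTPUT_REGEX,  # instead of None; expect slower generation
)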