From 32054fb0041fa122cce863228748f367c202adb5 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Tue, 29 Apr 2025 22:33:22 +0000 Subject: [PATCH 01/24] revised object detection tool to use triton + yolo --- pyproject.toml | 2 + src/r1_vlm/tools/object_detection.py | 221 +++++++---------------- uv.lock | 255 +++++++++++++++++++++++++++ 3 files changed, 325 insertions(+), 153 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 960d2c1c..239cfe0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,8 @@ dependencies = [ "tiktoken>=0.9.0", "openai>=1.65.4", "opencv-python>=4.11.0.86", + "tritonclient[all]>=2.51.0", + "ultralytics>=8.3.120", ] [tool.hatch.metadata] diff --git a/src/r1_vlm/tools/object_detection.py b/src/r1_vlm/tools/object_detection.py index af521e8e..83fc030d 100644 --- a/src/r1_vlm/tools/object_detection.py +++ b/src/r1_vlm/tools/object_detection.py @@ -1,17 +1,14 @@ -import base64 -import io +import contextlib import json import os import time # Import the time module # Add imports for numpy and cv2 -import cv2 -import numpy as np -import pytest -import requests from dotenv import load_dotenv from imgcat import imgcat from PIL import Image +from tritonclient.http import InferenceServerClient +from ultralytics import YOLO from r1_vlm.environments.tool_vision_env import RawToolArgs, TypedToolArgs @@ -20,16 +17,60 @@ API_IP = str(os.getenv("API_IP")) API_PORT = int(os.getenv("API_PORT")) +_object_detection_tool = None -def detect_objects( - image_name: str, classes: list[str], **kwargs -) -> tuple[list[dict], Image.Image]: + +class ObjectDetectionTool: + def __init__(self): + # url = f"{API_IP}:{API_PORT}/yolo" + url = "localhost:8000/yolo" + self.triton_client = InferenceServerClient(url=url, verbose=False, ssl=False) + + # Wait until model is ready + for _ in range(10): + with contextlib.suppress(Exception): + assert self.triton_client.is_model_ready("yolo") + break + time.sleep(1) + + self.model = YOLO("http://localhost:8000/yolo", task="detect") + + def detect_objects(self, image: Image.Image) -> list[dict]: + result = self.model(image, conf=0.3)[0] + boxes = [[int(round(x)) for x in box] for box in result.boxes.xyxy.tolist()] + labels = [result.names[int(cls)] for cls in result.boxes.cls.tolist()] + + detections = [ + {"bbox_2d": box, "label": label} for box, label in zip(boxes, labels) + ] + + if len(detections) == 0: + dets_string = "No objects detected." + annotated_image = None + else: + dets_string = "" + for index, det in enumerate(detections): + dets_string += f"{index + 1}. {det}" + + if index < len(detections) - 1: + dets_string += "\n" + + annotated_image = result.plot(conf=False, labels=True) + + return {"text_data": dets_string, "image_data": annotated_image} + + +def set_object_detection_tool(tool: ObjectDetectionTool): + global _object_detection_tool + _object_detection_tool = tool + + +def detect_objects(image_name: str, **kwargs) -> tuple[list[dict], Image.Image]: """ - Calls an open vocabulary object detection model on the image. Useful for localizing objects in an image or determining if an object is present. + Calls an object detection model on the image. Useful for localizing objects in an image or determining if an object is present. Args: image_name: str, the name of the image to detect objects in. Can only be called on the "input_image" image. - classes: list[str], the classes to detect. As the model is open vocabulary, your classes can be any object you want to detect in the image. Each class should contain an noun for best results. Returns: 1. A list of dictionaries, each containing the following keys: @@ -41,12 +82,6 @@ def detect_objects( name: detect_objects image_name: input_image - classes: ["car", "person", "train", "bus"] - - - name: detect_objects - image_name: input_image - classes: ["elephant", "white jeep", "tree", "water"] """ @@ -64,94 +99,15 @@ def detect_objects( f"Error: Image {image_name} is not the input_image. This tool can only be called on the input_image." ) - # construct the API request - # I decided to fix the confidence threshold at 0.10, as the model tends to set this value very high, which leads to a lot of false negatives - url = f"http://{API_IP}:{API_PORT}/detect?confidence={0.10}" - - # Convert PIL Image to bytes - img_byte_arr = io.BytesIO() - image.save(img_byte_arr, format="JPEG") - img_byte_arr = img_byte_arr.getvalue() - - files = {"image": img_byte_arr} - data = {} - for c in classes: - data.setdefault("classes", []).append(c) - - # send the request - start_time = time.time() # Record start time - response = requests.post(url, files=files, data=data) - end_time = time.time() # Record end time - print(f"API call took {end_time - start_time:.2f} seconds") # Print duration - - if response.status_code == 200: - result = response.json() - else: - raise Exception( - f"Error: API request failed with status code {response.status_code}" - ) - - detections = result["results"]["detections"] - - dets = [] - for detection in detections: - dets.append( - { - "bbox_2d": detection["bbox_2d"], - "label": detection["label"], - } - ) - - if len(dets) == 0: - dets_string = "No objects detected." - annotated_image = None - else: - dets_string = "" - for index, det in enumerate(dets): - dets_string += f"{index + 1}. {det}" - - if index < len(dets) - 1: - dets_string += "\n" - - # convert the annotated image(base64 encoded) to a PIL Image only if detections exist - annotated_image_data = base64.b64decode(result["annotated_image"]) - annotated_image_pil = Image.open(io.BytesIO(annotated_image_data)) - - # Convert PIL Image to NumPy array (OpenCV format) - # PIL images with mode 'RGB' are loaded as NumPy arrays with shape (H, W, 3) in RGB order. - # PIL images with mode 'RGBA' are loaded as NumPy arrays with shape (H, W, 4) in RGBA order. - annotated_image_np = np.array(annotated_image_pil) - - # Convert BGR(A) to RGB(A) using OpenCV if it's a color image - # Assuming the source API sent BGR/BGRA data, which np.array converted retaining channel order relative to PIL's interpretation. - # If PIL interpreted as RGB, the np array is RGB. If RGBA, the np array is RGBA. - # Since the *source* was BGR/BGRA, we convert the numpy array from BGR/BGRA to RGB/RGBA. - if annotated_image_np.ndim == 3 and annotated_image_np.shape[2] == 3: # RGB/BGR - annotated_image_np_rgb = cv2.cvtColor(annotated_image_np, cv2.COLOR_BGR2RGB) - elif ( - annotated_image_np.ndim == 3 and annotated_image_np.shape[2] == 4 - ): # RGBA/BGRA - annotated_image_np_rgb = cv2.cvtColor( - annotated_image_np, cv2.COLOR_BGRA2RGBA - ) - else: - # Grayscale or other formats, no conversion needed - annotated_image_np_rgb = annotated_image_np - - # Convert NumPy array back to PIL Image - annotated_image = Image.fromarray(annotated_image_np_rgb) - - # Return None for image_data if no detections were found - return {"text_data": dets_string, "image_data": annotated_image} + return _object_detection_tool.detect_objects(image) def parse_detect_objects_args(raw_args: RawToolArgs) -> TypedToolArgs: """ Parses raw string arguments for the detect_objects tool, focusing on type conversion. - Expects keys: 'name', 'image_name', 'classes'. - Converts 'classes' from a JSON string representing a list of strings. - Detailed validation of values (e.g., 'image_name' validity, 'classes' content) + Expects keys: 'name', 'image_name' + Detailed validation of values (e.g., 'image_name' validity) is deferred to the detect_objects function itself. Args: @@ -159,13 +115,12 @@ def parse_detect_objects_args(raw_args: RawToolArgs) -> TypedToolArgs: Returns: A dictionary containing the arguments with basic type conversions applied, - ready for the detect_objects function. Keys: 'image_name', 'classes'. + ready for the detect_objects function. Keys: 'image_name'. Raises: ValueError: If required keys are missing or basic type conversion fails - (e.g., 'classes' is not valid JSON). """ - required_keys = {"name", "image_name", "classes"} + required_keys = {"name", "image_name"} actual_keys = set(raw_args.keys()) # 1. Check for Missing Keys @@ -188,17 +143,6 @@ def parse_detect_objects_args(raw_args: RawToolArgs) -> TypedToolArgs: # Keep image_name as string typed_args["image_name"] = raw_args["image_name"] - # Convert classes string using json.loads - classes_list = json.loads(raw_args["classes"]) - - # Basic type check - ensure it's a list, defer content check (list of strings) to tool - if not isinstance(classes_list, list): - raise ValueError( - f"Error: Invalid format for 'classes': Expected a JSON list, got type {type(classes_list).__name__}" - ) - - typed_args["classes"] = classes_list - except json.JSONDecodeError: raise ValueError( f"Error: Invalid JSON format for 'classes': '{raw_args['classes']}'" @@ -213,44 +157,15 @@ def parse_detect_objects_args(raw_args: RawToolArgs) -> TypedToolArgs: return typed_args -@pytest.fixture -def sample_image_fixture(): - """Provides a simple dummy image for testing.""" - current_dir = os.path.dirname(os.path.abspath(__file__)) - img = Image.open(os.path.join(current_dir, "cars.jpeg")) - return {"input_image": img} - - -def test_basic_detection_integration(sample_image_fixture): - """Tests basic object detection call against the running API.""" - # Call the function under test - this will make a real HTTP request - # Using classes unlikely to be in a plain red image might be safer - # depending on the actual model behavior. Let's use "object". - try: - result = detect_objects( - image_name="input_image", - # there should be cars, but no dogs - classes=["car", "dog"], - images=sample_image_fixture, - ) +if __name__ == "__main__": + tool = ObjectDetectionTool() + set_object_detection_tool(tool) + image = Image.open( + "/millcreek/home/sunil/r1_vlm_bumbershoot0/r1_vlm/src/r1_vlm/tools/cars.jpeg" + ) + detections = detect_objects(image_name="input_image", images={"input_image": image}) - assert isinstance(result, dict) - assert "text_data" in result - assert "image_data" in result - assert isinstance(result["text_data"], str) - assert isinstance(result["image_data"], Image.Image) - - # visualize the annotated image - annotated_image = result["image_data"] - imgcat(annotated_image) - - # visualize the text data - print(result["text_data"]) - - except requests.exceptions.ConnectionError as e: - pytest.fail( - f"API connection failed. Is the server running at http://{API_IP}:{API_PORT}? Error: {e}" - ) - except Exception as e: - # Catch other potential errors during the API call or processing - pytest.fail(f"An unexpected error occurred: {e}") + image_data = detections["image_data"] + imgcat(image_data) + text_data = detections["text_data"] + print(text_data) diff --git a/uv.lock b/uv.lock index e925b579..8adcfdb0 100644 --- a/uv.lock +++ b/uv.lock @@ -790,6 +790,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/cf/1f7649b8b9a3543e042d3f348e398a061923ac05b507f3f4d95f11938aa9/cryptography-44.0.2-cp39-abi3-win_amd64.whl", hash = "sha256:5f6f90b72d8ccadb9c6e311c775c8305381db88374c65fa1a68250aa8a9cb3a6", size = 3210957 }, ] +[[package]] +name = "cuda-bindings" +version = "12.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pywin32", marker = "sys_platform == 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/16/621f2ff6e4c6a0c1d57f5a0a373d1fb9d10eb9a7f05052cc64eba2e7dab2/cuda_bindings-12.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0865c9b75ee8f0535044c3f0f06ca34a37131192b573ab59e20a9e058da1ead4", size = 10904424 }, + { url = "https://files.pythonhosted.org/packages/59/11/aee1afd60a5d6af67994dd88697912be22366a6e548e52e6cd2defdbe678/cuda_bindings-12.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0e6a889c87238e6cd55e9b25ce4fd1d90fe2d4169982860fed5f0bc3230795e", size = 11235285 }, + { url = "https://files.pythonhosted.org/packages/c1/c7/eedad18aeb461e9a3c1f8e2ea856caa50202a572b024912cb561f847a054/cuda_bindings-12.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:d0123d841cb3053d227e18b08ea7680d0b5ca64fab4664a2b80b7c83c8edf1ee", size = 11224401 }, + { url = "https://files.pythonhosted.org/packages/4e/82/dc34a092d9111524eea70671d41d72dd3a5452ef70c424680bee1daf9c45/cuda_bindings-12.8.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e264ea93027c7448b9efa134729c12217ca9096051114ee7a9425d49b5a14222", size = 10722116 }, + { url = "https://files.pythonhosted.org/packages/78/f2/b5c3f07f743e74c1f5c42bb2fc6e735f3adac8b526f60ef731d861663dd9/cuda_bindings-12.8.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:099f27e79e754346fa51517168787cda395fb437b31fbf20771c002f30adc0c9", size = 11039795 }, + { url = "https://files.pythonhosted.org/packages/d5/89/d1f3c70651cdeb7c276c0503aea34c1d0c22f8bc66de73887f5ce40c600a/cuda_bindings-12.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:630290148879b47f5e34629ee15061414caaf2f73ea284175a73b30427ad94fd", size = 11190771 }, +] + +[[package]] +name = "cuda-python" +version = "12.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cuda-bindings" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/2c/02bb311b996ffb91d05f8c1fb79131bf50855f7410dd33d09f800fe78c58/cuda_python-12.8.0-py3-none-any.whl", hash = "sha256:3fca3a03c247d6aa1c414989dfe0dd21e9500307b8573f72216ed57d99344c5a", size = 11930 }, +] + [[package]] name = "cupy-cuda12x" version = "13.4.0" @@ -1232,6 +1259,49 @@ http = [ { name = "aiohttp" }, ] +[[package]] +name = "gevent" +version = "25.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "platform_python_implementation == 'CPython' and sys_platform == 'win32'" }, + { name = "greenlet", marker = "platform_python_implementation == 'CPython'" }, + { name = "zope-event" }, + { name = "zope-interface" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/e5/a2d9c2d5bfb575973bca7733b23e7f8649f1079c18140a8680a551f3963e/gevent-25.4.2.tar.gz", hash = "sha256:7ffba461458ed28a85a01285ea0e0dc14f883204d17ce5ed82fa839a9d620028", size = 6342241 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/67/3c9a560d3b64510dc053714375b3d9f2c3d98192dc85b78a6e6f8b9a284b/gevent-25.4.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:5940174c7d1ffc7bb4b0ea9f2908f4f361eb03ada9e145d3590b8df1e61c379b", size = 2969979 }, + { url = "https://files.pythonhosted.org/packages/39/ee/594a40e09d9d56b76a04265ea37b825ec8e7b98cd41e8012eda413f233e6/gevent-25.4.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7ae7ad4ff9c4492d4b633702e35153509b07dc6ffd20f1577076d7647c9caba", size = 1805780 }, + { url = "https://files.pythonhosted.org/packages/d6/87/0707bfae4cc3728eb8d5fc29018b5ac3e0e1f8efca237d267d1d3abc7153/gevent-25.4.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d68fdf9bff0068367126983d7d85765124c292b4bc3d4d19ed8138335d8426a7", size = 1885718 }, + { url = "https://files.pythonhosted.org/packages/09/c6/4f35473d46ca8cfbffeee5e6f89ac29370280b3f34682ed8f0fea907f987/gevent-25.4.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ff92408011d78e4ffe297331ff30cded39a3e22845ba237516c646f6a485a241", size = 1845102 }, + { url = "https://files.pythonhosted.org/packages/7a/9b/d2269957be2867802d10bcb28e17eba64783067057d55e91e57207294c05/gevent-25.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a7c70ab6d33dfeb43bfe982c636609d8f90506dacaaa1f409a3c43c66d578fb1", size = 2084973 }, + { url = "https://files.pythonhosted.org/packages/6b/59/9a069d16d8b6b7ef82b0d241de9041b1341c9f132fbd096b80d6d1bc2345/gevent-25.4.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:8e740bc08ba4c34951f4bb6351dbe04209416e12d620691fb57e115b218a7818", size = 1822891 }, + { url = "https://files.pythonhosted.org/packages/96/0d/815808f04cef2410a93521814e51de7554874012fc49c5ca7197f86ac340/gevent-25.4.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c535d96ded6e26b37fadda9242a49fea6308754da5945173940614b7520c07b4", size = 2115665 }, + { url = "https://files.pythonhosted.org/packages/42/b4/15e5f9c06d50843c0e7c87d580acc2ac4e47fef0195c2d3f73c3bd54e3f0/gevent-25.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:c62bf14557d2cb54f5e3c1ba0a3b3f4b69bf0441081c32d63b205763b495b251", size = 1679652 }, + { url = "https://files.pythonhosted.org/packages/7d/1d/195936c1e0c5b1dc89a8b534c05d080d24d760f6913632cbb13d9430c907/gevent-25.4.2-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:f735f57bc19d0f8bbc784093cfb7953a9ad66612b05c3ff876ec7951a96d7edd", size = 2996686 }, + { url = "https://files.pythonhosted.org/packages/52/2a/a82de55db10ca17e210a61548a421d65d144045a62958d172537d4ea6f26/gevent-25.4.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63aecf1e43b8d01086ea574ed05f7272ed40c48dd41fa3d061e3c5ca900abcdd", size = 1809379 }, + { url = "https://files.pythonhosted.org/packages/77/73/3508d539c96e435d883aa07c67ad5859505af33346795c8c575501d3ebda/gevent-25.4.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f12e570777027f807dc7dc3ea1945ea040befaf1c9485deb6f24d7110009fc12", size = 1887353 }, + { url = "https://files.pythonhosted.org/packages/4d/40/911e4eca7958bea73d3889433e780b59413f3d7bbd4d24cadc0a2f276528/gevent-25.4.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:44acca4196d4a174c2b4817642564526898f42f72992dc1818b834b2bbf17582", size = 1848809 }, + { url = "https://files.pythonhosted.org/packages/59/eb/ccf5a2d7cb8ed2814b69fbe9cf46a8875f275fa0e5984889b1cbb0a67492/gevent-25.4.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75d2fdd24f3948c085d341281648014760f5cb23de9b29f710083e6911b2e605", size = 2084966 }, + { url = "https://files.pythonhosted.org/packages/7d/19/a1aadd6f3da55f18bb10877ccda7245be0c3b5e6acdc3c882fe54f412e01/gevent-25.4.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:0cc1d6093f482547ac522ab1a985429d8c12494518eeca354c956f0ff6de7a94", size = 1824458 }, + { url = "https://files.pythonhosted.org/packages/0f/70/ee8b5a4df0a6f587c44a102ad46356d626d652e35f46eeec05c5ba1575de/gevent-25.4.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:fe4a3e3fa3a16ed9b12b6ff0922208ef83287e066e696b82b96d33723d8207f2", size = 2116628 }, + { url = "https://files.pythonhosted.org/packages/13/c6/50ee863dd09dd31f61892b847b684fde730473487bcae3240acd9e3e412c/gevent-25.4.2-cp313-cp313-win_amd64.whl", hash = "sha256:8b90913360b1af058b279160679d804d4917a8661f128b2f7625f8665c39450f", size = 1678856 }, + { url = "https://files.pythonhosted.org/packages/54/d8/e29cc7f90ae7aa9e8f5298ca5a157bab34bfbc65d070385b28f4d72af1ac/gevent-25.4.2-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:b0a656eccd9cb115d01c9bbe55bfe84cf20c8422c495503f41aef747b193c33d", size = 3007128 }, +] + +[[package]] +name = "geventhttpclient" +version = "2.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "brotli" }, + { name = "certifi" }, + { name = "gevent" }, + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/98/1ee9fbab4ae97d5f0f05035059a56a61a9c966331e6c837f974b402fdf63/geventhttpclient-2.0.2.tar.gz", hash = "sha256:8135a85200b170def7293d01dd1557931fcd1bec1ac78c52ad7cedd22368b9ba", size = 73821 } + [[package]] name = "gguf" version = "0.10.0" @@ -1326,6 +1396,40 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/95/cb/002424d4f5af1425f9cfe7dcee3ed795ed1367bf0f185a6c4bf81385e1d6/gradio_client-1.7.2-py3-none-any.whl", hash = "sha256:50d61b4db3e87639430a121a7cde4303055486ed72a5035edae94b4fbe6a0e6b", size = 322052 }, ] +[[package]] +name = "greenlet" +version = "3.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3f/74/907bb43af91782e0366b0960af62a8ce1f9398e4291cac7beaeffbee0c04/greenlet-3.2.1.tar.gz", hash = "sha256:9f4dd4b4946b14bb3bf038f81e1d2e535b7d94f1b2a59fdba1293cd9c1a0a4d7", size = 184475 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f0/d1/e4777b188a04726f6cf69047830d37365b9191017f54caf2f7af336a6f18/greenlet-3.2.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:0ba2811509a30e5f943be048895a983a8daf0b9aa0ac0ead526dfb5d987d80ea", size = 270381 }, + { url = "https://files.pythonhosted.org/packages/59/e7/b5b738f5679247ddfcf2179c38945519668dced60c3164c20d55c1a7bb4a/greenlet-3.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4245246e72352b150a1588d43ddc8ab5e306bef924c26571aafafa5d1aaae4e8", size = 637195 }, + { url = "https://files.pythonhosted.org/packages/6c/9f/57968c88a5f6bc371364baf983a2e5549cca8f503bfef591b6dd81332cbc/greenlet-3.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7abc0545d8e880779f0c7ce665a1afc3f72f0ca0d5815e2b006cafc4c1cc5840", size = 651381 }, + { url = "https://files.pythonhosted.org/packages/40/81/1533c9a458e9f2ebccb3ae22f1463b2093b0eb448a88aac36182f1c2cd3d/greenlet-3.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6dcc6d604a6575c6225ac0da39df9335cc0c6ac50725063fa90f104f3dbdb2c9", size = 646110 }, + { url = "https://files.pythonhosted.org/packages/06/66/25f7e4b1468ebe4a520757f2e41c2a36a2f49a12e963431b82e9f98df2a0/greenlet-3.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2273586879affca2d1f414709bb1f61f0770adcabf9eda8ef48fd90b36f15d12", size = 648070 }, + { url = "https://files.pythonhosted.org/packages/d7/4c/49d366565c4c4d29e6f666287b9e2f471a66c3a3d8d5066692e347f09e27/greenlet-3.2.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ff38c869ed30fff07f1452d9a204ece1ec6d3c0870e0ba6e478ce7c1515acf22", size = 603816 }, + { url = "https://files.pythonhosted.org/packages/04/15/1612bb61506f44b6b8b6bebb6488702b1fe1432547e95dda57874303a1f5/greenlet-3.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e934591a7a4084fa10ee5ef50eb9d2ac8c4075d5c9cf91128116b5dca49d43b1", size = 1119572 }, + { url = "https://files.pythonhosted.org/packages/cc/2f/002b99dacd1610e825876f5cbbe7f86740aa2a6b76816e5eca41c8457e85/greenlet-3.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:063bcf7f8ee28eb91e7f7a8148c65a43b73fbdc0064ab693e024b5a940070145", size = 1147442 }, + { url = "https://files.pythonhosted.org/packages/c0/ba/82a2c3b9868644ee6011da742156247070f30e952f4d33f33857458450f2/greenlet-3.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7132e024ebeeeabbe661cf8878aac5d2e643975c4feae833142592ec2f03263d", size = 296207 }, + { url = "https://files.pythonhosted.org/packages/77/2a/581b3808afec55b2db838742527c40b4ce68b9b64feedff0fd0123f4b19a/greenlet-3.2.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:e1967882f0c42eaf42282a87579685c8673c51153b845fde1ee81be720ae27ac", size = 269119 }, + { url = "https://files.pythonhosted.org/packages/b0/f3/1c4e27fbdc84e13f05afc2baf605e704668ffa26e73a43eca93e1120813e/greenlet-3.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e77ae69032a95640a5fe8c857ec7bee569a0997e809570f4c92048691ce4b437", size = 637314 }, + { url = "https://files.pythonhosted.org/packages/fc/1a/9fc43cb0044f425f7252da9847893b6de4e3b20c0a748bce7ab3f063d5bc/greenlet-3.2.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3227c6ec1149d4520bc99edac3b9bc8358d0034825f3ca7572165cb502d8f29a", size = 651421 }, + { url = "https://files.pythonhosted.org/packages/8a/65/d47c03cdc62c6680206b7420c4a98363ee997e87a5e9da1e83bd7eeb57a8/greenlet-3.2.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ddda0197c5b46eedb5628d33dad034c455ae77708c7bf192686e760e26d6a0c", size = 645789 }, + { url = "https://files.pythonhosted.org/packages/2f/40/0faf8bee1b106c241780f377b9951dd4564ef0972de1942ef74687aa6bba/greenlet-3.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de62b542e5dcf0b6116c310dec17b82bb06ef2ceb696156ff7bf74a7a498d982", size = 648262 }, + { url = "https://files.pythonhosted.org/packages/e0/a8/73305f713183c2cb08f3ddd32eaa20a6854ba9c37061d682192db9b021c3/greenlet-3.2.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c07a0c01010df42f1f058b3973decc69c4d82e036a951c3deaf89ab114054c07", size = 606770 }, + { url = "https://files.pythonhosted.org/packages/c3/05/7d726e1fb7f8a6ac55ff212a54238a36c57db83446523c763e20cd30b837/greenlet-3.2.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:2530bfb0abcd451ea81068e6d0a1aac6dabf3f4c23c8bd8e2a8f579c2dd60d95", size = 1117960 }, + { url = "https://files.pythonhosted.org/packages/bf/9f/2b6cb1bd9f1537e7b08c08705c4a1d7bd4f64489c67d102225c4fd262bda/greenlet-3.2.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1c472adfca310f849903295c351d297559462067f618944ce2650a1878b84123", size = 1145500 }, + { url = "https://files.pythonhosted.org/packages/e4/f6/339c6e707062319546598eb9827d3ca8942a3eccc610d4a54c1da7b62527/greenlet-3.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:24a496479bc8bd01c39aa6516a43c717b4cee7196573c47b1f8e1011f7c12495", size = 295994 }, + { url = "https://files.pythonhosted.org/packages/f1/72/2a251d74a596af7bb1717e891ad4275a3fd5ac06152319d7ad8c77f876af/greenlet-3.2.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:175d583f7d5ee57845591fc30d852b75b144eb44b05f38b67966ed6df05c8526", size = 629889 }, + { url = "https://files.pythonhosted.org/packages/29/2e/d7ed8bf97641bf704b6a43907c0e082cdf44d5bc026eb8e1b79283e7a719/greenlet-3.2.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ecc9d33ca9428e4536ea53e79d781792cee114d2fa2695b173092bdbd8cd6d5", size = 635261 }, + { url = "https://files.pythonhosted.org/packages/1e/75/802aa27848a6fcb5e566f69c64534f572e310f0f12d41e9201a81e741551/greenlet-3.2.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3f56382ac4df3860ebed8ed838f268f03ddf4e459b954415534130062b16bc32", size = 632523 }, + { url = "https://files.pythonhosted.org/packages/56/09/f7c1c3bab9b4c589ad356503dd71be00935e9c4db4db516ed88fc80f1187/greenlet-3.2.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc45a7189c91c0f89aaf9d69da428ce8301b0fd66c914a499199cfb0c28420fc", size = 628816 }, + { url = "https://files.pythonhosted.org/packages/79/e0/1bb90d30b5450eac2dffeaac6b692857c4bd642c21883b79faa8fa056cf2/greenlet-3.2.1-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:51a2f49da08cff79ee42eb22f1658a2aed60c72792f0a0a95f5f0ca6d101b1fb", size = 593687 }, + { url = "https://files.pythonhosted.org/packages/c5/b5/adbe03c8b4c178add20cc716021183ae6b0326d56ba8793d7828c94286f6/greenlet-3.2.1-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:0c68bbc639359493420282d2f34fa114e992a8724481d700da0b10d10a7611b8", size = 1105754 }, + { url = "https://files.pythonhosted.org/packages/39/93/84582d7ef38dec009543ccadec6ab41079a6cbc2b8c0566bcd07bf1aaf6c/greenlet-3.2.1-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:e775176b5c203a1fa4be19f91da00fd3bff536868b77b237da3f4daa5971ae5d", size = 1125160 }, + { url = "https://files.pythonhosted.org/packages/01/e6/f9d759788518a6248684e3afeb3691f3ab0276d769b6217a1533362298c8/greenlet-3.2.1-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:d6668caf15f181c1b82fb6406f3911696975cc4c37d782e19cb7ba499e556189", size = 269897 }, +] + [[package]] name = "groovy" version = "0.1.2" @@ -3426,6 +3530,38 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546 }, ] +[[package]] +name = "python-rapidjson" +version = "1.20" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/47/2a/2510836a65a1fc40c923393611896c3c8ad1e2f583ed0c32cf0bb48cc378/python_rapidjson-1.20.tar.gz", hash = "sha256:115f08c86d2df7543c02605e77c84727cdabc4b08310d2f097e953efeaaa73eb", size = 238158 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/d1/40616f40499f8f61e83135aa078a0ba7d392e7ea63c016c7cc544ecb7344/python_rapidjson-1.20-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6056fcc8caeb9b04775bf655568bba362c7670ab792c1b438671bb056db954cd", size = 230104 }, + { url = "https://files.pythonhosted.org/packages/ea/2f/d28f4da4df83cfeb60fb7b84396a9c3678a0ac615012dc234d5b962fbaaf/python_rapidjson-1.20-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:225bd4cbabfe7910261cbcebb8b811d4ff98e90cdd17c233b916c6aa71a9553f", size = 211105 }, + { url = "https://files.pythonhosted.org/packages/b3/60/ebc521afbdb626bb571a815378831f685213cb6b98ffe08176fe3191c5a3/python_rapidjson-1.20-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:026077b663acf93a3f2b1adb87282e611a30214b8ae8001b7e4863a3b978e646", size = 1650309 }, + { url = "https://files.pythonhosted.org/packages/19/da/4c375b90c54091e93a600fca06a9f3b8456b0e09050e862e998fc22b6385/python_rapidjson-1.20-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:884e1dd4c0770ed424737941af4d5dc9014995f9c33595f151af13f83ce282c3", size = 1700043 }, + { url = "https://files.pythonhosted.org/packages/bc/6e/2718413e7bc300523c5d4eaa25418059d8b17effa9aef2f2ae370493b861/python_rapidjson-1.20-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f55531c8197cb7a21a5ef0ffa46f2b8fc8c5fe7c6fd08bdbd2063ae65d2ff65", size = 1700523 }, + { url = "https://files.pythonhosted.org/packages/32/fe/d96e996f9c5140d3ce93d440f871a1b336f1c14fae27b64d4872fc58d45d/python_rapidjson-1.20-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c60121d155562dc694c05ed7df4e39e42ee1d3adff2a060c64a004498e6451f7", size = 1598383 }, + { url = "https://files.pythonhosted.org/packages/46/32/ef3a381641b803e1b67c9b9c360d161b650620605768652e704fb35ad2b9/python_rapidjson-1.20-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3a6620eed0b04196f37fab7048c1d672d03391bb29d7f09ee8fee8dea33f11f4", size = 2454134 }, + { url = "https://files.pythonhosted.org/packages/2f/50/771826d3f217b7c597f14df0dfa943d9e6f2f14749d974de4402f56ce39a/python_rapidjson-1.20-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:ddb63eff401ce7cf20cdd5e21942fc23fbe0e1dc1d96d7ae838645fb1f74fb47", size = 2585576 }, + { url = "https://files.pythonhosted.org/packages/64/95/f3e7ed53c9ab27a99c876c42b7d1994312e6fd2c2d8131ce849bd4275be8/python_rapidjson-1.20-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:05e28c3dbb4a0d74ec13af9668ef2b9f302edf83cf7ce1d8316a95364720eec0", size = 2599382 }, + { url = "https://files.pythonhosted.org/packages/bc/4c/34778932d0145fdc7087274cd4c0fa421a96acbc96bf9860cbdf3e389dcd/python_rapidjson-1.20-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b733978ecd84fc5df9a778ce821dc1f3113f7bfc2493cac0bb17efb4ae0bb8fa", size = 2537066 }, + { url = "https://files.pythonhosted.org/packages/50/16/dfef47ec507d5a5d00281b8db8526d5c36b715afeeae0ceeef4030f1640f/python_rapidjson-1.20-cp312-cp312-win32.whl", hash = "sha256:d87041448cec00e2db5d858625a76dc1b59eef6691a039acff6d92ad8581cfc1", size = 128358 }, + { url = "https://files.pythonhosted.org/packages/bc/97/42a550a79ab90ab37fcd8b519cd71bba4b96b85679218100d63b437770c0/python_rapidjson-1.20-cp312-cp312-win_amd64.whl", hash = "sha256:5d3be149ce5475f9605f01240487541057792abad94d3fd0cd56af363cf5a4dc", size = 149067 }, + { url = "https://files.pythonhosted.org/packages/18/04/47d9d10c3fa6e57af9462792088187605a07d88ad6f6f2e193fb01eff0fc/python_rapidjson-1.20-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:daee815b4c20ca6e4dbc6bde373dd3f65b53813d775f1c94b765b33b402513a7", size = 229315 }, + { url = "https://files.pythonhosted.org/packages/9a/3a/0c4e0af51d7356d9efdef1bf1785d9d9f9e0789a7d2844cc3e9b35ef383f/python_rapidjson-1.20-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:083df379c769b30f9bc40041c91fd9d8f7bb8ca2b3c7170258842aced2098e05", size = 211111 }, + { url = "https://files.pythonhosted.org/packages/83/e1/e253de9a774d021f9a6947f845628fae8237f441c63198e8a72e5906d31f/python_rapidjson-1.20-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9399ad75a2e3377f9e6208caabe73eb9354cd01b732407475ccadcd42c577df", size = 1650131 }, + { url = "https://files.pythonhosted.org/packages/3e/93/8f723c7f7be055086d6bec2ba9e5ef13e749c3fb3ad5a3dc1d740acee889/python_rapidjson-1.20-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:599ab208ccf6172d6cfac1abe048c837e62612f91f97d198e32773c45346a0b4", size = 1699873 }, + { url = "https://files.pythonhosted.org/packages/7d/2e/eb7255601b81a5b70f2bff05caab136e191b66825c16db3e7db1bdaa8314/python_rapidjson-1.20-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf3c0e2a5b97b0d07311f15f0dce4434e43dec865c3794ad1b10d968460fd665", size = 1700484 }, + { url = "https://files.pythonhosted.org/packages/90/54/23d8b595dd4fdbdaa6c5f723a4df7a7be78aa702aa0b6dac6c964e6e6d30/python_rapidjson-1.20-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e8064b8edb57ddd9e3ffa539cf2ec2f03515751fb0698b40ba5cb66a2123af19", size = 1598344 }, + { url = "https://files.pythonhosted.org/packages/3d/3a/3628e199a826e7bc598633ce895516981602ab1d8fce76359005f90ca488/python_rapidjson-1.20-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bc79d7f00f7538e027960ca6bcd1e03ed99fcf660d4d882d1c22f641155d0db0", size = 2454206 }, + { url = "https://files.pythonhosted.org/packages/ed/19/eef8629f73b1af21fa778d140e68e72076fe5746357426d6716a0c411dd2/python_rapidjson-1.20-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:87aa0b01b8c20984844f1440b8ff6bdb32de911a1750fed344b9daed33b4b52b", size = 2585553 }, + { url = "https://files.pythonhosted.org/packages/d8/9d/217e56c74a65cfaf4441b26b6206b924b41fb339f98776a74e60dd287b46/python_rapidjson-1.20-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:4099cb9eae8a0ce19c09e02729eb6d69d5180424f13a2641a6c407d053e47a82", size = 2599513 }, + { url = "https://files.pythonhosted.org/packages/54/f6/4d40189f14e4fa5526a91aad9944864c8a4eebc0257e0314a331f3c64170/python_rapidjson-1.20-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4c680cd2b4de760ff6875de71fe6a87bd610aa116593d62e4f81a563be86ae18", size = 2537192 }, + { url = "https://files.pythonhosted.org/packages/ee/30/f3f40abfd8d7f0586b88ccfcd747f2e227fe589c16fbb485b1e238d8e641/python_rapidjson-1.20-cp313-cp313-win32.whl", hash = "sha256:9e431a7afc77aa874fed537c9f6bf5fcecaef124ebeae2a2379d3b9e9adce74b", size = 128362 }, + { url = "https://files.pythonhosted.org/packages/94/df/7126352e55cb72a5ca99630bd44ffb11bbf61ee35f4e1f34d203a77597c5/python_rapidjson-1.20-cp313-cp313-win_amd64.whl", hash = "sha256:7444bc7e6a04c03d6ed748b5dab0798fa2b3f2b303be8c38d3af405b2cac6d63", size = 149072 }, +] + [[package]] name = "pytz" version = "2025.1" @@ -3580,6 +3716,8 @@ dependencies = [ { name = "torch" }, { name = "torchvision" }, { name = "transformers" }, + { name = "tritonclient", extra = ["all"] }, + { name = "ultralytics" }, { name = "unsloth" }, { name = "verifiers" }, { name = "vllm" }, @@ -3619,6 +3757,8 @@ requires-dist = [ { name = "torch", specifier = "==2.5.1" }, { name = "torchvision", specifier = "==0.20.1" }, { name = "transformers", specifier = "==4.49.0" }, + { name = "tritonclient", extras = ["all"], specifier = ">=2.51.0" }, + { name = "ultralytics", specifier = ">=8.3.120" }, { name = "unsloth", specifier = ">=2025.3.19" }, { name = "verifiers", editable = "../verifiers" }, { name = "vllm", specifier = "==0.7.3" }, @@ -4032,6 +4172,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/c8/b3f566db71461cabd4b2d5b39bcc24a7e1c119535c8361f81426be39bb47/scipy-1.15.2-cp313-cp313t-win_amd64.whl", hash = "sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db", size = 40477705 }, ] +[[package]] +name = "seaborn" +version = "0.13.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "matplotlib" }, + { name = "numpy" }, + { name = "pandas" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/59/a451d7420a77ab0b98f7affa3a1d78a313d2f7281a57afb1a34bae8ab412/seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7", size = 1457696 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/11/00d3c3dfc25ad54e731d91449895a79e4bf2384dc3ac01809010ba88f6d5/seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987", size = 294914 }, +] + [[package]] name = "semantic-version" version = "2.10.0" @@ -4589,6 +4743,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f5/75/aac76f24dd17eb2245973ec1dd995759ce85ed91e5bb045fabb3c83ab1d6/triton_windows-3.2.0.post17-cp313-cp313-win_amd64.whl", hash = "sha256:539dd7ba8b7cc238930c1f4cb6e7819c22d1b8798fde361b78115b0fdb98a147", size = 40039344 }, ] +[[package]] +name = "tritonclient" +version = "2.51.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "python-rapidjson" }, + { name = "urllib3" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/a6/301bd2f431346adac05ff3c062bbcec0a93b567f1d3ef0d3ccf353a5bcd6/tritonclient-2.51.0-py3-none-any.whl", hash = "sha256:eef99681b0a18ee72808d887d2324a38a81fa1250924e595db46256b83f13668", size = 98012 }, + { url = "https://files.pythonhosted.org/packages/87/0b/57eae443655212c73ae3586b280e1b1c81ba1668afc94109b1efac8c23c4/tritonclient-2.51.0-py3-none-manylinux1_x86_64.whl", hash = "sha256:c485bb0123bdf310f90bc8b03d3489b28df2ffed55b30c7eee0b795b48113d52", size = 13956700 }, + { url = "https://files.pythonhosted.org/packages/70/bd/eb64fe810b8728f5f7936fe4d156062847d850c55923289dad8e281ee3d6/tritonclient-2.51.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:ee6f5a508409f6c95069f4d77d34e97bef84fb4a1aedb5d82ad0ad311ad128d5", size = 13325829 }, +] + +[package.optional-dependencies] +all = [ + { name = "aiohttp" }, + { name = "cuda-python" }, + { name = "geventhttpclient" }, + { name = "grpcio" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "protobuf" }, + { name = "python-rapidjson" }, +] + [[package]] name = "trl" version = "0.15.0.dev0" @@ -4740,6 +4921,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d7/72/6cb6728e2738c05bbe9bd522d6fc79f86b9a28402f38663e85a28fddd4a0/ujson-5.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:4573fd1695932d4f619928fd09d5d03d917274381649ade4328091ceca175539", size = 42212 }, ] +[[package]] +name = "ultralytics" +version = "8.3.120" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "matplotlib" }, + { name = "numpy" }, + { name = "opencv-python" }, + { name = "pandas" }, + { name = "pillow" }, + { name = "psutil" }, + { name = "py-cpuinfo" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "scipy" }, + { name = "seaborn" }, + { name = "torch" }, + { name = "torchvision" }, + { name = "tqdm" }, + { name = "ultralytics-thop" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1d/c8/921621be09aed3c498d0db807261a9737d04efe84f8cb729de3874dfe2d8/ultralytics-8.3.120.tar.gz", hash = "sha256:5b709c2a66fc1580dfbf8d6be56727b941d0d3d5906582f9613e72b90b486e53", size = 863199 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/bc/3f390c44ef15deb1af6235b349b6953c6409f45f02d49b1e22b6f940871c/ultralytics-8.3.120-py3-none-any.whl", hash = "sha256:7ac3bf90850eb7b943c3f1c8451eca271f8277c51d9af9cb34933c7a23cab9ad", size = 1004601 }, +] + +[[package]] +name = "ultralytics-thop" +version = "2.0.14" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "torch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c2/d8/e43a8bfcb03ff036119d098a7ea27be9f0adb715543ed6bd83b16cda83dc/ultralytics_thop-2.0.14.tar.gz", hash = "sha256:38ebfdbd3cd8dafdc3d26ec3a7d4f604fbeed5e69a74e61a48060b39736c945c", size = 28793 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/10/251f036b4c5d77249f9a119cc89dafe8745dc1ad1f1a5f06b6a3988ca454/ultralytics_thop-2.0.14-py3-none-any.whl", hash = "sha256:720b421e2459179fee21ec8f730d242a20774cd4b0a00a58d02351a39ec3881c", size = 26517 }, +] + [[package]] name = "unsloth" version = "2025.3.19" @@ -5283,3 +5503,38 @@ sdist = { url = "https://files.pythonhosted.org/packages/3f/50/bad581df71744867e wheels = [ { url = "https://files.pythonhosted.org/packages/b7/1a/7e4798e9339adc931158c9d69ecc34f5e6791489d469f5e50ec15e35f458/zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931", size = 9630 }, ] + +[[package]] +name = "zope-event" +version = "5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/46/c2/427f1867bb96555d1d34342f1dd97f8c420966ab564d58d18469a1db8736/zope.event-5.0.tar.gz", hash = "sha256:bac440d8d9891b4068e2b5a2c5e2c9765a9df762944bda6955f96bb9b91e67cd", size = 17350 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/42/f8dbc2b9ad59e927940325a22d6d3931d630c3644dae7e2369ef5d9ba230/zope.event-5.0-py3-none-any.whl", hash = "sha256:2832e95014f4db26c47a13fdaef84cef2f4df37e66b59d8f1f4a8f319a632c26", size = 6824 }, +] + +[[package]] +name = "zope-interface" +version = "7.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/30/93/9210e7606be57a2dfc6277ac97dcc864fd8d39f142ca194fdc186d596fda/zope.interface-7.2.tar.gz", hash = "sha256:8b49f1a3d1ee4cdaf5b32d2e738362c7f5e40ac8b46dd7d1a65e82a4872728fe", size = 252960 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/0b/c7516bc3bad144c2496f355e35bd699443b82e9437aa02d9867653203b4a/zope.interface-7.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:086ee2f51eaef1e4a52bd7d3111a0404081dadae87f84c0ad4ce2649d4f708b7", size = 208959 }, + { url = "https://files.pythonhosted.org/packages/a2/e9/1463036df1f78ff8c45a02642a7bf6931ae4a38a4acd6a8e07c128e387a7/zope.interface-7.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:21328fcc9d5b80768bf051faa35ab98fb979080c18e6f84ab3f27ce703bce465", size = 209357 }, + { url = "https://files.pythonhosted.org/packages/07/a8/106ca4c2add440728e382f1b16c7d886563602487bdd90004788d45eb310/zope.interface-7.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6dd02ec01f4468da0f234da9d9c8545c5412fef80bc590cc51d8dd084138a89", size = 264235 }, + { url = "https://files.pythonhosted.org/packages/fc/ca/57286866285f4b8a4634c12ca1957c24bdac06eae28fd4a3a578e30cf906/zope.interface-7.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8e7da17f53e25d1a3bde5da4601e026adc9e8071f9f6f936d0fe3fe84ace6d54", size = 259253 }, + { url = "https://files.pythonhosted.org/packages/96/08/2103587ebc989b455cf05e858e7fbdfeedfc3373358320e9c513428290b1/zope.interface-7.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cab15ff4832580aa440dc9790b8a6128abd0b88b7ee4dd56abacbc52f212209d", size = 264702 }, + { url = "https://files.pythonhosted.org/packages/5f/c7/3c67562e03b3752ba4ab6b23355f15a58ac2d023a6ef763caaca430f91f2/zope.interface-7.2-cp312-cp312-win_amd64.whl", hash = "sha256:29caad142a2355ce7cfea48725aa8bcf0067e2b5cc63fcf5cd9f97ad12d6afb5", size = 212466 }, + { url = "https://files.pythonhosted.org/packages/c6/3b/e309d731712c1a1866d61b5356a069dd44e5b01e394b6cb49848fa2efbff/zope.interface-7.2-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:3e0350b51e88658d5ad126c6a57502b19d5f559f6cb0a628e3dc90442b53dd98", size = 208961 }, + { url = "https://files.pythonhosted.org/packages/49/65/78e7cebca6be07c8fc4032bfbb123e500d60efdf7b86727bb8a071992108/zope.interface-7.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:15398c000c094b8855d7d74f4fdc9e73aa02d4d0d5c775acdef98cdb1119768d", size = 209356 }, + { url = "https://files.pythonhosted.org/packages/11/b1/627384b745310d082d29e3695db5f5a9188186676912c14b61a78bbc6afe/zope.interface-7.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:802176a9f99bd8cc276dcd3b8512808716492f6f557c11196d42e26c01a69a4c", size = 264196 }, + { url = "https://files.pythonhosted.org/packages/b8/f6/54548df6dc73e30ac6c8a7ff1da73ac9007ba38f866397091d5a82237bd3/zope.interface-7.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb23f58a446a7f09db85eda09521a498e109f137b85fb278edb2e34841055398", size = 259237 }, + { url = "https://files.pythonhosted.org/packages/b6/66/ac05b741c2129fdf668b85631d2268421c5cd1a9ff99be1674371139d665/zope.interface-7.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a71a5b541078d0ebe373a81a3b7e71432c61d12e660f1d67896ca62d9628045b", size = 264696 }, + { url = "https://files.pythonhosted.org/packages/0a/2f/1bccc6f4cc882662162a1158cda1a7f616add2ffe322b28c99cb031b4ffc/zope.interface-7.2-cp313-cp313-win_amd64.whl", hash = "sha256:4893395d5dd2ba655c38ceb13014fd65667740f09fa5bb01caa1e6284e48c0cd", size = 212472 }, +] From 610aeb984e004c56c9bcb71bb4adcf3e34fe12ba Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Tue, 29 Apr 2025 16:48:15 -0700 Subject: [PATCH 02/24] =?UTF-8?q?Re-introduce=20object=20detection=20tool,?= =?UTF-8?q?=20but=20this=20time=20it=20is=20a=20YOLO=20instead=20of=20open?= =?UTF-8?q?=20vocabulary=20(reducing=20degrees=20of=20freedom=20of=20the?= =?UTF-8?q?=20tool)=20Explicitly=20tell=20the=20model=20that=20it=20can=20?= =?UTF-8?q?call=20a=20tool=20but=20it=20does=20not=20have=20to.=20Explicit?= =?UTF-8?q?ly=20tell=20the=20model=20it=20needs=20to=20consider=20all=204?= =?UTF-8?q?=20options=20in=20the=20user=20prompt.=20Failures=20often=20loo?= =?UTF-8?q?k=20like=20torpedos,=20so=20maybe=20this=20helps=20prevent=20th?= =?UTF-8?q?at=3F=20Doing=20this=20in=20the=20bootstrap=20prompt=20didn?= =?UTF-8?q?=E2=80=99t=20help,=20but=20I=20think=20the=20IFT=20model=20?= =?UTF-8?q?=E2=80=9Clistens=E2=80=9D=20to=20the=20user=20more=20strongly.?= =?UTF-8?q?=20Reward=20schedule=20for=20tool=20use=20reward.=20The=20model?= =?UTF-8?q?=20gets=20200=20gradient=20updates=20with=20a=20tool=20use=20re?= =?UTF-8?q?ward.=20The=20reward=20decays=20linearly=20between=20steps=200?= =?UTF-8?q?=20and=20200.=20Then=20it=20stays=20at=200.=20Point=20tool=20at?= =?UTF-8?q?=20b0=20save=20more=20frequently?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../aok_vqa/aok_vqa_mc_tool_use_r1.py | 10 ++++---- .../tool_use_aokvqa_env/tool_use_aok_train.py | 8 +++---- .../tool_use_aokvqa_env.py | 24 ++++++++++++++----- src/r1_vlm/tools/object_detection.py | 5 ++-- src/r1_vlm/tools/tool_prompts.py | 4 +++- 5 files changed, 32 insertions(+), 19 deletions(-) diff --git a/src/r1_vlm/datasets/aok_vqa/aok_vqa_mc_tool_use_r1.py b/src/r1_vlm/datasets/aok_vqa/aok_vqa_mc_tool_use_r1.py index 12923ce9..56e8c599 100644 --- a/src/r1_vlm/datasets/aok_vqa/aok_vqa_mc_tool_use_r1.py +++ b/src/r1_vlm/datasets/aok_vqa/aok_vqa_mc_tool_use_r1.py @@ -28,7 +28,7 @@ def generate_r1_messages(example): system_prompt = "REPLACED WITH TOOLS SYSTEM PROMPT" - choices_str = "These are the possible answers, you must choose one: " + choices_str = "Possible answers: " for i, choice in enumerate(choices): if i == len(choices) - 1: choices_str += f"or {choice}." @@ -36,11 +36,11 @@ def generate_r1_messages(example): choices_str += f"{choice}, " instruction = f""" - {question} + Question: {question} - {choices_str} + {choices_str} You must choose one to answer the question and place in tags. - You must inspect the input image and gather visual evidence. The image size is {image_size}. + You must inspect the input image to gather visual evidence. After you've collected evidence, combine that with your knowledge of the world to answer the question. You must consider all 4 possible answers when thinking through your reasoning. The image size is {image_size}. """ r1_messages = [ @@ -66,7 +66,7 @@ def generate_r1_messages(example): "content": [ { "type": "text", - "text": "\n I'll collect as much visual evidence as possible from the image. First, I'll consider what region of the image to zoom in on to get the most information. Then, I'll review and consider the four possible answers. Then, I'll select the most likely answer based on the evidence and my knowledge of the world.", + "text": "\n I'll collect as much visual evidence as possible from the image. Then, I'll consider the four possible answers. Finally, I'll select the most likely answer based on the evidence and my knowledge of the world.", } ], }, diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py index dee5a949..6f19c408 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py @@ -4,11 +4,11 @@ from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl from peft import LoraConfig, TaskType from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration - -from r1_vlm.environments.tool_use_aokvqa_env.tool_use_aokvqa_env import AOKVQAToolEnv from trl import GRPOConfig, ModelConfig from trl.trainer.qwen_grpo_trainer import QwenGRPOTrainer +from r1_vlm.environments.tool_use_aokvqa_env.tool_use_aokvqa_env import AOKVQAToolEnv + os.environ["WANDB_ENTITY"] = "groundlightai" os.environ["WANDB_PROJECT"] = "tool-use-aokvqa-env" @@ -115,8 +115,8 @@ def train(): warmup_steps=10, logging_steps=1, save_steps=50, - save_total_limit=5, - num_train_epochs=10, + save_total_limit=10, + num_train_epochs=1, per_device_train_batch_size=2, num_generations=6, gradient_accumulation_steps=4, diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py index 67f46c95..d73b03fb 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py @@ -2,17 +2,28 @@ from datasets import Dataset from transformers import AutoProcessor +from trl.trainer.grpo_trainer import RewardFunc +from verifiers.parsers import XMLParser from r1_vlm.datasets.aok_vqa.aok_vqa_mc_tool_use_r1 import ( create_r1_aok_vqa_tool_use_dataset, ) from r1_vlm.datasets.utils import preprocess_r1_dataset from r1_vlm.environments.multistep_vision_env import MultistepVisionEnv +from r1_vlm.environments.reward_schedules import create_linear_decay_schedule from r1_vlm.environments.tool_vision_env import ToolArgParser, ToolVisionEnv -from r1_vlm.tools.tool_prompts import SINGLE_TOOL_PROMPT_TEMPLATE +from r1_vlm.tools.object_detection import ( + ObjectDetectionTool, + detect_objects, + parse_detect_objects_args, + set_object_detection_tool, +) +from r1_vlm.tools.tool_prompts import SINGLE_OPTIONAL_TOOL_PROMPT_TEMPLATE from r1_vlm.tools.zoom import parse_zoom_args, zoom -from trl.trainer.grpo_trainer import RewardFunc -from verifiers.parsers import XMLParser + +# This is a global variable that is used to store the object detection tool. It is accessed by the detect_objects function. +od_tool = ObjectDetectionTool() +set_object_detection_tool(od_tool) class AOKVQAToolEnv(ToolVisionEnv): @@ -21,11 +32,11 @@ def __init__( processing_class: AutoProcessor, dataset_name: str = "Groundlight/real-iad-toy-brick-tool-use-r1", tools_with_parsers: list[tuple[Callable, ToolArgParser]] = [ - # (detect_objects, parse_detect_objects_args), + (detect_objects, parse_detect_objects_args), (zoom, parse_zoom_args), ], max_steps: int = 3, - tool_prompt_template: str = SINGLE_TOOL_PROMPT_TEMPLATE, + tool_prompt_template: str = SINGLE_OPTIONAL_TOOL_PROMPT_TEMPLATE, ): super().__init__( processing_class=processing_class, @@ -103,7 +114,8 @@ def get_reward_weights(self) -> list[float]: schedule = 0.1 reward_weights.append(schedule) elif reward_function.__name__ == "tool_execution_reward_func": - schedule = 0.1 + # linearly decay from 0.1 to 0.0 over 200 global steps (200 gradient updates) + schedule = create_linear_decay_schedule(0.1, 0.0, 200) reward_weights.append(schedule) elif reward_function.__name__ == "correct_answer_reward_func": diff --git a/src/r1_vlm/tools/object_detection.py b/src/r1_vlm/tools/object_detection.py index 83fc030d..ddb17872 100644 --- a/src/r1_vlm/tools/object_detection.py +++ b/src/r1_vlm/tools/object_detection.py @@ -22,8 +22,7 @@ class ObjectDetectionTool: def __init__(self): - # url = f"{API_IP}:{API_PORT}/yolo" - url = "localhost:8000/yolo" + url = f"{API_IP}:{API_PORT}/yolo" self.triton_client = InferenceServerClient(url=url, verbose=False, ssl=False) # Wait until model is ready @@ -33,7 +32,7 @@ def __init__(self): break time.sleep(1) - self.model = YOLO("http://localhost:8000/yolo", task="detect") + self.model = YOLO(f"http://{url}", task="detect") def detect_objects(self, image: Image.Image) -> list[dict]: result = self.model(image, conf=0.3)[0] diff --git a/src/r1_vlm/tools/tool_prompts.py b/src/r1_vlm/tools/tool_prompts.py index ea300bd1..c8178258 100644 --- a/src/r1_vlm/tools/tool_prompts.py +++ b/src/r1_vlm/tools/tool_prompts.py @@ -37,7 +37,7 @@ {tool_descriptions} For each step: -1. Start by thinking through your reasoning inside tags. Then either return your answer inside tags, or use a tool inside tags. You are not required to use a tool if you can answer the question without one. +1. Start by thinking through your reasoning inside tags. Then either return your answer inside tags, or use a tool inside tags. 2. If needed, use a tool by writing its arguments inside tags. Use one line for each argument in the format 'key: value'. The first line must be 'name: '. 3. You will see the tool's output inside tags. 4. Continue until you can give the final answer inside tags. @@ -45,4 +45,6 @@ Tools expect specific arguments. Follow the examples carefully for the required keys and expected value formats. Do not make up tools or arguments that aren't listed. If the tool includes the argument "image_name", you must provide it the name of an image from this conversation. + +As a reminder, you are not required to use a tool if you can answer the user's question without one. """ From af03c8c48b4d273a1f150aaa93b8fa2e5e1c7026 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Tue, 29 Apr 2025 17:48:14 -0700 Subject: [PATCH 03/24] ready to train --- .../environments/tool_use_aokvqa_env/tool_use_aok_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py index 6f19c408..4fc9cfe5 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py @@ -106,7 +106,7 @@ def train(): training_args = GRPOConfig( model_init_kwargs=model_config, # save path on the runpod instance - output_dir="vlm-r1-new-zoom-tool-reward-independent-oversampling", + output_dir="vlm-r1-new-od-tool-reward-schedule-for-tools-apr-29", # increase learning rate for PEFT - 1e-4 learning_rate=1e-4 if peft_config is not None else 1e-6, max_grad_norm=1.0, From 76542f5eae88409191097a0f7cccd20d8c09519e Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Tue, 29 Apr 2025 18:11:01 -0700 Subject: [PATCH 04/24] start server was only local --- tool_server/pyproject.toml | 1 + tool_server/start_server.py | 2 +- tool_server/uv.lock | 11 +++++++++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/tool_server/pyproject.toml b/tool_server/pyproject.toml index b4241c49..371551e9 100644 --- a/tool_server/pyproject.toml +++ b/tool_server/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "ipdb>=0.13.13", "numpy>=1.26.4", "pillow>=11.2.1", + "python-dotenv>=1.1.0", ] [tool.uv.sources] diff --git a/tool_server/start_server.py b/tool_server/start_server.py index 1a7ebc04..6134e0de 100644 --- a/tool_server/start_server.py +++ b/tool_server/start_server.py @@ -20,7 +20,7 @@ container_id = ( subprocess.check_output( # Use the absolute path here - f"docker run -d --gpus 0 -v {absolute_triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models", + f"docker run -d --gpus 0 -v {absolute_triton_repo_path}:/models -p 0.0.0.0:8000:8000 {tag} tritonserver --model-repository=/models", shell=True, ) .decode("utf-8") diff --git a/tool_server/uv.lock b/tool_server/uv.lock index a6125dcf..2929c3fa 100644 --- a/tool_server/uv.lock +++ b/tool_server/uv.lock @@ -1118,6 +1118,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892 }, ] +[[package]] +name = "python-dotenv" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/88/2c/7bb1416c5620485aa793f2de31d3df393d3686aa8a8506d11e10e13c5baf/python_dotenv-1.1.0.tar.gz", hash = "sha256:41f90bc6f5f177fb41f53e87666db362025010eb28f60a01c9143bfa33a2b2d5", size = 39920 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/18/98a99ad95133c6a6e2005fe89faedf294a748bd5dc803008059409ac9b1e/python_dotenv-1.1.0-py3-none-any.whl", hash = "sha256:d7c01d9e2293916c18baf562d95698754b0dbbb5e74d457c45d4f6561fb9d55d", size = 20256 }, +] + [[package]] name = "python-rapidjson" version = "1.20" @@ -1381,6 +1390,7 @@ dependencies = [ { name = "onnxruntime-gpu" }, { name = "onnxslim" }, { name = "pillow" }, + { name = "python-dotenv" }, { name = "tensorrt" }, { name = "tritonclient", extra = ["all"] }, { name = "ultralytics" }, @@ -1398,6 +1408,7 @@ requires-dist = [ { name = "onnxruntime-gpu", specifier = ">=1.21.1" }, { name = "onnxslim", specifier = ">=0.1.50" }, { name = "pillow", specifier = ">=11.2.1" }, + { name = "python-dotenv", specifier = ">=1.1.0" }, { name = "tensorrt", specifier = ">=10.9.0.34" }, { name = "tritonclient", extras = ["all"], specifier = ">=2.56.0" }, { name = "ultralytics", specifier = ">=8.3.112" }, From 2d6b792c9321d67debaed2160778a4d80ed9ffcf Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Tue, 29 Apr 2025 18:23:56 -0700 Subject: [PATCH 05/24] working now --- src/r1_vlm/tools/object_detection.py | 13 +++--- tool_server/infer.py | 67 +++++++++------------------- 2 files changed, 28 insertions(+), 52 deletions(-) diff --git a/src/r1_vlm/tools/object_detection.py b/src/r1_vlm/tools/object_detection.py index ddb17872..cfaad25a 100644 --- a/src/r1_vlm/tools/object_detection.py +++ b/src/r1_vlm/tools/object_detection.py @@ -162,9 +162,12 @@ def parse_detect_objects_args(raw_args: RawToolArgs) -> TypedToolArgs: image = Image.open( "/millcreek/home/sunil/r1_vlm_bumbershoot0/r1_vlm/src/r1_vlm/tools/cars.jpeg" ) - detections = detect_objects(image_name="input_image", images={"input_image": image}) + for i in range(10): + detections = detect_objects( + image_name="input_image", images={"input_image": image} + ) - image_data = detections["image_data"] - imgcat(image_data) - text_data = detections["text_data"] - print(text_data) + image_data = detections["image_data"] + imgcat(image_data) + text_data = detections["text_data"] + print(text_data) diff --git a/tool_server/infer.py b/tool_server/infer.py index 45aff7a2..aa3c7264 100644 --- a/tool_server/infer.py +++ b/tool_server/infer.py @@ -1,74 +1,47 @@ import contextlib +import os import time import numpy as np +from dotenv import load_dotenv from imgcat import imgcat from PIL import Image from tritonclient.http import InferenceServerClient from ultralytics import YOLO +load_dotenv() + +API_IP = str(os.getenv("API_IP")) +API_PORT = int(os.getenv("API_PORT")) +url = f"{API_IP}:{API_PORT}/yolo" +print(url) + # Wait for the Triton server to start -triton_client = InferenceServerClient(url="localhost:8000", verbose=False, ssl=False) +triton_client = InferenceServerClient(url=url, verbose=False, ssl=False) # Wait until model is ready for _ in range(10): with contextlib.suppress(Exception): + print("checking if model is ready") assert triton_client.is_model_ready("yolo") + print("model is ready") break time.sleep(1) - +print("loading model") # Load the Triton Server model -model = YOLO("http://localhost:8000/yolo", task="detect") +model = YOLO(f"http://{url}", task="detect") # load the image via PIL img = Image.open( "/millcreek/home/sunil/r1_vlm_bumbershoot0/r1_vlm/tool_server/cars.jpeg" ) -# create 10 noisy copies and their crops -test_images = [] -crop_ratios = [(2, 1), (1, 1), (1, 2)] - -for i in range(10): - # Create noisy image - arr = np.array(img) - noise = np.random.normal(0, 5, arr.shape) - noisy_arr = np.clip(arr + noise, 0, 255).astype(np.uint8) - noisy_img = Image.fromarray(noisy_arr) - - # Create crops for this noisy image - img_w, img_h = noisy_img.size - for w_ratio, h_ratio in crop_ratios: - if img_w / img_h > w_ratio / h_ratio: - crop_h = img_h - crop_w = int(crop_h * w_ratio / h_ratio) - else: - crop_w = img_w - crop_h = int(crop_w * h_ratio / w_ratio) - - x0 = np.random.randint(0, img_w - crop_w + 1) - y0 = np.random.randint(0, img_h - crop_h + 1) - cropped = noisy_img.crop((x0, y0, x0 + crop_w, y0 + crop_h)) - test_images.append( - {"image": cropped, "ratio": f"{w_ratio}:{h_ratio}", "noise_id": i} - ) -speeds = [] - -# run inference on each variant -for test_case in test_images: - start = time.time() - results = model(test_case["image"]) # Pass the cropped image - end = time.time() - speeds.append(end - start) - print( - f"Noise #{test_case['noise_id']}, Aspect ratio {test_case['ratio']} – time taken: {end - start} seconds" - ) - # Convert the cropped image to numpy for visualization - vis_img = np.array(test_case["image"]) - # Plot directly on the cropped image - plotted = results[0].plot(img=vis_img) - imgcat(Image.fromarray(plotted)) +results = model(img) # Pass the cropped image -print(speeds) +# Convert the cropped image to numpy for visualization +vis_img = np.array(img) +# Plot directly on the cropped image +plotted = results[0].plot(img=vis_img) +imgcat(Image.fromarray(plotted)) From 41dceb640387f835840531ead0c5b64a7a7a5403 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Tue, 29 Apr 2025 18:56:02 -0700 Subject: [PATCH 06/24] fix threading issues and bgr to rgb --- src/r1_vlm/tools/object_detection.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/r1_vlm/tools/object_detection.py b/src/r1_vlm/tools/object_detection.py index cfaad25a..4722cbb6 100644 --- a/src/r1_vlm/tools/object_detection.py +++ b/src/r1_vlm/tools/object_detection.py @@ -22,8 +22,10 @@ class ObjectDetectionTool: def __init__(self): - url = f"{API_IP}:{API_PORT}/yolo" - self.triton_client = InferenceServerClient(url=url, verbose=False, ssl=False) + self.url = f"{API_IP}:{API_PORT}/yolo" + self.triton_client = InferenceServerClient( + url=self.url, verbose=False, ssl=False + ) # Wait until model is ready for _ in range(10): @@ -32,10 +34,12 @@ def __init__(self): break time.sleep(1) - self.model = YOLO(f"http://{url}", task="detect") + self.model = YOLO(f"http://{self.url}", task="detect") def detect_objects(self, image: Image.Image) -> list[dict]: - result = self.model(image, conf=0.3)[0] + local_model = YOLO(f"http://{self.url}", task="detect") + + result = local_model(image, conf=0.3)[0] boxes = [[int(round(x)) for x in box] for box in result.boxes.xyxy.tolist()] labels = [result.names[int(cls)] for cls in result.boxes.cls.tolist()] @@ -55,7 +59,8 @@ def detect_objects(self, image: Image.Image) -> list[dict]: dets_string += "\n" annotated_image = result.plot(conf=False, labels=True) - + # convert to rgb and then to PIL image + annotated_image = Image.fromarray(annotated_image[..., ::-1]) return {"text_data": dets_string, "image_data": annotated_image} From 2bedbacae87084fd39c6fbb17609d7d079dbd4d7 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Tue, 29 Apr 2025 19:41:04 -0700 Subject: [PATCH 07/24] up the weight to 1.0. --- .../environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py index d73b03fb..5e0ee177 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py @@ -114,8 +114,8 @@ def get_reward_weights(self) -> list[float]: schedule = 0.1 reward_weights.append(schedule) elif reward_function.__name__ == "tool_execution_reward_func": - # linearly decay from 0.1 to 0.0 over 200 global steps (200 gradient updates) - schedule = create_linear_decay_schedule(0.1, 0.0, 200) + # linearly decay from 1.0 to 0.0 over 200 global steps (200 gradient updates) + schedule = create_linear_decay_schedule(1.0, 0.0, 200) reward_weights.append(schedule) elif reward_function.__name__ == "correct_answer_reward_func": From 12fec83381035a794689d3d313a5cd772e4a6234 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Tue, 29 Apr 2025 23:51:11 -0700 Subject: [PATCH 08/24] setup to restart the run --- .../tool_use_aokvqa_env/tool_use_aok_train.py | 8 ++++-- .../tool_use_aokvqa_env.py | 4 +-- src/r1_vlm/tools/object_detection.py | 25 ++++++++++++++++--- 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py index 4fc9cfe5..b218a210 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py @@ -88,8 +88,12 @@ def find_target_linear_names( def train(): + checkpoint = "/millcreek/home/sunil/r1_vlm/vlm-r1-new-od-tool-reward-schedule-for-tools-apr-29/checkpoint-200" + model, peft_config, processor, model_config, gradient_checkpointing = ( - load_model_and_processor(gradient_checkpointing=True, use_peft=False) + load_model_and_processor( + model_name_or_path=checkpoint, gradient_checkpointing=True, use_peft=False + ) ) print("loaded model") @@ -106,7 +110,7 @@ def train(): training_args = GRPOConfig( model_init_kwargs=model_config, # save path on the runpod instance - output_dir="vlm-r1-new-od-tool-reward-schedule-for-tools-apr-29", + output_dir="vlm-r1-new-od-tool-reward-schedule-for-tools-apr-29-restart", # increase learning rate for PEFT - 1e-4 learning_rate=1e-4 if peft_config is not None else 1e-6, max_grad_norm=1.0, diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py index 5e0ee177..9a1a37a4 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py @@ -10,7 +10,6 @@ ) from r1_vlm.datasets.utils import preprocess_r1_dataset from r1_vlm.environments.multistep_vision_env import MultistepVisionEnv -from r1_vlm.environments.reward_schedules import create_linear_decay_schedule from r1_vlm.environments.tool_vision_env import ToolArgParser, ToolVisionEnv from r1_vlm.tools.object_detection import ( ObjectDetectionTool, @@ -115,7 +114,8 @@ def get_reward_weights(self) -> list[float]: reward_weights.append(schedule) elif reward_function.__name__ == "tool_execution_reward_func": # linearly decay from 1.0 to 0.0 over 200 global steps (200 gradient updates) - schedule = create_linear_decay_schedule(1.0, 0.0, 200) + # schedule = create_linear_decay_schedule(1.0, 0.0, 200) + schedule = 0.0 # restarting the run, past step 200 reward_weights.append(schedule) elif reward_function.__name__ == "correct_answer_reward_func": diff --git a/src/r1_vlm/tools/object_detection.py b/src/r1_vlm/tools/object_detection.py index 4722cbb6..5b8f4b01 100644 --- a/src/r1_vlm/tools/object_detection.py +++ b/src/r1_vlm/tools/object_detection.py @@ -1,6 +1,7 @@ import contextlib import json import os +import threading import time # Import the time module # Add imports for numpy and cv2 @@ -9,6 +10,7 @@ from PIL import Image from tritonclient.http import InferenceServerClient from ultralytics import YOLO +from ultralytics.utils import ThreadingLocked from r1_vlm.environments.tool_vision_env import RawToolArgs, TypedToolArgs @@ -34,12 +36,21 @@ def __init__(self): break time.sleep(1) - self.model = YOLO(f"http://{self.url}", task="detect") + # Thread-local storage for model instances + self._thread_local = threading.local() - def detect_objects(self, image: Image.Image) -> list[dict]: - local_model = YOLO(f"http://{self.url}", task="detect") + def _get_model(self): + """Get or create thread-local model instance""" + if not hasattr(self._thread_local, "model"): + self._thread_local.model = YOLO(f"http://{self.url}", task="detect") + return self._thread_local.model + + @ThreadingLocked() + def detect_objects(self, image: Image.Image) -> dict: + """Thread-safe object detection using thread-local model instances""" + model = self._get_model() + result = model(image, conf=0.3)[0] - result = local_model(image, conf=0.3)[0] boxes = [[int(round(x)) for x in box] for box in result.boxes.xyxy.tolist()] labels = [result.names[int(cls)] for cls in result.boxes.cls.tolist()] @@ -63,6 +74,12 @@ def detect_objects(self, image: Image.Image) -> list[dict]: annotated_image = Image.fromarray(annotated_image[..., ::-1]) return {"text_data": dets_string, "image_data": annotated_image} + def __del__(self): + """Cleanup method to ensure resources are properly released""" + if hasattr(self, "_thread_local"): + if hasattr(self._thread_local, "model"): + del self._thread_local.model + def set_object_detection_tool(tool: ObjectDetectionTool): global _object_detection_tool From 0f8791f4477b2e5e873aa800dfb05186b18a8973 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Wed, 30 Apr 2025 10:39:57 -0700 Subject: [PATCH 09/24] the code technically works now, but it isn't pretty --- .../tool_use_aokvqa_env/tool_use_aok_train.py | 16 +- src/r1_vlm/tools/object_detection.py | 179 +++++++++++++++--- 2 files changed, 167 insertions(+), 28 deletions(-) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py index b218a210..c60d3af6 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py @@ -1,3 +1,4 @@ +import multiprocessing import os import torch @@ -88,7 +89,7 @@ def find_target_linear_names( def train(): - checkpoint = "/millcreek/home/sunil/r1_vlm/vlm-r1-new-od-tool-reward-schedule-for-tools-apr-29/checkpoint-200" + checkpoint = "/millcreek/home/sunil/r1_vlm/vlm-r1-new-od-tool-reward-schedule-for-tools-apr-29-restart/checkpoint-50" model, peft_config, processor, model_config, gradient_checkpointing = ( load_model_and_processor( @@ -110,7 +111,7 @@ def train(): training_args = GRPOConfig( model_init_kwargs=model_config, # save path on the runpod instance - output_dir="vlm-r1-new-od-tool-reward-schedule-for-tools-apr-29-restart", + output_dir="vlm-r1-new-od-tool-reward-schedule-for-tools-apr-29-restart-2", # increase learning rate for PEFT - 1e-4 learning_rate=1e-4 if peft_config is not None else 1e-6, max_grad_norm=1.0, @@ -163,6 +164,17 @@ def train(): if __name__ == "__main__": + # --- Set Multiprocessing Start Method --- + # Must be done early, before processes are created, and ideally only once. + # Using force=True might be necessary if it's potentially set elsewhere. + try: + multiprocessing.set_start_method("spawn", force=True) + print("Multiprocessing start method set to 'spawn'.") + except RuntimeError as e: + # Handles cases where the start method might have already been set. + print(f"Multiprocessing start method already set or error setting it: {e}") + # --- End Set Start Method --- + train() # CUDA_VISIBLE_DEVICES=0,1,2,3 uv run accelerate launch --config_file src/r1_vlm/deepspeed_configs/multi_gpu_3only.yaml src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py diff --git a/src/r1_vlm/tools/object_detection.py b/src/r1_vlm/tools/object_detection.py index 5b8f4b01..cccbe55c 100644 --- a/src/r1_vlm/tools/object_detection.py +++ b/src/r1_vlm/tools/object_detection.py @@ -1,8 +1,9 @@ import contextlib import json import os -import threading +import pickle import time # Import the time module +from multiprocessing import Pipe, Process # Add imports for numpy and cv2 from dotenv import load_dotenv @@ -10,7 +11,6 @@ from PIL import Image from tritonclient.http import InferenceServerClient from ultralytics import YOLO -from ultralytics.utils import ThreadingLocked from r1_vlm.environments.tool_vision_env import RawToolArgs, TypedToolArgs @@ -22,9 +22,72 @@ _object_detection_tool = None +# --- Worker Process Function --- +def _yolo_worker(conn, url, image_bytes): + """Runs YOLO detection in a separate process.""" + try: + t_start = time.time() + # Deserialize image + image = pickle.loads(image_bytes) + t_deserialized = time.time() + + # Create transient model instance INSIDE the subprocess + # Note: Add requests patch here if needed for the subprocess context + # import requests + # _worker_session = requests.Session() + # requests.Session = lambda *args, **kwargs: _worker_session + # requests.session = lambda *args, **kwargs: _worker_session + model = YOLO(f"http://{url}", task="detect") + t_model_created = time.time() + result = model(image, conf=0.3)[0] + t_inference_done = time.time() + del model # Explicitly delete + + # Extract necessary data + boxes = [[int(round(x)) for x in box] for box in result.boxes.xyxy.tolist()] + labels = [result.names[int(cls)] for cls in result.boxes.cls.tolist()] + plot_img_array = result.plot(conf=False, labels=True) # Get plotted numpy array + t_results_extracted = time.time() + + # Serialize results (only send data, not complex objects) + output = { + "boxes": boxes, + "labels": labels, + "plot_img_array": plot_img_array, + } + serialized_output = pickle.dumps(output) + t_results_serialized = time.time() + + # Print worker timings + print( + f" [Worker {os.getpid()}] Timings (s): " + f"Deserialize: {t_deserialized - t_start:.3f}, " + f"ModelCreate: {t_model_created - t_deserialized:.3f}, " + f"Inference: {t_inference_done - t_model_created:.3f}, " + f"Extract: {t_results_extracted - t_inference_done:.3f}, " + f"Serialize: {t_results_serialized - t_results_extracted:.3f}, " + f"Total: {t_results_serialized - t_start:.3f}" + ) + + conn.send(serialized_output) + except Exception as e: + # Send back exception info if something goes wrong + print(f"YOLO Worker Error: {e}") # Log error in worker + import traceback + + traceback.print_exc() + conn.send(pickle.dumps(e)) + finally: + conn.close() + + +# --- End Worker Process Function --- + + class ObjectDetectionTool: def __init__(self): self.url = f"{API_IP}:{API_PORT}/yolo" + # Keep triton client for readiness check, but not for inference here self.triton_client = InferenceServerClient( url=self.url, verbose=False, ssl=False ) @@ -36,23 +99,77 @@ def __init__(self): break time.sleep(1) - # Thread-local storage for model instances - self._thread_local = threading.local() - - def _get_model(self): - """Get or create thread-local model instance""" - if not hasattr(self._thread_local, "model"): - self._thread_local.model = YOLO(f"http://{self.url}", task="detect") - return self._thread_local.model - - @ThreadingLocked() def detect_objects(self, image: Image.Image) -> dict: - """Thread-safe object detection using thread-local model instances""" - model = self._get_model() - result = model(image, conf=0.3)[0] - - boxes = [[int(round(x)) for x in box] for box in result.boxes.xyxy.tolist()] - labels = [result.names[int(cls)] for cls in result.boxes.cls.tolist()] + """Performs object detection using a separate worker process.""" + t_parent_start = time.time() + parent_conn, child_conn = Pipe() + + # Serialize image data + try: + image_bytes = pickle.dumps(image) + except Exception as e: + raise RuntimeError(f"Failed to pickle input image: {e}") from e + + proc = None # Initialize proc to None + t_process_start = 0.0 + t_process_end = 0.0 + try: + # Create and start the worker process + proc = Process( + target=_yolo_worker, args=(child_conn, self.url, image_bytes) + ) + t_process_start = time.time() + proc.start() + child_conn.close() # Close child end in parent immediately after start + + # Wait for result from worker using poll (with timeout) + if parent_conn.poll(timeout=60.0): # Wait up to 60 seconds + result_bytes = parent_conn.recv() + t_process_end = time.time() # Record time when result received + else: + t_process_end = time.time() # Record time even on timeout + raise TimeoutError("YOLO worker process timed out waiting for result.") + + except EOFError: # Handle case where child exits unexpectedly before sending + result_bytes = pickle.dumps( + RuntimeError("YOLO worker process exited before sending results.") + ) + except Exception as e: # Catch other potential errors during process management + result_bytes = pickle.dumps( + RuntimeError(f"Error managing YOLO worker process: {e}") + ) + finally: + # Ensure process cleanup + if proc is not None: + proc.join(timeout=5.0) # Short wait for graceful exit + if proc.is_alive(): + print("Warning: YOLO worker process unresponsive, terminating.") + proc.terminate() + proc.join(timeout=1.0) # Wait after terminate + # Ensure parent connection is closed + if "parent_conn" in locals() and not parent_conn.closed: + parent_conn.close() + # Ensure child connection is closed (belt and suspenders) + if "child_conn" in locals() and not child_conn.closed: + child_conn.close() + + # Deserialize result + try: + result_data = pickle.loads(result_bytes) + except Exception as e: + # If deserialization fails, result_bytes might contain partial/error data + raise RuntimeError( + f"Failed to deserialize result from YOLO worker. Raw data: {result_bytes!r}. Error: {e}" + ) from e + + # Re-raise exception if worker sent one + if isinstance(result_data, Exception): + raise RuntimeError("YOLO worker process failed") from result_data + + # Process results back into the expected format + boxes = result_data["boxes"] + labels = result_data["labels"] + plot_img_array = result_data["plot_img_array"] detections = [ {"bbox_2d": box, "label": label} for box, label in zip(boxes, labels) @@ -65,20 +182,30 @@ def detect_objects(self, image: Image.Image) -> dict: dets_string = "" for index, det in enumerate(detections): dets_string += f"{index + 1}. {det}" - if index < len(detections) - 1: dets_string += "\n" + # Convert plotted numpy array back to PIL Image + annotated_image = Image.fromarray( + plot_img_array[..., ::-1] + ) # BGR->RGB for PIL + + t_parent_end = time.time() + # Print parent timings + process_duration = ( + t_process_end - t_process_start + if t_process_start > 0 and t_process_end > 0 + else -1.0 + ) + print( + f"[Parent {os.getpid()}] Subprocess call duration: {process_duration:.3f} s, " + f"Total detect_objects duration: {t_parent_end - t_parent_start:.3f} s" + ) - annotated_image = result.plot(conf=False, labels=True) - # convert to rgb and then to PIL image - annotated_image = Image.fromarray(annotated_image[..., ::-1]) return {"text_data": dets_string, "image_data": annotated_image} def __del__(self): - """Cleanup method to ensure resources are properly released""" - if hasattr(self, "_thread_local"): - if hasattr(self._thread_local, "model"): - del self._thread_local.model + """Cleanup method - nothing persistent to clean up""" + pass def set_object_detection_tool(tool: ObjectDetectionTool): From 950950cb46f9dbe4ae3f65c6bce32b161367476d Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Wed, 30 Apr 2025 12:07:59 -0700 Subject: [PATCH 10/24] this works but it is stupid slow, trying to move call out of training process via an api --- .../tool_use_aokvqa_env/tool_use_aok_train.py | 11 ++-- src/r1_vlm/tools/object_detection.py | 52 ++++++++++++------- tool_server/training_server.py | 0 3 files changed, 38 insertions(+), 25 deletions(-) create mode 100644 tool_server/training_server.py diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py index c60d3af6..9f044852 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py @@ -164,14 +164,13 @@ def train(): if __name__ == "__main__": - # --- Set Multiprocessing Start Method --- - # Must be done early, before processes are created, and ideally only once. - # Using force=True might be necessary if it's potentially set elsewhere. + # --- Set Multiprocessing Start Method to FORKSERVER --- + # Offers a potential speedup over 'spawn' while aiming for better + # isolation than 'fork' for CUDA. Still experimental here. try: - multiprocessing.set_start_method("spawn", force=True) - print("Multiprocessing start method set to 'spawn'.") + multiprocessing.set_start_method("forkserver", force=True) + print("Multiprocessing start method set to 'forkserver'.") except RuntimeError as e: - # Handles cases where the start method might have already been set. print(f"Multiprocessing start method already set or error setting it: {e}") # --- End Set Start Method --- diff --git a/src/r1_vlm/tools/object_detection.py b/src/r1_vlm/tools/object_detection.py index cccbe55c..b4aceb91 100644 --- a/src/r1_vlm/tools/object_detection.py +++ b/src/r1_vlm/tools/object_detection.py @@ -10,7 +10,6 @@ from imgcat import imgcat from PIL import Image from tritonclient.http import InferenceServerClient -from ultralytics import YOLO from r1_vlm.environments.tool_vision_env import RawToolArgs, TypedToolArgs @@ -24,24 +23,30 @@ # --- Worker Process Function --- def _yolo_worker(conn, url, image_bytes): - """Runs YOLO detection in a separate process.""" + """Runs YOLO detection in a separate process, forcing CPU.""" + # --- Force CPU for this process --- + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + + # --- End Force CPU --- + try: + # --- Import YOLO *inside* the worker, *after* setting CUDA_VISIBLE_DEVICES --- + from ultralytics import YOLO + # Note: This will also import torch internally here in the child process + # --- End Import --- + t_start = time.time() + # Deserialize image image = pickle.loads(image_bytes) t_deserialized = time.time() - # Create transient model instance INSIDE the subprocess - # Note: Add requests patch here if needed for the subprocess context - # import requests - # _worker_session = requests.Session() - # requests.Session = lambda *args, **kwargs: _worker_session - # requests.session = lambda *args, **kwargs: _worker_session - model = YOLO(f"http://{url}", task="detect") + # Create transient model instance INSIDE the subprocess (will use CPU) + model = YOLO(f"http://{url}", task="detect") # Should now default to CPU t_model_created = time.time() result = model(image, conf=0.3)[0] t_inference_done = time.time() - del model # Explicitly delete + del model # Extract necessary data boxes = [[int(round(x)) for x in box] for box in result.boxes.xyxy.tolist()] @@ -58,18 +63,27 @@ def _yolo_worker(conn, url, image_bytes): serialized_output = pickle.dumps(output) t_results_serialized = time.time() - # Print worker timings + # Print worker timings (original) + # print(...) # Keep this or comment out if too verbose + + t_before_send = time.time() + print(f" [Worker {os.getpid()} CPU] Attempting to send results...") + conn.send(serialized_output) + t_after_send = time.time() + print(f" [Worker {os.getpid()} CPU] Results sent.") + + # Print worker timings (updated with send time) print( - f" [Worker {os.getpid()}] Timings (s): " - f"Deserialize: {t_deserialized - t_start:.3f}, " - f"ModelCreate: {t_model_created - t_deserialized:.3f}, " - f"Inference: {t_inference_done - t_model_created:.3f}, " - f"Extract: {t_results_extracted - t_inference_done:.3f}, " - f"Serialize: {t_results_serialized - t_results_extracted:.3f}, " - f"Total: {t_results_serialized - t_start:.3f}" + f" [Worker {os.getpid()} CPU] Timings (s): " + f"Deserialize: {t_deserialized - t_start:.4f}, " + f"ModelCreate: {t_model_created - t_deserialized:.4f}, " + f"Inference: {t_inference_done - t_model_created:.4f}, " + f"Extract: {t_results_extracted - t_inference_done:.4f}, " + f"Serialize: {t_results_serialized - t_results_extracted:.4f}, " + f"Send: {t_after_send - t_before_send:.4f}, " # Added Send time + f"Total: {t_after_send - t_start:.4f}" # Use t_after_send for total now ) - conn.send(serialized_output) except Exception as e: # Send back exception info if something goes wrong print(f"YOLO Worker Error: {e}") # Log error in worker diff --git a/tool_server/training_server.py b/tool_server/training_server.py new file mode 100644 index 00000000..e69de29b From 4d93020c655f7d35c33f5d2ae70735d012691303 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Wed, 30 Apr 2025 12:34:18 -0700 Subject: [PATCH 11/24] its working! and its fast --- src/r1_vlm/tools/object_detection.py | 322 +++++++++------------------ tool_server/pyproject.toml | 2 + tool_server/training_server.py | 211 ++++++++++++++++++ tool_server/uv.lock | 148 ++++++++++++ 4 files changed, 466 insertions(+), 217 deletions(-) diff --git a/src/r1_vlm/tools/object_detection.py b/src/r1_vlm/tools/object_detection.py index b4aceb91..28c173fc 100644 --- a/src/r1_vlm/tools/object_detection.py +++ b/src/r1_vlm/tools/object_detection.py @@ -1,218 +1,122 @@ -import contextlib +import base64 # For encoding/decoding images +import io # For handling image bytes import json import os -import pickle -import time # Import the time module -from multiprocessing import Pipe, Process +import time -# Add imports for numpy and cv2 +import requests # To make HTTP requests to the API server + +# Remove multiprocessing imports +# from multiprocessing import Pipe, Process +# Remove YOLO import from here +# from ultralytics import YOLO +# Add imports for numpy and cv2 (if still needed for other parts, unlikely now) from dotenv import load_dotenv -from imgcat import imgcat from PIL import Image -from tritonclient.http import InferenceServerClient from r1_vlm.environments.tool_vision_env import RawToolArgs, TypedToolArgs load_dotenv() -API_IP = str(os.getenv("API_IP")) -API_PORT = int(os.getenv("API_PORT")) +# --- Configuration for the Detection API Server --- +# Get the API server's URL from environment variables, default to localhost:8001 +DETECTION_API_HOST = os.getenv("DETECTION_API_HOST", "localhost") +DETECTION_API_PORT = int(os.getenv("DETECTION_API_PORT", 8001)) +DETECTION_API_URL = f"http://{DETECTION_API_HOST}:{DETECTION_API_PORT}/detect" +# --- End Configuration --- _object_detection_tool = None -# --- Worker Process Function --- -def _yolo_worker(conn, url, image_bytes): - """Runs YOLO detection in a separate process, forcing CPU.""" - # --- Force CPU for this process --- - os.environ["CUDA_VISIBLE_DEVICES"] = "-1" - - # --- End Force CPU --- - - try: - # --- Import YOLO *inside* the worker, *after* setting CUDA_VISIBLE_DEVICES --- - from ultralytics import YOLO - # Note: This will also import torch internally here in the child process - # --- End Import --- - - t_start = time.time() - - # Deserialize image - image = pickle.loads(image_bytes) - t_deserialized = time.time() - - # Create transient model instance INSIDE the subprocess (will use CPU) - model = YOLO(f"http://{url}", task="detect") # Should now default to CPU - t_model_created = time.time() - result = model(image, conf=0.3)[0] - t_inference_done = time.time() - del model - - # Extract necessary data - boxes = [[int(round(x)) for x in box] for box in result.boxes.xyxy.tolist()] - labels = [result.names[int(cls)] for cls in result.boxes.cls.tolist()] - plot_img_array = result.plot(conf=False, labels=True) # Get plotted numpy array - t_results_extracted = time.time() - - # Serialize results (only send data, not complex objects) - output = { - "boxes": boxes, - "labels": labels, - "plot_img_array": plot_img_array, - } - serialized_output = pickle.dumps(output) - t_results_serialized = time.time() - - # Print worker timings (original) - # print(...) # Keep this or comment out if too verbose - - t_before_send = time.time() - print(f" [Worker {os.getpid()} CPU] Attempting to send results...") - conn.send(serialized_output) - t_after_send = time.time() - print(f" [Worker {os.getpid()} CPU] Results sent.") - - # Print worker timings (updated with send time) - print( - f" [Worker {os.getpid()} CPU] Timings (s): " - f"Deserialize: {t_deserialized - t_start:.4f}, " - f"ModelCreate: {t_model_created - t_deserialized:.4f}, " - f"Inference: {t_inference_done - t_model_created:.4f}, " - f"Extract: {t_results_extracted - t_inference_done:.4f}, " - f"Serialize: {t_results_serialized - t_results_extracted:.4f}, " - f"Send: {t_after_send - t_before_send:.4f}, " # Added Send time - f"Total: {t_after_send - t_start:.4f}" # Use t_after_send for total now - ) - - except Exception as e: - # Send back exception info if something goes wrong - print(f"YOLO Worker Error: {e}") # Log error in worker - import traceback - - traceback.print_exc() - conn.send(pickle.dumps(e)) - finally: - conn.close() - - -# --- End Worker Process Function --- - - class ObjectDetectionTool: def __init__(self): - self.url = f"{API_IP}:{API_PORT}/yolo" - # Keep triton client for readiness check, but not for inference here - self.triton_client = InferenceServerClient( - url=self.url, verbose=False, ssl=False - ) - - # Wait until model is ready - for _ in range(10): - with contextlib.suppress(Exception): - assert self.triton_client.is_model_ready("yolo") - break - time.sleep(1) + # Store the URL for the detection API server + self.api_url = DETECTION_API_URL def detect_objects(self, image: Image.Image) -> dict: - """Performs object detection using a separate worker process.""" - t_parent_start = time.time() - parent_conn, child_conn = Pipe() - - # Serialize image data - try: - image_bytes = pickle.dumps(image) - except Exception as e: - raise RuntimeError(f"Failed to pickle input image: {e}") from e + """Sends image to detection API server and returns results.""" + t_client_start = time.time() + annotated_image = None # Default + dets_string = "Error: Detection failed." # Default error message - proc = None # Initialize proc to None - t_process_start = 0.0 - t_process_end = 0.0 try: - # Create and start the worker process - proc = Process( - target=_yolo_worker, args=(child_conn, self.url, image_bytes) + # 1. Prepare Image for Sending + buffer = io.BytesIO() + # Save image to buffer in a common format like PNG + image.save(buffer, format="PNG") + img_bytes = buffer.getvalue() + img_base64 = base64.b64encode(img_bytes).decode("utf-8") + t_encoded = time.time() + + # 2. Prepare Request Payload + payload = {"image_base64": img_base64} + + # 3. Call the API Server + response = requests.post( + self.api_url, + json=payload, + headers={"Content-Type": "application/json"}, + timeout=60.0, # Set a reasonable timeout (e.g., 60 seconds) ) - t_process_start = time.time() - proc.start() - child_conn.close() # Close child end in parent immediately after start - - # Wait for result from worker using poll (with timeout) - if parent_conn.poll(timeout=60.0): # Wait up to 60 seconds - result_bytes = parent_conn.recv() - t_process_end = time.time() # Record time when result received + t_responded = time.time() + + # 4. Process Response + if response.status_code == 200: + try: + response_data = response.json() + dets_string = response_data.get( + "text_data", "Error: Missing text_data in response." + ) + image_data_base64 = response_data.get("image_data_base64") + + if image_data_base64: + try: + annotated_bytes = base64.b64decode(image_data_base64) + annotated_image = Image.open(io.BytesIO(annotated_bytes)) + except Exception as img_err: + raise ValueError( + f"Failed to decode/load annotated image from response: {img_err}" + ) + # Keep annotated_image as None + + except json.JSONDecodeError as json_err: + raise ValueError( + f"Failed to decode JSON response from API: {json_err}" + ) + + except Exception as proc_err: # Catch other errors processing response + raise ValueError( + f"Error processing successful API response: {proc_err}" + ) + else: - t_process_end = time.time() # Record time even on timeout - raise TimeoutError("YOLO worker process timed out waiting for result.") + # Handle HTTP errors + error_msg = f"Error from detection API: {response.status_code}" + try: + error_detail = response.json().get("detail", response.text) + error_msg += f" - {error_detail}" + except json.JSONDecodeError: + error_msg += f" - {response.text}" + raise ValueError(error_msg) + + except requests.exceptions.Timeout: + raise ValueError("Request to detection API timed out after 60s.") + except requests.exceptions.RequestException as req_err: + raise ValueError(f"Request to detection API failed: {req_err}") - except EOFError: # Handle case where child exits unexpectedly before sending - result_bytes = pickle.dumps( - RuntimeError("YOLO worker process exited before sending results.") - ) - except Exception as e: # Catch other potential errors during process management - result_bytes = pickle.dumps( - RuntimeError(f"Error managing YOLO worker process: {e}") - ) - finally: - # Ensure process cleanup - if proc is not None: - proc.join(timeout=5.0) # Short wait for graceful exit - if proc.is_alive(): - print("Warning: YOLO worker process unresponsive, terminating.") - proc.terminate() - proc.join(timeout=1.0) # Wait after terminate - # Ensure parent connection is closed - if "parent_conn" in locals() and not parent_conn.closed: - parent_conn.close() - # Ensure child connection is closed (belt and suspenders) - if "child_conn" in locals() and not child_conn.closed: - child_conn.close() - - # Deserialize result - try: - result_data = pickle.loads(result_bytes) except Exception as e: - # If deserialization fails, result_bytes might contain partial/error data - raise RuntimeError( - f"Failed to deserialize result from YOLO worker. Raw data: {result_bytes!r}. Error: {e}" - ) from e - - # Re-raise exception if worker sent one - if isinstance(result_data, Exception): - raise RuntimeError("YOLO worker process failed") from result_data - - # Process results back into the expected format - boxes = result_data["boxes"] - labels = result_data["labels"] - plot_img_array = result_data["plot_img_array"] - - detections = [ - {"bbox_2d": box, "label": label} for box, label in zip(boxes, labels) - ] - - if len(detections) == 0: - dets_string = "No objects detected." - annotated_image = None - else: - dets_string = "" - for index, det in enumerate(detections): - dets_string += f"{index + 1}. {det}" - if index < len(detections) - 1: - dets_string += "\n" - # Convert plotted numpy array back to PIL Image - annotated_image = Image.fromarray( - plot_img_array[..., ::-1] - ) # BGR->RGB for PIL - - t_parent_end = time.time() - # Print parent timings - process_duration = ( - t_process_end - t_process_start - if t_process_start > 0 and t_process_end > 0 - else -1.0 - ) + # Catch-all for other unexpected errors in the client logic + raise ValueError( + f"Unexpected error in detect_objects client: {e}", exc_info=True + ) + + t_client_end = time.time() print( - f"[Parent {os.getpid()}] Subprocess call duration: {process_duration:.3f} s, " - f"Total detect_objects duration: {t_parent_end - t_parent_start:.3f} s" + f"detect_objects client timings (s): " + f"Encode: {t_encoded - t_client_start:.3f}, " + f"API Call: {t_responded - t_encoded:.3f}, " + f"Decode/Process: {t_client_end - t_responded:.3f}, " + f"Total: {t_client_end - t_client_start:.3f}" ) return {"text_data": dets_string, "image_data": annotated_image} @@ -261,12 +165,18 @@ def detect_objects(image_name: str, **kwargs) -> tuple[list[dict], Image.Image]: f"Error: Image {image_name} is not the input_image. This tool can only be called on the input_image." ) - return _object_detection_tool.detect_objects(image) + if _object_detection_tool is None: + raise RuntimeError( + "ObjectDetectionTool not initialized. Call set_object_detection_tool first." + ) + + # Call the method which now calls the API + return _object_detection_tool.detect_objects(image) # Return type is now dict def parse_detect_objects_args(raw_args: RawToolArgs) -> TypedToolArgs: """ - Parses raw string arguments for the detect_objects tool, focusing on type conversion. + Parses raw string arguments for the detect_objects tool. Expects keys: 'name', 'image_name' Detailed validation of values (e.g., 'image_name' validity) @@ -276,11 +186,10 @@ def parse_detect_objects_args(raw_args: RawToolArgs) -> TypedToolArgs: raw_args: Dictionary with string keys and string values from the general parser. Returns: - A dictionary containing the arguments with basic type conversions applied, - ready for the detect_objects function. Keys: 'image_name'. + A dictionary containing the arguments. Keys: 'image_name'. Raises: - ValueError: If required keys are missing or basic type conversion fails + ValueError: If required keys are missing or extra keys are present. """ required_keys = {"name", "image_name"} actual_keys = set(raw_args.keys()) @@ -299,16 +208,12 @@ def parse_detect_objects_args(raw_args: RawToolArgs) -> TypedToolArgs: f"Error: Unexpected arguments for detect_objects tool: {', '.join(sorted(extra_keys))}" ) - # 3. Perform Basic Type Conversions + # 3. Prepare typed args (only image_name needed) typed_args: TypedToolArgs = {} try: # Keep image_name as string typed_args["image_name"] = raw_args["image_name"] - except json.JSONDecodeError: - raise ValueError( - f"Error: Invalid JSON format for 'classes': '{raw_args['classes']}'" - ) except ValueError as e: # Catch the list type error from above raise ValueError(f"Error: processing 'classes': {e}") @@ -317,20 +222,3 @@ def parse_detect_objects_args(raw_args: RawToolArgs) -> TypedToolArgs: raise ValueError(f"Error: Missing key '{e}' during conversion phase.") return typed_args - - -if __name__ == "__main__": - tool = ObjectDetectionTool() - set_object_detection_tool(tool) - image = Image.open( - "/millcreek/home/sunil/r1_vlm_bumbershoot0/r1_vlm/src/r1_vlm/tools/cars.jpeg" - ) - for i in range(10): - detections = detect_objects( - image_name="input_image", images={"input_image": image} - ) - - image_data = detections["image_data"] - imgcat(image_data) - text_data = detections["text_data"] - print(text_data) diff --git a/tool_server/pyproject.toml b/tool_server/pyproject.toml index 371551e9..a2bb6f18 100644 --- a/tool_server/pyproject.toml +++ b/tool_server/pyproject.toml @@ -19,6 +19,8 @@ dependencies = [ "numpy>=1.26.4", "pillow>=11.2.1", "python-dotenv>=1.1.0", + "uvicorn>=0.34.2", + "fastapi>=0.115.12", ] [tool.uv.sources] diff --git a/tool_server/training_server.py b/tool_server/training_server.py index e69de29b..d935d5b3 100644 --- a/tool_server/training_server.py +++ b/tool_server/training_server.py @@ -0,0 +1,211 @@ +import base64 +import io +import logging +import os +import time +from typing import Optional + +# --- Force CPU Usage for this Server --- +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +logger = logging.getLogger(__name__) # Get logger early for info message +logger.info("CUDA_VISIBLE_DEVICES set to -1. Server will attempt to use CPU for YOLO.") +# --- End Force CPU --- + +import uvicorn +from dotenv import load_dotenv +from fastapi import FastAPI, HTTPException +from fastapi.responses import JSONResponse +from PIL import Image +from pydantic import BaseModel +from ultralytics import YOLO + +# --- Configuration & Initialization --- + +# Load environment variables (pointing to the *actual* YOLO/Triton backend) +load_dotenv() +API_IP = str(os.getenv("API_IP")) +API_PORT = int(os.getenv("API_PORT")) +BACKEND_URL = f"http://{API_IP}:{API_PORT}/yolo" + +# Set up logging +logging.basicConfig(level=logging.INFO) + +# Global variable to hold the loaded YOLO model +yolo_model: Optional[YOLO] = None + +# --- Pydantic Models --- + + +class DetectionRequest(BaseModel): + """Request body for the /detect endpoint.""" + + image_base64: str + + +class DetectionResponse(BaseModel): + """Successful response body for the /detect endpoint.""" + + text_data: str + image_data_base64: Optional[str] = None + + +class ErrorResponse(BaseModel): + """Error response body.""" + + error: str + + +# --- FastAPI App --- + +app = FastAPI(title="YOLO Detection API Server") + + +@app.on_event("startup") +async def startup_event(): + """Load the YOLO model on server startup (will use CPU).""" + global yolo_model + logger.info( + f"Attempting to load YOLO model targeting backend: {BACKEND_URL} (Forcing CPU)" + ) # Added CPU note + start_time = time.time() + try: + # Model will initialize on CPU due to CUDA_VISIBLE_DEVICES=-1 + yolo_model = YOLO(BACKEND_URL, task="detect") + # Perform a dummy inference to ensure connection/initialization on CPU + dummy_img = Image.new("RGB", (64, 64), color="red") + _ = yolo_model(dummy_img, verbose=False) + end_time = time.time() + logger.info( + f"YOLO model loaded successfully on CPU in {end_time - start_time:.2f} seconds." # Added CPU note + ) + except Exception as e: + logger.error(f"Failed to load YOLO model on startup: {e}", exc_info=True) + yolo_model = None + + +@app.post( + "/detect", + response_model=DetectionResponse, + responses={500: {"model": ErrorResponse}}, + summary="Perform object detection on an image", +) +async def detect_objects_api(request: DetectionRequest): + """ + Accepts a base64 encoded image, performs YOLO detection using the + pre-loaded model targeting the backend service, and returns results. + """ + global yolo_model + if yolo_model is None: + logger.error("YOLO model is not loaded. Cannot process request.") + raise HTTPException( + status_code=503, detail="Model service unavailable" + ) # 503 Service Unavailable + + logger.info("Received detection request.") + t_start = time.time() + + try: + # 1. Decode Base64 Image + try: + image_bytes = base64.b64decode(request.image_base64) + image = Image.open(io.BytesIO(image_bytes)).convert("RGB") # Ensure RGB + except Exception as e: + logger.error(f"Failed to decode/load image from base64: {e}") + raise HTTPException(status_code=400, detail=f"Invalid image data: {e}") + t_decoded = time.time() + + # 2. Run YOLO Inference + try: + # Use the globally loaded model + result = yolo_model(image, conf=0.3, verbose=False)[ + 0 + ] # verbose=False is quieter + except Exception as e: + logger.error(f"YOLO inference failed: {e}", exc_info=True) + # Consider more specific error codes if YOLO/Triton provide them + raise HTTPException(status_code=500, detail=f"Inference failed: {e}") + t_inferred = time.time() + + # 3. Process Results + boxes = [[int(round(x)) for x in box] for box in result.boxes.xyxy.tolist()] + labels = [result.names[int(cls)] for cls in result.boxes.cls.tolist()] + detections = [ + {"bbox_2d": box, "label": label} for box, label in zip(boxes, labels) + ] + + # 4. Format Output String + if not detections: + dets_string = "No objects detected." + annotated_image = None + plot_img_array = None + else: + dets_string = "" + for index, det in enumerate(detections): + dets_string += f"{index + 1}. {det}" + if index < len(detections) - 1: + dets_string += "\n" + # Generate annotated image array + plot_img_array = result.plot(conf=False, labels=True) + annotated_image = Image.fromarray(plot_img_array[..., ::-1]) # BGR->RGB + + t_processed = time.time() + + # 5. Encode Annotated Image (if any) + image_data_base64: Optional[str] = None + if annotated_image: + try: + with io.BytesIO() as buffer: + # Save as PNG (generally lossless) + annotated_image.save(buffer, format="PNG") + img_bytes = buffer.getvalue() + image_data_base64 = base64.b64encode(img_bytes).decode("utf-8") + except Exception as e: + logger.error(f"Failed to encode annotated image: {e}") + # Proceed without annotated image if encoding fails + t_encoded = time.time() + + logger.info( + f"Detection successful. Timings (s): " + f"Decode: {t_decoded - t_start:.3f}, " + f"Inference: {t_inferred - t_decoded:.3f}, " + f"Process: {t_processed - t_inferred:.3f}, " + f"Encode: {t_encoded - t_processed:.3f}, " + f"Total: {t_encoded - t_start:.3f}" + ) + + return DetectionResponse( + text_data=dets_string, + image_data_base64=image_data_base64, + ) + + except HTTPException as http_exc: + # Re-raise HTTPExceptions (like 400 Bad Request) + raise http_exc + except Exception as e: + # Catch-all for unexpected server errors during processing + logger.error(f"Unexpected error during detection request: {e}", exc_info=True) + # Return a generic 500 error response + return JSONResponse( + status_code=500, content={"error": f"Internal server error: {e}"} + ) + + +# --- Run Server --- + +if __name__ == "__main__": + # Set default port if not specified in environment + server_port = int(os.getenv("DETECTION_API_PORT", 8001)) + num_workers = int(os.getenv("DETECTION_API_WORKERS", 6)) # Default to 6 workers + logger.info( + f"Starting YOLO detection server on port {server_port} with {num_workers} workers (CPU forced)..." # Added CPU note + ) + # Note: Using uvicorn.run() with workers > 1 might have limitations + # compared to running via the command line with a process manager like gunicorn. + # See Uvicorn documentation for details on multi-process modes. + uvicorn.run( + "training_server:app", # Need to specify app string for reload/workers + host="0.0.0.0", + port=server_port, + workers=num_workers, + # reload=False # Ensure reload is False when using workers programmatically + ) diff --git a/tool_server/uv.lock b/tool_server/uv.lock index 2929c3fa..f440049b 100644 --- a/tool_server/uv.lock +++ b/tool_server/uv.lock @@ -62,6 +62,29 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/6a/bc7e17a3e87a2985d3e8f4da4cd0f481060eb78fb08596c42be62c90a4d9/aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5", size = 7597 }, ] +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, +] + +[[package]] +name = "anyio" +version = "4.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/95/7d/4c1bd541d4dffa1b52bd83fb8527089e097a106fc90b467a7313b105f840/anyio-4.9.0.tar.gz", hash = "sha256:673c0c244e15788651a4ff38710fea9675823028a6f08a5eda409e0c9840a028", size = 190949 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a1/ee/48ca1a7c89ffec8b6a0c5d02b89c305671d5ffd8d3c94acf8b8c408575bb/anyio-4.9.0-py3-none-any.whl", hash = "sha256:9f76d541cad6e36af7beb62e978876f3b41e3e04f2c1fbf0884604c0a9c4d93c", size = 100916 }, +] + [[package]] name = "asttokens" version = "3.0.0" @@ -150,6 +173,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 }, ] +[[package]] +name = "click" +version = "8.1.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 }, +] + [[package]] name = "clip" version = "1.0" @@ -255,6 +290,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 }, ] +[[package]] +name = "fastapi" +version = "0.115.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/55/ae499352d82338331ca1e28c7f4a63bfd09479b16395dce38cf50a39e2c2/fastapi-0.115.12.tar.gz", hash = "sha256:1e2c2a2646905f9e83d32f04a3f86aff4a286669c6c950ca95b5fd68c2602681", size = 295236 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/50/b3/b51f09c2ba432a576fe63758bddc81f78f0c6309d9e5c10d194313bf021e/fastapi-0.115.12-py3-none-any.whl", hash = "sha256:e94613d6c05e27be7ffebdd6ea5f388112e5e430c8f7d6494a9d1d88d43e814d", size = 95164 }, +] + [[package]] name = "filelock" version = "3.18.0" @@ -420,6 +469,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7d/36/0c03e2d80db69e2472cf81c6123aa7d14741de7cf790117291a703ae6ae1/grpcio-1.67.1-cp312-cp312-win_amd64.whl", hash = "sha256:5721e66a594a6c4204458004852719b38f3d5522082be9061d6510b455c90afc", size = 4346574 }, ] +[[package]] +name = "h11" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515 }, +] + [[package]] name = "huggingface-hub" version = "0.30.2" @@ -1079,6 +1137,46 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/13/a3/a812df4e2dd5696d1f351d58b8fe16a405b234ad2886a0dab9183fb78109/pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc", size = 117552 }, ] +[[package]] +name = "pydantic" +version = "2.11.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/77/ab/5250d56ad03884ab5efd07f734203943c8a8ab40d551e208af81d0257bf2/pydantic-2.11.4.tar.gz", hash = "sha256:32738d19d63a226a52eed76645a98ee07c1f410ee41d93b4afbfa85ed8111c2d", size = 786540 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/12/46b65f3534d099349e38ef6ec98b1a5a81f42536d17e0ba382c28c67ba67/pydantic-2.11.4-py3-none-any.whl", hash = "sha256:d9615eaa9ac5a063471da949c8fc16376a84afb5024688b3ff885693506764eb", size = 443900 }, +] + +[[package]] +name = "pydantic-core" +version = "2.33.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ad/88/5f2260bdfae97aabf98f1778d43f69574390ad787afb646292a638c923d4/pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc", size = 435195 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/8a/2b41c97f554ec8c71f2a8a5f85cb56a8b0956addfe8b0efb5b3d77e8bdc3/pydantic_core-2.33.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a7ec89dc587667f22b6a0b6579c249fca9026ce7c333fc142ba42411fa243cdc", size = 2009000 }, + { url = "https://files.pythonhosted.org/packages/a1/02/6224312aacb3c8ecbaa959897af57181fb6cf3a3d7917fd44d0f2917e6f2/pydantic_core-2.33.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3c6db6e52c6d70aa0d00d45cdb9b40f0433b96380071ea80b09277dba021ddf7", size = 1847996 }, + { url = "https://files.pythonhosted.org/packages/d6/46/6dcdf084a523dbe0a0be59d054734b86a981726f221f4562aed313dbcb49/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e61206137cbc65e6d5256e1166f88331d3b6238e082d9f74613b9b765fb9025", size = 1880957 }, + { url = "https://files.pythonhosted.org/packages/ec/6b/1ec2c03837ac00886ba8160ce041ce4e325b41d06a034adbef11339ae422/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb8c529b2819c37140eb51b914153063d27ed88e3bdc31b71198a198e921e011", size = 1964199 }, + { url = "https://files.pythonhosted.org/packages/2d/1d/6bf34d6adb9debd9136bd197ca72642203ce9aaaa85cfcbfcf20f9696e83/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c52b02ad8b4e2cf14ca7b3d918f3eb0ee91e63b3167c32591e57c4317e134f8f", size = 2120296 }, + { url = "https://files.pythonhosted.org/packages/e0/94/2bd0aaf5a591e974b32a9f7123f16637776c304471a0ab33cf263cf5591a/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:96081f1605125ba0855dfda83f6f3df5ec90c61195421ba72223de35ccfb2f88", size = 2676109 }, + { url = "https://files.pythonhosted.org/packages/f9/41/4b043778cf9c4285d59742281a769eac371b9e47e35f98ad321349cc5d61/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f57a69461af2a5fa6e6bbd7a5f60d3b7e6cebb687f55106933188e79ad155c1", size = 2002028 }, + { url = "https://files.pythonhosted.org/packages/cb/d5/7bb781bf2748ce3d03af04d5c969fa1308880e1dca35a9bd94e1a96a922e/pydantic_core-2.33.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:572c7e6c8bb4774d2ac88929e3d1f12bc45714ae5ee6d9a788a9fb35e60bb04b", size = 2100044 }, + { url = "https://files.pythonhosted.org/packages/fe/36/def5e53e1eb0ad896785702a5bbfd25eed546cdcf4087ad285021a90ed53/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:db4b41f9bd95fbe5acd76d89920336ba96f03e149097365afe1cb092fceb89a1", size = 2058881 }, + { url = "https://files.pythonhosted.org/packages/01/6c/57f8d70b2ee57fc3dc8b9610315949837fa8c11d86927b9bb044f8705419/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:fa854f5cf7e33842a892e5c73f45327760bc7bc516339fda888c75ae60edaeb6", size = 2227034 }, + { url = "https://files.pythonhosted.org/packages/27/b9/9c17f0396a82b3d5cbea4c24d742083422639e7bb1d5bf600e12cb176a13/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5f483cfb75ff703095c59e365360cb73e00185e01aaea067cd19acffd2ab20ea", size = 2234187 }, + { url = "https://files.pythonhosted.org/packages/b0/6a/adf5734ffd52bf86d865093ad70b2ce543415e0e356f6cacabbc0d9ad910/pydantic_core-2.33.2-cp312-cp312-win32.whl", hash = "sha256:9cb1da0f5a471435a7bc7e439b8a728e8b61e59784b2af70d7c169f8dd8ae290", size = 1892628 }, + { url = "https://files.pythonhosted.org/packages/43/e4/5479fecb3606c1368d496a825d8411e126133c41224c1e7238be58b87d7e/pydantic_core-2.33.2-cp312-cp312-win_amd64.whl", hash = "sha256:f941635f2a3d96b2973e867144fde513665c87f13fe0e193c158ac51bfaaa7b2", size = 1955866 }, + { url = "https://files.pythonhosted.org/packages/0d/24/8b11e8b3e2be9dd82df4b11408a67c61bb4dc4f8e11b5b0fc888b38118b5/pydantic_core-2.33.2-cp312-cp312-win_arm64.whl", hash = "sha256:cca3868ddfaccfbc4bfb1d608e2ccaaebe0ae628e1416aeb9c4d88c001bb45ab", size = 1888894 }, +] + [[package]] name = "pygments" version = "2.19.1" @@ -1295,6 +1393,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050 }, ] +[[package]] +name = "sniffio" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 }, +] + [[package]] name = "stack-data" version = "0.6.3" @@ -1309,6 +1416,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521 }, ] +[[package]] +name = "starlette" +version = "0.46.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/0c/9d30a4ebeb6db2b25a841afbb80f6ef9a854fc3b41be131d249a977b4959/starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35", size = 72037 }, +] + [[package]] name = "sympy" version = "1.13.1" @@ -1381,6 +1500,7 @@ version = "0.1.0" source = { virtual = "." } dependencies = [ { name = "clip" }, + { name = "fastapi" }, { name = "imgcat" }, { name = "ipdb" }, { name = "mobileclip" }, @@ -1394,11 +1514,13 @@ dependencies = [ { name = "tensorrt" }, { name = "tritonclient", extra = ["all"] }, { name = "ultralytics" }, + { name = "uvicorn" }, ] [package.metadata] requires-dist = [ { name = "clip", git = "https://github.com/ultralytics/CLIP.git" }, + { name = "fastapi", specifier = ">=0.115.12" }, { name = "imgcat", specifier = ">=0.6.0" }, { name = "ipdb", specifier = ">=0.13.13" }, { name = "mobileclip", git = "https://github.com/THU-MIG/yoloe.git?subdirectory=third_party%2Fml-mobileclip" }, @@ -1412,6 +1534,7 @@ requires-dist = [ { name = "tensorrt", specifier = ">=10.9.0.34" }, { name = "tritonclient", extras = ["all"], specifier = ">=2.56.0" }, { name = "ultralytics", specifier = ">=8.3.112" }, + { name = "uvicorn", specifier = ">=0.34.2" }, ] [[package]] @@ -1530,6 +1653,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8b/54/b1ae86c0973cc6f0210b53d508ca3641fb6d0c56823f288d108bc7ab3cc8/typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c", size = 45806 }, ] +[[package]] +name = "typing-inspection" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/82/5c/e6082df02e215b846b4b8c0b887a64d7d08ffaba30605502639d44c06b82/typing_inspection-0.4.0.tar.gz", hash = "sha256:9765c87de36671694a67904bf2c96e395be9c6439bb6c87b5142569dcdd65122", size = 76222 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/08/aa4fdfb71f7de5176385bd9e90852eaf6b5d622735020ad600f2bab54385/typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f", size = 14125 }, +] + [[package]] name = "tzdata" version = "2025.2" @@ -1587,6 +1722,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6b/11/cc635220681e93a0183390e26485430ca2c7b5f9d33b15c74c2861cb8091/urllib3-2.4.0-py3-none-any.whl", hash = "sha256:4e16665048960a0900c702d4a66415956a584919c03361cac9f1df5c5dd7e813", size = 128680 }, ] +[[package]] +name = "uvicorn" +version = "0.34.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a6/ae/9bbb19b9e1c450cf9ecaef06463e40234d98d95bf572fab11b4f19ae5ded/uvicorn-0.34.2.tar.gz", hash = "sha256:0e929828f6186353a80b58ea719861d2629d766293b6d19baf086ba31d4f3328", size = 76815 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/4b/4cef6ce21a2aaca9d852a6e84ef4f135d99fcd74fa75105e2fc0c8308acd/uvicorn-0.34.2-py3-none-any.whl", hash = "sha256:deb49af569084536d269fe0a6d67e3754f104cf03aba7c11c40f01aadf33c403", size = 62483 }, +] + [[package]] name = "wcwidth" version = "0.2.13" From f28e31c28e071c644252a96332e420dd1eb5de47 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Wed, 30 Apr 2025 12:46:52 -0700 Subject: [PATCH 12/24] remvoe fork thing --- .../tool_use_aokvqa_env/tool_use_aok_train.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py index 9f044852..de7c27f9 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py @@ -1,4 +1,3 @@ -import multiprocessing import os import torch @@ -164,16 +163,6 @@ def train(): if __name__ == "__main__": - # --- Set Multiprocessing Start Method to FORKSERVER --- - # Offers a potential speedup over 'spawn' while aiming for better - # isolation than 'fork' for CUDA. Still experimental here. - try: - multiprocessing.set_start_method("forkserver", force=True) - print("Multiprocessing start method set to 'forkserver'.") - except RuntimeError as e: - print(f"Multiprocessing start method already set or error setting it: {e}") - # --- End Set Start Method --- - train() # CUDA_VISIBLE_DEVICES=0,1,2,3 uv run accelerate launch --config_file src/r1_vlm/deepspeed_configs/multi_gpu_3only.yaml src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py From c471d8b914e1d5e759a886485b0c3feace98fb5c Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Wed, 30 Apr 2025 13:33:44 -0700 Subject: [PATCH 13/24] shuffle order of tools in system prompt --- .../tool_use_aokvqa_env.py | 4 ++++ src/r1_vlm/environments/tool_vision_env.py | 23 +++++++++++++------ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py index 9a1a37a4..a4f5b64f 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py @@ -313,3 +313,7 @@ def correct_answer_reward_func( if __name__ == "__main__": env = AOKVQAToolEnv(processing_class=None) train_dataset, val_dataset, test_dataset = env.get_dataset() + import ipdb + + ipdb.set_trace() + print("hi") diff --git a/src/r1_vlm/environments/tool_vision_env.py b/src/r1_vlm/environments/tool_vision_env.py index f309a2e4..34d3472f 100644 --- a/src/r1_vlm/environments/tool_vision_env.py +++ b/src/r1_vlm/environments/tool_vision_env.py @@ -1,4 +1,5 @@ import inspect +import random import traceback from typing import Any, Callable, Dict, List @@ -147,12 +148,8 @@ def __init__( # Schema inference still uses the tool function's signature/docstring self.tool_schemas.append(infer_schema_from_function(tool_func)) - # Format the system prompt with tool descriptions - tool_descriptions = format_tool_descriptions(self.tool_schemas) - formatted_prompt = tool_prompt_template.format( - tool_descriptions=tool_descriptions - ) - self.formatted_prompt = formatted_prompt + # Store the template for dynamic formatting later + self.tool_prompt_template = tool_prompt_template # Set the general parser (use internal default if none provided) self.general_parser = general_parser or self._general_parse_key_value @@ -188,11 +185,23 @@ def _inject_prompt(examples): if not messages or messages[0]["role"] != "system": raise ValueError("Expected first message to be a system message") + # Create a shuffled copy of tool schemas for this sample + shuffled_schemas = self.tool_schemas[:] # Create a copy + random.shuffle(shuffled_schemas) + + # Format tool descriptions with the shuffled order + tool_descriptions = format_tool_descriptions(shuffled_schemas) + + # Format the prompt template with the randomized descriptions + formatted_prompt = self.tool_prompt_template.format( + tool_descriptions=tool_descriptions + ) + # Replace the content of the system message with the formatted prompt messages[0]["content"] = [ { "type": "text", - "text": self.formatted_prompt, + "text": formatted_prompt, } ] From a74cd682a5a9f8e768531264bef8bb690e50ed63 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Wed, 30 Apr 2025 15:10:07 -0700 Subject: [PATCH 14/24] log metrics callback --- .../environments/multistep_vision_env.py | 87 ++++++++++++------- src/r1_vlm/environments/simple_vision_env.py | 12 ++- .../tool_use_aokvqa_env.py | 39 +++++++++ 3 files changed, 103 insertions(+), 35 deletions(-) diff --git a/src/r1_vlm/environments/multistep_vision_env.py b/src/r1_vlm/environments/multistep_vision_env.py index 24dc3ec5..cb93a967 100644 --- a/src/r1_vlm/environments/multistep_vision_env.py +++ b/src/r1_vlm/environments/multistep_vision_env.py @@ -57,7 +57,7 @@ def is_completed(self, messages: list[dict[str, str]], **kwargs: Any) -> bool: @abstractmethod def env_response( self, messages: list[dict[str, str]], **kwargs: Any - ) -> list[dict[str, Any]]: + ) -> list[dict[str, Any]]: pass def prepare_data(self, *, inputs, processing_class): @@ -96,12 +96,12 @@ def update_state(j, vlm_response): state["prompt_ids"] = vlm_response.prompt_token_ids # update the conversation with the model's response - state["messages"].append({ - "role": "assistant", - "content": [ - {"text": vlm_response.outputs[0].text, "type": "text"} - ] - }) + state["messages"].append( + { + "role": "assistant", + "content": [{"text": vlm_response.outputs[0].text, "type": "text"}], + } + ) # get token lengths of env response and new completion total_prev_len = len(state["prompt_ids"]) + len(state["completion_ids"]) @@ -120,7 +120,7 @@ def update_state(j, vlm_response): ] # if we are done, we mark the state as completed - # we do not want to truncate the completion ids here, + # we do not want to truncate the completion ids here, # because the number of image tokens returned from the tools is variable if ( self.is_completed(state["messages"]) @@ -153,7 +153,7 @@ def update_state(j, vlm_response): for j, state in results: states[j] = state - + return states def generate( @@ -176,7 +176,7 @@ def generate( } for conversation in conversations ] - + # main loop while not all_completed: states = self.step(states, vlm, custom_sp) @@ -190,8 +190,7 @@ def generate( "messages": completion_messages, "mask": completion_mask, } - - + def clean_messages_for_logging(messages): cleaned = [] images = [] @@ -201,7 +200,10 @@ def clean_messages_for_logging(messages): cleaned_content = [] for item in cleaned_message["content"]: cleaned_item = item.copy() - if "image" in cleaned_item and cleaned_item["image"] is not None: + if ( + "image" in cleaned_item + and cleaned_item["image"] is not None + ): images.append(cleaned_item["image"]) cleaned_item["image"] = "" cleaned_content.append(cleaned_item) @@ -212,28 +214,30 @@ def clean_messages_for_logging(messages): cleaned_messages, images = clean_messages_for_logging(states[0]["messages"]) self.logger.info( - "Full conversation 0:\n" - + json.dumps(cleaned_messages, indent=4) + "Full conversation 0:\n" + json.dumps(cleaned_messages, indent=4) ) for image in images: imgcat.imgcat(image) - + return output - + @staticmethod - def preprocess_messages(prompts_messages: list[list[dict[str, Any]]], completions_messages: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]: - ''' + def preprocess_messages( + prompts_messages: list[list[dict[str, Any]]], + completions_messages: list[list[dict[str, Any]]], + ) -> list[list[dict[str, Any]]]: + """ 1. Combines prompts and completion messages into full conversations 2. Removes all messages before the first assistant message, leaving only the completion 3. Merges elements of the completion that come from the same source and are text only - + Args: prompts: list of prompt conversations completions_messages: list of completion conversations - + Returns: list of preprocessed completion conversations - ''' + """ # Combine prompts and completions into full conversations combined_messages = [] for prompt_msgs, completion_msgs in zip(prompts_messages, completions_messages): @@ -241,22 +245,29 @@ def preprocess_messages(prompts_messages: list[list[dict[str, Any]]], completion conversation.extend(prompt_msgs) conversation.extend(completion_msgs) combined_messages.append(conversation) - + filtered_messages = [] for completion in combined_messages: # find the index of the first assistant message - assistant_message_index = next((i for i, message in enumerate(completion) if message["role"] == "assistant"), None) - + assistant_message_index = next( + ( + i + for i, message in enumerate(completion) + if message["role"] == "assistant" + ), + None, + ) + if assistant_message_index is not None: # keep only messages from the first assistant message onwards filtered_messages.append(completion[assistant_message_index:]) - + merged_completions = [] - + for completion in filtered_messages: merged_completion = [] current_message = None - + for message in completion: # If message has non-text content, add it as is if any(item["type"] != "text" for item in message["content"]): @@ -265,7 +276,7 @@ def preprocess_messages(prompts_messages: list[list[dict[str, Any]]], completion current_message = None merged_completion.append(message) continue - + # For text messages if current_message and current_message["role"] == message["role"]: # Merge text content @@ -277,11 +288,21 @@ def preprocess_messages(prompts_messages: list[list[dict[str, Any]]], completion merged_completion.append(current_message) current_message = { "role": message["role"], - "content": [{"type": "text", "text": message["content"][0]["text"]}] + "content": [ + {"type": "text", "text": message["content"][0]["text"]} + ], } - + if current_message: merged_completion.append(current_message) merged_completions.append(merged_completion) - - return merged_completions \ No newline at end of file + + return merged_completions + + def log_metrics(self, data): + """ + Callback for logging metrics. Can be implemented by subclasses. + + Should return a dictionary of metrics (key = metric name, value = metric value) + """ + return {} diff --git a/src/r1_vlm/environments/simple_vision_env.py b/src/r1_vlm/environments/simple_vision_env.py index 7e269bd6..5f1465a9 100644 --- a/src/r1_vlm/environments/simple_vision_env.py +++ b/src/r1_vlm/environments/simple_vision_env.py @@ -4,12 +4,12 @@ import imgcat from qwen_vl_utils import process_vision_info +from verifiers import SimpleEnv from vllm import LLM, SamplingParams # type: ignore from r1_vlm.budget_forcing.budget_forcing import ( generate_completions_with_budget_forcing, ) -from verifiers import SimpleEnv class SimpleVisionEnv(SimpleEnv): @@ -93,7 +93,7 @@ def generate( completions = vlm.generate( vlm_inputs, sampling_params=custom_sp, use_tqdm=False ) # type: ignore - + stop_reasons = [c.outputs[0].stop_reason for c in completions] print(f"Stop reasons: {stop_reasons}") @@ -166,6 +166,14 @@ def prepare_data(self, *, inputs, processing_class): return conversations, texts, batch, vllm_inputs + def log_metrics(self, data): + """ + Callback for logging metrics. Can be implemented by subclasses. + + Should return a dictionary of metrics (key = metric name, value = metric value) + """ + return {} + def prepare_inputs_for_env(*, inputs, processing_class): """ diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py index a4f5b64f..f1f13768 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py @@ -1,3 +1,4 @@ +import re from typing import Any, Callable from datasets import Dataset @@ -309,6 +310,44 @@ def correct_answer_reward_func( correct_answer_reward_func, ] + def log_metrics(self, conversations, completions_text, completion_messages): + # 1. compute how many completions attempt to use any tool + # 2. for each tool, compute how many completions attempt to use it + + completions_with_tool_use = 0 + completions_with_zoom_use = 0 + completions_with_detect_objects_use = 0 + + for completion in completions_text: + print(f"HERE: {completion}") + tool_use_regex = r".*.*.*" + zoom_use_string = "name: zoom" + detect_objects_use_string = "name: detect_objects" + + if re.search(tool_use_regex, completion, re.DOTALL): + completions_with_tool_use += 1 + if zoom_use_string in completion: + completions_with_zoom_use += 1 + if detect_objects_use_string in completion: + completions_with_detect_objects_use += 1 + + print( + f"There are {len(completions_text)} completions, {completions_with_tool_use} of which attempt to use a tool, {completions_with_zoom_use} of which attempt to use zoom, and {completions_with_detect_objects_use} of which attempt to use detect_objects" + ) + + num_completions = len(completions_text) + tool_use_proportion = completions_with_tool_use / num_completions + zoom_use_proportion = completions_with_zoom_use / num_completions + detect_objects_use_proportion = ( + completions_with_detect_objects_use / num_completions + ) + + return { + "tool_use_proportion": tool_use_proportion, + "zoom_use_proportion": zoom_use_proportion, + "detect_objects_use_proportion": detect_objects_use_proportion, + } + if __name__ == "__main__": env = AOKVQAToolEnv(processing_class=None) From c2c702f15af772d422c22641e8c8429d6d87fb62 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Wed, 30 Apr 2025 15:13:27 -0700 Subject: [PATCH 15/24] reset the schedule --- .../environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py index f1f13768..a430ad2b 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py @@ -11,6 +11,7 @@ ) from r1_vlm.datasets.utils import preprocess_r1_dataset from r1_vlm.environments.multistep_vision_env import MultistepVisionEnv +from r1_vlm.environments.reward_schedules import create_linear_decay_schedule from r1_vlm.environments.tool_vision_env import ToolArgParser, ToolVisionEnv from r1_vlm.tools.object_detection import ( ObjectDetectionTool, @@ -115,8 +116,7 @@ def get_reward_weights(self) -> list[float]: reward_weights.append(schedule) elif reward_function.__name__ == "tool_execution_reward_func": # linearly decay from 1.0 to 0.0 over 200 global steps (200 gradient updates) - # schedule = create_linear_decay_schedule(1.0, 0.0, 200) - schedule = 0.0 # restarting the run, past step 200 + schedule = create_linear_decay_schedule(1.0, 0.0, 200) reward_weights.append(schedule) elif reward_function.__name__ == "correct_answer_reward_func": From 77df0b6720fa280b44cfcaf830b96be6f6ed5af2 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Wed, 30 Apr 2025 15:18:05 -0700 Subject: [PATCH 16/24] ready to start training again --- .../tool_use_aokvqa_env/tool_use_aok_train.py | 8 ++------ .../tool_use_aokvqa_env/tool_use_aokvqa_env.py | 1 - 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py index de7c27f9..9107b868 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py @@ -88,12 +88,8 @@ def find_target_linear_names( def train(): - checkpoint = "/millcreek/home/sunil/r1_vlm/vlm-r1-new-od-tool-reward-schedule-for-tools-apr-29-restart/checkpoint-50" - model, peft_config, processor, model_config, gradient_checkpointing = ( - load_model_and_processor( - model_name_or_path=checkpoint, gradient_checkpointing=True, use_peft=False - ) + load_model_and_processor(gradient_checkpointing=True, use_peft=False) ) print("loaded model") @@ -110,7 +106,7 @@ def train(): training_args = GRPOConfig( model_init_kwargs=model_config, # save path on the runpod instance - output_dir="vlm-r1-new-od-tool-reward-schedule-for-tools-apr-29-restart-2", + output_dir="vlm-r1-od-tool-fixed-reward-schedule-for-tools-apr-30", # increase learning rate for PEFT - 1e-4 learning_rate=1e-4 if peft_config is not None else 1e-6, max_grad_norm=1.0, diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py index a430ad2b..e36866c2 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py @@ -319,7 +319,6 @@ def log_metrics(self, conversations, completions_text, completion_messages): completions_with_detect_objects_use = 0 for completion in completions_text: - print(f"HERE: {completion}") tool_use_regex = r".*.*.*" zoom_use_string = "name: zoom" detect_objects_use_string = "name: detect_objects" From 56f804b465a7b0770915adf6081930a3f4d2f0bd Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Wed, 30 Apr 2025 15:21:27 -0700 Subject: [PATCH 17/24] more robust way of catching --- .../tool_use_aokvqa_env/tool_use_aokvqa_env.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py index e36866c2..d8b8ac46 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py @@ -319,16 +319,18 @@ def log_metrics(self, conversations, completions_text, completion_messages): completions_with_detect_objects_use = 0 for completion in completions_text: - tool_use_regex = r".*.*.*" + tool_use_regex = r"(.*?)" zoom_use_string = "name: zoom" detect_objects_use_string = "name: detect_objects" - if re.search(tool_use_regex, completion, re.DOTALL): + tool_matches = re.findall(tool_use_regex, completion, re.DOTALL) + if tool_matches: completions_with_tool_use += 1 - if zoom_use_string in completion: - completions_with_zoom_use += 1 - if detect_objects_use_string in completion: - completions_with_detect_objects_use += 1 + for tool_content in tool_matches: + if zoom_use_string in tool_content: + completions_with_zoom_use += 1 + if detect_objects_use_string in tool_content: + completions_with_detect_objects_use += 1 print( f"There are {len(completions_text)} completions, {completions_with_tool_use} of which attempt to use a tool, {completions_with_zoom_use} of which attempt to use zoom, and {completions_with_detect_objects_use} of which attempt to use detect_objects" From be43b071fe1b2a6d9c59f904a923bf7e1ebf137f Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Thu, 1 May 2025 10:26:19 -0700 Subject: [PATCH 18/24] generalized the eval script --- .../environments/tool_use_aokvqa_env/eval.py | 172 +++++++++++------- 1 file changed, 104 insertions(+), 68 deletions(-) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/eval.py b/src/r1_vlm/environments/tool_use_aokvqa_env/eval.py index 635930e1..55ed613f 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/eval.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/eval.py @@ -2,6 +2,7 @@ import os import re +from datasets import Dataset from imgcat import imgcat from tqdm import tqdm from transformers import AutoProcessor @@ -19,78 +20,78 @@ def extract_answer(generation: str): return None -def main(): - checkpoint = "/millcreek/home/sunil/r1_vlm/vlm-r1-zoom-only-reward-refactor-oversampling/checkpoint-700" - processor = AutoProcessor.from_pretrained(checkpoint, padding_side="left") - vf_env = AOKVQAToolEnv(processing_class=processor) - train_dataset, val_dataset, test_dataset = vf_env.get_dataset() - - if not os.path.exists("generations.json"): - vlm = LLM( - model=checkpoint, - gpu_memory_utilization=1.0, - dtype="bfloat16", - tensor_parallel_size=2, - enable_prefix_caching=True, - limit_mm_per_prompt={"image": 2, "video": 0}, - ) +def generate_completions( + checkpoint_path: str, file_path: str, dataset: Dataset, env, processor +): + """ + Generate completions given a checkpoint and a file path to save the generations + """ + if os.path.exists(file_path): + raise ValueError(f"File {file_path} already exists") + + vlm = LLM( + model=checkpoint_path, + gpu_memory_utilization=1.0, + dtype="bfloat16", + tensor_parallel_size=4, + enable_prefix_caching=True, + limit_mm_per_prompt={"image": 2, "video": 0}, + ) + + sampling_params = SamplingParams( + temperature=0.1, + max_tokens=2048, + ) + + batch_size = 24 + batches = [] + + for example in dataset: + if len(batches) == 0: + batches.append([example]) + elif len(batches[-1]) < batch_size: + batches[-1].append(example) + else: + batches.append([example]) - sampling_params = SamplingParams( - temperature=0.1, - max_tokens=2048, + generations = [] + for batch in tqdm(batches, desc="Generating completions"): + conversations, texts, processed_batch, vllm_inputs = env.prepare_data( + inputs=batch, processing_class=processor ) - batch_size = 6 - batches = [] - - for example in val_dataset: - if len(batches) == 0: - batches.append([example]) - elif len(batches[-1]) < batch_size: - batches[-1].append(example) - else: - batches.append([example]) - - generations = [] - for batch in tqdm(batches, desc="Generating completions"): - conversations, texts, processed_batch, vllm_inputs = vf_env.prepare_data( - inputs=batch, processing_class=processor - ) - - completion_ids = vf_env.generate( - conversations=conversations, - vlm_inputs=vllm_inputs, - vlm=vlm, - sampling_params=sampling_params, - ) + completion_ids = env.generate( + conversations=conversations, + vlm_inputs=vllm_inputs, + vlm=vlm, + sampling_params=sampling_params, + ) - generated_texts = processor.batch_decode( - completion_ids["ids"], - skip_special_tokens=False, - clean_up_tokenization_spaces=False, - ) + generated_texts = processor.batch_decode( + completion_ids["ids"], + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) - print(generated_texts) + for example, generation in zip(batch, generated_texts): + data = { + "question_id": example["question_id"], + "question": example["question"], + "options": example["choices"], + "rationales": example["rationales"], + "gt_answer": example["multiple_choice_answer"], + "generation": generation, + "model_answer": extract_answer(generation), + } + generations.append(data) - for example, generation in zip(batch, generated_texts): - data = { - "question_id": example["question_id"], - "question": example["question"], - "options": example["choices"], - "rationales": example["rationales"], - "gt_answer": example["multiple_choice_answer"], - "generation": generation, - "model_answer": extract_answer(generation), - } - generations.append(data) + with open(file_path, "w") as f: + json.dump(generations, f, indent=2) - # Save the generations list as a JSON array to a file - with open("generations.json", "w") as f: - json.dump(generations, f, indent=2) # Use indent for readability (optional) - else: - with open("generations.json", "r") as f: - generations = json.load(f) +def evaluate(generations_dict: dict, dataset: Dataset): + with open(generations_dict, "r") as f: + generations = json.load(f) generations_dict = {} for generation in generations: @@ -101,7 +102,8 @@ def main(): total = 0 correct = 0 in_option_set = 0 - for example in val_dataset: + + for example in dataset: question_id = example["question_id"] if question_id not in generations_dict: @@ -131,9 +133,43 @@ def main(): imgcat(example["image"]) print("--------------------------------") - print(f"Accuracy: {correct / total}") - print(f"In option set: {in_option_set / total}") + results = { + "accuracy": correct / total, + "in_option_set": in_option_set / total, + } + + print(f"Accuracy: {results['accuracy']}") + print(f"In option set: {results['in_option_set']}") + + return results if __name__ == "__main__": - main() + checkpoints_folder = "/millcreek/home/sunil/r1_vlm/vlm-r1-od-tool-fixed-reward-schedule-for-tools-apr-30" + + checkpoint_paths = [ + os.path.join(checkpoints_folder, f) + for f in os.listdir(checkpoints_folder) + if os.path.isdir(os.path.join(checkpoints_folder, f)) + ] + + processor = AutoProcessor.from_pretrained(checkpoint_paths[0], padding_side="left") + env = AOKVQAToolEnv(processing_class=processor) + train_dataset, val_dataset, test_dataset = env.get_dataset() + + results_dict = {} + + # we'll save evaluations to the same folder as the checkpoints + for checkpoint_path in checkpoint_paths: + file_path = os.path.join( + checkpoints_folder, f"{checkpoint_path}_generations.json" + ) + if not os.path.exists(file_path): + generate_completions( + checkpoint_path, file_path, val_dataset, env, processor + ) + + results = evaluate(file_path, val_dataset) + results_dict[checkpoint_path] = results + + print(results_dict) From bd99398604158f322f28c6649c0ca5fdbea48548 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Thu, 1 May 2025 11:46:05 -0700 Subject: [PATCH 19/24] better eval script --- .../environments/tool_use_aokvqa_env/eval.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/eval.py b/src/r1_vlm/environments/tool_use_aokvqa_env/eval.py index 55ed613f..76ca7032 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/eval.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/eval.py @@ -1,6 +1,7 @@ import json import os import re +from copy import deepcopy from datasets import Dataset from imgcat import imgcat @@ -153,6 +154,14 @@ def evaluate(generations_dict: dict, dataset: Dataset): if os.path.isdir(os.path.join(checkpoints_folder, f)) ] + checkpoints_to_eval = ["150", "350", "600"] + + checkpoint_paths = [ + path + for path in checkpoint_paths + if any(num in path for num in checkpoints_to_eval) + ] + processor = AutoProcessor.from_pretrained(checkpoint_paths[0], padding_side="left") env = AOKVQAToolEnv(processing_class=processor) train_dataset, val_dataset, test_dataset = env.get_dataset() @@ -166,10 +175,12 @@ def evaluate(generations_dict: dict, dataset: Dataset): ) if not os.path.exists(file_path): generate_completions( - checkpoint_path, file_path, val_dataset, env, processor + checkpoint_path, file_path, deepcopy(val_dataset), env, processor ) + else: + print(f"Skipping {checkpoint_path} because it already exists") - results = evaluate(file_path, val_dataset) + results = evaluate(file_path, deepcopy(val_dataset)) results_dict[checkpoint_path] = results print(results_dict) From 9eab93e4fa8f2e18b611ac31626c1f1ff3db1147 Mon Sep 17 00:00:00 2001 From: ROIM1998 Date: Thu, 1 May 2025 16:00:45 -0700 Subject: [PATCH 20/24] implement the new combined correctness-and-tool-use reward --- .../tool_use_aokvqa_env.py | 83 +++++++++++++++++-- 1 file changed, 78 insertions(+), 5 deletions(-) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py index d8b8ac46..98ede3ee 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py @@ -38,6 +38,7 @@ def __init__( ], max_steps: int = 3, tool_prompt_template: str = SINGLE_OPTIONAL_TOOL_PROMPT_TEMPLATE, + use_combined_tool_correctness_reward: bool = False, ): super().__init__( processing_class=processing_class, @@ -53,6 +54,7 @@ def __init__( ("answer", ["answer"]), ("tool", ["tool"]), ] + self.use_combined_tool_correctness_reward = use_combined_tool_correctness_reward def parse(self, text: str, strip: bool = True): return self.parser.parse(text, strip=strip) @@ -123,6 +125,10 @@ def get_reward_weights(self) -> list[float]: # consistent high reward for getting the answer right schedule = 1.0 reward_weights.append(schedule) + elif reward_function.__name__ == "combined_tool_correctness_reward_func": + # consistent high reward for getting the answer right + schedule = 1.0 + reward_weights.append(schedule) else: raise ValueError( f"Unknown reward function: {reward_function.__name__} encountered in get_reward_weights" @@ -233,6 +239,17 @@ def check_format(trajectory): return [check_format(m) for m in merged_completion_conversations] + def check_tool_use_attempt(conversation) -> bool: + """ + Returns True if the model attempts to use any tool. + """ + for i, message in enumerate(conversation): + if message["role"] == "assistant": + parsed = self.parser.parse(message["content"][0]["text"]) + if hasattr(parsed, "tool") and parsed.tool is not None: + return True + return False + def check_execution(conversation): """ Returns the ratio of successful tool executions to total attempts. @@ -304,11 +321,67 @@ def correct_answer_reward_func( return [1.0 if result else 0.0 for result in correctness_results] - return [ - format_reward_func, - tool_execution_reward_func, - correct_answer_reward_func, - ] + def combined_tool_correctness_reward_func( + prompts, completions, completions_messages, **kwargs + ) -> list[float]: + """ + Reward function that checks if tools were executed successfully only if tool use is necessary to answer the question. + """ + merged_completion_conversations = MultistepVisionEnv.preprocess_messages( + prompts_messages=prompts, completions_messages=completions_messages + ) + + # For each response sampled, check if the completion has the correct answer + correct_answers = kwargs["multiple_choice_answer"] + correctness_results: list[bool] = [ + check_correctness(conv, correct_answer) + for conv, correct_answer in zip( + merged_completion_conversations, correct_answers + ) + ] + # For each response sampled, check if the any tool use is correct + tool_use_correctness: list[bool] = [ + check_execution(conv) > 0.0 for conv in merged_completion_conversations + ] + # For each response sampled, check if the model attempts to use any tool + tool_use_attempts: list[bool] = [ + check_tool_use_attempt(conv) for conv in merged_completion_conversations + ] + + # For all responses sampled, check if there is a completion that has the correct answer successfully using a tool + correct_with_tool = any(correctness_results[i] and tool_use_correctness[i] for i in range(len(correctness_results))) + # For all responses sampled, check if there is a completion that has the correct answer without successfully using a tool + correct_without_tool = any(correctness_results[i] and not tool_use_correctness[i] for i in range(len(correctness_results))) + + if correct_without_tool: + # If the question is answerable without using a tool, the model will be penalized for using a tool + rewards = [ + 0.0 if not correctness_results[i] else + (0.5 if tool_use_attempts[i] else 1.0) + for i in range(len(correctness_results)) + ] + elif correct_with_tool: + # The model is only rewarded if the tool use used correctly, AND the answer is correct + rewards = [ + 1.0 if correctness_results[i] and tool_use_correctness[i] else 0.0 + for i in range(len(correctness_results)) + ] + else: + # The model is not rewarded for any incorrect responses + rewards = [0.0 for _ in range(len(correctness_results))] + return rewards + + if self.use_combined_tool_correctness_reward: + return [ + format_reward_func, + combined_tool_correctness_reward_func, + ] + else: + return [ + format_reward_func, + tool_execution_reward_func, + correct_answer_reward_func, + ] def log_metrics(self, conversations, completions_text, completion_messages): # 1. compute how many completions attempt to use any tool From 1cbdd10367c4a4e83f72201fb2d74a315a834274 Mon Sep 17 00:00:00 2001 From: ROIM1998 Date: Thu, 1 May 2025 17:01:44 -0700 Subject: [PATCH 21/24] always return all 4 rewards but setting the schedules differently based on the flag --- .../tool_use_aokvqa_env.py | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py index 98ede3ee..8ba1e84f 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py @@ -118,16 +118,16 @@ def get_reward_weights(self) -> list[float]: reward_weights.append(schedule) elif reward_function.__name__ == "tool_execution_reward_func": # linearly decay from 1.0 to 0.0 over 200 global steps (200 gradient updates) - schedule = create_linear_decay_schedule(1.0, 0.0, 200) + schedule = create_linear_decay_schedule(1.0, 0.0, 200) if not self.use_combined_tool_correctness_reward else 0.0 reward_weights.append(schedule) elif reward_function.__name__ == "correct_answer_reward_func": # consistent high reward for getting the answer right - schedule = 1.0 + schedule = 1.0 if not self.use_combined_tool_correctness_reward else 0.0 reward_weights.append(schedule) elif reward_function.__name__ == "combined_tool_correctness_reward_func": # consistent high reward for getting the answer right - schedule = 1.0 + schedule = 1.0 if self.use_combined_tool_correctness_reward else 0.0 reward_weights.append(schedule) else: raise ValueError( @@ -371,17 +371,12 @@ def combined_tool_correctness_reward_func( rewards = [0.0 for _ in range(len(correctness_results))] return rewards - if self.use_combined_tool_correctness_reward: - return [ - format_reward_func, - combined_tool_correctness_reward_func, - ] - else: - return [ - format_reward_func, - tool_execution_reward_func, - correct_answer_reward_func, - ] + return [ + format_reward_func, + tool_execution_reward_func, + correct_answer_reward_func, + combined_tool_correctness_reward_func, + ] def log_metrics(self, conversations, completions_text, completion_messages): # 1. compute how many completions attempt to use any tool From 0ed84296f90c7027e727ca4c8de25ab281a3f996 Mon Sep 17 00:00:00 2001 From: ROIM1998 Date: Thu, 1 May 2025 17:11:28 -0700 Subject: [PATCH 22/24] add num_generations check to make sure the new reward sees the entire sampled group --- .../tool_use_aokvqa_env/tool_use_aok_train.py | 8 +++++++- .../tool_use_aokvqa_env/tool_use_aokvqa_env.py | 4 ++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py index 9107b868..7962e3a5 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py @@ -92,8 +92,14 @@ def train(): load_model_and_processor(gradient_checkpointing=True, use_peft=False) ) print("loaded model") + num_generations = 6 - vf_env = AOKVQAToolEnv(processing_class=processor, max_steps=3) + vf_env = AOKVQAToolEnv( + processing_class=processor, + max_steps=3, + num_generations=num_generations, + use_combined_tool_correctness_reward=True, + ) print("loaded env") diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py index 8ba1e84f..9a4c85a9 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py @@ -38,6 +38,7 @@ def __init__( ], max_steps: int = 3, tool_prompt_template: str = SINGLE_OPTIONAL_TOOL_PROMPT_TEMPLATE, + num_generations: int = 6, use_combined_tool_correctness_reward: bool = False, ): super().__init__( @@ -54,6 +55,7 @@ def __init__( ("answer", ["answer"]), ("tool", ["tool"]), ] + self.num_generations = num_generations self.use_combined_tool_correctness_reward = use_combined_tool_correctness_reward def parse(self, text: str, strip: bool = True): @@ -327,6 +329,8 @@ def combined_tool_correctness_reward_func( """ Reward function that checks if tools were executed successfully only if tool use is necessary to answer the question. """ + if self.num_generations != len(prompts) or self.num_generations != len(completions_messages): + raise ValueError(f"Expected num_generations to be equal to the number of prompts and completions, but got num_generations={self.num_generations}, len(prompts)={len(prompts)}, len(completions_messages)={len(completions_messages)}") merged_completion_conversations = MultistepVisionEnv.preprocess_messages( prompts_messages=prompts, completions_messages=completions_messages ) From 39f6e9e88827bd7afcbeb758c649ef9b851a0b14 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Thu, 1 May 2025 21:03:23 -0700 Subject: [PATCH 23/24] run name --- .../environments/tool_use_aokvqa_env/tool_use_aok_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py index 7962e3a5..4cafc653 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aok_train.py @@ -112,7 +112,7 @@ def train(): training_args = GRPOConfig( model_init_kwargs=model_config, # save path on the runpod instance - output_dir="vlm-r1-od-tool-fixed-reward-schedule-for-tools-apr-30", + output_dir="vlm-r1-new-fancy-tool-aligned-reward-may1", # increase learning rate for PEFT - 1e-4 learning_rate=1e-4 if peft_config is not None else 1e-6, max_grad_norm=1.0, From d9ddbcd4b57a63e5e315d5e68e4f73a72aa8b94f Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Thu, 1 May 2025 21:46:41 -0700 Subject: [PATCH 24/24] try adding a short term incentive to use tools with new fancy reward --- .../tool_use_aokvqa_env.py | 34 ++++++++++++++----- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py index 9a4c85a9..44085222 100644 --- a/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py +++ b/src/r1_vlm/environments/tool_use_aokvqa_env/tool_use_aokvqa_env.py @@ -19,7 +19,7 @@ parse_detect_objects_args, set_object_detection_tool, ) -from r1_vlm.tools.tool_prompts import SINGLE_OPTIONAL_TOOL_PROMPT_TEMPLATE +from r1_vlm.tools.tool_prompts import SINGLE_TOOL_PROMPT_TEMPLATE from r1_vlm.tools.zoom import parse_zoom_args, zoom # This is a global variable that is used to store the object detection tool. It is accessed by the detect_objects function. @@ -37,7 +37,7 @@ def __init__( (zoom, parse_zoom_args), ], max_steps: int = 3, - tool_prompt_template: str = SINGLE_OPTIONAL_TOOL_PROMPT_TEMPLATE, + tool_prompt_template: str = SINGLE_TOOL_PROMPT_TEMPLATE, num_generations: int = 6, use_combined_tool_correctness_reward: bool = False, ): @@ -120,7 +120,12 @@ def get_reward_weights(self) -> list[float]: reward_weights.append(schedule) elif reward_function.__name__ == "tool_execution_reward_func": # linearly decay from 1.0 to 0.0 over 200 global steps (200 gradient updates) - schedule = create_linear_decay_schedule(1.0, 0.0, 200) if not self.use_combined_tool_correctness_reward else 0.0 + schedule = ( + create_linear_decay_schedule(1.0, 0.0, 200) + if not self.use_combined_tool_correctness_reward + # quick burst of reward for tool use at the beginning to teach the model to use tools + else create_linear_decay_schedule(1.0, 0.0, 50) + ) reward_weights.append(schedule) elif reward_function.__name__ == "correct_answer_reward_func": @@ -329,8 +334,12 @@ def combined_tool_correctness_reward_func( """ Reward function that checks if tools were executed successfully only if tool use is necessary to answer the question. """ - if self.num_generations != len(prompts) or self.num_generations != len(completions_messages): - raise ValueError(f"Expected num_generations to be equal to the number of prompts and completions, but got num_generations={self.num_generations}, len(prompts)={len(prompts)}, len(completions_messages)={len(completions_messages)}") + if self.num_generations != len(prompts) or self.num_generations != len( + completions_messages + ): + raise ValueError( + f"Expected num_generations to be equal to the number of prompts and completions, but got num_generations={self.num_generations}, len(prompts)={len(prompts)}, len(completions_messages)={len(completions_messages)}" + ) merged_completion_conversations = MultistepVisionEnv.preprocess_messages( prompts_messages=prompts, completions_messages=completions_messages ) @@ -353,15 +362,22 @@ def combined_tool_correctness_reward_func( ] # For all responses sampled, check if there is a completion that has the correct answer successfully using a tool - correct_with_tool = any(correctness_results[i] and tool_use_correctness[i] for i in range(len(correctness_results))) + correct_with_tool = any( + correctness_results[i] and tool_use_correctness[i] + for i in range(len(correctness_results)) + ) # For all responses sampled, check if there is a completion that has the correct answer without successfully using a tool - correct_without_tool = any(correctness_results[i] and not tool_use_correctness[i] for i in range(len(correctness_results))) + correct_without_tool = any( + correctness_results[i] and not tool_use_correctness[i] + for i in range(len(correctness_results)) + ) if correct_without_tool: # If the question is answerable without using a tool, the model will be penalized for using a tool rewards = [ - 0.0 if not correctness_results[i] else - (0.5 if tool_use_attempts[i] else 1.0) + 0.0 + if not correctness_results[i] + else (0.5 if tool_use_attempts[i] else 1.0) for i in range(len(correctness_results)) ] elif correct_with_tool: