diff --git a/orby/data/analyse_uground.py b/orby/data/analyse_uground.py
new file mode 100644
index 00000000000..4f1b0fe2112
--- /dev/null
+++ b/orby/data/analyse_uground.py
@@ -0,0 +1,550 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Preprocess the Uground dataset to parquet format
+"""
+
+import argparse
+import io
+import json
+import os
+
+import datasets
+from datasets import Dataset, Sequence
+from datasets import Image as ImageData
+from PIL import Image, ImageDraw
+from transformers import AutoProcessor
+from qwen_vl_utils import smart_resize
+
+from verl.utils.hdfs_io import copy, makedirs
+from qwen_agent.llm.fncall_prompts.nous_fncall_prompt import (
+ NousFnCallPrompt,
+ Message,
+ ContentItem,
+)
+from orby.utils.dataset.qwen_agent_function_call import ComputerUse
+from orby.data.prompts import get_subtask_messages
+
+
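+# Load the Qwen2.5-VL processor once at module import; its image_processor settings
+# (patch_size, merge_size, min_pixels, max_pixels) drive smart_resize below.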
+MODEL_PATH = "Qwen/Qwen2.5-VL-7B-Instruct"
+PROCESSOR = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
+
+def to_rgb(pil_image: Image.Image) -> Image.Image:
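+    # Flatten transparent (RGBA) images onto a white background rather than
+    # dropping the alpha channel outright.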
+ if pil_image.mode == 'RGBA':
+ white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
+ white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
+ return white_background
+ else:
+ return pil_image.convert("RGB")
+
+def get_resized_hw(image, max_pixels=None):
+    """
+    Get the resized height and width of the image.
+    """
+ # if max_pixels is not set, use the max pixels of the image processor
+ if not max_pixels:
+ print("Max pixels not set, using the max pixels of the image processor", flush=True)
+ max_pixels = PROCESSOR.image_processor.max_pixels
+
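+    # smart_resize keeps the aspect ratio while snapping both dimensions to a
+    # multiple of patch_size * merge_size and bounding the total pixel count.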
+ resized_height, resized_width = smart_resize(
+ height=image.height,
+ width=image.width,
+ factor=PROCESSOR.image_processor.patch_size
+ * PROCESSOR.image_processor.merge_size,
+ min_pixels=PROCESSOR.image_processor.min_pixels,
+ max_pixels=max_pixels,
+ )
+
+ return resized_height, resized_width
+
+
+def save_in_chunks(
+ all_data, output_dir, prefix, start_file_counter=0
+):
+ """Save processed data in multiple parquet files"""
+ os.makedirs(output_dir, exist_ok=True)
+
+ file_counter = start_file_counter
+
+ # If all_data is a single dataset, convert to list
+ if not isinstance(all_data, list):
+ all_data = [all_data]
+
+ # Process each dataset chunk immediately
+ for dataset_chunk in all_data:
+ if len(dataset_chunk) == 0:
+ continue
+
+ # Remove width and height columns if they exist
+ columns_to_remove = []
+ if "width" in dataset_chunk.column_names:
+ columns_to_remove.append("width")
+ if "height" in dataset_chunk.column_names:
+ columns_to_remove.append("height")
+
+ if columns_to_remove:
+ dataset_chunk = dataset_chunk.remove_columns(columns_to_remove)
+ print(f"Removed columns: {columns_to_remove}", flush=True)
+
+        # Save the chunk as a single parquet file
+ output_file = os.path.join(
+ output_dir, f"{prefix}_part_{file_counter:04d}.parquet"
+ )
+ dataset_chunk.to_parquet(output_file)
+ print(f"✓ Saved {len(dataset_chunk)} examples to {output_file}", flush=True)
+ file_counter += 1
+
+ return file_counter
+
+
+def process_in_chunks(streaming_dataset, chunk_size):
+ """Process streaming dataset in chunks with immediate saving capability"""
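+    # NOTE: the caller attaches `map_fn` and `max_examples` as attributes on this
+    # function (see __main__) before iterating.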
+ chunk = []
+ total_processed = 0
+
+ for i, example in enumerate(streaming_dataset):
+ if (
+ hasattr(process_in_chunks, "max_examples")
+ and total_processed >= process_in_chunks.max_examples
+ ):
+ break
+
+ chunk.append(example)
+
+ if len(chunk) >= chunk_size:
+ print(
+ f"Processing chunk {total_processed//chunk_size + 1}, examples {total_processed}-{total_processed + len(chunk)}",
+ flush=True,
+ )
+
+ # Convert chunk to Dataset for processing
+ chunk_dataset = Dataset.from_list(chunk)
+
+ # Process the chunk
+ processed_chunk = chunk_dataset.map(
+ function=process_in_chunks.map_fn,
+ with_indices=True,
+ num_proc=4, # Reduced from 16 to manage memory
+ )
+ processed_chunk = processed_chunk.cast_column(
+ "images", Sequence(ImageData())
+ )
+
+ yield processed_chunk, total_processed
+
+ total_processed += len(chunk)
+ chunk = []
+
+ # Process remaining examples
+ if chunk:
+ print(
+ f"Processing final chunk, examples {total_processed}-{total_processed + len(chunk)}",
+ flush=True,
+ )
+ chunk_dataset = Dataset.from_list(chunk)
+ processed_chunk = chunk_dataset.map(
+ function=process_in_chunks.map_fn, with_indices=True, num_proc=4
+ )
+ processed_chunk = processed_chunk.cast_column("images", Sequence(ImageData()))
+ yield processed_chunk, total_processed
+
+
+def check_shard_loading_pattern():
+    """Inspect how individual UGround shards load and visualize a few samples."""
+
+ data_source = "osunlp/UGround-V1-Data-Box"
+
+ # Create base directory for saving images
+ base_save_dir = "uground_images"
+ os.makedirs(base_save_dir, exist_ok=True)
+
+    # Iterate through the first 100 shards
+ for shard_idx in range(100):
+ data_files = f"shard_{shard_idx:04d}.parquet"
+ print(f"\n=== Processing {data_files} ===")
+
+ # Create shard-specific directory
+ shard_dir = os.path.join(base_save_dir, f"shard_{shard_idx:04d}")
+ os.makedirs(shard_dir, exist_ok=True)
+
+ try:
+ # Load dataset info to see available shards
+ dataset_info = datasets.get_dataset_config_info(data_source)
+ print(f"Dataset info: {dataset_info}")
+
+ # Load in streaming mode with detailed logging
+ dataset = datasets.load_dataset(
+ data_source,
+ data_files=data_files,
+ streaming=True,
+ download_mode="force_redownload"
+ )
+
+ # Check which files are being loaded
+ print(f"Dataset files: {dataset['train']._ex_iterable}")
+
+ # Sample first few examples to see which shard they come from
+ count = 0
+ total_x = 0
+ total_y = 0
+
+            import random  # only needed if the random-sampling line below is re-enabled
+
+            # Collect all examples first so they can be sampled
+            all_examples = []
+            for example in dataset['train']:
+                all_examples.append(example)
+
+            # Take the first 15 examples (random sampling is currently disabled)
+            sample_size = min(15, len(all_examples))
+            # sampled_examples = random.sample(all_examples, sample_size)
+            sampled_examples = all_examples[:sample_size]
+ for i, example in enumerate(sampled_examples):
+ # Extract and save image
+ image = example.pop("image")
+
+ image_filename = f"image_{i:06d}.jpg"
+ image_path = os.path.join(shard_dir, image_filename)
+
+                # Decode the raw image bytes into a PIL Image
+                pil_image = Image.open(io.BytesIO(image))
+
+ pil_image = to_rgb(pil_image)
+ # Get the resized width and height of the pil_image.
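+                # NOTE: `args` here is the module-level namespace parsed in __main__.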
+ resized_height, resized_width = get_resized_hw(pil_image, args.max_pixels)
+
+                pil_image = pil_image.resize((resized_width, resized_height))
+
+ conversation = example.pop("conversations").strip()
+ # Use the first message for now. Uground has multiple grounding
+ # commands / groundtruths in the conversation.
+ command, label = json.loads(conversation)[:2]
+ assert command["from"] == "human" and label["from"] == "gpt"
+ instruction = command["value"]
+ label_text = label["value"]
+
+ # Parse the label text as "(x1, y1, x2, y2)" format
+ label_text = label_text.strip("()")
+ bbox = [int(x.strip()) for x in label_text.split(",")]
+ assert len(bbox) == 4, f"Expected 4 coordinates, got {len(bbox)}"
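+                # UGround labels are normalized to a 0-999 grid; rescale them into
+                # the resized image's pixel coordinates.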
+ bbox = [
+ bbox[0] * resized_width / 1000.0,
+ bbox[1] * resized_height / 1000.0,
+ bbox[2] * resized_width / 1000.0,
+ bbox[3] * resized_height / 1000.0,
+ ]
+ print(f"Bounding box: {bbox}")
+ print(f"Instruction: {instruction}")
+ draw = ImageDraw.Draw(pil_image)
+
+ x1, y1, x2, y2 = bbox
+ draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
+ # Draw GT center point
+ center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
+ # draw.ellipse([center_x-3, center_y-3, center_x+3, center_y+3], fill="red")
+ # draw.text((center_x+10, center_y-10), "GT", fill="red")
+
+                if instruction:
+                    # Position text outside the bounding box and within image bounds.
+                    # Try placing it above the bbox first.
+                    text_x = max(10, min(x1, resized_width - 200))  # Keep text within image width
+                    text_y = y1 - 50  # Tentatively place above the bbox
+
+                    # If there is not enough space above, place it below the bbox instead
+                    if text_y < 10:
+                        text_y = min(y2 + 10, resized_height - 30)  # Below bbox, within image
+
+ # Draw text background for better visibility
+ text_bbox = draw.textbbox((text_x, text_y), instruction)
+ # Ensure text box doesn't go outside image bounds
+ text_bbox = (
+ max(0, text_bbox[0]-5),
+ max(0, text_bbox[1]-5),
+ min(resized_width, text_bbox[2]+5),
+ min(resized_height, text_bbox[3]+5)
+ )
+ draw.rectangle(text_bbox, fill="yellow", outline="black")
+ draw.text((text_x, text_y), instruction, fill="black")
+ pil_image.save(image_path, 'JPEG')
+ # Calculate mean x and y coordinates of bbox and accumulate
+ mean_x = (x1 + x2) / 2
+ mean_y = (y1 + y2) / 2
+ total_x += mean_x
+ total_y += mean_y
+
+ count += 1
+
+ print(f"Shard {shard_idx}: Total examples: {count}")
+ if count > 0:
+ overall_mean_x = total_x / count
+ overall_mean_y = total_y / count
+ print(f"Shard {shard_idx}: Overall mean coordinates: ({overall_mean_x}, {overall_mean_y})")
+ print(f"Shard {shard_idx}: Images saved to {shard_dir}")
+
+ except Exception as e:
+ print(f"Error processing shard {shard_idx}: {e}")
+ continue
+
+ print(f"\nAll images saved to {base_save_dir} in hierarchical shard folders")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--local_dir", default="~/data/uground/")
+ parser.add_argument("--hdfs_dir", default=None)
+ parser.add_argument("--data_files", default="shard_*.parquet")
+ parser.add_argument("--output_filename", default="train")
+ parser.add_argument(
+ "--prompt_format",
+ choices=["qwen", "thinking", "subtask", "sft"],
+ default="subtask",
+ help="Select prompt format: 'qwen' or 'thinking' or 'subtask' or 'sft'",
+ )
+ parser.add_argument(
+ "--chunk_size",
+ type=int,
+ default=5000,
+ help="Number of examples per chunk",
+ )
+ parser.add_argument(
+ "--max_examples",
+ type=int,
+ default=100000,
+ help="Maximum number of examples to process (for testing)",
+ )
+
+ parser.add_argument(
+ "--max_pixels",
+ type=int,
+ default=None,
+ help="Maximum number of pixels in the image",
+ )
+
+ args = parser.parse_args()
+
+ # if args.max_examples and args.max_examples > 5000:
+ # print(f"⚠️ WARNING: You've set max_examples to {args.max_examples:,}, which is quite large.")
+ # print(" This will process a lot of data and may take a long time.")
+ # response = input(" Are you sure you want to continue? (y/N): ")
+ # if response.lower() not in ['y', 'yes']:
+ # print(" Exiting...")
+ # exit(0)
+ # print(" Continuing with processing...")
+
+    data_source = "osunlp/UGround-V1-Data-Box"
+    print(
+        f"Loading the {data_source} dataset from huggingface in streaming mode...",
+        flush=True,
+    )
+
+    # Analysis-only mode: inspect how shards load and visualize a few samples,
+    # then exit before the parquet conversion below runs.
+    check_shard_loading_pattern()
+    exit(0)
+
+ # Load in streaming mode
+ dataset = datasets.load_dataset(
+ data_source, data_files=args.data_files, streaming=True
+ )
+ dataset = dataset["train"]
+
+ def make_map_fn(split):
+ def process_fn(example, idx):
+ image = example.pop("image")
+ conversation = example.pop("conversations").strip()
+ # Use the first message for now. Uground has multiple grounding
+ # commands / groundtruths in the conversation.
+ command, label = json.loads(conversation)[:2]
+ assert command["from"] == "human" and label["from"] == "gpt"
+ instruction = command["value"]
+ label_text = label["value"]
+
+ # Parse the label text as "(x1, y1, x2, y2)" format
+ label_text = label_text.strip("()")
+ bbox = [int(x.strip()) for x in label_text.split(",")]
+ assert len(bbox) == 4, f"Expected 4 coordinates, got {len(bbox)}"
+
+ # Get image and resize ratios
+ if isinstance(image, bytes):
+ image = Image.open(io.BytesIO(image))
+ # Convert image to RGB if it's RGBA
+ image = to_rgb(image)
+ # Get the resized width and height of the image.
+ resized_height, resized_width = get_resized_hw(image, args.max_pixels)
+ image = image.resize((resized_width, resized_height))
+ # Adjust bbox based on resize ratios. Uground labels range from
+ # [0, 999]
+ bbox = [
+ bbox[0] * resized_width / 1000.0,
+ bbox[1] * resized_height / 1000.0,
+ bbox[2] * resized_width / 1000.0,
+ bbox[3] * resized_height / 1000.0,
+ ]
+
+ ground_truth = {
+ "bbox": bbox,
+ }
+
+ center_x = (bbox[0] + bbox[2]) / 2
+ center_y = (bbox[1] + bbox[3]) / 2
+
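+            # Supervision target: a single click at the center of the resized
+            # ground-truth box, formatted as `click(x, y)`.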
+ answer = [
+ {"role": "assistant", "content": f"click({center_x:.0f}, {center_y:.0f})"}
+ ]
+
+ data = {
+ "data_source": "uground",
+ "images": [image],
+ "ability": "vision-grounding",
+ "reward_model": {
+ "style": "rule",
+ "ground_truth": ground_truth,
+ },
+ "extra_info": {
+ "split": split,
+ "index": idx,
+ "question": instruction,
+ "bounding_box": bbox,
+ "max_pixels": args.max_pixels,
+ },
+ "response": answer
+ }
+
+ # Create prompt based on selected format
+
+ if args.prompt_format == "thinking":
+ data["prompt"] = [
+ {
+ "role": "user",
+ "content": (
+ "Map the user instruction to the coordinates in the UI image. "
+ "Think step by step before you answer. The reasoning process MUST BE enclosed within tags. "
+                        "The coordinate x and y MUST BE put in tags, separated by space. "
+ " Instruction: " + instruction
+ ),
+ },
+ ]
+ elif args.prompt_format == "sft":
+ data["prompt"] = [
+ {
+ "role": "user",
+ "content": (" Instruction: " + instruction),
+ },
+ ]
+ elif args.prompt_format == "subtask":
+ prompt = get_subtask_messages(instruction)
+ data["prompt"] = prompt
+ elif args.prompt_format == "qwen":
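+                # Wrap the instruction in Qwen-Agent's function-calling chat format,
+                # exposing a ComputerUse tool sized to the resized image.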
+ prompt = NousFnCallPrompt().preprocess_fncall_messages(
+ messages=[
+ Message(
+ role="system",
+ content=[ContentItem(text="You are a helpful assistant.")],
+ ),
+ Message(
+ role="user",
+ content=[
+ ContentItem(text=instruction + ""),
+ ],
+ ),
+ ],
+ functions=[
+ ComputerUse(
+ cfg={
+ "display_width_px": resized_width,
+ "display_height_px": resized_height,
+ }
+ ).function
+ ],
+ lang=None,
+ )
+
+ prompt = [msg.model_dump() for msg in prompt]
+ for message in prompt:
+ # Replace the list of content to a string.
+ content = "".join(m["text"] for m in message["content"])
+ message["content"] = content
+
+ data["prompt"] = prompt
+
+ return data
+
+ return process_fn
+
+ local_dir = os.path.expanduser(args.local_dir)
+ local_dir = os.path.join(local_dir, args.prompt_format)
+ if args.max_examples % 1000 == 0:
+ folder_name = f"{args.max_examples//1000}k"
+ else:
+ folder_name = f"{args.max_examples/1000:.2f}k"
+ local_dir = os.path.join(local_dir, folder_name)
+ print(f"Saving to {local_dir}...", flush=True)
+ os.makedirs(local_dir, exist_ok=True)
+
+ # Initialize counters and directories
+ train_file_counter = 0
+ test_file_counter = 0
+ total_processed = 0
+
+ train_dir = os.path.join(local_dir, "train")
+ test_dir = os.path.join(local_dir, "test")
+
+ # Set up the map function for process_in_chunks
+ process_in_chunks.map_fn = make_map_fn("train")
+ process_in_chunks.max_examples = args.max_examples
+
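+    # Each chunk is processed, split 95/5 into train/test, and written to parquet
+    # immediately, so peak memory stays bounded by chunk_size.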
+ for chunk_dataset, chunk_start in process_in_chunks(dataset, args.chunk_size):
+ # Split each chunk into train/test
+ chunk_split = chunk_dataset.train_test_split(train_size=0.95, seed=42)
+ train_chunk = chunk_split["train"]
+ test_chunk = chunk_split["test"]
+
+ # Save train data
+ train_file_counter = save_in_chunks(
+ [train_chunk],
+ train_dir,
+ "train",
+ train_file_counter,
+ )
+
+ # Save test data
+ test_file_counter = save_in_chunks(
+ [test_chunk],
+ test_dir,
+ "test",
+ test_file_counter,
+ )
+
+ total_processed += len(chunk_dataset)
+
+ print(f"Processing completed! {total_processed} examples processed")
+
+ if args.hdfs_dir is not None:
+ makedirs(args.hdfs_dir)
+ copy(src=local_dir, dst=args.hdfs_dir)
diff --git a/orby/data/convert_osatlas.py b/orby/data/convert_osatlas.py
index 59bf0865c4f..1eca16bb677 100644
--- a/orby/data/convert_osatlas.py
+++ b/orby/data/convert_osatlas.py
@@ -70,17 +70,31 @@
PROCESSOR = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
-def get_resized_wh(image):
+def to_rgb(pil_image: Image.Image) -> Image.Image:
+ if pil_image.mode == 'RGBA':
+ white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
+ white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
+ return white_background
+ else:
+ return pil_image.convert("RGB")
+
+def get_resized_hw(image, max_pixels=None):
"""
Get the resized width and height of the image.
"""
+
+ # if max_pixels is not set, use the max pixels of the image processor
+ if not max_pixels:
+ print("Max pixels not set, using the max pixels of the image processor", flush=True)
+ max_pixels = PROCESSOR.image_processor.max_pixels
+
resized_height, resized_width = smart_resize(
- image.height,
- image.width,
+ height=image.height,
+ width=image.width,
factor=PROCESSOR.image_processor.patch_size
* PROCESSOR.image_processor.merge_size,
min_pixels=PROCESSOR.image_processor.min_pixels,
- max_pixels=PROCESSOR.image_processor.max_pixels,
+ max_pixels=max_pixels,
)
return resized_height, resized_width
@@ -102,6 +116,17 @@ def save_in_chunks(
for dataset_chunk in all_data:
if len(dataset_chunk) == 0:
continue
+
+ # Remove width and height columns if they exist
+ columns_to_remove = []
+ if "width" in dataset_chunk.column_names:
+ columns_to_remove.append("width")
+ if "height" in dataset_chunk.column_names:
+ columns_to_remove.append("height")
+
+ if columns_to_remove:
+ dataset_chunk = dataset_chunk.remove_columns(columns_to_remove)
+ print(f"Removed columns: {columns_to_remove}", flush=True)
# Save the chunk as-is (remove the splitting logic)
output_file = os.path.join(
@@ -245,6 +270,12 @@ def process_in_chunks(dataset, chunk_size):
parser.add_argument(
"--image_dir", default="/root/data/os_atlas/desktop_domain/merged_images/", help="Path to the directory containing images"
)
+ parser.add_argument(
+ "--max_pixels",
+ type=int,
+ default=None,
+ help="Maximum number of pixels in the image",
+ )
args = parser.parse_args()
@@ -274,7 +305,10 @@ def process_fn(example, idx):
# Get image and resize ratios
if isinstance(image, bytes):
image = Image.open(io.BytesIO(image))
- resized_height, resized_width = get_resized_wh(image)
+ # Convert image to RGB if it's RGBA
+ image = to_rgb(image)
+ # Get the resized width and height of the image.
+ resized_height, resized_width = get_resized_hw(image, args.max_pixels)
bbox = [
@@ -309,6 +343,7 @@ def process_fn(example, idx):
"index": idx,
"question": instruction,
"bounding_box": bbox,
+ "max_pixels": args.max_pixels,
},
"response": answer
}
diff --git a/orby/data/convert_screenspot.py b/orby/data/convert_screenspot.py
index 04b44abedb2..966ad19b0ed 100644
--- a/orby/data/convert_screenspot.py
+++ b/orby/data/convert_screenspot.py
@@ -45,6 +45,14 @@
"macos": "desktop",
}
+def to_rgb(pil_image: Image.Image) -> Image.Image:
+ if pil_image.mode == 'RGBA':
+ white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
+ white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
+ return white_background
+ else:
+ return pil_image.convert("RGB")
+
def get_resized_wh(image):
"""
@@ -56,7 +64,7 @@ def get_resized_wh(image):
factor=PROCESSOR.image_processor.patch_size
* PROCESSOR.image_processor.merge_size,
min_pixels=PROCESSOR.image_processor.min_pixels,
- max_pixels=PROCESSOR.image_processor.max_pixels,
+        max_pixels=1e6,  # was PROCESSOR.image_processor.max_pixels
)
return resized_height, resized_width
@@ -100,8 +108,11 @@ def process_fn(example, idx):
# Get image and resize ratios
if isinstance(image, bytes):
image = Image.open(io.BytesIO(image))
+ image = to_rgb(image)
resized_height, resized_width = get_resized_wh(image)
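+        # Resize the pixels too, so the stored image matches the coordinate space
+        # the bbox is scaled into below.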
+ image = image.resize((resized_width, resized_height))
+
# Adjust bbox based on resize ratios
bbox = [
bbox[0] * resized_width,
diff --git a/orby/data/convert_screenspot_pro.py b/orby/data/convert_screenspot_pro.py
index 26af3d04c40..5065b3c2b0f 100644
--- a/orby/data/convert_screenspot_pro.py
+++ b/orby/data/convert_screenspot_pro.py
@@ -66,7 +66,7 @@ def get_resized_ratio(image):
factor=PROCESSOR.image_processor.patch_size
* PROCESSOR.image_processor.merge_size,
min_pixels=PROCESSOR.image_processor.min_pixels,
- max_pixels=PROCESSOR.image_processor.max_pixels,
+        max_pixels=1e6,  # was PROCESSOR.image_processor.max_pixels
)
height_ratio = resized_height / image.height
@@ -74,6 +74,13 @@ def get_resized_ratio(image):
return height_ratio, width_ratio
+def to_rgb(pil_image: Image.Image) -> Image.Image:
+ if pil_image.mode == 'RGBA':
+ white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
+ white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
+ return white_background
+ else:
+ return pil_image.convert("RGB")
def process_json_file(json_path, image_dir, split, prompt_format="thinking"):
"""
@@ -99,6 +106,10 @@ def process_json_file(json_path, image_dir, split, prompt_format="thinking"):
image = Image.open(img_path)
# Convert PIL Image to bytes
img_byte_arr = io.BytesIO()
+ image = to_rgb(image)
+ height_ratio, width_ratio = get_resized_ratio(image)
+ resized_height, resized_width = image.height * height_ratio, image.width * width_ratio
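+            # Resize before serializing so the saved bytes match the resized
+            # coordinate space used for the bbox below.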
+ image = image.resize((int(resized_width), int(resized_height))) # Convert to integers
image.save(img_byte_arr, format=image.format or "PNG")
img_byte_arr = img_byte_arr.getvalue()
except Exception as e:
@@ -106,7 +117,7 @@ def process_json_file(json_path, image_dir, split, prompt_format="thinking"):
continue
# Get image resize ratios
- height_ratio, width_ratio = get_resized_ratio(image)
+
# Adjust bbox based on resize ratios
bbox = example["bbox"]
diff --git a/orby/data/convert_uground.py b/orby/data/convert_uground.py
index 074e91b1af8..94ffec64739 100644
--- a/orby/data/convert_uground.py
+++ b/orby/data/convert_uground.py
@@ -42,18 +42,30 @@
MODEL_PATH = "Qwen/Qwen2.5-VL-7B-Instruct"
PROCESSOR = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
+def to_rgb(pil_image: Image.Image) -> Image.Image:
+ if pil_image.mode == 'RGBA':
+ white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
+ white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
+ return white_background
+ else:
+ return pil_image.convert("RGB")
-def get_resized_wh(image):
+def get_resized_hw(image, max_pixels=None):
"""
Get the resized width and height of the image.
"""
+ # if max_pixels is not set, use the max pixels of the image processor
+ if not max_pixels:
+ print("Max pixels not set, using the max pixels of the image processor", flush=True)
+ max_pixels = PROCESSOR.image_processor.max_pixels
+
resized_height, resized_width = smart_resize(
- image.height,
- image.width,
+ height=image.height,
+ width=image.width,
factor=PROCESSOR.image_processor.patch_size
* PROCESSOR.image_processor.merge_size,
min_pixels=PROCESSOR.image_processor.min_pixels,
- max_pixels=PROCESSOR.image_processor.max_pixels,
+ max_pixels=max_pixels,
)
return resized_height, resized_width
@@ -76,6 +88,17 @@ def save_in_chunks(
if len(dataset_chunk) == 0:
continue
+ # Remove width and height columns if they exist
+ columns_to_remove = []
+ if "width" in dataset_chunk.column_names:
+ columns_to_remove.append("width")
+ if "height" in dataset_chunk.column_names:
+ columns_to_remove.append("height")
+
+ if columns_to_remove:
+ dataset_chunk = dataset_chunk.remove_columns(columns_to_remove)
+ print(f"Removed columns: {columns_to_remove}", flush=True)
+
# Save the chunk as-is (remove the splitting logic)
output_file = os.path.join(
output_dir, f"{prefix}_part_{file_counter:04d}.parquet"
@@ -165,6 +188,13 @@ def process_in_chunks(streaming_dataset, chunk_size):
help="Maximum number of examples to process (for testing)",
)
+ parser.add_argument(
+ "--max_pixels",
+ type=int,
+ default=None,
+ help="Maximum number of pixels in the image",
+ )
+
args = parser.parse_args()
@@ -209,9 +239,13 @@ def process_fn(example, idx):
# Get image and resize ratios
if isinstance(image, bytes):
image = Image.open(io.BytesIO(image))
- resized_height, resized_width = get_resized_wh(image)
+ # Convert image to RGB if it's RGBA
+ image = to_rgb(image)
+ # Get the resized width and height of the image.
+ resized_height, resized_width = get_resized_hw(image, args.max_pixels)
- # Adjust bbox based on resize ratios. Uground labels range from
+ image = image.resize((resized_width, resized_height))
+ # Adjust bbox based on resize ratios. Uground labels range from
# [0, 999]
bbox = [
bbox[0] * resized_width / 1000.0,
@@ -244,6 +278,7 @@ def process_fn(example, idx):
"index": idx,
"question": instruction,
"bounding_box": bbox,
+ "max_pixels": args.max_pixels,
},
"response": answer
}
diff --git a/orby/data/debug_images.py b/orby/data/debug_images.py
new file mode 100644
index 00000000000..f43ec857f85
--- /dev/null
+++ b/orby/data/debug_images.py
@@ -0,0 +1,241 @@
+import os
+import pandas as pd
+import argparse
+from PIL import Image, ImageDraw
+import io
+
+# # Set pandas display options to show full content
+# # pd.set_option('display.max_columns', None)
+# # pd.set_option('display.max_colwidth', None)
+# # pd.set_option('display.width', None)
+# # pd.set_option('display.max_rows', None)
+
+# def extract_image_from_message(message):
+# """Extract image data from a message if it contains an image."""
+# if isinstance(message, dict) and 'content' in message:
+# content = message['content']
+# if isinstance(content, list):
+# for item in content:
+# if isinstance(item, dict) and item.get('type') == 'image':
+# return item.get('image')
+# return None
+
+# def save_image_with_bbox(image_data, bbox, output_dir, filename, instruction=None, response=None):
+# """Save image with bounding box overlay and predicted click point."""
+# # Handle dict format for image data
+# if isinstance(image_data, dict):
+# actual_bytes = image_data.get('bytes') or image_data.get('data') or image_data.get('image')
+# if actual_bytes:
+# image_data = actual_bytes
+
+# # Create PIL Image from bytes
+# image = Image.open(io.BytesIO(image_data))
+
+# # Create a copy for drawing
+# img_with_bbox = image.copy()
+# draw = ImageDraw.Draw(img_with_bbox)
+
+# # Parse bounding box coordinates - handle numpy array
+# if bbox is not None and hasattr(bbox, '__len__') and len(bbox) >= 4:
+# # Convert to list if it's a numpy array
+# bbox_list = list(bbox) if hasattr(bbox, 'tolist') else bbox
+# x1, y1, x2, y2 = bbox_list[:4]
+
+# # Draw ground truth bounding box in red
+# draw.rectangle([x1, y1, x2, y2], outline="red", width=5)
+
+# # Calculate and draw ground truth center
+# center_x = (x1 + x2) / 2
+# center_y = (y1 + y2) / 2
+# # Draw center point as a small circle
+# radius = 2
+# draw.ellipse([center_x-radius, center_y-radius, center_x+radius, center_y+radius],
+# fill="red", outline="darkred", width=1)
+
+# # Add "GT" label near the center
+# draw.text((center_x + 15, center_y - 10), "GT", fill="red", font=None)
+
+# # Extract and draw predicted click coordinates from response
+# if response:
+# import re
+# # Convert numpy array to regular list
+# if hasattr(response, 'tolist'):
+# response = response.tolist()
+
+# response_content = ""
+# print(f"Response value: {response}")
+
+# if isinstance(response, list) and len(response) > 0:
+# response_content = str(response[0]['content']) # Force string conversion
+# elif isinstance(response, str):
+# response_content = response
+
+# # Extract coordinates from response using regex
+# match = re.search(r'click\(([0-9.]+),\s*([0-9.]+)\)', response_content)
+# if match:
+# print(match.groups())
+# pred_x, pred_y = float(match.group(1)), float(match.group(2))
+
+# # Draw predicted click point in blue
+# radius = 2
+# draw.ellipse([pred_x-radius, pred_y-radius, pred_x+radius, pred_y+radius],
+# fill="blue", outline="darkblue", width=2)
+
+# # Add "PRED" label near the predicted point
+# draw.text((pred_x + 15, pred_y - 10), "PRED", fill="blue", font=None)
+
+# # Draw a line connecting GT center to predicted point
+# if bbox is not None and hasattr(bbox, '__len__') and len(bbox) >= 4:
+# draw.line([center_x, center_y, pred_x, pred_y], fill="yellow", width=3)
+
+# print(f"Response center: ({pred_x:.0f}, {pred_y:.0f})")
+
+# # Add instruction text if provided
+# if instruction:
+# # Try to use a larger font
+# try:
+# from PIL import ImageFont
+# font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
+# except:
+# font = None
+
+# # Position text at top of image
+# text_x = 10
+# text_y = 10
+
+# # Draw text background for better visibility
+# if font:
+# text_bbox = draw.textbbox((text_x, text_y), instruction, font=font)
+# else:
+# text_bbox = draw.textbbox((text_x, text_y), instruction)
+# draw.rectangle(text_bbox, fill="yellow", outline="black")
+# draw.text((text_x, text_y), instruction, fill="black", font=font)
+
+# # Save the image with bounding box
+# os.makedirs(output_dir, exist_ok=True)
+# img_with_bbox.save(os.path.join(output_dir, filename))
+
+# return img_with_bbox
+
+# def save_original_image(image_data, output_dir, filename):
+# """Save original image without modifications."""
+# # Handle dict format (likely has 'bytes' or 'data' key)
+# if isinstance(image_data, dict):
+# # Try common keys for image bytes
+# actual_bytes = image_data.get('bytes') or image_data.get('data') or image_data.get('image')
+# if actual_bytes:
+# image_data = actual_bytes
+
+# image = Image.open(io.BytesIO(image_data))
+# os.makedirs(output_dir, exist_ok=True)
+# image.save(os.path.join(output_dir, filename))
+# return image
+
+# def print_image_info(image, filename, response):
+# """Print detailed information about the image."""
+# print(f"\n--- IMAGE INFO: {filename} ---")
+# print(f"Size: {image.size} (width x height)")
+# print(f"Mode: {image.mode}")
+# print(f"Response: {response}")
+
+# def visualize_parquet(parquet_file, save_images=False, original_dir="original_images", bbox_dir="bbox_images", num_examples=20):
+# # Load the parquet file
+# df = pd.read_parquet(parquet_file)
+# print(f"\nParquet file contents ({len(df)} rows):")
+# print("\nColumns:", df.columns.tolist())
+
+# # Print first few rows
+# print("\nFirst few rows:")
+# print(df.head())
+
+
+
+# idx = 0
+
+# for idx in range(min(len(df), num_examples)):
+# # Get the row and extract image data
+# row = df.iloc[idx]
+# image_data = None
+
+# # Try to get image from 'images' column
+# if 'images' in row and row['images'] is not None:
+# images = row['images']
+# if isinstance(images, list) and len(images) > 0:
+# image_data = images[0] # Take first image
+# elif hasattr(images, '__len__') and len(images) > 0:
+# image_data = images[0] # Handle numpy array
+# elif isinstance(images, bytes):
+# image_data = images
+# else:
+# print(f"Images column is None or missing")
+
+# if image_data is None:
+# print(f"No image data found for example {idx+1}")
+
+
+# # Save images if requested and image data is available
+# if image_data is not None:
+# filename = f"example_{idx+1}.png"
+
+# # Save original image
+# original_image = save_original_image(image_data, original_dir, filename)
+# # print_image_info(original_image, f"Original - {filename}", row.get('response', ''))
+
+# # Save image with bounding box
+# extra_info = row.get('extra_info', {})
+# bbox = extra_info.get('bounding_box')
+# print(f"Bounding box: {bbox}")
+# print(f"Center of bbox: {(bbox[0] + bbox[2]) / 2}, {(bbox[1] + bbox[3]) / 2}")
+# if bbox is not None and len(bbox) > 0:
+# # Get instruction from extra_info
+# instruction = extra_info.get('question', '') or extra_info.get('answer', '')
+# print("="*100)
+# bbox_image = save_image_with_bbox(image_data, bbox, bbox_dir, filename, instruction, row.get('response'))
+# print_image_info(bbox_image, f"With BBox - {filename}", row.get('response', ''))
+# else:
+# print("No bounding box found for this example")
+
+# if __name__ == "__main__":
+# parser = argparse.ArgumentParser()
+# parser.add_argument("--parquet_file",
+# default="~/data/uground/subtask/0.50k/train/train_part_0000.parquet",
+# help="Path to parquet file to visualize")
+# parser.add_argument("--save_images", action="store_true", help="Save images to directories")
+# parser.add_argument("--original_dir", default="/workspace/verl/orby/data/uground/original_images/", help="Directory to save original images")
+# parser.add_argument("--bbox_dir", default="/workspace/verl/orby/data/uground/bbox_images/", help="Directory to save images with bounding boxes")
+# parser.add_argument("--num_examples", type=int, default=20, help="Number of examples to visualize")
+# args = parser.parse_args()
+
+# visualize_parquet(args.parquet_file, args.save_images, args.original_dir, args.bbox_dir, args.num_examples)
+
+
+def visualize_parquet(parquet_file):
+ # Load the parquet file
+ df = pd.read_parquet(parquet_file)
+ print(f"\nParquet file contents ({len(df)} rows):")
+ print("\nColumns:", df.columns.tolist())
+
+    # Inspect a handful of rows (guard against files with fewer than 5 rows)
+    for i in range(min(5, len(df))):
+ print(f"\nExample {i+1}:")
+ print(f"Bounding box: {df.iloc[i]['extra_info']['bounding_box']}")
+ # print(f"Image size: {df.iloc[i]['width']}x{df.iloc[i]['height']}")
+ # Convert image bytes to PIL Image
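+        # Depending on how the parquet was written, the image may be raw bytes or a
+        # dict holding the bytes under 'bytes' / 'data' / 'image'.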
+ image_data = df.iloc[i]['images'][0]
+ if isinstance(image_data, dict):
+ image_bytes = image_data.get('bytes') or image_data.get('data') or image_data.get('image')
+ else:
+ image_bytes = image_data
+ image = Image.open(io.BytesIO(image_bytes))
+ width, height = image.size
+ print(f"PIL Image size: {width}x{height}")
+ x1, y1, x2, y2 = df.iloc[i]['extra_info']['bounding_box']
+ print(f"Center of bbox: {(x1 + x2) / 2:.0f}, {(y1 + y2) / 2:.0f}")
+ print(f"Response: {df.iloc[i]['response']}")
+ print("-" * 80)
+
+
+if __name__ == "__main__":
+ parquet_file = "/root/data/uground/subtask/0.10k/test/test_part_0000.parquet"
+ visualize_parquet(parquet_file)
\ No newline at end of file
diff --git a/orby/data/visualize_screenspot.py b/orby/data/visualize_screenspot.py
new file mode 100644
index 00000000000..dd4d38ae9c8
--- /dev/null
+++ b/orby/data/visualize_screenspot.py
@@ -0,0 +1,90 @@
+import pyarrow.parquet as pq
+import io
+from PIL import Image, ImageDraw
+import os
+import re
+
+def extract_images(parquet_file, num_images=10000, output_dir="/workspace/verl/orby/data/screenspot_v1/"):
+ """
+ Extract images from ScreenSpot parquet file and draw bounding boxes with coordinates
+ """
+    parquet_file_expanded = os.path.expanduser(parquet_file)
+
+ # Create output directory
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Open the parquet file and iterate through batches
+ pf = pq.ParquetFile(parquet_file_expanded)
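+    # batch_size=1 streams one row at a time, so large parquet files never have to fit in memory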
+ batch_reader = pf.iter_batches(batch_size=1, columns=['images', 'responses', 'extra_info'])
+
+ extracted_count = 0
+
+ for i, batch in enumerate(batch_reader):
+ if extracted_count >= num_images:
+ break
+
+ images_col = batch.column('images')
+ responses_col = batch.column('responses')
+ extra_info_col = batch.column('extra_info')
+
+ if len(images_col) > 0 and images_col[0].is_valid:
+ image_data = images_col[0].as_py()
+ responses = responses_col[0].as_py() if len(responses_col) > 0 and responses_col[0].is_valid else None
+ extra_info = extra_info_col[0].as_py() if len(extra_info_col) > 0 and extra_info_col[0].is_valid else None
+
+ if image_data and isinstance(image_data, list) and len(image_data) > 0:
+ first_item = image_data[0]
+
+ if isinstance(first_item, dict) and 'bytes' in first_item:
+ image_bytes = first_item['bytes']
+
+ try:
+ image = Image.open(io.BytesIO(image_bytes))
+ draw = ImageDraw.Draw(image)
+
+ # Draw ground truth bounding box in red
+ if extra_info and 'bounding_box' in extra_info:
+ bbox = extra_info['bounding_box']
+ x1, y1, x2, y2 = bbox
+ draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
+ # Draw GT center point
+ center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
+ draw.ellipse([center_x-3, center_y-3, center_x+3, center_y+3], fill="red")
+ draw.text((center_x+10, center_y-10), "GT", fill="red")
+
+ # Extract and draw predicted coordinates in blue
+ if responses and len(responses) > 0:
+ response_text = responses[0]
+ match = re.search(r'click\((\d+\.?\d*),\s*(\d+\.?\d*)\)', response_text)
+ if match:
+ pred_x, pred_y = float(match.group(1)), float(match.group(2))
+ draw.ellipse([pred_x-3, pred_y-3, pred_x+3, pred_y+3], fill="blue")
+ draw.text((pred_x+10, pred_y-10), "PRED", fill="blue")
+
+ # Draw question text near the bounding box
+ if extra_info and 'question' in extra_info and 'bounding_box' in extra_info:
+ question = extra_info['question']
+ bbox = extra_info['bounding_box']
+ x1, y1, x2, y2 = bbox
+                        # Position the text just inside the top-left corner of the bounding box
+                        text_x = x1 + 30  # 30 pixels right of the box's left edge
+                        text_y = y1 + 30  # 30 pixels below the box's top edge
+ # Draw text background for better visibility
+ text_bbox = draw.textbbox((text_x, text_y), question)
+ draw.rectangle([text_bbox[0]-5, text_bbox[1]-5, text_bbox[2]+5, text_bbox[3]+5],
+ fill="yellow", outline="black")
+ draw.text((text_x, text_y), question, fill="black")
+ output_path = os.path.join(output_dir, f"image_{i:04d}.png")
+ image.save(output_path)
+ extracted_count += 1
+ print(f"Saved image {extracted_count}: {output_path}")
+
+ except Exception as e:
+ print(f"Failed to process image {i}: {e}")
+
+ print(f"\nExtracted {extracted_count} images to {output_dir}")
+ return extracted_count
+
+if __name__ == "__main__":
+ parquet_file = "~/data/screenspot_subtask/result-test-output-1.parquet"
+ extract_images(parquet_file, num_images=100)
\ No newline at end of file
diff --git a/orby/data/visualize_uground.py b/orby/data/visualize_uground.py
index 1feff0c51fc..d28cb6a538b 100644
--- a/orby/data/visualize_uground.py
+++ b/orby/data/visualize_uground.py
@@ -57,8 +57,8 @@ def visualize_parquet(parquet_file):
extra_info = df.iloc[idx]['extra_info']
print(f"\n--- EXTRA INFO ---")
print(f"Question: {extra_info['question']}")
- print(f"Answer: {extra_info['answer']}")
print(f"Bounding Box: {extra_info['bounding_box']}")
+ print(f"Max Pixels: {extra_info['max_pixels']}")
print("\n" + "="*60)
@@ -66,7 +66,7 @@ def visualize_parquet(parquet_file):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--parquet_file",
- default="~/data/executor-reward/train/part-00000-tid-3712964276653840281-af2210b2-e910-4427-aa16-9f2a2cfdae0a-844-1-c000.snappy.parquet",
+ default="~/data/uground/subtask/0.50k/train/train_part_0000.parquet",
help="Path to parquet file to visualize")
args = parser.parse_args()
diff --git a/orby/mcli/sft_qwen2_5_vl_7b_grounding.yaml b/orby/mcli/sft_qwen2_5_vl_7b_grounding.yaml
index df4f427adf4..095274df810 100644
--- a/orby/mcli/sft_qwen2_5_vl_7b_grounding.yaml
+++ b/orby/mcli/sft_qwen2_5_vl_7b_grounding.yaml
@@ -4,7 +4,7 @@ image: whatcanyousee/verl:ngc-cu124-vllm0.8.3-sglang0.4.5-mcore0.12.0-te2.2
integrations:
- integration_type: git_repo
git_repo: orby-ai-engineering/verl
- git_branch: main # TODO: Change this according to your experiment!
+ git_branch: rishu/grounding-1k-pixel # TODO: Change this according to your experiment!
pip_install: .
ssh_clone: true
@@ -18,15 +18,15 @@ command: |
export NUM_NODES=2
export MODEL_NAME=Qwen/Qwen2.5-VL-7B-Instruct
export PROJECT_NAME=verl_sft_grounding
- export DATASET_VERSION=os_atlas
- export EXPERIMENT_NAME=$MODEL_NAME-$DATASET_VERSION-sft
- export DATA_SPLIT=5k # "Set the data split here (example 100k, 5k, 0.05k, etc.)"
+ export DATASET_VERSION=os_atlas_uground
+ export EXPERIMENT_NAME=$MODEL_NAME-$DATASET_VERSION-sft-1k-pixel
+ export DATA_SPLIT=50k # "Set the data split here (example 100k, 5k, 0.05k, etc.)"
export S3_CHECKPOINT_DIR=s3://orby-osu-va/verl-checkpoints/$PROJECT_NAME/$EXPERIMENT_NAME/$DATA_SPLIT
export TRAIN_BATCH_SIZE=32
export MICRO_BATCH_SIZE_PER_GPU=2
export FILTER_OVERLONG_PROMPTS_WORKERS=24 # (24 seems to work well for OSAtlas + Uground data)
- export TRAIN_DIR=s3://orby-osu-va/Rishu-SFT-Dataset/os_atlas/subtask/$DATA_SPLIT/train/
- export TEST_DIR=s3://orby-osu-va/Rishu-SFT-Dataset/os_atlas/subtask/$DATA_SPLIT/test/
+ export TRAIN_DIR=s3://orby-osu-va/rishu_test/1kpixel/50k/train/
+ export TEST_DIR=s3://orby-osu-va/rishu_test/1kpixel/50k/test/
export MAX_PROMPT_LENGTH=7100
cd /workspace/verl
@@ -91,6 +91,7 @@ command: |
+data.max_prompt_length=$MAX_PROMPT_LENGTH \
+processor.use_fast=true \
+processor.trust_remote_code=true \
+ +processor.max_pixels=1000000 \
optim.lr=1e-6 \
model.partial_pretrain=$MODEL_NAME \
model.fsdp_config.cpu_offload=true \
diff --git a/orby/scripts/eval_screenspot.sh b/orby/scripts/eval_screenspot.sh
index aa68bd7f427..d9565694f0f 100644
--- a/orby/scripts/eval_screenspot.sh
+++ b/orby/scripts/eval_screenspot.sh
@@ -24,7 +24,7 @@ set -e
# Default values
DATASET_VERSION="screenspot"
-MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct
+MODEL_PATH=/root/experiment/uground-osatlas-1kpixel/subtask/100k/global_step_300/
MODEL_SIZE=7
REWARD_FILE=orby/reward/screenspot.py
REWARD_FN=reward_func
diff --git a/orby/trainer/fsdp_sft_trainer.py b/orby/trainer/fsdp_sft_trainer.py
index d60858fe9ea..1ef830e265a 100644
--- a/orby/trainer/fsdp_sft_trainer.py
+++ b/orby/trainer/fsdp_sft_trainer.py
@@ -1001,6 +1001,11 @@ def main(config):
local_model_path, trust_remote_code=config.model.trust_remote_code
)
processor = hf_processor(local_model_path, **config.get("processor", {}))
+ if rank == 0:
+ print("Processor parameters:")
+ for name, param in processor.__dict__.items():
+ if not name.startswith('_'):
+ print(f" {name}: {param}")
train_dataset = create_sft_dataset(
config.data.train_files, config.data, tokenizer, processor
)
diff --git a/orby/trainer/main_generation.py b/orby/trainer/main_generation.py
index 2848fac28c9..322370ace4b 100644
--- a/orby/trainer/main_generation.py
+++ b/orby/trainer/main_generation.py
@@ -101,7 +101,7 @@ def main_task(config):
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
processor = hf_processor(
- local_path, use_fast=True
+        local_path, use_fast=True, max_pixels=1e6
) # used for multimodal LLM, could be none
dataset = _create_dataloader(config, tokenizer, processor)