Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions orby/scripts/eval_only_screenspot.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Score previously generated ScreenSpot responses with the custom reward
# function. Expects $DATA_PATH to already contain $OUTPUT_FILE.
set -e

# Default values
DATASET_VERSION="screenspot"
# NOTE(review): MODEL_PATH and PROMPT_FORMAT are not referenced below; they are
# kept so anything sourcing or editing this script still finds them.
MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct
REWARD_FILE=orby/reward/screenspot.py
REWARD_FN=reward_func
OUTPUT_FILE=responses.parquet
PROMPT_FORMAT="qwen"
# $HOME instead of ~ so the value can be quoted safely; the trailing slash is
# part of the value, so paths are joined without adding another one.
DATA_PATH="$HOME/data/$DATASET_VERSION/"

echo "Using dataset version: $DATASET_VERSION"
echo "Data path: $DATA_PATH"

# Evaluation
python3 -m orby.trainer.main_eval \
    data.path="${DATA_PATH}${OUTPUT_FILE}" \
    data.prompt_key=prompt \
    data.response_key=responses \
    custom_reward_function.path="$REWARD_FILE" \
    custom_reward_function.name="$REWARD_FN"
36 changes: 33 additions & 3 deletions orby/scripts/init_interactive.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,41 @@ apt install -y awscli
pip install 'urllib3<2'
pip install parquet-tools

# Install Miniconda (batch mode, no prompts) under $HOME/miniconda
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh -b -p "$HOME/miniconda"
source "$HOME/miniconda/bin/activate"

# Create and setup conda environment
conda create -n verl python=3.12 -y
conda activate verl

# Clone VERL repository; every later command in this section runs inside it
git clone git@github.com:orby-ai-engineering/verl.git && cd verl
# Initialize git submodules if needed
git submodule update --init --recursive
# Install dependencies (quoted so the shell never treats [vllm] as a glob)
pip install -e '.[vllm]'

# Handle Flash Attention installation: drop any prebuilt wheel and rebuild
# it from source against the packages installed above
pip uninstall -y flash-attn
FLASH_ATTENTION_FORCE_BUILD=TRUE pip install flash-attn --no-build-isolation

# Install additional utilities
pip install qwen_vl_utils
pip install qwen_agent
pip install hf_transfer

# Download model, verify transformers installation
python3 -c "import transformers; transformers.pipeline(model='Qwen/Qwen2.5-VL-7B-Instruct')"

# Install verl lib: https://verl.readthedocs.io/en/latest/start/install.html
# NOTE(review): this repeats the editable install done right after cloning.
pip3 install -e '.[vllm]'
# Install internal packages
conda deactivate
conda create -n digital-agent python=3.12 -y
# NOTE(review): the digital-agent environment is created but never activated,
# so the pip install below targets whichever environment is active after
# `conda deactivate` — confirm that is intended.
# `&&` (not `&`): only enter the checkout once the clone has finished.
git clone git@github.com:orby-ai-engineering/digital-agent.git && cd digital-agent
git submodule update --init --recursive
pip install -r requirements.txt


# Download and convert action description dev set
# mkdir -p ~/data/action_description/raw/
Expand Down
143 changes: 143 additions & 0 deletions orby/trainer/vllm_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import os
import json
import torch
from typing import List, Dict, Any
from openai import OpenAI
from dataclasses import dataclass
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import base64
from io import BytesIO
import ray
from ray.util.multiprocessing import Pool

@dataclass
class Batch:
    """One batch of model inputs paired with their reference data.

    Both lists are index-aligned: ``ground_truth[i]`` belongs to
    ``prompts[i]``.
    """

    # Each prompt is a full chat conversation (a list of message dicts),
    # not a plain string.
    prompts: List[List[Dict[str, Any]]]
    # One reference record per prompt, taken from the dataset as-is.
    ground_truth: List[Dict[str, Any]]

class DataLoader:
    """Load a parquet dataset and serve it as batches of chat prompts.

    Attributes:
        data_path: Location of the parquet file, with a leading ``~`` expanded.
        batch_size: Maximum number of rows per batch.
        data: The whole dataset, read eagerly when the loader is constructed.
    """

    def __init__(self, data_path: str, batch_size: int = 128):
        """Read the dataset at ``data_path`` into memory.

        Args:
            data_path: Path to a parquet file; ``~`` is expanded.
            batch_size: Number of rows per batch produced by ``get_batches``.
        """
        # pathlib does not expand "~" on its own, so do it explicitly.
        self.data_path = Path(data_path).expanduser()
        self.batch_size = batch_size
        self.data = self._load_data()

    def _load_data(self) -> pd.DataFrame:
        """Load data from the parquet file."""
        return pd.read_parquet(self.data_path)

    def _create_chat_messages(self, row: pd.Series) -> List[Dict[str, Any]]:
        """Build the chat conversation for one dataset row.

        Args:
            row: A record with a ``prompt`` entry (system message first, user
                message second) and an ``images`` entry holding exactly one
                item with the raw image under ``bytes``.

        Returns:
            A two-message conversation: the system text, then the user text
            followed by the image embedded as a base64 data URL.

        Raises:
            ValueError: If the row does not contain exactly one image.
        """
        prompt = row['prompt']
        images = row['images']

        if len(images) != 1:
            # Covers both the "no image" and the "several images" case.
            raise ValueError(f"Expected exactly 1 image, got {len(images)}")

        # Convert image bytes to base64 encoding
        base64_image = base64.b64encode(images[0]['bytes']).decode('utf-8')

        # NOTE(review): the data URL always declares image/png — confirm the
        # dataset only contains PNG images.
        # NOTE(review): the user text is forwarded verbatim, including any
        # "<image>" placeholder it may contain — confirm the server expects it.
        return [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": prompt[0]["content"]},
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt[1]["content"]},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{base64_image}"
                        },
                    },
                ],
            },
        ]

    def get_batches(self):
        """Yield consecutive ``Batch`` objects of at most ``batch_size`` rows.

        The final batch is shorter when the dataset size is not a multiple of
        ``batch_size``.
        """
        for start in range(0, len(self.data), self.batch_size):
            batch_data = self.data.iloc[start:start + self.batch_size]
            prompts = [
                self._create_chat_messages(row)
                for _, row in batch_data.iterrows()
            ]
            ground_truth = batch_data['reward_model'].tolist()
            yield Batch(prompts=prompts, ground_truth=ground_truth)

@ray.remote
class VLLMClient:
    """Ray actor wrapping an OpenAI-compatible client for a VLLM server."""

    def __init__(self, server_url: str, model: str = "qwen25vl7b-9"):
        """Create the underlying client.

        Args:
            server_url: Base URL of the OpenAI-compatible endpoint.
            model: Model name sent with every request.
        """
        self.model = model
        self.client = OpenAI(
            base_url=server_url,
            api_key="not-needed",  # VLLM server doesn't require API key
        )

    def generate(self, prompt: Any) -> List[str]:
        """Generate a response for a single conversation.

        Args:
            prompt: The chat messages of one conversation.

        Returns:
            A one-element list holding the text of the first choice.
        """
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=prompt,
            temperature=0,
            top_p=1.0,
        )
        # Wrapped in a list on purpose: callers store one list per dataset row.
        return [completion.choices[0].message.content]

def process_batch(batch: Batch, client_pool: List[Any]) -> List[List[str]]:
    """Generate responses for every prompt of ``batch`` in parallel.

    Prompts are spread over ``client_pool`` round-robin, so a batch larger
    than the pool is still processed completely instead of being truncated.

    Args:
        batch: The batch whose prompts should be answered.
        client_pool: Ray actor handles exposing ``generate.remote(prompt)``.

    Returns:
        One result per prompt, in prompt order.

    Raises:
        ValueError: If there are prompts to process but the pool is empty.
    """
    if batch.prompts and not client_pool:
        raise ValueError("client_pool must contain at least one client")

    pool_size = len(client_pool)
    futures = [
        client_pool[index % pool_size].generate.remote(prompt)
        for index, prompt in enumerate(batch.prompts)
    ]
    # ray.get on a list blocks until all results are ready and keeps the order.
    return ray.get(futures)

def main():
    """Generate a response for every dataset row and save them to parquet.

    The responses are written as a new ``responses`` column to
    ``responses.parquet`` next to the input file.
    """
    # Configuration
    # os.path functions do not expand "~", so resolve it once up front.
    data_path = os.path.expanduser("~/data/screenspot/test.parquet")  # Replace with your data path
    server_url = "http://model.orbyapi.com/v1"
    batch_size = 16
    output_file = os.path.join(os.path.dirname(data_path), "responses.parquet")

    # Initialize Ray
    ray.init()
    try:
        # Create a pool of VLLM clients, one per prompt of a full batch
        client_pool = [VLLMClient.remote(server_url) for _ in range(batch_size)]

        # Initialize components
        data_loader = DataLoader(data_path, batch_size)

        # Process batches and collect responses
        all_responses = []
        for batch in tqdm(data_loader.get_batches(), desc="Processing batches"):
            # Generate responses from VLLM server in parallel using the client pool
            all_responses.extend(process_batch(batch, client_pool))

        # Add responses to the original dataset
        data_loader.data['responses'] = all_responses

        # Save the updated dataset
        data_loader.data.to_parquet(output_file, index=False)
        print(f"Saved dataset with {len(all_responses)} responses to {output_file}")
    finally:
        # Shutdown Ray even when generation or saving fails
        ray.shutdown()


if __name__ == "__main__":
    main()
128 changes: 128 additions & 0 deletions orby/trainer/vllm_generation_without_ray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import os
import json
import torch
from typing import List, Dict, Any
from openai import OpenAI
from dataclasses import dataclass
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import base64
from io import BytesIO

@dataclass
class Batch:
    """One batch of model inputs paired with their reference data.

    Both lists are index-aligned: ``ground_truth[i]`` belongs to
    ``prompts[i]``.
    """

    # Each prompt is a full chat conversation (a list of message dicts),
    # not a plain string.
    prompts: List[List[Dict[str, Any]]]
    # One reference record per prompt, taken from the dataset as-is.
    ground_truth: List[Dict[str, Any]]

class DataLoader:
    """Load a parquet dataset and serve it as batches of chat prompts.

    Attributes:
        data_path: Location of the parquet file, with a leading ``~`` expanded.
        batch_size: Maximum number of rows per batch.
        data: The whole dataset, read eagerly when the loader is constructed.
    """

    def __init__(self, data_path: str, batch_size: int = 128):
        """Read the dataset at ``data_path`` into memory.

        Args:
            data_path: Path to a parquet file; ``~`` is expanded.
            batch_size: Number of rows per batch produced by ``get_batches``.
        """
        # pathlib does not expand "~" on its own, so do it explicitly.
        self.data_path = Path(data_path).expanduser()
        self.batch_size = batch_size
        self.data = self._load_data()

    def _load_data(self) -> pd.DataFrame:
        """Load data from the parquet file."""
        return pd.read_parquet(self.data_path)

    def _create_chat_messages(self, row: pd.Series) -> List[Dict[str, Any]]:
        """Build the chat conversation for one dataset row.

        Args:
            row: A record with a ``prompt`` entry (system message first, user
                message second) and an ``images`` entry holding exactly one
                item with the raw image under ``bytes``.

        Returns:
            A two-message conversation: the system text, then the user text
            (with every ``<image>`` placeholder removed) followed by the
            image embedded as a base64 data URL.

        Raises:
            ValueError: If the row does not contain exactly one image.
        """
        prompt = row['prompt']
        images = row['images']

        if len(images) != 1:
            # Covers both the "no image" and the "several images" case.
            raise ValueError(f"Expected exactly 1 image, got {len(images)}")

        # Convert image bytes to base64 encoding
        base64_image = base64.b64encode(images[0]['bytes']).decode('utf-8')

        # The image travels as its own content part, so the textual
        # placeholder is dropped from the user message.
        user_text = prompt[1]["content"].replace("<image>", "")

        # NOTE(review): the data URL always declares image/png — confirm the
        # dataset only contains PNG images.
        return [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": prompt[0]["content"]},
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_text},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{base64_image}"
                        },
                    },
                ],
            },
        ]

    def get_batches(self):
        """Yield consecutive ``Batch`` objects of at most ``batch_size`` rows.

        The final batch is shorter when the dataset size is not a multiple of
        ``batch_size``.
        """
        for start in range(0, len(self.data), self.batch_size):
            batch_data = self.data.iloc[start:start + self.batch_size]
            prompts = [
                self._create_chat_messages(row)
                for _, row in batch_data.iterrows()
            ]
            ground_truth = batch_data['reward_model'].tolist()
            yield Batch(prompts=prompts, ground_truth=ground_truth)

class VLLMClient:
    """Client for interacting with an OpenAI-compatible VLLM server."""

    def __init__(self, server_url: str, model: str = "qwen25vl7b-9"):
        """Create the underlying client.

        Args:
            server_url: Base URL of the OpenAI-compatible endpoint.
            model: Model name sent with every request.
        """
        self.model = model
        self.client = OpenAI(
            base_url=server_url,
            api_key="not-needed",  # VLLM server doesn't require API key
        )

    def generate(self, prompts: List[Any]) -> List[str]:
        """Generate one response per conversation in ``prompts``.

        A chat-completion request carries a single conversation, so every
        prompt is sent as its own request, one after the other.

        Args:
            prompts: Conversations, each a list of chat messages.

        Returns:
            The text of the first choice for each prompt, in prompt order.
        """
        responses = []
        for messages in prompts:
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0,
                top_p=1.0,
            )
            responses.append(completion.choices[0].message.content)
        return responses

def process_batch(batch: Batch, client: VLLMClient) -> List[str]:
    """Return what ``client.generate`` produces for the prompts of ``batch``."""
    return client.generate(batch.prompts)

def main():
    """Generate a response for every dataset row and save them to parquet.

    The responses are written as a new ``responses`` column to
    ``responses.parquet`` next to the input file.
    """
    # Configuration
    # os.path functions do not expand "~", so resolve it once up front.
    data_path = os.path.expanduser("~/data/screenspot/test.parquet")  # Replace with your data path
    server_url = "http://model.orbyapi.com/v1"
    batch_size = 16
    output_file = os.path.join(os.path.dirname(data_path), "responses.parquet")

    # Create a single VLLM client
    client = VLLMClient(server_url)

    # Initialize components
    data_loader = DataLoader(data_path, batch_size)

    # Process batches and collect responses
    all_responses = []
    for batch in tqdm(data_loader.get_batches(), desc="Processing batches"):
        # Generate responses from VLLM server sequentially
        all_responses.extend(process_batch(batch, client))

    # Add responses to the original dataset
    data_loader.data['responses'] = all_responses

    # Save the updated dataset
    data_loader.data.to_parquet(output_file, index=False)
    print(f"Saved dataset with {len(all_responses)} responses to {output_file}")


if __name__ == "__main__":
    main()