diff --git a/orby/scripts/eval_only_screenspot.sh b/orby/scripts/eval_only_screenspot.sh
new file mode 100644
index 00000000000..8dc94c116f2
--- /dev/null
+++ b/orby/scripts/eval_only_screenspot.sh
@@ -0,0 +1,21 @@
+set -e
+
+# Default values
+DATASET_VERSION="screenspot"
+MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct
+REWARD_FILE=orby/reward/screenspot.py
+REWARD_FN=reward_func
+OUTPUT_FILE=responses.parquet
+PROMPT_FORMAT="qwen"
+DATA_PATH=~/data/$DATASET_VERSION
+
+echo "Using dataset version: $DATASET_VERSION"
+echo "Data path: $DATA_PATH"
+
+# Evaluation
+python3 -m orby.trainer.main_eval \
+    data.path=$DATA_PATH/$OUTPUT_FILE \
+    data.prompt_key=prompt \
+    data.response_key=responses \
+    custom_reward_function.path=$REWARD_FILE \
+    custom_reward_function.name=$REWARD_FN
diff --git a/orby/scripts/init_interactive.sh b/orby/scripts/init_interactive.sh
index eb07705e120..184c58d191b 100644
--- a/orby/scripts/init_interactive.sh
+++ b/orby/scripts/init_interactive.sh
@@ -18,11 +18,42 @@ apt install -y awscli
 pip install 'urllib3<2'
 pip install parquet-tools
 
-# Download model.
+# Install Miniconda
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
+bash Miniconda3-latest-Linux-x86_64.sh -b -p $HOME/miniconda
+source $HOME/miniconda/bin/activate
+
+# Create and set up the conda environment
+conda create -n verl python=3.12 -y
+conda activate verl
+
+# Clone the verl repository
+git clone git@github.com:orby-ai-engineering/verl.git && cd verl
+# Initialize git submodules if needed
+git submodule update --init --recursive
+# Install dependencies
+pip install -e .[vllm]
+
+# Handle Flash Attention installation
+pip uninstall -y flash-attn
+FLASH_ATTENTION_FORCE_BUILD=TRUE pip install flash-attn --no-build-isolation
+
+# Install additional utilities
+pip install qwen_vl_utils
+pip install qwen_agent
+pip install hf_transfer
+
+# Download the model and verify the transformers installation
 python3 -c "import transformers; transformers.pipeline(model='Qwen/Qwen2.5-VL-7B-Instruct')"
 
-# Install verl lib: https://verl.readthedocs.io/en/latest/start/install.html
-pip3 install -e .[vllm]
+# Install internal packages
+conda deactivate
+conda create -n digital-agent python=3.12 -y
+conda activate digital-agent
+git clone git@github.com:orby-ai-engineering/digital-agent.git && cd digital-agent
+git submodule update --init --recursive
+pip install -r requirements.txt
+
 
 # Download and convert action description dev set
 # mkdir -p ~/data/action_description/raw/
diff --git a/orby/trainer/vllm_generation.py b/orby/trainer/vllm_generation.py
new file mode 100644
index 00000000000..94b9eb0b9ab
--- /dev/null
+++ b/orby/trainer/vllm_generation.py
@@ -0,0 +1,143 @@
+import os
+from typing import List, Dict, Any
+from openai import OpenAI
+from dataclasses import dataclass
+from pathlib import Path
+from tqdm import tqdm
+import pandas as pd
+import base64
+import ray
+
+@dataclass
+class Batch:
+    """Class to hold a batch of data"""
+    prompts: List[List[Dict[str, Any]]]
+    ground_truth: List[Dict[str, Any]]
+
+class DataLoader:
+    """Data loader for loading and batching data"""
+    def __init__(self, data_path: str, batch_size: int = 128):
+        self.data_path = Path(data_path)
+        self.batch_size = batch_size
+        self.data = self._load_data()
+
+    def _load_data(self) -> pd.DataFrame:
+        """Load data from the parquet file"""
+        return pd.read_parquet(self.data_path)
+
+    def _create_chat_messages(self, row: pd.Series) -> List[Dict[str, Any]]:
+        """Create chat messages for the prompt"""
+        prompt = row['prompt']
+
+        # Convert image bytes to base64 encoding
+        if len(row['images']) != 1:
+            raise ValueError(f"Expected exactly 1 image, got {len(row['images'])}")
+        image_bytes = row['images'][0]['bytes']
+        base64_image = base64.b64encode(image_bytes).decode('utf-8')
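+        # Build an OpenAI-compatible vision conversation: the system and user
+        # text come from the stored prompt; the image rides along as a data URL.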
+        messages = [
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": prompt[0]["content"]
+                    }
+                ]
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": prompt[1]["content"]
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/png;base64,{base64_image}"
+                        }
+                    }
+                ]
+            }
+        ]
+
+        return messages
+
+    def get_batches(self):
+        """Generate batches of data"""
+        for i in range(0, len(self.data), self.batch_size):
+            batch_data = self.data.iloc[i:i + self.batch_size]
+            prompts = [self._create_chat_messages(row) for _, row in batch_data.iterrows()]
+            ground_truth = batch_data['reward_model'].tolist()
+            yield Batch(prompts=prompts, ground_truth=ground_truth)
+
+@ray.remote
+class VLLMClient:
+    """Client for interacting with the vLLM server"""
+    def __init__(self, server_url: str):
+        self.client = OpenAI(
+            base_url=server_url,
+            api_key="not-needed"  # The vLLM server doesn't require an API key
+        )
+
+    def generate(self, prompt: Any) -> List[str]:
+        """Generate a response for a single prompt"""
+        completion = self.client.chat.completions.create(
+            model="qwen25vl7b-9",  # Must match the model name served by the vLLM server
+            messages=prompt,
+            temperature=0,
+            # max_tokens=2048,
+            top_p=1.0,
+            # frequency_penalty=0.0,
+            # presence_penalty=0.0
+        )
+        return [completion.choices[0].message.content]
+
+def process_batch(batch: Batch, client_pool: List[Any]) -> List[List[str]]:
+    """Process a batch of prompts using the existing client pool"""
+    # Generate responses in parallel using the client pool
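+    # zip() pairs each prompt with a dedicated actor, so the pool must be at
+    # least as large as the batch; any extra prompts would be dropped silently.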
"""Class to hold a batch of data""" + prompts: List[str] + ground_truth: List[Dict[str, Any]] + +class DataLoader: + """Data loader for loading and batching data""" + def __init__(self, data_path: str, batch_size: int = 128): + self.data_path = Path(data_path) + self.batch_size = batch_size + self.data = self._load_data() + + def _load_data(self) -> pd.DataFrame: + """Load data from the parquet file""" + return pd.read_parquet(self.data_path) + + def _create_chat_messages(self, row: pd.Series) -> str: + """Create chat template for the prompt""" + prompt = row['prompt'] + + # Convert image bytes to base64 encoding + if len(row['images']) != 1: + raise ValueError("More than 1 image") + image_bytes = BytesIO(row['images'][0]['bytes']).getvalue() + base64_image = base64.b64encode(image_bytes).decode('utf-8') + messages = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": prompt[0]["content"] + } + ] + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt[1]["content"].replace("", "") + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{base64_image}" + } + } + ] + } + ] + + return messages + + def get_batches(self): + """Generate batches of data""" + for i in range(0, len(self.data), self.batch_size): + batch_data = self.data.iloc[i:i + self.batch_size] + prompts = [self._create_chat_messages(row) for _, row in batch_data.iterrows()] + ground_truth = batch_data['reward_model'].tolist() + yield Batch(prompts=prompts, ground_truth=ground_truth) + +class VLLMClient: + """Client for interacting with VLLM server""" + def __init__(self, server_url: str): + self.client = OpenAI( + base_url=server_url, + api_key="not-needed" # VLLM server doesn't require API key + ) + + def generate(self, prompts: List[Any]) -> List[str]: + """Generate responses for a batch of prompts""" + completion = self.client.chat.completions.create( + model="qwen25vl7b-9", # Model name not needed as it's configured on server + messages=prompts, + temperature=0, + top_p=1.0, + ) + return [choice.message.content for choice in completion.choices] + +def process_batch(batch: Batch, client: VLLMClient) -> List[str]: + """Process a batch of prompts using a single client""" + responses = client.generate(batch.prompts) + return responses + +def main(): + # Configuration + data_path = "~/data/screenspot/test.parquet" # Replace with your data path + server_url = "http://model.orbyapi.com/v1" + batch_size = 16 + output_file = os.path.join(os.path.dirname(data_path), "responses.parquet") + + # Create a single VLLM client + client = VLLMClient(server_url) + + # Initialize components + data_loader = DataLoader(data_path, batch_size) + + # Process batches and collect responses + all_responses = [] + + for batch in tqdm(data_loader.get_batches(), desc="Processing batches"): + # Generate responses from VLLM server sequentially + responses = process_batch(batch, client) + all_responses.extend(responses) + + # Add responses to the original dataset + data_loader.data['responses'] = all_responses + + # Save the updated dataset + data_loader.data.to_parquet(output_file, index=False) + print(f"Saved dataset with {len(all_responses)} responses to {output_file}") + +if __name__ == "__main__": + main() \ No newline at end of file