This repository implements Owen-Shapley Policy Optimization (OSPO) for training large language models on the Amazon ESCI (Shopping Queries) dataset. OSPO addresses the credit assignment gap in reinforcement learning for search query generation by redistributing sequence-level rewards based on tokens' marginal contributions to retrieval outcomes.
Large language models are increasingly trained via reinforcement learning for personalized recommendation tasks, but standard methods like GRPO rely on sparse, sequence-level rewards. These obscure which tokens actually contribute to high-quality outputs, creating a credit assignment gap. This gap is especially problematic when models must infer latent user intent from under-specified language without ground truth labels, which is a reasoning pattern rarely seen during pretraining but commonly required in deployment. We introduce Owen-Shapley Policy Optimization (OSPO), a framework that redistributes sequence-level advantages based on tokens' marginal contributions to outcomes. OSPO transforms task feedback into potential-based reward shaping via Shapley-Owen attributions to assign segment-level credit while preserving the optimal policy—all without parametric value models. By forming coalitions of semantically coherent units (e.g., phrases describing product attributes or sentences capturing preferences), OSPO identifies which response parts drive performance. Experiments on Amazon ESCI and H&M Fashion datasets show consistent gains over baselines and notable test-time robustness to out-of-distribution retrievers unseen during training.
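To make the attribution idea concrete, here is a toy Monte Carlo Shapley estimator over response segments. This is an illustrative sketch only: `value_fn` stands in for the retrieval reward, and the repository's actual Owen-style coalition structure and function names differ.

```python
import random

def shapley_attributions(segments, value_fn, num_permutations=64, seed=0):
    """Monte Carlo Shapley values: average marginal contribution of each
    segment over random orderings."""
    rng = random.Random(seed)
    phi = {i: 0.0 for i in range(len(segments))}
    for _ in range(num_permutations):
        order = list(range(len(segments)))
        rng.shuffle(order)
        included, prev = set(), value_fn(set())
        for i in order:
            included.add(i)
            cur = value_fn({segments[j] for j in included})
            phi[i] += cur - prev  # marginal contribution of segment i
            prev = cur
    return {i: v / num_permutations for i, v in phi.items()}

# Toy reward: retrieval succeeds only if the key attribute phrase is present.
segments = ["red shoes", "for women", "size 8"]
value_fn = lambda coalition: 1.0 if "red shoes" in coalition else 0.0
phi = shapley_attributions(segments, value_fn)
# All credit goes to "red shoes"; the other segments get zero.
```

By the efficiency property, the attributions sum to the reward of the full response, which is what lets OSPO redistribute a sequence-level advantage across segments without changing its total.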
- Python 3.10+
- CUDA-capable GPU(s)
- Clone the repository:

```bash
git clone <repository-url>
cd LLM-Seq-Shapley-Owen-PO
```

- Install dependencies:

```bash
pip install -r requirements.txt
```

Key dependencies include:

- `transformers`, `accelerate`, `trl` (LLM training)
- `sentence-transformers` (embedding generation)
- `faiss-gpu` (vector search)
- `datasets` (HuggingFace datasets)
- `torch`, `numpy`, `pandas`
```
llm_shapley_owen_code/
├── data/                                   # Datasets and indices
│   └── esci/                               # ESCI product search data
│       ├── metadata/
│       │   └── item_catalog.jsonl          # ASIN→metadata mapping (~3M items, 1.8GB)
│       ├── embeddings/
│       │   ├── all-mpnet-base-v2.npy       # Dense embeddings (3M × 768, 10GB)
│       │   └── all-mpnet-base-v2_asin_mapping.json
│       ├── index/
│       │   ├── all-mpnet-base-v2_faiss.bin # FAISS HNSW index
│       │   ├── all-mpnet-base-v2_asin_mapping.json
│       │   └── all-mpnet-base-v2_faiss_metadata.json
│       ├── rl_dataset/                     # RL training dataset
│       ├── sft_dataset.jsonl               # Generated SFT training data
│       └── dpo_dataset.jsonl               # Generated DPO preference pairs
│
├── src/esci_search/                        # Main source code
│   ├── configs/                            # Training configurations
│   │   ├── grpo_config.yaml                # GRPO baseline config
│   │   ├── ospo_prop_config.yaml           # OSPO proportional
│   │   ├── ospo_rank_config.yaml           # OSPO rank-based
│   │   └── ospo_prop_no_clip.yaml          # OSPO without advantage clipping
│   │
│   ├── data_processing/                    # Data preparation pipeline
│   │   ├── build_index.py                  # Build FAISS search index
│   │   ├── create_item_metadata.py
│   │   ├── generate_embeddings.py
│   │   ├── generate_rl_data.py             # Create RL training dataset
│   │   ├── generate_dpo_sft_data.py        # Generate model trajectories
│   │   ├── process_sft_dpo_data.py         # Process into SFT/DPO format
│   │   ├── run_sft_dpo_pipeline.sh         # End-to-end dataset generation
│   │   └── sample_queries.py
│   │
│   ├── evals/                              # Evaluation scripts
│   │   ├── generate_test.py                # Run inference on test set
│   │   ├── score_search_only.py            # Compute retrieval metrics
│   │   └── test_generations_sft_dpo/       # Generated trajectory CSVs
│   │
│   └── trainers/                           # Training implementations
│       ├── dense_search/                   # Dense retrieval (FAISS)
│       │   └── search.py
│       ├── train_sft.py                    # Supervised fine-tuning
│       ├── train_dpo.py                    # Direct preference optimization
│       ├── train_grpo.py                   # Group relative policy optimization
│       ├── train_ospo.py                   # Owen-Shapley policy optimization
│       ├── ospo_utils.py                   # OSPO-specific utilities
│       ├── reward_utils.py                 # Reward computation
│       └── generation_utils.py             # Text generation helpers
│
├── outputs/                                # Training outputs (created at runtime)
│   ├── ospo_ablations_esci/                # OSPO/GRPO checkpoints
│   ├── sft_models/                         # SFT checkpoints
│   └── dpo_models/                         # DPO checkpoints
├── ospo_code_final_review.pdf
├── pyproject.toml
├── README.md
├── README_old.md
├── requirements.txt
└── setup.py
```
- `data/esci/`: ESCI product search dataset and artifacts
  - `metadata/`: Product catalog with ASIN metadata (~3M items, 1.8GB)
  - `embeddings/`: Pre-computed dense embeddings (all-mpnet-base-v2, 10GB)
  - `index/`: FAISS HNSW search index for dense retrieval
  - `rl_dataset/`: Prepared dataset for RL training (queries + candidate pools)
  - `sft_dataset.jsonl`: High-quality samples for supervised fine-tuning
  - `dpo_dataset.jsonl`: Preference pairs for direct preference optimization
- `src/esci_search/`: Core codebase for the ESCI product search task
  - `configs/`: YAML configurations for different training methods (OSPO, GRPO)
  - `data_processing/`: Scripts to prepare datasets, indices, embeddings, and SFT/DPO data
  - `evals/`: Inference and evaluation utilities, trajectory generation outputs
  - `trainers/`: Training implementations (SFT, DPO, GRPO, OSPO) and search utilities
- `outputs/`: All model checkpoints organized by training method (git-ignored)
  - `ospo_ablations_esci/`: OSPO and GRPO model checkpoints
  - `sft_models/`: Supervised fine-tuning checkpoints
  - `dpo_models/`: Direct preference optimization checkpoints
The pipeline prepares the ESCI Shopping Queries dataset for OSPO training by creating item metadata, dense embeddings, search indices, ground truth queries, and RL-formatted training data.
Run all 5 steps sequentially from the data processing directory:
```bash
cd src/esci_search/data_processing
bash data_process.sh
```

This executes all stages in order and stops automatically if any step fails.
Run each step individually for debugging or customization:
Script: create_item_metadata.py
Extracts product metadata (titles + descriptions) from Amazon Reviews dataset for all ASINs in ESCI.
What it does:
- Downloads ESCI dataset from HuggingFace
- Filters for US locale, exact matches, small version
- Samples 50 negatives per query from same category
- Streams metadata from Amazon Reviews JSONL files
- Creates ASIN→metadata mapping
Output: data/esci/metadata/item_catalog.jsonl (~3M items, 1.8GB)
```bash
python create_item_metadata.py
```

Script: generate_embeddings.py
Creates dense vector embeddings for all products using sentence-transformers.
What it does:
- Loads `item_catalog.jsonl`
- Encodes metadata using `all-mpnet-base-v2` (768-dim)
- Multi-GPU acceleration (distributes across available GPUs)
- Saves embeddings as binary numpy array
- Creates index→ASIN mapping for FAISS lookups
Output:

- `data/esci/embeddings/all-mpnet-base-v2.npy`
- `data/esci/embeddings/all-mpnet-base-v2_asin_mapping.json`
Runtime: ~2-3 hours for 3M items on 2× RTX A6000
```bash
python generate_embeddings.py \
    --model_name sentence-transformers/all-mpnet-base-v2 \
    --batch_size 1024
```

Script: build_index.py
Builds HNSW index for fast approximate nearest neighbor search.
What it does:
- Loads embeddings from Step 2
- Normalizes vectors for cosine similarity
- Builds FAISS HNSW index (M=32, ef_construction=200)
- Copies ASIN mapping to index directory
- Saves index metadata (dimensions, build time, etc.)
Output:

- `data/esci/index/all-mpnet-base-v2_faiss.bin` (160MB)
- `data/esci/index/all-mpnet-base-v2_asin_mapping.json` (1.2MB)
- `data/esci/index/all-mpnet-base-v2_faiss_metadata.json` (152B)
Runtime: ~2-5 minutes
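The key trick in this step is L2-normalizing embeddings so that inner product equals cosine similarity; the HNSW index then approximates the exact top-k search sketched below in numpy. This is an illustrative stand-in, not the repository's FAISS code (in FAISS the M and ef_construction values parameterize `IndexHNSWFlat`).

```python
import numpy as np

def normalize(embs):
    # L2-normalize rows so that dot product == cosine similarity
    return embs / np.clip(np.linalg.norm(embs, axis=1, keepdims=True), 1e-12, None)

def exact_search(index_embs, query, k=3):
    q = query / max(np.linalg.norm(query), 1e-12)
    scores = index_embs @ q            # cosine similarities
    top = np.argsort(-scores)[:k]      # exact top-k; HNSW approximates this
    return top, scores[top]

rng = np.random.default_rng(0)
catalog = normalize(rng.normal(size=(1000, 768)).astype(np.float32))
ids, sims = exact_search(catalog, catalog[42])
# The query vector is its own nearest neighbor with similarity ~1.0.
```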
```bash
python build_index.py
```

Script: sample_queries.py
Curates high-quality queries with multiple verified exact matches for evaluation.
What it does:
- Processes ESCI to group exact matches by query
- Filters for queries with 3-10 exact matches
- Validates semantic relevance (≥30% token overlap with titles)
- Samples 40,000 queries with fixed random seed
- Deduplicates product matches per query
Output: data/esci/ground_truth/query_ground_truth_40000.json
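The ≥30% token-overlap filter above can be read as the following sketch (hypothetical helper names, simple whitespace tokenization; the real script may tokenize differently):

```python
def token_overlap(query, title):
    """Fraction of query tokens that also appear in the product title."""
    q = set(query.lower().split())
    t = set(title.lower().split())
    return len(q & t) / len(q) if q else 0.0

def is_relevant(query, title, threshold=0.30):
    # Keep a query-product pair only if >=30% of query tokens appear in the title.
    return token_overlap(query, title) >= threshold

is_relevant("red running shoes", "Nike Red Running Shoes for Men")  # True
is_relevant("red running shoes", "Blue Denim Jacket")               # False
```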
```bash
python sample_queries.py \
    --num_queries 40000 \
    --min_matches 3 \
    --max_matches 10 \
    --seed 42
```

Script: generate_rl_data.py
Creates TRL-compatible RL dataset with candidate pools for OSPO/GRPO training.
What it does:
- Loads ground truth queries from Step 4
- Loads ASIN mapping from index to filter valid items
- Samples diverse negatives per query:
  - 50 from same category (hard negatives)
  - 30 from related categories (medium)
  - 20 random from catalog (easy)
- Creates candidate pools (1 target + 99 negatives)
- Formats prompts with query expansion system instructions
- Splits 90/10 into train/test
Output: data/esci/rl_dataset/ (HuggingFace DatasetDict)
- Train: 9,339 samples
- Test: 1,038 samples
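The stratified negative sampling and pool assembly amount to something like the sketch below. Names are hypothetical, dedup across strata is ignored, and the real script also attaches metadata and prompt formatting.

```python
import random

def build_candidate_pool(target_asin, same_category, related, catalog, rng):
    """Assemble one query's candidate pool: the target plus stratified negatives."""
    hard = rng.sample(same_category, 50)   # same category (hard)
    medium = rng.sample(related, 30)       # related categories (medium)
    easy = rng.sample(catalog, 20)         # random catalog items (easy)
    pool = [target_asin] + hard + medium + easy
    rng.shuffle(pool)                      # hide the target's position
    return pool

rng = random.Random(42)
catalog = [f"B{i:06d}" for i in range(10_000)]
pool = build_candidate_pool("TARGET", catalog[:500], catalog[500:1500], catalog, rng)
```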
```bash
python generate_rl_data.py
```

We provide training scripts for OSPO and GRPO baselines. All experiments use Qwen/Qwen2.5-7B-Instruct as the base model with LoRA fine-tuning.
Training configs are located in src/esci_search/configs/:
- `ospo_prop_config.yaml` - OSPO with proportional redistribution
- `ospo_rank_config.yaml` - OSPO with rank-based redistribution
- `ospo_prop_no_clip.yaml` - OSPO without advantage clipping
- `grpo_config.yaml` - GRPO baseline
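The proportional and rank-based variants differ only in how a sequence-level advantage is split across segments given their credit scores. A plausible sketch of the two modes (hypothetical function, not the repository's implementation):

```python
import numpy as np

def redistribute(advantage, attributions, mode="proportional"):
    """Split a sequence-level advantage across segments.

    proportional: weight each segment by its normalized |attribution|.
    rank:         weight by the rank of its attribution (higher credit,
                  larger share), which is robust to attribution magnitudes.
    """
    a = np.asarray(attributions, dtype=float)
    if mode == "proportional":
        total = np.abs(a).sum()
        w = np.abs(a) / total if total > 0 else np.full(len(a), 1.0 / len(a))
    elif mode == "rank":
        ranks = a.argsort().argsort() + 1   # 1 = lowest credit
        w = ranks / ranks.sum()
    else:
        raise ValueError(f"unknown mode: {mode}")
    return advantage * w                    # per-segment advantages

seg_adv = redistribute(2.0, [0.6, 0.3, 0.1])
# proportional weights 0.6/0.3/0.1 of the advantage 2.0 -> [1.2, 0.6, 0.2]
```

In both modes the weights sum to one, so the total advantage is preserved; only its allocation across segments changes.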
Example configuration:

```yaml
# Model Parameters
model_name_or_path: "Qwen/Qwen2.5-7B-Instruct"
use_peft: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
load_in_4bit: true
torch_dtype: "bfloat16"

# Training Hyperparameters
learning_rate: 5.0e-6
lr_scheduler_type: "cosine"
warmup_steps: 100
weight_decay: 0.05
optimizer_type: "paged_adamw_32bit"
max_steps: 2000
seed: 0

# Batch & Distribution
per_device_train_batch_size: 8
per_device_eval_batch_size: 8
gradient_accumulation_steps: 2
gradient_checkpointing: true
num_generations: 8              # MCMC samples for OSPO
max_prompt_length: 256
max_completion_length: 356

# Data Paths
dataset_path: "data/esci/rl_dataset"
faiss_index_path: "data/esci/index/all-mpnet-base-v2_faiss.bin"
asin_mapping_path: "data/esci/index/all-mpnet-base-v2_asin_mapping.json"
metadata_path: "data/esci/metadata/item_catalog.jsonl"
embedding_model: "sentence-transformers/all-mpnet-base-v2"

# OSPO / Search Specific
owen_max_width: 32
owen_max_permutations: 64
redistribution_mode: "owen_weights"
clip_ospo_advantages: true
faiss_topk: 1000
reward_topk: 100

# Logging & Eval
output_dir: "./outputs/ospo_ablations_esci"
logging_steps: 1
save_steps: 500
eval_steps: 200
eval_strategy: "steps"
report_to: "wandb"
log_completions: true
sanity_check: false
```

Note: The token length parameters (256 for prompts, 356 for completions) are shorter than for the H&M dataset, since ESCI prompts are based on conventional query rewrites, which are naturally shorter. Verify with your tokenizer that these values capture >99% of the dataset.
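A quick way to run the suggested >99% coverage check, with a whitespace tokenizer as a stand-in (swap in your model's tokenizer, e.g. `AutoTokenizer.from_pretrained(...).encode`, for a real check):

```python
def length_coverage(texts, tokenize, max_len):
    """Fraction of texts whose token count fits within max_len."""
    fits = sum(1 for t in texts if len(tokenize(t)) <= max_len)
    return fits / len(texts)

tokenize = str.split   # stand-in; use the actual model tokenizer in practice
prompts = ["red running shoes",
           "wireless noise cancelling headphones for travel"]
cov = length_coverage(prompts, tokenize, max_len=256)
# cov should exceed 0.99 on your dataset before trusting max_prompt_length
```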
First, generate training datasets from model trajectories:
```bash
cd src/esci_search/data_processing
./run_sft_dpo_pipeline.sh
```

This creates:

- `data/esci/sft_dataset.jsonl` - High-quality samples (NDCG > 0.3)
- `data/esci/dpo_dataset.jsonl` - Preference pairs (chosen vs rejected)
- Generation CSVs in `src/esci_search/evals/test_generations_sft_dpo/`
Train with supervised fine-tuning for better initialization:
```bash
# Basic SFT
python src/esci_search/trainers/train_sft.py --model qwen-7b

# With accelerate
accelerate launch src/esci_search/trainers/train_sft.py --model qwen-7b
```

Checkpoints saved to: `./sft_trained_models/qwen-7b/checkpoint-{step}/`
Train with preference optimization (optionally from SFT checkpoint):
```bash
# DPO from base model
python src/esci_search/trainers/train_dpo.py --model qwen-7b

# DPO from SFT checkpoint (recommended)
python src/esci_search/trainers/train_dpo.py \
    --model qwen-7b \
    --sft_checkpoint_path sft_trained_models/qwen-7b/checkpoint-150

# With accelerate
accelerate launch src/esci_search/trainers/train_dpo.py --model qwen-7b
```

Checkpoints saved to: `./trained_dpo_models/qwen-7b/checkpoint-{step}/`
```bash
# OSPO Proportional
accelerate launch src/esci_search/trainers/train_ospo.py \
    --config src/esci_search/configs/ospo_prop_config.yaml

# OSPO Rank-based
accelerate launch src/esci_search/trainers/train_ospo.py \
    --config src/esci_search/configs/ospo_rank_config.yaml

# OSPO without clipping
accelerate launch src/esci_search/trainers/train_ospo.py \
    --config src/esci_search/configs/ospo_prop_no_clip.yaml

# GRPO baseline
accelerate launch src/esci_search/trainers/train_grpo.py \
    --config src/esci_search/configs/grpo_config.yaml
```

SFT/DPO Checkpoints:
```
./sft_trained_models/qwen-7b/
├── checkpoint-50/
├── checkpoint-100/
└── checkpoint-150/

./trained_dpo_models/qwen-7b/
├── checkpoint-15/
└── checkpoint-30/
```
OSPO/GRPO Checkpoints:
```
./outputs/ospo_ablations_esci/<MODEL_TAG>/<RUN_NAME>/
├── checkpoint-500/
├── checkpoint-1000/
├── checkpoint-1500/
└── checkpoint-2000/
```
Run inference over the evaluation dataset using stored checkpoints:
For OSPO/GRPO models:
```bash
cd src/esci_search/evals
python generate_test.py \
    --base_model "Qwen/Qwen2.5-7B-Instruct" \
    --checkpoint_base "<PATH_TO_RUN_DIR_WITH_CHECKPOINTS>" \
    --dataset_path "../../data/esci/rl_dataset" \
    --output_dir "./test_generations" \
    --temperature 1.0 \
    --batch_size 1024 \
    --max_new_tokens 512
```

For SFT/DPO models:
```bash
python generate_test.py \
    --base_model "Qwen/Qwen2.5-7B-Instruct" \
    --checkpoint_base "sft_trained_models/qwen-7b" \
    --dataset_path "../../data/esci/rl_dataset" \
    --output_dir "./test_generations_sft" \
    --temperature 1.0 \
    --batch_size 1024
```

For base model baseline:
```bash
python generate_test.py \
    --base_model "Qwen/Qwen2.5-7B-Instruct" \
    --dataset_path "../../data/esci/rl_dataset" \
    --output_dir "./test_generations_base" \
    --temperature 1.0 \
    --batch_size 1024
```

Notes:

- `--dataset_path` must contain a "test" split with messages in chat format
- `--checkpoint_base` should contain `checkpoint-*` subdirectories
- Reduce `--batch_size` to 256/128 if OOM occurs
We compute retrieval-based metrics (NDCG, AP, MRR) using FAISS-based evaluation. We report two variants:
This variant restricts scoring to each query's candidate pool (recommended for pooled evaluation).
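For reference, binary-relevance sketches of two of the reported metrics are shown below; the exact implementations live in `score_search_only.py`.

```python
import math

def mrr(ranked, relevant):
    """Reciprocal rank of the first relevant item (0 if none retrieved)."""
    for i, item in enumerate(ranked, start=1):
        if item in relevant:
            return 1.0 / i
    return 0.0

def ndcg_at_k(ranked, relevant, k):
    """Binary-relevance NDCG@k: DCG of hits over DCG of the ideal ranking."""
    dcg = sum(1.0 / math.log2(i + 1)
              for i, item in enumerate(ranked[:k], start=1) if item in relevant)
    ideal = sum(1.0 / math.log2(i + 1)
                for i in range(1, min(len(relevant), k) + 1))
    return dcg / ideal if ideal > 0 else 0.0

ranked = ["B01", "B02", "B03", "B04"]
relevant = {"B02", "B04"}
rr = mrr(ranked, relevant)          # first hit at rank 2 -> 0.5
nd = ndcg_at_k(ranked, relevant, 4)
```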
```bash
python score_search_only.py \
    --input_csv_glob "./test_generations/*.csv" \
    --output_dir "./test_generations_with_pools" \
    --fixed_k 10,20,50,100,1000 \
    --text_cols "expanded_query,response" \
    --batch_size 512 \
    --top_k 1000 \
    --device cuda \
    --faiss_index_path "../../data/esci/index/all-mpnet-base-v2_faiss.bin" \
    --mapping_path "../../data/esci/index/all-mpnet-base-v2_asin_mapping.json" \
    --use_pools true \
    --model_name "sentence-transformers/all-mpnet-base-v2"
```

This variant scores against raw FAISS top-K retrieval without candidate filtering.
```bash
python score_search_only.py \
    --input_csv_glob "./test_generations/*.csv" \
    --output_dir "./test_generations_without_pools" \
    --fixed_k 10,20,50,100,1000 \
    --text_cols "expanded_query,response" \
    --batch_size 512 \
    --top_k 1000 \
    --device cuda \
    --faiss_index_path "../../data/esci/index/all-mpnet-base-v2_faiss.bin" \
    --mapping_path "../../data/esci/index/all-mpnet-base-v2_asin_mapping.json" \
    --use_pools false \
    --model_name "sentence-transformers/all-mpnet-base-v2"
```

All commands assume you're in the appropriate subdirectory:
- Data processing: `cd src/esci_search/data_processing`
- Training: `cd src/esci_search/trainers`
- Evaluation: `cd src/esci_search/evals`
Important Note: While this pipeline references the Amazon Reviews 2023 dataset, it does NOT use customer review data. We only extract:
- ASIN-to-category mappings (`asin2category.json`) - publicly available product taxonomy
- Product metadata (`meta_*.jsonl` files) - product titles and descriptions from the catalog
No customer reviews, ratings, or review text are accessed or processed. The Amazon Reviews 2023 dataset serves as a source of product catalog information only, which is necessary to enrich ESCI product ASINs with textual metadata for embedding generation.
If you cannot access the Amazon Reviews 2023 dataset, you can alternatively use the ESCI product titles directly, though this may reduce retrieval quality due to less rich metadata.
If you use this code or OSPO in your research, please cite:
```bibtex
@misc{nath2026owenshapleypolicyoptimizationospo,
      title={Owen-Shapley Policy Optimization (OSPO): A Principled RL Algorithm for Generative Search LLMs},
      author={Abhijnan Nath and Alireza Bagheri Garakani and Tianchen Zhou and Fan Yang and Nikhil Krishnaswamy},
      year={2026},
      eprint={2601.08403},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2601.08403},
}
```

CC-BY-NC-4.0 license
For questions or issues, please open a GitHub issue or contact abhijnan.nath@colostate.edu.