McGill-NLP/the-markovian-thinker

The Markovian Thinker: Architecture-Agnostic Linear Scaling of Reasoning

Milad Aghajohari*, Kamran Chitsaz*, Amirhossein Kazemnejad*,
Sarath Chandar, Alessandro Sordoni†, Aaron Courville†, Siva Reddy†

*Equal Contribution   †Equal Advising

Paper Hugging Face Collection

Updates

  • October 2025: 🎉 We release our paper, models and codebase.

Links

TL;DR

RL for reasoning LLMs uses a trivial underlying RL environment (MDP) that treats the state as the whole prompt plus all thinking tokens generated so far. That state keeps growing, which makes the compute cost quadratic in the thinking length.

We propose the Markovian Thinking paradigm, in which the state size stays bounded. By design, and regardless of the policy architecture, this makes the compute cost linear in the number of thinking tokens, while memory stays flat.
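
As a back-of-the-envelope illustration of this scaling argument, the sketch below only counts attended token pairs and ignores the prompt and all constant factors; it is not a FLOP model of either method.

# Rough illustration of why LongCoT compute grows quadratically with thinking
# length while a bounded-state (Markovian) setup grows linearly.
def longcot_pairs(n: int) -> int:
    # token t attends over the t thinking tokens before it -> n * (n + 1) / 2
    return n * (n + 1) // 2

def markovian_pairs(n: int, state_size: int) -> int:
    # every token attends over at most a fixed-size state -> n * C
    return n * state_size

for n in (8_192, 24_576, 98_304):
    print(f"n={n:>6}  longcot={longcot_pairs(n):>13,}  markovian={markovian_pairs(n, 8_192):>13,}")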

Delethink:

Delethink makes the environment chunked: generation happens in fixed-size chunks. At each boundary, we reset the context to a fresh prompt containing the original question plus a short carryover. The model learns to carry its progress forward in text, becoming a Markovian thinker. LongCoT, in contrast, keeps concatenating tokens, so its context keeps growing.
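
A minimal sketch of this chunked loop is shown below. It is illustrative only: generate stands in for any LLM sampling call, END is a hypothetical end-of-thinking marker, and the actual implementation lives in verl/experiments/agent_loop/delethink_agent_loop.py.

# Illustrative sketch of a Delethink-style chunked rollout; not the repo's API.
# `generate(prompt_tokens, max_new_tokens)` stands in for any sampling call,
# and END is a hypothetical end-of-thinking token id.
END = -1

def delethink_rollout(question, generate, C=8192, m=4096, I=5):
    """Generate in chunks while conditioning only on a bounded state.

    `question` is a list of prompt token ids. Each chunk sees only the question
    plus the last m tokens of the previous chunk, so the prompt never grows.
    """
    trace, carryover = [], []
    for i in range(I):
        prompt = question + carryover       # bounded state: question + short carryover
        budget = C if i == 0 else C - m     # later chunks add only C - m new tokens
        chunk = generate(prompt, budget)
        trace += chunk
        if END in chunk:                    # model finished; stop early
            break
        carryover = chunk[-m:]              # reset context, keep only the tail
    return trace                            # full trace, used only for reward/logging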

  • With only an 8K active context, Delethink 24K matches or beats LongCoT-RL 24K in accuracy while using less compute.
  • Beyond the trained budget, Delethink keeps improving while others plateau. Within the budget, both scale similarly under sequential sampling.
  • Training cost vs. average thinking length on H100s: LongCoT ~ quadratic, Delethink ~ linear.

Scaling Delethink to 96K

GPT-OSS & Qwen3 Thinking show signs of Markovian Thinking

State-of-the-art reasoning LLMs, GPT-OSS-120B and Qwen3-30B-A3B, can do Markovian Thinking zero-shot, which provides a strong initialization for training and signals scalability. Delethink closely tracks LongCoT and recovers most of its final accuracy (the curves nearly coincide on Qwen3). Overall, Delethink looks promising for scaling.

Takeaways

  • RLVR is governed by a trivial MDP that everyone has forgotten. We can actually design any MDP we want.
  • By making the state size bounded, we propose the Markovian Thinking paradigm, where the model learns to advance its reasoning while conditioning only on fixed-size states.
  • Delethink is simple and effective: with an 8K fixed state it matches or beats LongCoT-RL and can think up to 128K tokens.
  • GPT-OSS 120B and Qwen3 30B-A3B already show strong signs of Markovian thinking.

If you are interested in more details, please check out our paper!

Usage

Install

Our codebase is built on top of verl and SGLang. We provide a pre-built Docker image and an installation recipe based on uv. For more details, please see INSTALLATION.md.

Quick Start

Use the following script to quickly start Delethink training on the DeepScaleR dataset with a 24K response length and linear compute!

#!/bin/bash

NUM_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)

CONFIG_NAME="r1d-1.5b_deepscaler_delethink_24k"

export VERL_LOGGING_LEVEL=INFO
export TREETUNEV__EXP_NAME=$CONFIG_NAME
export TREETUNEV__NNODE=1
export TREETUNEV__NUM_GPUS_PER_NODE=$NUM_GPUS # should be set, otherwise defaults to 1

# base model
model_path=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

# delethink parameters
# Total reasoning budget: C + (I-1) * (C-m) = 8192 + (5-1) * (8192 - 4096) = 24576
max_response=8192
intermediate_max_new_tokens=4096
keep_head=100
fixed_num_optim_steps=2
multi_turn_max_assistant_turns=5

# rollout and sampling
gpu_memory_utilization=0.8
rollout_n=8
temp=0.6
top_k=-1
top_p=1.0

# validation sampling params
val_top_k=-1
val_top_p=1.0
val_temperature=0.6
val_n=32
val_do_sample=true

# trainer schedule and logging
train_steps=1000
save_freq=10

# checkpoint policy
keep_every_n_saves=5
push_to_hub_freq=20

# actor optimization
clip_ratio_high=0.26
ppo_max_token_len_per_gpu=32768
ppo_epochs=1

# build hydra overrides
overrides=(
  actor_rollout_ref.model.path=$model_path

  data.max_response_length=$max_response
  algorithm.delethink.intermediate_max_new_tokens=$intermediate_max_new_tokens
  algorithm.delethink.keep_head=$keep_head
  algorithm.delethink.fixed_num_optim_steps=$fixed_num_optim_steps
  actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$multi_turn_max_assistant_turns

  actor_rollout_ref.rollout.gpu_memory_utilization=$gpu_memory_utilization

  actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-norm-trace-length
  actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high
  actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$ppo_max_token_len_per_gpu
  actor_rollout_ref.actor.ppo_epochs=$ppo_epochs

  actor_rollout_ref.actor.checkpoint.keep_every_n_saves=$keep_every_n_saves
  actor_rollout_ref.actor.checkpoint.push_to_hub_freq=$push_to_hub_freq

  actor_rollout_ref.rollout.n=$rollout_n
  actor_rollout_ref.rollout.temperature=$temp
  actor_rollout_ref.rollout.top_k=$top_k
  actor_rollout_ref.rollout.top_p=$top_p

  actor_rollout_ref.rollout.val_kwargs.top_k=$val_top_k
  actor_rollout_ref.rollout.val_kwargs.top_p=$val_top_p
  actor_rollout_ref.rollout.val_kwargs.temperature=$val_temperature
  actor_rollout_ref.rollout.val_kwargs.n=$val_n
  actor_rollout_ref.rollout.val_kwargs.do_sample=$val_do_sample

  trainer.total_training_steps=$train_steps
  trainer.save_freq=$save_freq
)

python -m verl.trainer.main_policy_iteration \
    --config-name=$CONFIG_NAME \
    "${overrides[@]}"

For more tunable scripts, please see:

Reproduce RL experiments

To reproduce the RL experiments in the paper, please use the following scripts:

# LongCoT-RL with 8K thinking tokens
sh examples/reproduce_rl_training/longcot_8k.sh

# LongCoT-RL with 24K thinking tokens
sh examples/reproduce_rl_training/longcot_24k.sh

# Delethink-RL with 24K thinking tokens
sh examples/reproduce_rl_training/delethink_24k.sh

# Delethink-RL with 96K thinking tokens
sh examples/reproduce_rl_training/delethink_96k.sh

These scripts are tested on a single 8xH100 GPU node. For multi-node training, please see ./examples/multi_node_template.sh.

Evaluation

TBD

Delethink Tracing Demo

We provide a single-file, self-contained Delethink tracing implementation in delethink_tracing_demo.py. You can use the following command to run the demo:

python delethink_tracing_demo.py

This will run Delethink tracing on the entire AIME 2024 eval set using SGLang and R1-Distill-Qwen-1.5B. Options:

  • --model (str, default: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B): Model path to use for sglang.Engine
  • --tp (int, default: 1): Tensor model parallel size
  • --dp (int, default: 1): Data parallel size; parallel inference engines
  • --samples (int, default: -1): Number of samples (all if -1)
  • --delethink_context_size (int, default: 8192): Delethink context size (C)
  • --delethink_markovian_size (int, default: 4096): Delethink markovian size (m)
  • --delethink_iteration_cap (int, default: 5): Delethink iteration cap (I)
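
Assuming the demo enforces the same budget formula as the Quick Start training script (a reading of the parameters above, not something the demo documents explicitly), the defaults cap the total thinking length at 24K tokens:

# Maximum thinking budget implied by the default demo parameters,
# using the formula from the Quick Start script: C + (I - 1) * (C - m).
C, m, I = 8192, 4096, 5
print(C + (I - 1) * (C - m))  # 24576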

Delethink Implementation Reference

The core parts of the Delethink implementation are:

  1. Inference Loop (verl calls this "Agent Loop"): verl/experiments/agent_loop/delethink_agent_loop.py - Implements the logic for generating Delethink traces.
  2. Trainer: verl/trainer/delethink_ppo/ray_trainer.py - Implements training using Delethink traces
  3. Configs:

Codebase Overview

Directory Structure

├── examples/                   # Example scripts for training
├── verl/                       # the verl codebase
│   ├── trainer/
│   │   └── config/             # Configuration files for different experiments
├── scripts/                    # Utility scripts
├── delethink_tracing_demo.py   # Single-file, self-contained Delethink tracing implementation
├── INSTALLATION.md             # Installation guide
├── README.md                   # README for the codebase

Important Config Files

  • verl/trainer/config/r1d-1.5b_deepscaler.yaml: parent config for all DeepScaleR experiments
  • verl/trainer/config/r1d-1.5b_deepscaler_longcot_24k.yaml: LongCoT-RL 24K
  • verl/trainer/config/r1d-1.5b_deepscaler_longcot_8k.yaml: LongCoT-RL 8K
  • verl/trainer/config/r1d-1.5b_deepscaler_delethink_24k.yaml: Delethink 24K
  • verl/trainer/config/r1d-1.5b_openmath_delethink_96k.yaml: Delethink 96K

Core Components

Main Entry Point:

Trainers:

Actor (Policy backward):

Rollout (SGLang/Inference Engine):

Dataset:

Reward Function/Mechanism:

Reward managers wrap the reward function and handle custom logic such as length penalties. We mostly use the naive.py/delethink.py reward managers, which are thin wrappers that do not modify the reward computation.

ActorRolloutRefWorker:

Trajectory Visualization

Use the following commands to visualize the LongCoT trajectories:

pip install textual==0.52.1
python scripts/rollout_viewer.py experiments/<exp_name>/train_rollouts

To visualize the unflattened Delethink trajectories, use the following command:

python scripts/rollout_viewer.py experiments/<exp_name>/train_rollouts_delethink

FAQ

How to setup the VSCode ray debugger?

Follow the guide: VERL Ray Debug Tutorial

Where is the config file for Delethink?

The config files are verl/trainer/config/r1d-1.5b_deepscaler_delethink_24k.yaml (Delethink 24K) and verl/trainer/config/r1d-1.5b_openmath_delethink_96k.yaml (Delethink 96K).

Where is loss aggregation implemented?

Loss aggregation is implemented in verl/trainer/ppo/core_algos.py using the agg_loss and agg_loss_with_trace_lengths functions.

To change the aggregation mode, use actor_rollout_ref.actor.loss_agg_mode in the config.
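
For intuition, here is a simplified sketch of two common aggregation modes; the actual implementation in core_algos.py handles masking details, trace lengths, and additional modes such as the seq-mean-token-norm-trace-length mode used in the Quick Start script.

import torch

# Simplified sketch of loss aggregation over a batch of sequences (not the verl code).
# loss_mat: (batch, seq_len) per-token losses; mask: (batch, seq_len) 0/1 response mask.
def agg_loss_sketch(loss_mat, mask, mode="token-mean"):
    if mode == "token-mean":
        # average over all valid tokens in the whole batch
        return (loss_mat * mask).sum() / mask.sum()
    if mode == "seq-mean-token-mean":
        # average tokens within each sequence, then average over sequences
        per_seq = (loss_mat * mask).sum(dim=-1) / mask.sum(dim=-1)
        return per_seq.mean()
    raise ValueError(f"unknown mode: {mode}")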

What do all the batch size configurations mean in the VERL universe?

  • data.train_batch_size and actor_rollout_ref.actor.ppo_mini_batch_size are specified in terms of the number of prompts
  • To map to the TreeTune Next convention:
    • number_of_episodes_per_iteration = data.train_batch_size * actor_rollout_ref.rollout.n
    • target_batch_size = actor_rollout_ref.actor.ppo_mini_batch_size * actor_rollout_ref.rollout.n
  • The number of gradient updates per iteration is data.train_batch_size / actor_rollout_ref.actor.ppo_mini_batch_size
  • We also have algorithm.delethink.fixed_num_optim_steps, which dynamically adjusts the batch size so that the number of gradient updates per iteration stays fixed (see the worked example below)
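
A small worked example of this mapping, using hypothetical values (these are not the repo's defaults):

# Hypothetical config values, chosen only to illustrate the mapping above.
train_batch_size = 128      # data.train_batch_size (prompts per iteration)
mini_batch_size = 64        # actor_rollout_ref.actor.ppo_mini_batch_size (prompts)
rollout_n = 8               # actor_rollout_ref.rollout.n (samples per prompt)

episodes_per_iteration = train_batch_size * rollout_n              # 1024 episodes
target_batch_size = mini_batch_size * rollout_n                    # 512 episodes per mini-batch
grad_updates_per_iteration = train_batch_size // mini_batch_size   # 2 gradient updates
print(episodes_per_iteration, target_batch_size, grad_updates_per_iteration)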

How to use smaller batch sizes for debugging/development?

Use configuration overrides like this:

python -m verl.trainer.main_policy_iteration \
    --config-name="<config_name>" \
    trainer.val_before_train=False \
    actor_rollout_ref.rollout.enable_debug=True \
    actor_rollout_ref.rollout.n=2 \
    actor_rollout_ref.actor.ppo_mini_batch_size=4 \
    data.train_batch_size=4

How can I add a new config?

Simply define it in YAML format. You can also assign a target config class to perform config validation.

See algorithm.delethink._target_ in verl/trainer/config/r1d-1.5b_deepscaler_delethink_24k.yaml for an example.

Where are all important components in VERL implemented?

See the Core Components section above; it lists the main entry point, trainers, actor (policy backward), rollout (SGLang/inference engine), dataset, reward function/managers, and ActorRolloutRefWorker.

Citation

@misc{Aghajohari2025:TheMarkovianThinker,
      title={The Markovian Thinker}, 
      author={Milad Aghajohari and Kamran Chitsaz and Amirhossein Kazemnejad and Sarath Chandar and Alessandro Sordoni and Aaron Courville and Siva Reddy},
      year={2025},
      eprint={2510.06557},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2510.06557}, 
}

Acknowledgement

  • The training infra is built on top of verl, which employs SGLang.
  • Parts of this code are borrowed from VinePPO.

About

Code for paper "The Markovian Thinker: Architecture-Agnostic Linear Scaling of Reasoning"
