Skip to content
Merged
Changes from 8 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
6c25d22
docs: refresh README for amd-flashinfer library consumers
demandal25 Apr 23, 2026
3b7de13
docs: align README with amd-integration and refresh feature matrix
demandal25 May 21, 2026
1256e66
docs: clarify HIP vs AITER backends and auto-routing in feature matrix
demandal25 May 21, 2026
efb7ee1
docs: fold fp8 status into per-row notes in feature matrix
demandal25 May 21, 2026
4650cf5
docs: fix torch.compile env var name and document runtime env vars / …
demandal25 May 21, 2026
9f650bf
docs: collapse docker image table to the latest tag
demandal25 May 21, 2026
7c122a6
docs: drop manual micromamba activate from docker verify step
demandal25 May 21, 2026
012c053
docs: use concrete image tag and container name in docker run
demandal25 May 21, 2026
4cdf823
docs(readme): simplify "Trying the Examples" to point at examples/
demandal25 May 21, 2026
7da63b9
docs(readme): drop redundant Single Prefill Example from AITER section
demandal25 May 21, 2026
7b3bcdd
docs(readme): rename "Build from Source" to "Install from Source"
demandal25 May 21, 2026
9f13ae1
docs(readme): link CDNA3 / CDNA4 to their architecture references
demandal25 May 21, 2026
86448ff
docs(readme): point CDNA3 / CDNA4 links to official whitepapers
demandal25 May 21, 2026
57c8cc8
docs(readme): tighten intro and call out the HIP + AITER split
demandal25 May 21, 2026
df1cb84
docs(readme): move Basic Usage to the end of the README
demandal25 May 21, 2026
dfa7a0f
docs(readme): proofread, dedupe, and clarify after fact-check
demandal25 May 21, 2026
2d1d0f9
feat(mla): accept backend="auto" on ROCm as an alias for "aiter"
demandal25 May 21, 2026
713aeca
docs(readme): clarify the dev-container "Environment name" note
demandal25 May 21, 2026
16afd18
docs(readme): drop redundant Docker-tag note from dev-container section
demandal25 May 21, 2026
6140da4
docs(readme): address Copilot review feedback
demandal25 May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 182 additions & 65 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
# FlashInfer+ROCm: An AMD ROCm port of FlashInfer

FlashInfer+ROCm is a port of the [FlashInfer](https://github.com/flashinfer-ai/flashinfer) library
that adds support for AMD Instinct GPUs. The project is in active development with current focus on
porting attention kernels to ROCm.

**Versioning:** The release tag format `<upstream_version>+amd` ties each FlashInfer+ROCm release
to its corresponding upstream tag (e.g., `0.2.5+amd.2` is second release of amd-flashinfer based on upstream version `v0.2.5`).
FlashInfer+ROCm is an AMD ROCm port of the
[FlashInfer](https://github.com/flashinfer-ai/flashinfer) attention,
RoPE, normalization, sampling, and logits-processor kernels for LLM
inference on AMD Instinct GPUs. The port targets CDNA3 (gfx942 —
MI300X / MI325X) and CDNA4 (gfx950 — MI355X), and is aimed at developers
embedding FlashInfer kernels into their own training or serving stack.

The project is in active development with the primary focus on attention
(single and batch prefill / decode) and the surrounding KV-cache, RoPE,
and normalization kernels. See [CHANGELOG.md](CHANGELOG.md) for the
full release history.

**Versioning:** The release tag format `<upstream_version>+amd.<n>` ties
each FlashInfer+ROCm release to its corresponding upstream tag (e.g.
`0.5.3+amd.1` is the first AMD release based on upstream `v0.5.3`).

## Table of Contents

* [Basic Usage](#basic-usage)
* [Feature Support Matrix](#feature-support-matrix)
* [GPU and ROCm Support](#gpu-and-rocm-support)
* [GPU, ROCm, and PyTorch Support](#gpu-rocm-and-pytorch-support)
* [Getting Started](#getting-started)
* [Option 1: Get a Pre-built Docker Image](#option-1-get-a-pre-built-docker-image)
* [Option 2: Install from a Wheel Package](#option-2-install-from-a-wheel-package)
Expand All @@ -20,67 +30,121 @@ to its corresponding upstream tag (e.g., `0.2.5+amd.2` is second release of amd-
* [Building and Installing a Wheel Package](#building-and-installing-a-wheel-package)
* [Running Tests](#running-tests)
* [AITER Support](#aiter-support)
* [Single Prefill AITER example](#single-prefill-example)
* [Install AITER from source](#install-aiter-from-source)
* [Install AITER wheel package](#install-aiter-wheel-package)
* [Known Limitations](#known-limitations)
* [Single Prefill Example](#single-prefill-example)
* [Environment Variables](#environment-variables)
* [Runtime Helpers](#runtime-helpers)
* [License and Acknowledgements](#license-and-acknowledgements)

## Feature Support Matrix
## Basic Usage

| Kernel Type | FP16 / BF16 | FP8 (E4M3, E5M2) | Has AITER backend | Notes |
| :--- | :---: | :---: | :---: | :--- |
| **Decode Attention** | ✅ | ✅ | No | Supports MHA, GQA, and MQA |
| **Prefill Attention** | ✅ | WIP | ✅ | Supports MHA, GQA, and MQA |
| **Cascade Attention** | TBD | TBD | No | Not Yet Ported |
| **MLA** | TBD | TBD | No | Not Yet Ported |
| **POD** | TBD | TBD | No | Not Yet Ported |
| **Positional Encoding** | TBD | TBD | No | Not Yet Ported |
| **Sampling** | ✅ | TBD | No | Supports Top-K/Top-P Sampling/OnlineSoftmax/SamplingFromLogits |
| **Logits Processor** | ✅ | TBD | No | |
| **Normalization** | ✅ | TBD | No | Supports RMS-Norm/Layer-Norm |
```python
import torch
import flashinfer

## GPU and ROCm Support
# PyTorch+ROCm still uses device="cuda" for AMD GPUs.
q = torch.randn(1024, 32, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda") # GQA 4:1
v = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda")

**Supported GPU:** gfx942 (CDNA3 architecture), gfx950 (CDNA4 architecture)
# backend="auto" (default) routes to AITER when supported on gfx942/gfx950
# and falls back to the in-tree fa2 HIP kernel otherwise.
output = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True)
```

**Supported ROCm versions:** 7.0.2, 7.1.1, 7.2
See [`examples/`](examples/) for batch prefill, batch decode, and a
Jupyter tutorial that walks through the full public API on ROCm.

## Torch Version Support
## Feature Support Matrix

**Torch+ROCm:** 2.8.0, 2.9.1
Most kernels ship with an in-tree HIP implementation. A subset also has
an [AITER](https://github.com/ROCm/aiter) backend; for those, the
`backend="auto"` default picks AITER when its compatibility conditions
hold and transparently falls back to HIP otherwise. AITER-only kernels
(currently MLA) require an explicit `backend="aiter"`.

Legend: **HIP** = in-tree FlashInfer+ROCm kernel (the historical `fa2`
HIP port, or the `native` JIT kernel for non-attention ops). **AITER** =
ROCm AITER backend.

| Kernel | HIP | AITER | `backend="auto"` resolves to | Notes |
| :--- | :---: | :---: | :--- | :--- |
| **Single decode attention** | ✅ `fa2` | — | HIP | MHA / GQA / MQA |
| **Batch decode attention (paged)** | ✅ `fa2` | ✅ | **AITER** when `fp16/bf16` + `NHD` + no CUDA-graph + `use_tensor_cores=False`; else **HIP** | MHA / GQA / MQA; **fp8 KV-cache (E4M3FNUZ)** on the HIP path; sliding-window on the AITER path; CUDA-graph auto-routes back to HIP |
Comment thread
demandal25 marked this conversation as resolved.
Outdated
Comment thread
demandal25 marked this conversation as resolved.
Outdated
Comment thread
demandal25 marked this conversation as resolved.
Outdated
Comment thread
demandal25 marked this conversation as resolved.
Outdated
| **Single prefill attention** | ✅ `fa2` | ✅ | **AITER** when `fp16/bf16` + `NHD` + no custom mask + equal Q/KV dtypes & head dims + `pos_encoding_mode="NONE"`; else **HIP** | MHA / GQA / MQA; fp8 WIP |
| **Batch prefill attention (paged + ragged)** | ✅ `fa2` | ✅ | Same auto criteria as single prefill | MHA / GQA / MQA; fp8 WIP. AITER native page sizes: `{16, 1024}` (`{128, 256, 1024}` on `amd-aiter==0.1.10`); other sizes go through a gather on the AITER path |
| **Cascade attention** | ✅ | — | HIP | Two-level shared-prefix attention; a fused single-kernel HIP variant is gated behind `FLASHINFER_HIP_FUSED_CASCADE=1` |
| **MLA (Multi-Latent Attention)** | — | ✅ | **AITER only** (no HIP fallback) | DeepSeek-style 192/128 head-dim split; bf16 + `page_size=1`; must pass `backend="aiter"` explicitly |
| **POD attention** | TBD | — | n/a | Code present; **not yet validated on ROCm** |
| **RoPE (positional encoding)** | ✅ | — | HIP | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + fp8 quant + paged-KV append (E4M3FNUZ, E5M2FNUZ) |
| **Paged KV-cache append** | ✅ `native` | ✅ | **AITER** when `fp16/bf16` + `NHD` + AITER importable; else **HIP `native`** | `append_paged_kv_cache`; fp8 KV-cache supported on the HIP path |
Comment thread
demandal25 marked this conversation as resolved.
Outdated
| **RMSNorm** | ✅ `native` | ✅ | **HIP `native`** (auto stays on HIP — AITER is opt-in via `backend="aiter"`) | AITER path is fp16/bf16, 2-D only; slightly lower precision at `hidden_size >= 1024` |
| **LayerNorm / Gemma RMSNorm** | ✅ | — | HIP | |
| **Sampling** | ✅ | — | HIP | Top-K / Top-P / Min-P / OnlineSoftmax / SamplingFromLogits |
| **Logits processor** | ✅ | — | HIP | Composable processor pipeline (cap, mask, temperature, …) |
| **Activation** | ✅ | — | HIP | SiLU / GELU with fused gating |
| **Quantization** | ✅ | — | HIP | `packbits`, `segment_packbits` |
| **`torch.compile`** | ✅ (opt-in) | n/a | n/a | Set `FLASHINFER_USE_TORCH_CUSTOM_OPS=1` **before** importing `flashinfer`; requires PyTorch ≥ 2.4. Without it, `torch.compile` raises a clear error if it traces into a flashinfer op |

Every ✅ row above is exercised by a matching `tests/rocm_tests/test_*_hip.py`.
The full set of conditions that cause AITER auto-routing to fall back to
HIP is documented in [Known Limitations](#known-limitations) below.

Comment thread
demandal25 marked this conversation as resolved.
## GPU, ROCm, and PyTorch Support

**Supported GPUs:** gfx942 (CDNA3 — MI300X, MI325X), gfx950 (CDNA4 — MI355X).

**Supported ROCm versions:** 7.0.2, 7.1.1, 7.2.

**Supported PyTorch+ROCm versions:** 2.8.0, 2.9.1.

Comment thread
demandal25 marked this conversation as resolved.
Install the matching ROCm-enabled PyTorch wheel from
<https://repo.radeon.com>:

**Note**: Other versions may work but have not been tested. Refer to <https://repo.radeon.com/rocm/manylinux/rocm-rel-{rocm-version}/> (replacing `{rocm-version}` with the desired ROCm version, e.g., `7.0.2`) for available versions.
```bash
pip install torch==2.9.1 --index-url https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/
```

Other versions may work but have not been tested. Replace `7.2` with the
ROCm version you need; refer to
<https://repo.radeon.com/rocm/manylinux/rocm-rel-{rocm-version}/> for
available wheels.

## Getting Started

### Option 1: Get a Pre-built Docker Image

AMD validates and publishes [FlashInfer images](https://hub.docker.com/r/rocm/flashinfer/tags)
with ROCm backends on Docker Hub. The following Docker image tag and associated
inventories represent the latest available FlashInfer version from the official Docker Hub.
AMD validates and publishes FlashInfer images with ROCm backends on
Docker Hub. The latest validated tag is:

| Docker image | ROCm | FlashInfer | PyTorch | Ubuntu | Python | GPU |
| ------------ | ---- | ---------- | ------- | ------ | ------ | --- |
| rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1 |7.2.0 | v0.5.3 | 2.9.1 | 24.04 | 3.12 | MI355x, MI325X, MI300X |
| rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.0.2_ubuntu24.04_py3.12_pytorch2.9.1 | 7.0.2 | v0.5.3 | 2.9.1 | 24.04 | 3.12 | MI355x, MI325X, MI300X |
| rocm/flashinfer:flashinfer-0.2.5.amd2_rocm7.1.1_ubuntu24.04_py3.12_pytorch2.8 | 7.1.1 | v0.2.5 | 2.8.0 | 24.04 | 3.12 | MI325X, MI300X |
| `rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1` | 7.2.0 | v0.5.3 | 2.9.1 | 24.04 | 3.12 | MI355X, MI325X, MI300X |
Comment thread
demandal25 marked this conversation as resolved.

For older releases (earlier ROCm / PyTorch / FlashInfer combinations),
see the full tag list at
<https://hub.docker.com/r/rocm/flashinfer/tags>.
Comment thread
demandal25 marked this conversation as resolved.

**Start a container:**
Comment thread
demandal25 marked this conversation as resolved.

```bash
docker run -it --privileged --network=host --device=/dev/kfd --device=/dev/dri \
--group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--ipc=host --shm-size 128G --name=<container-name> <docker-image-tag>
--ipc=host --shm-size 128G --name=flashinfer-rocm \
rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1
```

**Activate the environment and verify:**
**Verify the installation:**

```bash
# Activate micromamba environment (Note: env name may vary based on the image)
micromamba activate base

# Verify installation
python -c "import flashinfer; print(flashinfer.__version__)"
```

Expected output: `0.5.3+amd.1` (with a possible JIT backend message)
Expected output: `0.5.3+amd.1` (with a possible JIT backend message).
The container's micromamba environment is activated automatically on
shell start — no manual `micromamba activate` is required.

### Option 2: Install from a Wheel Package

Expand All @@ -90,14 +154,14 @@ Install from AMD's package repository:
pip install amd-flashinfer --index-url https://pypi.amd.com/simple/
```

Install the needed ROCm-enabled torch package from <https://repo.radeon.com>:
Install the matching ROCm-enabled torch package from <https://repo.radeon.com>:

```bash
pip install torch==2.9.1 -f https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2
pip install torch==2.9.1 --index-url https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/
```

**NOTE**: The torch version should be exactly as available on repo.radeon.com otherwise a non-ROCm
torch version will get installed from pypi.
**NOTE:** Use `--index-url` (not `-f`) so pip cannot silently fall back
to a CPU-only PyPI wheel.

### Trying the Examples

Expand All @@ -116,11 +180,14 @@ done

**Available examples:**

* `single_prefill_example.py` - Single-sequence prefill attention
* `batch_prefill_example.py` - Batched prefill attention
* `batch_decode_example.py` - Batched decode attention
* `examples/amd_flashinfer_rocm_tutorial.ipynb` - Jupyter tutorial: environment verification (`hip_utils`), AITER-backed prefill examples, and `logits_processor` on ROCm
* `examples/run_jupyter_server.sh` - Start JupyterLab from the repo root (run inside your ROCm/FlashInfer environment or Docker container)
* `single_prefill_example.py` — single-sequence prefill attention
* `batch_prefill_example.py` — batched prefill attention
* `batch_decode_example.py` — batched decode attention
* `examples/amd_flashinfer_rocm_tutorial.ipynb` — Jupyter tutorial:
environment verification (`hip_utils`), AITER-backed prefill examples,
and `logits_processor` on ROCm
* `examples/run_jupyter_server.sh` — start JupyterLab from the repo root
(run inside your ROCm/FlashInfer environment or Docker container)

## Build from Source

Expand Down Expand Up @@ -184,7 +251,8 @@ docker run -it \
</details>
<!-- markdownlint-enable MD033 -->

**Note:** Environment name varies based on Python, PyTorch, and ROCm versions.
**Note:** Environment name varies based on Python, PyTorch, and ROCm
versions.

### Building and Installing a Wheel Package

Expand All @@ -198,12 +266,12 @@ cd dist && pip install amd_flashinfer-*.whl
**Editable install for development:**

```bash
python -m pip install --no-build-isolation -ve.
python -m pip install --no-build-isolation -ve .
```

**Note:** The `--no-deps` flag assumes dependencies are pre-installed. Omit it
to download dependencies during build. AOT builds take longer and use more disk
space but avoid JIT compilation at runtime.
**Note:** The `--no-deps` flag assumes dependencies are pre-installed.
Omit it to download dependencies during build. AOT builds take longer
and use more disk space but avoid JIT compilation at runtime.

### Running Tests

Expand All @@ -214,20 +282,21 @@ The Python tests suite can be run with pytest:
pytest

# Run specific test file
pytest tests/test_decode_kernels_hip.py
pytest tests/rocm_tests/test_batch_decode_kernels_hip.py

# Run with pattern matching
pytest -k "test_decode_kernels_hip"
pytest -k "test_batch_decode_kernels_hip"

# Verbose output
pytest -v

# To run tests parallely on multiple GPUs
pytest -n auto # Uses all available GPUs
pytest -n 2 # Use only two GPUs
# Run tests in parallel across multiple GPUs
pytest -n auto # Uses all available GPUs
pytest -n 2 # Use only two GPUs
```

The default test configuration is specified in [pyproject.toml](pyproject.toml) under the `testpaths` setting.
The default test configuration is specified in [pyproject.toml](pyproject.toml)
under the `testpaths` setting.

#### Recommended invocation on AMD CPX systems

Expand Down Expand Up @@ -256,7 +325,10 @@ pytest -n auto --reruns 2 -m "slow"
## AITER Support

FlashInfer+ROCm supports the use of [AITER](https://github.com/ROCm/aiter) as a
backend. The `aiter` backend is enabled for the `single_prefill` and `batch_prefill` kernels.
backend. The `aiter` backend is enabled for the `single_prefill`,
`batch_prefill` (paged and ragged), `batch_decode`, `append_paged_kv_cache`,
`rmsnorm`, and `MLA` paths. MLA on ROCm is **only** available via AITER —
there is no in-tree HIP MLA kernel yet.
Comment thread
demandal25 marked this conversation as resolved.
Outdated
Comment thread
demandal25 marked this conversation as resolved.
Outdated

**On gfx942/gfx950 GPUs, `backend="auto"` (the default) automatically selects the AITER backend**
when the call parameters are compatible (fp16/bf16, NHD layout, no custom mask, equal Q/K/V
Expand Down Expand Up @@ -316,10 +388,12 @@ results — pass `backend="fa2"` explicitly if you need any of these):
`{16, 1024}` (or `{128, 256, 1024}` on `amd-aiter==0.1.10`). Other page
sizes still work but go through an extra GPU gather to flatten paged KV
before the AITER call.
* Ragged (non-paged) KV is not yet implemented on the AITER batch-prefill
path. `BatchPrefillWithRaggedKVCacheWrapper` therefore forces the backend
to `fa2` regardless of whether you pass `backend="auto"` or
`backend="aiter"` (a warning is logged in the latter case).
* Ragged (non-paged) batch prefill via AITER is supported through
`BatchPrefillWithRaggedKVCacheWrapper`. The wrapper auto-routes to
AITER under `backend="auto"` when the standard AITER compatibility
conditions are met and falls back to `fa2` otherwise.
* MLA on ROCm currently supports only `bfloat16` and `page_size=1`
through the AITER backend.

### Single Prefill Example

Expand All @@ -340,8 +414,51 @@ q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.float16, device="cu
k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")

# Run single prefill attention with causal masking
# Run single prefill attention with causal masking.
# On gfx942/gfx950, backend="auto" (default) routes to AITER automatically.
# Pass backend="aiter" to require AITER explicitly, or backend="fa2" to skip it.
output = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, backend="auto")
```

## Environment Variables

FlashInfer+ROCm reads the following environment variables at runtime
or import time. Build-time variables (`FLASHINFER_ROCM_ARCH_LIST`,
`FLASHINFER_JIT_VERBOSE`, `FLASHINFER_JIT_DEBUG`, `MAX_JOBS`, …) are
documented in [CLAUDE.md](CLAUDE.md).

Comment thread
demandal25 marked this conversation as resolved.
| Variable | Default | Purpose |
| :--- | :--- | :--- |
| `FLASHINFER_USE_TORCH_CUSTOM_OPS` | `0` | Set to `1` **before** importing `flashinfer` to wrap kernels in `torch.library.custom_op` so `torch.compile` / Dynamo can trace them. Requires PyTorch ≥ 2.4. Adds a small per-call dispatch overhead. |
Comment thread
demandal25 marked this conversation as resolved.
| `FLASHINFER_HIP_FUSED_CASCADE` | `0` | Set to `1` to use a fused single-kernel HIP cascade attention path instead of the default two-level merge-based path. Experimental on ROCm. |
| `FLASHINFER_LOGGING_LEVEL` | `INFO` | Logger verbosity (e.g. `DEBUG`, `INFO`, `WARNING`). Affects AITER auto-fallback warnings and JIT build messages. |
| `FLASHINFER_DISABLE_JIT` | unset | Set to any non-empty value to skip JIT compilation. Useful when running an AOT-built wheel and you want to fail loudly on missing kernels rather than trigger a build. |
| `ROCM_PATH` / `ROCM_HOME` | `/opt/rocm` | Used by `flashinfer.hip_utils` to locate the ROCm install. Override only for non-standard ROCm layouts. |

## Runtime Helpers

`flashinfer` ships a few ROCm-specific helpers that are useful when
guarding code paths or diagnosing setup issues:

```python
from flashinfer.aiter_utils import is_aiter_supported
from flashinfer.hip_utils import (
check_torch_rocm_compatibility,
validate_flashinfer_rocm_arch,
)

# Returns True only on gfx942/gfx950 with the aiter package importable.
if is_aiter_supported(torch.device("cuda")):
...
Comment thread
demandal25 marked this conversation as resolved.

# Raises a clear error if PyTorch + ROCm versions are incompatible
# (e.g. a CPU-only torch wheel was picked up from PyPI).
check_torch_rocm_compatibility()
```
Comment thread
demandal25 marked this conversation as resolved.

## License and Acknowledgements

FlashInfer+ROCm is released under the Apache-2.0 License — see
[LICENSE](LICENSE) and [NOTICE](NOTICE). Upstream project:
[flashinfer-ai/flashinfer](https://github.com/flashinfer-ai/flashinfer).
Run `pre-commit run -a` and `pytest` before opening a PR.
Loading