diff --git a/CLAUDE.md b/CLAUDE.md index 0a622166ad..05383b2c70 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -27,8 +27,9 @@ pip install torch== \ --index-url https://repo.radeon.com/rocm/manylinux/rocm-rel-/ ``` -See the [GPU and ROCm Support](README.md#gpu-and-rocm-support) table in -`README.md` for current `` and `` values. +See the [GPU, ROCm, and PyTorch Support](README.md#gpu-rocm-and-pytorch-support) +table in `README.md` for current `` and `` +values. ## Non-Obvious Gotchas @@ -90,6 +91,8 @@ gh api repos/ROCm/flashinfer/pulls/ --method PATCH --field body="" gh api repos/ROCm/flashinfer/pulls/ --method PATCH --field body="$(cat /tmp/pr_body.md)" ``` +Ask to push to remote. + ## PR Description **Body** — include sections that apply, skip the rest: diff --git a/README.md b/README.md index 1446c2f4d9..30f7fa4096 100644 --- a/README.md +++ b/README.md @@ -1,86 +1,133 @@ # FlashInfer+ROCm: An AMD ROCm port of FlashInfer -FlashInfer+ROCm is a port of the [FlashInfer](https://github.com/flashinfer-ai/flashinfer) library -that adds support for AMD Instinct GPUs. The project is in active development with current focus on -porting attention kernels to ROCm. - -**Versioning:** The release tag format `+amd` ties each FlashInfer+ROCm release -to its corresponding upstream tag (e.g., `0.2.5+amd.2` is second release of amd-flashinfer based on upstream version `v0.2.5`). +FlashInfer+ROCm brings the +[FlashInfer](https://github.com/flashinfer-ai/flashinfer) inference +kernel library to AMD Instinct GPUs — currently +[CDNA3](https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/white-papers/amd-cdna-3-white-paper.pdf) +(gfx942 — MI300X / MI325X) and +[CDNA4](https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/white-papers/amd-cdna-4-architecture-whitepaper.pdf) +(gfx950 — MI355X). It ships in-tree HIP ports of the attention, +KV-cache, RoPE, normalization, sampling, and logits-processor kernels, +and transparently dispatches a subset of attention paths to AMD's +[AITER](https://github.com/ROCm/aiter) backend when its compatibility +conditions hold (see [Feature Support Matrix](#feature-support-matrix)). + +The port is in active development and is aimed at developers embedding +FlashInfer kernels into their own training or serving stack. See +[CHANGELOG.md](CHANGELOG.md) for the full release history. + +**Versioning:** The release tag format `+amd.` ties +each FlashInfer+ROCm release to its corresponding upstream tag (e.g. +`0.5.3+amd.1` is the first AMD release based on upstream `v0.5.3`). ## Table of Contents * [Feature Support Matrix](#feature-support-matrix) -* [GPU and ROCm Support](#gpu-and-rocm-support) +* [GPU, ROCm, and PyTorch Support](#gpu-rocm-and-pytorch-support) * [Getting Started](#getting-started) * [Option 1: Get a Pre-built Docker Image](#option-1-get-a-pre-built-docker-image) * [Option 2: Install from a Wheel Package](#option-2-install-from-a-wheel-package) * [Trying the Examples](#trying-the-examples) -* [Build from Source](#build-from-source) +* [Install from Source](#install-from-source) * [Setting up a Development Environment](#setting-up-a-development-environment) * [Building and Installing a Wheel Package](#building-and-installing-a-wheel-package) * [Running Tests](#running-tests) * [AITER Support](#aiter-support) - * [Single Prefill AITER example](#single-prefill-example) + * [Install AITER from source](#install-aiter-from-source) + * [Install AITER wheel package](#install-aiter-wheel-package) + * [Known Limitations](#known-limitations) +* [Environment Variables](#environment-variables) +* [Runtime Helpers](#runtime-helpers) +* [Basic Usage](#basic-usage) +* [License and Acknowledgements](#license-and-acknowledgements) ## Feature Support Matrix -| Kernel Type | FP16 / BF16 | FP8 (E4M3, E5M2) | Has AITER backend | Notes | -| :--- | :---: | :---: | :---: | :--- | -| **Decode Attention** | ✅ | ✅ | No | Supports MHA, GQA, and MQA | -| **Prefill Attention** | ✅ | WIP | ✅ | Supports MHA, GQA, and MQA | -| **Cascade Attention** | TBD | TBD | No | Not Yet Ported | -| **MLA** | TBD | TBD | No | Not Yet Ported | -| **POD** | TBD | TBD | No | Not Yet Ported | -| **Positional Encoding** | TBD | TBD | No | Not Yet Ported | -| **Sampling** | ✅ | TBD | No | Supports Top-K/Top-P Sampling/OnlineSoftmax/SamplingFromLogits | -| **Logits Processor** | ✅ | TBD | No | | -| **Normalization** | ✅ | TBD | No | Supports RMS-Norm/Layer-Norm | - -## GPU and ROCm Support - -**Supported GPU:** gfx942 (CDNA3 architecture), gfx950 (CDNA4 architecture) - -**Supported ROCm versions:** 7.0.2, 7.1.1, 7.2 +Most kernels ship with an in-tree HIP implementation. A subset also has +an AITER backend; for those, `backend="auto"` picks AITER when its +compatibility conditions hold and falls back to HIP otherwise. The one +AITER-only kernel today (MLA) has no HIP path — `backend="auto"` +resolves directly to `"aiter"`. + +Legend: **HIP** = in-tree kernel (`fa2` for attention, `native` JIT +kernel for non-attention ops). **AITER** = ROCm AITER backend. + +| Kernel | HIP | AITER | `backend="auto"` resolves to | Notes | +| :--- | :---: | :---: | :--- | :--- | +| **Single decode attention** | ✅ `fa2` | — | HIP | MHA / GQA / MQA | +| **Batch decode attention (paged)** | ✅ `fa2` | ✅ | **AITER** when `fp16/bf16` + `NHD` + `pos_encoding_mode="NONE"` + no CUDA-graph + `use_tensor_cores=False`; else **HIP** | MHA / GQA / MQA; **fp8 KV-cache (E4M3FNUZ)** on the HIP path; sliding-window on the AITER path; CUDA-graph auto-routes back to HIP | +| **Single prefill attention** | ✅ `fa2` | ✅ | **AITER** when `fp16/bf16` + `NHD` + no custom mask + equal Q/KV dtypes & head dims + `pos_encoding_mode="NONE"`; else **HIP** | MHA / GQA / MQA; fp8 WIP | +| **Batch prefill attention (paged + ragged)** | ✅ `fa2` | ✅ | Same auto criteria as single prefill | MHA / GQA / MQA; fp8 WIP. AITER native page sizes: `{16, 1024}` (`{128, 256, 1024}` on `amd-aiter==0.1.10`); other sizes go through a gather on the AITER path | +| **Cascade attention** | ✅ | — | HIP | Two-level shared-prefix attention; a fused single-kernel HIP variant is gated behind `FLASHINFER_HIP_FUSED_CASCADE=1` | +| **MLA (Multi-Latent Attention)** | — | ✅ | **AITER** (no HIP fallback) | DeepSeek-style 192/128 head-dim split; bf16 + `page_size=1`; `backend="auto"` (default) resolves to `"aiter"` | +| **POD attention** | TBD | — | n/a | Code present; **not yet validated on ROCm** | +| **RoPE (positional encoding)** | ✅ | — | HIP | LLaMA-style + LLaMA 3.1 scaling; fused RoPE + fp8 quant + paged-KV append (E4M3FNUZ, E5M2FNUZ) | +| **Paged KV-cache append** | ✅ `native` | ✅ | **AITER** when `fp16/bf16` + `NHD` + gfx942/gfx950 + AITER importable; else **HIP `native`** | `append_paged_kv_cache`; fp8 KV-cache supported on the HIP path | +| **RMSNorm** | ✅ `native` | ✅ | **HIP `native`** (auto stays on HIP — AITER is opt-in via `backend="aiter"`) | AITER path is fp16/bf16, 2-D only; slightly lower precision at `hidden_size >= 1024` | +| **LayerNorm / Gemma RMSNorm** | ✅ | — | HIP | | +| **Sampling** | ✅ | — | HIP | Top-K / Top-P / Min-P / OnlineSoftmax / SamplingFromLogits | +| **Logits processor** | ✅ | — | HIP | Composable processor pipeline (cap, mask, temperature, …) | +| **Activation** | ✅ | — | HIP | SiLU / GELU with fused gating | +| **Quantization** | ✅ | — | HIP | `packbits`, `segment_packbits` | +| **`torch.compile`** | ✅ (opt-in) | n/a | n/a | Set `FLASHINFER_USE_TORCH_CUSTOM_OPS=1` **before** importing `flashinfer`; requires PyTorch ≥ 2.4. Without it, `torch.compile` raises a clear error if it traces into a flashinfer op | + +Every ✅ row above is exercised by a matching `tests/rocm_tests/test_*_hip.py`. +The full set of conditions that cause AITER auto-routing to fall back to +HIP is documented in [Known Limitations](#known-limitations) below. + +## GPU, ROCm, and PyTorch Support + +**Supported GPUs:** gfx942 (CDNA3 — MI300X, MI325X), gfx950 (CDNA4 — MI355X). + +**Supported ROCm versions:** 7.0.2, 7.1.1, 7.2. + +**Supported PyTorch+ROCm versions:** 2.8.0, 2.9.1. + +Install the matching ROCm-enabled PyTorch wheel from +: -## Torch Version Support - -**Torch+ROCm:** 2.8.0, 2.9.1 +```bash +pip install torch==2.9.1 --index-url https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/ +``` -**Note**: Other versions may work but have not been tested. Refer to (replacing `{rocm-version}` with the desired ROCm version, e.g., `7.0.2`) for available versions. +Other versions may work but have not been tested. Replace `7.2` with the +ROCm version you need; refer to + for +available wheels. ## Getting Started ### Option 1: Get a Pre-built Docker Image -AMD validates and publishes [FlashInfer images](https://hub.docker.com/r/rocm/flashinfer/tags) -with ROCm backends on Docker Hub. The following Docker image tag and associated -inventories represent the latest available FlashInfer version from the official Docker Hub. +AMD validates and publishes FlashInfer images with ROCm backends on +Docker Hub. The latest validated tag is: | Docker image | ROCm | FlashInfer | PyTorch | Ubuntu | Python | GPU | | ------------ | ---- | ---------- | ------- | ------ | ------ | --- | -| rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1 |7.2.0 | v0.5.3 | 2.9.1 | 24.04 | 3.12 | MI355x, MI325X, MI300X | -| rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.0.2_ubuntu24.04_py3.12_pytorch2.9.1 | 7.0.2 | v0.5.3 | 2.9.1 | 24.04 | 3.12 | MI355x, MI325X, MI300X | -| rocm/flashinfer:flashinfer-0.2.5.amd2_rocm7.1.1_ubuntu24.04_py3.12_pytorch2.8 | 7.1.1 | v0.2.5 | 2.8.0 | 24.04 | 3.12 | MI325X, MI300X | +| `rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1` | 7.2.0 | v0.5.3 | 2.9.1 | 24.04 | 3.12 | MI355X, MI325X, MI300X | + +For older releases (earlier ROCm / PyTorch / FlashInfer combinations), +see the full tag list at +. **Start a container:** ```bash docker run -it --privileged --network=host --device=/dev/kfd --device=/dev/dri \ --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ - --ipc=host --shm-size 128G --name= + --ipc=host --shm-size 128G --name=flashinfer-rocm \ + rocm/flashinfer:flashinfer-0.5.3.amd1_rocm7.2_ubuntu24.04_py3.12_pytorch2.9.1 ``` -**Activate the environment and verify:** +**Verify the installation:** ```bash -# Activate micromamba environment (Note: env name may vary based on the image) -micromamba activate base - -# Verify installation python -c "import flashinfer; print(flashinfer.__version__)" ``` -Expected output: `0.5.3+amd.1` (with a possible JIT backend message) +Expected output: `0.5.3+amd.1` (with a possible JIT backend message). +The container's micromamba environment is activated automatically on +shell start — no manual `micromamba activate` is required. ### Option 2: Install from a Wheel Package @@ -90,39 +137,27 @@ Install from AMD's package repository: pip install amd-flashinfer --index-url https://pypi.amd.com/simple/ ``` -Install the needed ROCm-enabled torch package from : +Install the matching ROCm-enabled torch package from : ```bash -pip install torch==2.9.1 -f https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2 +pip install torch==2.9.1 --index-url https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/ ``` -**NOTE**: The torch version should be exactly as available on repo.radeon.com otherwise a non-ROCm -torch version will get installed from pypi. +**NOTE:** Use `--index-url` (not `-f`) so pip cannot silently fall back +to a CPU-only PyPI wheel. ### Trying the Examples -Download and run example scripts from the repository: +Runnable scripts live in the [`examples/`](examples/) directory of this +repository (single/batch prefill, batch decode, plus an +`amd_flashinfer_rocm_tutorial.ipynb` Jupyter notebook). After cloning, +run any of them directly, for example: ```bash -# Download a single example -wget https://raw.githubusercontent.com/ROCm/flashinfer/amd-integration/examples/single_prefill_example.py -python single_prefill_example.py - -# Download all examples -for example in single_prefill_example.py batch_prefill_example.py batch_decode_example.py; do - wget https://raw.githubusercontent.com/ROCm/flashinfer/amd-integration/examples/$example -done +python examples/single_prefill_example.py ``` -**Available examples:** - -* `single_prefill_example.py` - Single-sequence prefill attention -* `batch_prefill_example.py` - Batched prefill attention -* `batch_decode_example.py` - Batched decode attention -* `examples/amd_flashinfer_rocm_tutorial.ipynb` - Jupyter tutorial: environment verification (`hip_utils`), AITER-backed prefill examples, and `logits_processor` on ROCm -* `examples/run_jupyter_server.sh` - Start JupyterLab from the repo root (run inside your ROCm/FlashInfer environment or Docker container) - -## Build from Source +## Install from Source ### Setting up a Development Environment @@ -184,8 +219,6 @@ docker run -it \ -**Note:** Environment name varies based on Python, PyTorch, and ROCm versions. - ### Building and Installing a Wheel Package **Build with JIT (Just-in-Time) compilation only:** @@ -198,36 +231,37 @@ cd dist && pip install amd_flashinfer-*.whl **Editable install for development:** ```bash -python -m pip install --no-build-isolation -ve. +python -m pip install --no-build-isolation -ve . ``` -**Note:** The `--no-deps` flag assumes dependencies are pre-installed. Omit it -to download dependencies during build. AOT builds take longer and use more disk -space but avoid JIT compilation at runtime. +**Note:** The `--no-deps` flag assumes dependencies are pre-installed. +Omit it to download dependencies during build. AOT builds take longer +and use more disk space but avoid JIT compilation at runtime. ### Running Tests -The Python tests suite can be run with pytest: +Run the Python test suite with pytest: ```bash # Run default tests (configured in pyproject.toml) pytest # Run specific test file -pytest tests/test_decode_kernels_hip.py +pytest tests/rocm_tests/test_batch_decode_kernels_hip.py # Run with pattern matching -pytest -k "test_decode_kernels_hip" +pytest -k "test_batch_decode_kernels_hip" # Verbose output pytest -v -# To run tests parallely on multiple GPUs -pytest -n auto # Uses all available GPUs -pytest -n 2 # Use only two GPUs +# Run tests in parallel across multiple GPUs +pytest -n auto # Uses all available GPUs +pytest -n 2 # Use only two GPUs ``` -The default test configuration is specified in [pyproject.toml](pyproject.toml) under the `testpaths` setting. +The default test configuration is specified in [pyproject.toml](pyproject.toml) +under the `testpaths` setting. #### Recommended invocation on AMD CPX systems @@ -248,22 +282,55 @@ pytest -n auto --reruns 2 -m "slow" **Notes** -* `pytest -n auto` for the `tests/rocm_tests/` suite spawns **half as many xdist workers as physical AMD cards** (e.g. 4 workers on a CPX-mode 8-card MI308X / MI325X host). One worker per physical card was tried first but produced sporadic failures across rope, single_prefill, and logits_cap under residual concurrent load; halving the count produces reliable green runs. Each worker is pinned to its card via `HIP_VISIBLE_DEVICES`. On non-CPX systems the helper applies the same halving; users who want every device used can pass an explicit `-n N`. -* `--reruns 2` (from `pytest-rerunfailures`) absorbs the residual ~0.01 % of transient HIP runtime crashes (HSA exceptions, HIPBLAS handle-pool exhaustion, intermittent generator non-determinism) that worker pinning cannot fully eliminate. Successful tests are not duplicated; only failed tests are retried. -* The `slow` marker is registered in [pyproject.toml](pyproject.toml). It tags the 1M-trial sampling-frequency tests, the 4 GB-tensor speculative-sampling cases, and the entire `TestLogitsPipeCompilationHIP` class (every test there runs the sampling kernel twice per case for compile=True/False). -* The reference attention helper in `tests/attention_reference.py` wraps `torch.matmul` in a `_hipblas_safe_matmul` retry helper that catches `HIPBLAS_STATUS_ALLOC_FAILED` and retries with a short back-off — needed under heavy concurrent xdist load. +* **Worker count.** `pytest -n auto` for the `tests/rocm_tests/` suite + spawns **half as many xdist workers as physical AMD cards** (e.g. 4 + workers on a CPX-mode 8-card MI308X / MI325X host) and pins each + worker to its card via `HIP_VISIBLE_DEVICES`. One worker per physical + card was tried first but produced sporadic failures across rope, + single_prefill, and logits_cap under residual concurrent load. + Pass an explicit `-n N` to override the halving. +* **Reruns.** `--reruns 2` (from `pytest-rerunfailures`) absorbs the + residual ~0.01 % of transient HIP runtime crashes (HSA exceptions, + HIPBLAS handle-pool exhaustion, intermittent generator + non-determinism) that worker pinning cannot fully eliminate. Only + failed tests are retried. +* **`slow` marker.** Registered in [pyproject.toml](pyproject.toml). It + tags the 1M-trial sampling-frequency tests, the 4 GB-tensor + speculative-sampling cases, and the entire `TestLogitsPipeCompilationHIP` + class (each test runs the sampling kernel twice for compile=True/False). +* **HIPBLAS retry.** The reference attention helper in + `tests/attention_reference.py` wraps `torch.matmul` in a + `_hipblas_safe_matmul` retry that catches `HIPBLAS_STATUS_ALLOC_FAILED` + and retries with a short back-off — needed under heavy concurrent + xdist load. ## AITER Support -FlashInfer+ROCm supports the use of [AITER](https://github.com/ROCm/aiter) as a -backend. The `aiter` backend is enabled for the `single_prefill` and `batch_prefill` kernels. - -**On gfx942/gfx950 GPUs, `backend="auto"` (the default) automatically selects the AITER backend** -when the call parameters are compatible (fp16/bf16, NHD layout, no custom mask, equal Q/K/V -dtypes and head dims, `pos_encoding_mode="NONE"`). It falls back to `fa2` with a one-time -`logger.warning` when any condition is not met. You can also pass `backend="aiter"` explicitly. - -Unless you are using the prebuilt docker image, AITER must also be installed on your system. You may follow one of the following ways to do so. +FlashInfer+ROCm can dispatch the `single_prefill`, `batch_prefill` +(paged and ragged), `batch_decode`, `append_paged_kv_cache`, `rmsnorm`, +and `MLA` paths to [AITER](https://github.com/ROCm/aiter). MLA on ROCm +is **AITER-only** — there is no in-tree HIP MLA kernel yet, so +`backend="auto"` (the default for the MLA wrapper) resolves directly +to `"aiter"`. + +On gfx942/gfx950, `backend="auto"` (the default) selects AITER when the +call is compatible (see [Known Limitations](#known-limitations) for the +full list) and otherwise falls back to the in-tree HIP kernel, emitting +a one-time `logger.warning`. Pass `backend="aiter"` to require AITER +explicitly, or pass the in-tree backend string to skip it: +`backend="fa2"` for the attention wrappers (single/batch +prefill/decode), `backend="native"` for non-attention ops +(`append_paged_kv_cache`, `rmsnorm`). Two backend-specific exceptions +to "auto picks AITER when supported": + +* `rmsnorm`: `backend="auto"` stays on the HIP `native` kernel; the + AITER path is opt-in via `backend="aiter"`. +* `batch_decode`: `use_cuda_graph=True` or `use_tensor_cores=True` + force `auto` back to `fa2` (AITER decode does not support either), + and `pos_encoding_mode != "NONE"` raises under `backend="aiter"`. + +Unless you are using the prebuilt Docker image, install AITER separately +via one of the options below. ### Install AITER from source @@ -283,11 +350,15 @@ pip install amd-aiter --index-url https://pypi.amd.com/simple/ ### Known Limitations -The AITER backend has the following constraints. With `backend="aiter"` the -call will error on the first group of conditions, or for the second group, -run but silently ignore the unsupported argument. +AITER constraints fall into two groups: hard incompatibilities (the call +errors with `backend="aiter"` and triggers fallback under +`backend="auto"`), and silently-ignored kwargs (the call runs but the +flag has no effect on AITER — pass the in-tree backend explicitly if +you need any of them: `backend="fa2"` for attention wrappers, or +`backend="native"` for `append_paged_kv_cache` / `rmsnorm`). -**Conditions that fall back to `fa2` under `backend="auto"`:** +**Conditions that fall back to the in-tree HIP kernel under +`backend="auto"`** (and raise under `backend="aiter"`): * GPU is not gfx942 or gfx950 * `kv_layout` is not `NHD` @@ -295,15 +366,19 @@ run but silently ignore the unsupported argument. * `q_dtype` is not `float16` / `bfloat16` (no fp32, fp8, or int8) * `q_dtype != kv_dtype` (mixed-precision Q/KV is unsupported) * `head_dim_qk != head_dim_vo` (e.g. DeepSeek-style MLA with 192/128 head dims) +* `pos_encoding_mode != "NONE"` (AITER attention paths only support `"NONE"`) +* batch decode: `use_cuda_graph=True` or `use_tensor_cores=True` * the `aiter` Python package is not importable -**Features silently ignored on the AITER path** (the kwargs are accepted by -the FlashInfer wrapper but not forwarded to AITER, which can produce wrong -results — pass `backend="fa2"` explicitly if you need any of these): +**Features silently ignored on the AITER path** (kwargs are accepted by +the FlashInfer wrapper but not forwarded to AITER, which can produce +wrong results): * ALiBi slopes (`maybe_alibi_slopes`) -* in-kernel positional encoding modes (`pos_encoding_mode`, `rope_scale`, - `rope_theta`) +* RoPE scaling kwargs (`rope_scale`, `rope_theta`) — these are only + consumed alongside `pos_encoding_mode != "NONE"`, which AITER + attention rejects outright; the kwargs themselves pass through + silently when the mode is `"NONE"` * attention sinks (`sinks`) * multi-modal / prefix-cache helpers (`maybe_prefix_len_ptr`, `maybe_token_pos_in_items_ptr`, `maybe_max_item_len_ptr`) @@ -316,32 +391,79 @@ results — pass `backend="fa2"` explicitly if you need any of these): `{16, 1024}` (or `{128, 256, 1024}` on `amd-aiter==0.1.10`). Other page sizes still work but go through an extra GPU gather to flatten paged KV before the AITER call. -* Ragged (non-paged) KV is not yet implemented on the AITER batch-prefill - path. `BatchPrefillWithRaggedKVCacheWrapper` therefore forces the backend - to `fa2` regardless of whether you pass `backend="auto"` or - `backend="aiter"` (a warning is logged in the latter case). +* Ragged (non-paged) batch prefill via AITER is supported through + `BatchPrefillWithRaggedKVCacheWrapper`. The wrapper auto-routes to + AITER under `backend="auto"` when the standard AITER compatibility + conditions are met and falls back to `fa2` otherwise. +* MLA on ROCm currently supports only `bfloat16` and `page_size=1` + through the AITER backend. + +## Environment Variables + +FlashInfer+ROCm reads the following environment variables at runtime +or import time. Build-time variables (`FLASHINFER_ROCM_ARCH_LIST`, +`FLASHINFER_JIT_VERBOSE`, `FLASHINFER_JIT_DEBUG`, `MAX_JOBS`, …) are +documented in [CLAUDE.md](CLAUDE.md). -### Single Prefill Example +| Variable | Default | Purpose | +| :--- | :--- | :--- | +| `FLASHINFER_USE_TORCH_CUSTOM_OPS` | `0` | Set to `1` **before** importing `flashinfer` to wrap kernels in `torch.library.custom_op` so `torch.compile` / Dynamo can trace them. Requires PyTorch ≥ 2.4. Adds a small per-call dispatch overhead. | +| `FLASHINFER_HIP_FUSED_CASCADE` | `0` | Set to `1` to use a fused single-kernel HIP cascade attention path instead of the default two-level merge-based path. Experimental on ROCm. | +| `FLASHINFER_LOGGING_LEVEL` | `INFO` | Logger verbosity (e.g. `DEBUG`, `INFO`, `WARNING`). Affects AITER auto-fallback warnings and JIT build messages. | +| `FLASHINFER_DISABLE_JIT` | unset | Set to any non-empty value to skip JIT compilation. Useful when running an AOT-built wheel and you want to fail loudly on missing kernels rather than trigger a build. | +| `ROCM_PATH` / `ROCM_HOME` | `/opt/rocm` | Used by `flashinfer.hip_utils` to locate the ROCm install. Override only for non-standard ROCm layouts. | -This section provides an example on how to use Single Prefill with AITER. +## Runtime Helpers + +`flashinfer` ships a few ROCm-specific helpers that are useful when +guarding code paths or diagnosing setup issues: + +```python +import torch + +from flashinfer.aiter_utils import is_aiter_supported +from flashinfer.hip_utils import check_torch_rocm_compatibility + +# True on gfx942/gfx950 (a ROCm build + supported GPU arch). Does *not* +# verify the `aiter` Python package is importable — wrap the actual +# AITER call in a try/except ImportError if you need that guarantee. +if is_aiter_supported(torch.device("cuda")): + ... + +# Raises a clear error if PyTorch + ROCm versions are incompatible +# (e.g. a CPU-only torch wheel was picked up from PyPI). +check_torch_rocm_compatibility() +``` + +`flashinfer.hip_utils.validate_flashinfer_rocm_arch` is a related +build-time validator used by `setup.py` to cross-check +`FLASHINFER_ROCM_ARCH_LIST` against ROCm and PyTorch — not typically +called from application code. + +## Basic Usage ```python import torch import flashinfer -# Configuration -seq_len = 1024 # Prompt length -num_qo_heads = 32 # Number of query/output heads -num_kv_heads = 8 # Number of KV heads (GQA with 4:1 ratio) -head_dim = 128 - -# Create Q, K, V tensors (NHD layout: sequence, heads, dimension) -q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.float16, device="cuda") -k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda") -v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda") - -# Run single prefill attention with causal masking -# On gfx942/gfx950, backend="auto" (default) routes to AITER automatically. -# Pass backend="aiter" to require AITER explicitly, or backend="fa2" to skip it. -output = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, backend="auto") +# PyTorch+ROCm still uses device="cuda" for AMD GPUs. +q = torch.randn(1024, 32, 128, dtype=torch.float16, device="cuda") +k = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda") # GQA 4:1 +v = torch.randn(1024, 8, 128, dtype=torch.float16, device="cuda") + +# backend="auto" (default) routes to AITER when supported on gfx942/gfx950 +# and falls back to the in-tree fa2 HIP kernel otherwise. +output = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) ``` + +See [`examples/`](examples/) for batch prefill, batch decode, and a +Jupyter tutorial that walks through the full public API on ROCm. + +## License and Acknowledgements + +FlashInfer+ROCm is released under the Apache-2.0 License — see +[LICENSE](LICENSE) and [NOTICE](NOTICE). Upstream project: +[flashinfer-ai/flashinfer](https://github.com/flashinfer-ai/flashinfer). + +Contributions are welcome. Please run `pre-commit run -a` and the +relevant `pytest` selection before opening a PR. diff --git a/flashinfer/mla_rocm.py b/flashinfer/mla_rocm.py index ec00503ae2..5cb3e1e211 100644 --- a/flashinfer/mla_rocm.py +++ b/flashinfer/mla_rocm.py @@ -99,18 +99,21 @@ class BatchMLAPagedAttentionWrapper: float_workspace_buffer : torch.Tensor Reserved workspace. Size is ignored; only the device is used. backend : str - Must be ``"aiter"`` (only supported backend on ROCm). + Either ``"auto"`` (the default, resolves to ``"aiter"`` on ROCm) + or ``"aiter"``. Any other value raises ``ValueError``. """ def __init__( self, float_workspace_buffer: torch.Tensor, - backend: str = "aiter", + backend: str = "auto", ) -> None: - if backend != "aiter": + if backend not in ("auto", "aiter"): raise ValueError( - f"Only backend='aiter' is supported on ROCm; got {backend!r}." + f"Only backend='aiter' (or 'auto', which resolves to " + f"'aiter') is supported on ROCm; got {backend!r}." ) + backend = "aiter" self.device = float_workspace_buffer.device _require_aiter_mla(self.device) diff --git a/tests/rocm_tests/test_mla_aiter_hip.py b/tests/rocm_tests/test_mla_aiter_hip.py index 7f2f90785c..b0b84c6111 100644 --- a/tests/rocm_tests/test_mla_aiter_hip.py +++ b/tests/rocm_tests/test_mla_aiter_hip.py @@ -326,3 +326,36 @@ def test_mla_run_before_plan_raises(): torch.zeros(4, 16, 512, dtype=torch.float16, device=device), torch.zeros(4, 16, 64, dtype=torch.float16, device=device), ) + + +@pytest.mark.skipif( + not is_aiter_supported(torch.device("cuda:0")), + reason="AITER backend requires gfx942/gfx950", +) +@pytest.mark.parametrize("backend", ["auto", "aiter"]) +def test_mla_backend_accepts_auto_and_aiter(backend): + """The ROCm MLA wrapper accepts both 'auto' (default) and 'aiter'. + + 'auto' resolves to 'aiter' since there is no HIP MLA kernel. + """ + from flashinfer.mla_rocm import BatchMLAPagedAttentionWrapper + + device = torch.device("cuda:0") + ws = torch.empty(1, dtype=torch.float32, device=device) + BatchMLAPagedAttentionWrapper(ws, backend=backend) + + +def test_mla_backend_rejects_unsupported(): + """Any backend other than 'auto'/'aiter' raises ValueError. + + The check fires before the AITER-availability probe, so this test + runs on any host (no GPU / no AITER required). + """ + from flashinfer.mla_rocm import BatchMLAPagedAttentionWrapper + + device = ( + torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + ) + ws = torch.empty(1, dtype=torch.float32, device=device) + with pytest.raises(ValueError, match="aiter.*auto"): + BatchMLAPagedAttentionWrapper(ws, backend="fa2")