Quantize kernel with the layout that deepgemm wants #97

Open
wants to merge 2 commits into base: varun/masked-kernels
174 changes: 167 additions & 7 deletions vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -2,6 +2,8 @@
import importlib.util
from typing import Optional

import triton
import triton.language as tl
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -15,6 +17,166 @@
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None




@triton.jit
def _per_token_group_quant_fp8_3d(
# Pointers ------------------------------------------------------------
y_ptr, # FP16 activations (E, T, H)
y_q_ptr, # FP8 quantized activations (E, T, H)

y_s_ptr, # FP32 scales (E, T, G)
counts_ptr, # INT32 number of tokens per expert (E)

# Sizes ---------------------------------------------------------------
E: tl.constexpr, # num_experts
T: tl.constexpr, # max_num_tokens
H: tl.constexpr, # hidden dimension
GROUP_SIZE: tl.constexpr, # elements per group (usually 128)

# Strides for y (elements) -------------------------------------------
stride_y_e,
stride_y_t,
stride_y_h,

# Strides for y_q (elements) -----------------------------------------
stride_yq_e,
stride_yq_t,
stride_yq_h,


# Strides for y_s (elements) -----------------------------------------
stride_ys_e,
stride_ys_t,
stride_ys_g,

# Stride for counts (elements)
stride_counts_e,

# Numeric params ------------------------------------------------------
eps: tl.constexpr,
fp8_min: tl.constexpr,
fp8_max: tl.constexpr,

# Meta ---------------------------------------------------------------
BLOCK: tl.constexpr,
):
"""Dynamic FP8 quantisation over a 3‑D tensor laid out **(E, T, H)**.

* Each program instance handles **one** `GROUP_SIZE`‑length slice along H
for a single (expert *e*, token *t*).
* Scales are produced with shape **(E, T, G)** where
`G = H // GROUP_SIZE` and with *element* strides
`(T*G, 1, T)` so that the *token* dimension is the fastest‑varying in
memory – matching the downstream reshape you showed.
* All strides are expressed **in elements**, not bytes.
"""

G = H // GROUP_SIZE # groups per hidden dim

# ----------------------- map program id -> (e, g) --------------------
pid = tl.program_id(0)
e = pid // G
g = pid % G

# number of valid tokens for this expert
n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int32)

# block for H dimension
cols = tl.arange(0, BLOCK)

mask_h = cols < BLOCK  # BLOCK == GROUP_SIZE, so every lane is in range

# iterate over tokens for this (expert, group)
t = tl.zeros([], tl.int32)
while t < n_tokens:
base_y_offset = (e * stride_y_e + t * stride_y_t +
                 g * GROUP_SIZE * stride_y_h)
base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
                  g * GROUP_SIZE * stride_yq_h)
base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g

mask = mask_h
y = tl.load(y_ptr + base_y_offset + cols * stride_y_h,
mask=mask, other=0.0).to(tl.float32)

# Dynamic per-group scale: the eps floor keeps all-zero (padded) groups
# from producing a zero divisor.
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max

y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h,
y_q, mask=mask)
tl.store(y_s_ptr + base_ys_offset, y_s)

t += 1


def quant_fp8_3d(
y: torch.Tensor, # (E, T, H)
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
group_size: int = 128,
fp8_dtype: torch.dtype = torch.float8_e4m3fn,
eps: float = 1e-6,
):
"""Quantize y into FP8 with per‑(expert, token, group) scales.

Only the first `tokens_per_expert[e]` tokens are quantized per expert;
the remaining positions in each (E, T, H) slice are treated as padding.

Returns `(y_q, y_s)` where
* `y_q` is the FP8 tensor, same shape and **standard PyTorch order** as *y*.
* `y_s` has shape `(E, T, H // group_size)` and element strides
`(T * G, 1, T)` so that the *token* dimension is contiguous.
"""

assert y.ndim == 3, "y must be (E, T, H)"
E, T, H = y.shape
assert H % group_size == 0, "H must be divisible by group_size"
G = H // group_size
assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \
"tokens_per_expert must be shape (E,)"
tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32)

# ---------------- allocate outputs ----------------------------------
y_q = torch.empty_like(y, dtype=fp8_dtype)

# desired scale strides (elements): (T*G, 1, T)
stride_ys_e = T * G
stride_ys_t = 1
stride_ys_g = T

# allocate scale buffer with proper shape and stride
y_s = torch.empty_strided((E, T, G),
                          (stride_ys_e, stride_ys_t, stride_ys_g),
                          dtype=torch.float32, device=y.device)
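# Deliberately non-contiguous: the token dimension is fastest-varying,
# which is the scale layout this PR targets for DeepGEMM.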

# ---------------- stride bookkeeping (elements, not bytes) ----------
stride_y_e, stride_y_t, stride_y_h = y.stride()

stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()


# stride for tokens_per_expert (elements)
stride_cnt_e = tokens_per_expert.stride()[0]

# static grid over experts and H-groups; tokens loop is internal to the kernel
grid = (E * G,)

f_info = torch.finfo(fp8_dtype)
fp8_max = f_info.max
fp8_min = -f_info.max
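# Symmetric range: for float8_e4m3fn, finfo.min == -finfo.max (±448).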

_per_token_group_quant_fp8_3d[grid](
y, y_q, y_s, tokens_per_expert,
E, T, H, group_size,
stride_y_e, stride_y_t, stride_y_h,
stride_yq_e, stride_yq_t, stride_yq_h,
stride_ys_e, stride_ys_t, stride_ys_g,
stride_cnt_e,
eps, fp8_min, fp8_max,
BLOCK=group_size,
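# BLOCK must equal group_size so each program covers one full quant group.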
num_warps=4,
)

return y_q, y_s

class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):

# The Deep Gemm kernels only support block size of 128
@@ -87,6 +249,8 @@
):
import deep_gemm as dg
assert hidden_states.ndim == 3
assert w1_zp is None and w2_zp is None
assert a2_scale is None

a1q = hidden_states
_, N, K = w1.size()
@@ -113,13 +277,9 @@
self.masked_activation(activation, workspace2, workspace1,
expert_num_tokens)

# TODO (varun) : Pass in an output tensor derived from workspace
# as a memory optimization.
a2q, a2q_scale = masked_per_token_group_quant_fp8(
x=workspace2,
valid_tokens_array=expert_num_tokens,
group_size=self.block_shape[1],
column_major_scales=False)
a2q, a2q_scale = quant_fp8_3d(workspace2,
tokens_per_expert=expert_num_tokens,
group_size=self.block_shape[1])

dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
(w2, w2_scale),
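
For reference (not part of the diff): a minimal sketch of how `quant_fp8_3d` could be exercised and its scale layout checked. It assumes a CUDA GPU with FP8 support; the shapes, random input, and token counts are illustrative only, and the import path simply mirrors the file this PR touches.

import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    quant_fp8_3d)

# Illustrative sizes (hypothetical): 4 experts, 64 token slots, hidden dim 1024.
E, T, H, GROUP = 4, 64, 1024, 128
y = torch.randn(E, T, H, dtype=torch.float16, device="cuda")
tokens_per_expert = torch.tensor([64, 10, 0, 33], dtype=torch.int32,
                                 device="cuda")

y_q, y_s = quant_fp8_3d(y, tokens_per_expert, group_size=GROUP)

G = H // GROUP
assert y_q.shape == y.shape and y_q.dtype == torch.float8_e4m3fn
assert y_s.shape == (E, T, G)
# Token dimension is the fastest-varying, as the kernel docstring advertises.
assert y_s.stride() == (T * G, 1, T)

In `apply()` above, the resulting `(a2q, a2q_scale)` pair is then passed straight to `dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked` without any further reshaping.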