diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py new file mode 100644 index 00000000000..d0945388979 --- /dev/null +++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py @@ -0,0 +1,818 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py +import argparse +import json +import os +import time +from contextlib import nullcontext +from datetime import datetime +from typing import Any, Dict, List, Tuple, TypedDict + +import ray +import torch +import triton +import triton.language as tl +from ray.experimental.tqdm_ray import tqdm +from sgl_kernel import silu_and_mul +from transformers import AutoConfig + +from sglang.srt.layers.moe.fused_moe_triton import override_config +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + get_config_dtype_str, + invoke_fused_moe_kernel, + moe_align_block_size, +) +from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import ( + get_config_file_name, +) +from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig +from sglang.srt.layers.moe.topk import TopKConfig, select_experts +from sglang.srt.utils import is_hip + +_is_hip = is_hip() + + +class BenchmarkConfig(TypedDict): + BLOCK_SIZE_M: int + BLOCK_SIZE_N: int + BLOCK_SIZE_K: int + GROUP_SIZE_M: int + num_warps: int + num_stages: int + + +def benchmark_config( + config: BenchmarkConfig, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + topk_ids_dir: str, + block_shape: List[int] = None, + num_iters: int = 100, +) -> float: + ncu_enable = os.getenv("NCU_ENABLE", "0") == "1" + if ncu_enable: + num_iters = 1 + init_dtype = torch.float16 if use_fp8_w8a8 else dtype + hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) + if use_int8_w8a16 or use_int8_w8a8: + w1 = torch.randint( + -127, + 127, + ( + num_experts, + shard_intermediate_size, + hidden_size, + ), + dtype=torch.int8, + ) + w2 = torch.randint( + -127, + 127, + ( + num_experts, + hidden_size, + shard_intermediate_size // 2, + ), + dtype=torch.int8, + ) + else: + w1 = torch.randn( + num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype + ) + w2 = torch.randn( + num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype + ) + gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32) + + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + if use_int8_w8a16: + w1_scale = torch.randn( + (num_experts, 2 * shard_intermediate_size), dtype=torch.float32 + ) + w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) + if use_fp8_w8a8 or use_int8_w8a8: + if use_int8_w8a8 and block_shape is None: + w1_scale = torch.randn( + num_experts, shard_intermediate_size, dtype=torch.float32 + ) + w2_scale = torch.randn(num_experts, hidden_size, dtype=torch.float32) + elif block_shape is None: + w1_scale = torch.randn(num_experts, dtype=torch.float32) + w2_scale = torch.randn(num_experts, dtype=torch.float32) + a1_scale = torch.randn(1, dtype=torch.float32) + a2_scale = torch.randn(1, dtype=torch.float32) + else: + block_n, block_k = block_shape[0], block_shape[1] + n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n + n_tiles_w2 = (hidden_size + block_n - 1) // block_n + k_tiles_w1 = (hidden_size + block_k - 1) // block_k + 
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k + w1_scale = torch.rand( + (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32 + ) + w2_scale = torch.rand( + (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32 + ) + + if use_fp8_w8a8: + w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn) + + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) + topk_config = TopKConfig( + top_k=topk, + renormalize=True, + ) + topk_output = select_experts(hidden_states, input_gating, topk_config) + + def prepare(i: int): + input_gating = gating_output[i] + topk_ids = torch.load(f"{topk_ids_dir}/topk_ids_layer{i%58+3}_idx{i//58}.pt") + new_topk_output = select_experts(hidden_states, input_gating, topk_config) + topk_output.topk_weights.copy_(new_topk_output.topk_weights) + tokens, _topk = topk_output.topk_ids.shape + topk_output.topk_ids.copy_(topk_ids[:tokens, :_topk]) + topk_output.router_logits.copy_(new_topk_output.router_logits) + + moe_use_tma = False + + def run(): + moe_runner_config = MoeRunnerConfig( + inplace=True, + ) + topk_weights, topk_ids, _ = topk_output + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config["BLOCK_SIZE_M"], num_experts + ) + M = hidden_states.shape[0] + E, N, _ = w1.shape + + topk = topk_ids.shape[1] + padded_tokens = ( + min(M * topk, E + 1) * (config["BLOCK_SIZE_M"] - 1) if moe_use_tma else 0 + ) + total_tokens = M * topk + padded_tokens + cache = torch.empty( + total_tokens * max(N, w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache1 = cache[: total_tokens * N].view( + (total_tokens, N), + ) + intermediate_cache2 = torch.empty( + (total_tokens, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache3 = cache[: M * topk * w2.shape[1]].view( + (M, topk, w2.shape[1]), + ) + + compute_type = ( + tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + ) + apply_router_weight_on_input = moe_runner_config.apply_router_weight_on_input + + with override_config(config): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() + for _ in range(10 if not ncu_enable else 1): + invoke_fused_moe_kernel( + hidden_states, + w1, + None, + intermediate_cache1, + None, + w1_scale, + None, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + apply_router_weight_on_input, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_channel_quant=False, + block_shape=block_shape, + b_use_tma=moe_use_tma, + c_sorted=moe_use_tma, + filter_expert=False, + ) + end_event.record() + end_event.synchronize() + time_cost0 = start_event.elapsed_time(end_event) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() + + silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) + for _ in range(10 if not ncu_enable else 1): + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + None, + intermediate_cache3, + a2_scale, + w2_scale, + None, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + not apply_router_weight_on_input, + 1, + config, 
+ compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_channel_quant=False, + block_shape=block_shape, + a_use_tma=moe_use_tma, + b_use_tma=moe_use_tma, + filter_expert=False, + ) + end_event.record() + end_event.synchronize() + time_cost1 = start_event.elapsed_time(end_event) + return time_cost0, time_cost1 + + # JIT compilation & warmup + if not ncu_enable: + moe_use_tma = False + run() + moe_use_tma = True + run() + latencies: List[float] = [] + latencies1: List[float] = [] + latencies_tma: List[float] = [] + latencies1_tma: List[float] = [] + + for i in range(num_iters): + prepare(i) + torch.cuda.synchronize() + moe_use_tma = False + t0, t1 = run() + torch.cuda.synchronize() + latencies.append(t0) + latencies1.append(t1) + + moe_use_tma = True + t0, t1 = run() + torch.cuda.synchronize() + latencies_tma.append(t0) + latencies1_tma.append(t1) + + avg = sum(latencies) / (num_iters * 10) * 1000 # us + avg_tma = sum(latencies_tma) / (num_iters * 10) * 1000 # us + avg1 = sum(latencies1) / (num_iters * 10) * 1000 # us + avg1_tma = sum(latencies1_tma) / (num_iters * 10) * 1000 # us + + return avg, avg_tma, avg1, avg1_tma + + +def get_rocm_configs_compute_bound() -> List[Dict[str, int]]: + configs: List[BenchmarkConfig] = [] + waves_per_eu_range = 0 + for block_m in [32, 64, 128, 256]: + for block_k in [32, 64, 128, 256]: + for block_n in [16, 32, 64, 128, 256]: + for num_stages in [2]: + for num_warps in [1, 2, 4, 8]: + for group_size in [1, 4, 8, 16, 32]: + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + "waves_per_eu": waves_per_eu_range, + } + ) + return configs + + +def get_configs_compute_bound() -> List[Dict[str, int]]: + # Reduced search space for faster tuning. + # TODO(woosuk): Increase the search space and use a performance model to + # prune the search space. 
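The nested loops below enumerate the full Cartesian product of tile sizes and launch parameters. As a rough sketch (not part of the patch), the same CUDA search space can be written with `itertools.product`, which also makes it easy to count how many candidates get benchmarked:

```python
# Sketch only: equivalent enumeration of the CUDA search space built by the
# nested loops below in get_configs_compute_bound().
import itertools

def sketch_configs_compute_bound():
    block_ms = [16, 32, 64, 128, 256]
    block_ks = [32, 64, 128, 256]
    block_ns = [32, 64, 128, 256]
    stages = [2, 3, 4, 5]
    warps = [4, 8]
    group_sizes = [1, 16, 32, 64]
    return [
        {
            "BLOCK_SIZE_M": bm,
            "BLOCK_SIZE_N": bn,
            "BLOCK_SIZE_K": bk,
            "GROUP_SIZE_M": gs,
            "num_warps": nw,
            "num_stages": ns,
        }
        for bm, bk, bn, ns, nw, gs in itertools.product(
            block_ms, block_ks, block_ns, stages, warps, group_sizes
        )
    ]

# 5 * 4 * 4 * 4 * 2 * 4 = 2560 candidate configs before filtering; block-quantized
# models later drop configs whose BLOCK_SIZE_K does not divide the weight block_k.
```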
+ configs: List[BenchmarkConfig] = [] + if _is_hip: + configs = get_rocm_configs_compute_bound() + else: + for block_m in [16, 32, 64, 128, 256]: + for block_k in [32, 64, 128, 256]: + for block_n in [32, 64, 128, 256]: + for num_stages in [2, 3, 4, 5]: + for num_warps in [4, 8]: + for group_size in [1, 16, 32, 64]: + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +class BestConfigTrace: + def __init__(self, name): + self.name = name + self.config = None + self.time_cost = float("inf") + self.time_cost_all = None # kernel0 without tma,, kernel0 with tma, kernel1 without tma, kernel1 with tma + + def update(self, config, time_cost, time_cost_all): + if time_cost < self.time_cost: + print( + f"New best config for {self.name}: {config}, {time_cost=}, {time_cost_all=}, org: {self.config}, {self.time_cost_all}", + flush=True, + ) + self.config = config + self.time_cost = time_cost + self.time_cost_all = time_cost_all + + @property + def total_time(self): + return self.time_cost_all[0] + min(self.time_cost_all[2], self.time_cost_all[3]) + + def config_dict(self, down_moe=False): + if not down_moe: + return self.config + else: + return { + **self.config, + "USE_TMA": self.time_cost_all[2] > self.time_cost_all[3], + } + + +class BenchmarkWorker: + + def __init__(self, seed: int) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + self.seed = seed + # Get the device ID to allocate tensors and kernels + # on the respective GPU. + self.device_id = 0 # int(ray.get_gpu_ids()[0]) + + def benchmark( + self, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + block_shape: List[int], + cfg: Dict[str, int], + topk_ids_dir: str, + ) -> Tuple[Dict[str, int], float]: + torch.cuda.manual_seed_all(0) + dtype_str = get_config_dtype_str( + dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 + ) + # NOTE(woosuk): The current naming convention uses w2.shape[2], which + # is the intermediate size after silu_and_mul. 
+ block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + with torch.cuda.device(self.device_id) if is_hip() else nullcontext(): + kernel_time = benchmark_config( + cfg, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a8, + use_int8_w8a16, + topk_ids_dir, + block_shape, + ) + return cfg, kernel_time + + def tune( + self, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + block_shape: List[int], + search_space: List[Dict[str, int]], + topk_ids_dir: str, + ) -> Dict[str, int]: + trace0 = BestConfigTrace("kernel0") + trace1 = BestConfigTrace("kernel1") + trace2 = BestConfigTrace("kernel all") + + with torch.cuda.device(self.device_id) if is_hip() else nullcontext(): + for config in tqdm(search_space): + try: + kt0_no_tma, kt0_tma, kt1_no_tma, kt1_tma = benchmark_config( + config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a8, + use_int8_w8a16, + topk_ids_dir, + block_shape, + num_iters=10, + ) + except triton.runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. + continue + kt0 = kt0_no_tma + kt1 = min(kt1_no_tma, kt1_tma) + trace0.update( + config, + kt0, + (kt0_no_tma, kt0_tma, kt1_no_tma, kt1_tma), + ) + trace1.update( + config, + kt1, + (kt0_no_tma, kt0_tma, kt1_no_tma, kt1_tma), + ) + trace2.update( + config, + kt0 + kt1, + (kt0_no_tma, kt0_tma, kt1_no_tma, kt1_tma), + ) + + now = datetime.now() + print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") + assert trace0.config is not None + assert trace1.config is not None + print( + f"{num_tokens=}, {trace0.config=}, {trace0.time_cost_all=}, {trace1.config=}, {trace1.time_cost_all=}" + ) + if trace0.config["BLOCK_SIZE_M"] != trace1.config["BLOCK_SIZE_M"]: + best_trace = trace0 if trace0.total_time < trace1.total_time else trace1 + best_trace = ( + best_trace if best_trace.total_time < trace2.total_time else trace2 + ) + return ( + best_trace.config_dict(), + best_trace.config_dict(True), + best_trace.time_cost_all, + best_trace.time_cost_all, + ) + return ( + trace0.config_dict(), + trace1.config_dict(True), + trace0.time_cost_all, + trace1.time_cost_all, + ) + + +def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: + return { + "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], + "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], + "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], + "GROUP_SIZE_M": config["GROUP_SIZE_M"], + "num_warps": config["num_warps"], + "num_stages": config["num_stages"], + **( + {"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {} + ), + **({"USE_TMA": config["USE_TMA"]} if "USE_TMA" in config else {}), + } + + +def save_configs( + configs: Dict[int, BenchmarkConfig], + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + block_shape: List[int], + down_moe: bool = False, +) -> None: + dtype_str = get_config_dtype_str( + dtype, + use_int8_w8a16=use_int8_w8a16, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + ) + + # NOTE(woosuk): The current naming convention uses w2.shape[2], which + # is the intermediate size after silu_and_mul. 
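The call below relies on the naming convention from `get_config_file_name`, which this patch extends with a `down_moe` flag that appends a `_down` suffix. A minimal sketch of the scheme (device name hard-coded for illustration; the real helper also handles `per_channel_quant`):

```python
# Sketch of the config-file naming convention; the real helper lives in
# sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config.
def sketch_config_file_name(E, N, dtype, block_shape, down_moe=False,
                            device_name="NVIDIA_H20"):
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    block_shape_selector = (
        "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
    )
    down_moe_selector = "_down" if down_moe else ""
    return (
        f"E={E},N={N},device_name={device_name}"
        f"{dtype_selector}{block_shape_selector}{down_moe_selector}.json"
    )

# N is shard_intermediate_size // 2, i.e. w2.shape[2] after silu_and_mul.
print(sketch_config_file_name(257, 256, "fp8_w8a8", [128, 128]))
# -> E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
print(sketch_config_file_name(257, 256, "fp8_w8a8", [128, 128], down_moe=True))
# -> E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json
```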
+ filename = get_config_file_name( + num_experts, + shard_intermediate_size // 2, + dtype_str, + block_shape, + down_moe=down_moe, + ) + + print(f"Writing best config to {filename}...") + with open(filename, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def main(args: argparse.Namespace): + print(args) + + config = AutoConfig.from_pretrained(args.model, trust_remote_code=True) + if config.architectures[0] == "DbrxForCausalLM": + E = config.ffn_config.moe_num_experts + topk = config.ffn_config.moe_top_k + intermediate_size = config.ffn_config.ffn_hidden_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + elif config.architectures[0] == "JambaForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]: + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: + E = ( + config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1) + if config.architectures[0] in ["DeepseekV3ForCausalLM"] + else config.n_routed_experts + ) + topk = ( + config.num_experts_per_tok + + (0 if args.disable_shared_experts_fusion else 1) + if config.architectures[0] in ["DeepseekV3ForCausalLM"] + else config.num_experts_per_tok + ) + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + elif config.architectures[0] == "Llama4ForConditionalGeneration": + E = config.text_config.num_local_experts + ( + 0 if args.disable_shared_experts_fusion else 1 + ) + topk = config.text_config.num_experts_per_tok + intermediate_size = config.text_config.intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + elif config.architectures[0] in [ + "Grok1ForCausalLM", + "Grok1ImgGen", + "Grok1AForCausalLM", + ]: + E = config.num_local_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + elif config.architectures[0] in ["Glm4MoeForCausalLM"]: + E = config.n_routed_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + else: + # Default: Mixtral + E = config.num_local_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size + + hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size + dtype = config.torch_dtype + use_fp8_w8a8 = args.dtype == "fp8_w8a8" + use_int8_w8a8 = args.dtype == "int8_w8a8" + use_int8_w8a16 = args.dtype == "int8_w8a16" + block_shape = None + if ( + hasattr(config, "quantization_config") + and "weight_block_size" in config.quantization_config + ): + block_shape = config.quantization_config["weight_block_size"] + assert len(block_shape) == 2 + + topk_ids_dir = args.topk_ids_dir + if args.batch_size is None: + batch_sizes = [ + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, + 8192, + ] + batch_sizes.reverse() + else: + batch_sizes = [args.batch_size] + if 
len(batch_sizes) == 1: + worker = BenchmarkWorker(args.seed) + if args.tune: + search_space = get_configs_compute_bound() + worker.tune( + batch_sizes[0], + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a8, + use_int8_w8a16, + block_shape, + search_space, + topk_ids_dir, + ) + else: + cfg = { + "BLOCK_SIZE_M": args.configs[0], + "BLOCK_SIZE_N": args.configs[1], + "BLOCK_SIZE_K": args.configs[2], + "GROUP_SIZE_M": args.configs[3], + "num_warps": args.configs[4], + "num_stages": args.configs[5], + } + + _, (t0, t0_tma, t1, t1_tma) = worker.benchmark( + args.batch_size, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a8, + use_int8_w8a16, + block_shape, + cfg, + topk_ids_dir, + ) + print(f"{t0=}, {t0_tma=}, {t1=}, {t1_tma=}") + return + + assert args.tune + + ray.init() + num_gpus = int(ray.available_resources()["GPU"]) + workers = [ + ray.remote(num_gpus=1)(BenchmarkWorker).remote(args.seed) + for _ in range(num_gpus) + ] + + def _distribute(method: str, inputs: List[Any]) -> List[Any]: + outputs = [] + worker_idx = 0 + for input_args in inputs: + worker = workers[worker_idx] + worker_method = getattr(worker, method) + output = worker_method.remote(*input_args) + outputs.append(output) + worker_idx = (worker_idx + 1) % num_gpus + return ray.get(outputs) + + search_space = get_configs_compute_bound() + if block_shape is not None: + block_n, block_k = block_shape[0], block_shape[1] + search_space = [ + config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0 + ] + print(f"Start tuning over {len(search_space)} configurations...") + + start = time.perf_counter() + configs = _distribute( + "tune", + [ + ( + batch_size, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a8, + use_int8_w8a16, + block_shape, + search_space, + topk_ids_dir, + ) + for batch_size in batch_sizes + ], + ) + print(f"{configs=}", flush=True) + cur_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + with open(f"tuning_result_{cur_time}.txt", "w") as f: + print(configs, file=f) + batch_sizes.reverse() + configs0 = [config[0] for config in configs] + configs1 = [config[1] for config in configs] + configs0.reverse() + configs1.reverse() + best_configs0 = {M: sort_config(config) for M, config in zip(batch_sizes, configs0)} + save_configs( + best_configs0, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a8, + use_int8_w8a16, + block_shape, + ) + + best_configs1 = {M: sort_config(config) for M, config in zip(batch_sizes, configs1)} + save_configs( + best_configs1, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a8, + use_int8_w8a16, + block_shape, + down_moe=True, + ) + end = time.perf_counter() + print(f"Tuning took {end - start:.2f} seconds") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1" + ) + parser.add_argument("--tp-size", "--tp", type=int, default=2) + parser.add_argument( + "--dtype", + type=str, + choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"], + default="auto", + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument("--tune", action="store_true") + parser.add_argument("--disable-shared-experts-fusion", action="store_true") + parser.add_argument("--configs", type=int, 
nargs="+", required=False) + parser.add_argument("--topk-ids-dir", type=str, required=True) + args = parser.parse_args() + + main(args) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 8d8f789d079..4eafb2013ee 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -855,14 +855,24 @@ def forward_extend( ) else: # MHA for extend part of sequence without attending prefix kv cache + cu_seqlens_k = ( + metadata.cu_seqlens_q + if not forward_batch.mha_one_shot + else metadata.cu_seqlens_k + ) + max_seqlen_k = ( + metadata.max_seq_len_q + if not forward_batch.mha_one_shot + else metadata.max_seq_len_k + ) output = flash_attn_varlen_func( q=q.view(-1, layer.tp_q_head_num, layer.head_dim), k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype), cu_seqlens_q=metadata.cu_seqlens_q, - cu_seqlens_k=metadata.cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, max_seqlen_q=metadata.max_seq_len_q, - max_seqlen_k=metadata.max_seq_len_q, + max_seqlen_k=max_seqlen_k, softmax_scale=layer.scaling, causal=True, return_softmax_lse=forward_batch.mha_return_lse, diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index ad9cbfd44b9..a9de5577fb8 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -80,6 +80,7 @@ def __init__( # Buffers and wrappers self.qo_indptr = attn_backend.qo_indptr + self.kv_indptr = attn_backend.kv_indptr self.workspace_buffer = attn_backend.workspace_buffer self.fmha_backend = attn_backend.fmha_backend @@ -130,9 +131,14 @@ def update_wrapper( ) # ragged prefill if not disable_flashinfer_ragged: + kv_indptr = ( + qo_indptr + if not forward_batch.mha_one_shot + else self.kv_indptr[: bs + 1] + ) self.ragged_wrapper.begin_forward( qo_indptr=qo_indptr, - kv_indptr=qo_indptr, + kv_indptr=kv_indptr, num_qo_heads=self.num_local_heads, num_kv_heads=self.num_local_heads, head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim, @@ -154,7 +160,7 @@ def forward( chunk_idx = forward_batch.prefix_chunk_idx assert chunk_idx >= 0 wrapper = self.chunk_ragged_wrappers[chunk_idx] - o1, s1 = wrapper.forward_return_lse( + o = wrapper.forward_return_lse( q.view(-1, layer.tp_q_head_num, layer.head_dim), k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype), @@ -163,7 +169,12 @@ def forward( logits_soft_cap=logits_soft_cap, ) else: - o1, s1 = self.ragged_wrapper.forward_return_lse( + forward = ( + self.ragged_wrapper.forward_return_lse + if forward_batch.mha_return_lse + else self.ragged_wrapper.forward + ) + o = forward( q.view(-1, layer.tp_q_head_num, layer.head_dim), k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype), @@ -171,8 +182,7 @@ def forward( sm_scale=layer.scaling, logits_soft_cap=logits_soft_cap, ) - - return o1, s1 + return o class FlashInferMLAAttnBackend(AttentionBackend): @@ -510,15 +520,13 @@ def forward_extend( q_rope: Optional[torch.Tensor] = None, k_rope: Optional[torch.Tensor] = None, ): - if ( - forward_batch.attn_attend_prefix_cache is not None - and forward_batch.mha_return_lse + if forward_batch.attn_attend_prefix_cache is not None and 
any( + forward_batch.extend_prefix_lens_cpu ): # MHA Chunk assert self.enable_chunk_kv assert q_rope is None assert k_rope is None - o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch) - return o1, s1 + return self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch) cache_loc = forward_batch.out_cache_loc logits_soft_cap = layer.logit_cap diff --git a/python/sglang/srt/layers/attention/utils.py b/python/sglang/srt/layers/attention/utils.py index e8cd2e1580a..5481d7ec7bb 100644 --- a/python/sglang/srt/layers/attention/utils.py +++ b/python/sglang/srt/layers/attention/utils.py @@ -1,3 +1,4 @@ +import torch import triton import triton.language as tl @@ -97,3 +98,80 @@ def create_flashmla_kv_indices_triton( data // PAGED_SIZE, mask=mask_out, ) + + +@triton.jit +def concat_and_cast_mha_k_kernel( + k_ptr, + k_nope_ptr, + k_rope_ptr, + head_cnt: tl.constexpr, + k_stride0: tl.constexpr, + k_stride1: tl.constexpr, + nope_stride0: tl.constexpr, + nope_stride1: tl.constexpr, + rope_stride0: tl.constexpr, + nope_dim: tl.constexpr, + rope_dim: tl.constexpr, +): + pid_loc = tl.program_id(0) + head_range = tl.arange(0, head_cnt) + + k_head_ptr = k_ptr + pid_loc * k_stride0 + head_range[:, None] * k_stride1 + + nope_offs = tl.arange(0, nope_dim) + + src_nope_ptr = ( + k_nope_ptr + + pid_loc * nope_stride0 + + head_range[:, None] * nope_stride1 + + nope_offs[None, :] + ) + dst_nope_ptr = k_head_ptr + nope_offs[None, :] + + src_nope = tl.load(src_nope_ptr) + tl.store(dst_nope_ptr, src_nope) + + rope_offs = tl.arange(0, rope_dim) + src_rope_ptr = k_rope_ptr + pid_loc * rope_stride0 + rope_offs[None, :] + dst_rope_ptr = k_head_ptr + nope_dim + rope_offs[None, :] + src_rope = tl.load(src_rope_ptr) + tl.store(dst_rope_ptr, src_rope) + + +def concat_and_cast_mha_k_triton( + k: torch.Tensor, + k_nope: torch.Tensor, + k_rope: torch.Tensor, +): + # The source data type will be implicitly converted to the target data type. 
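The kernel above copies, for each token, every head's nope part plus a single shared rope part into a preallocated `k`, with `tl.store` implicitly casting to `k`'s dtype. A plain-PyTorch reference (sketch only, e.g. for unit-testing the kernel with `torch.testing.assert_close`):

```python
import torch

def concat_and_cast_mha_k_ref(k: torch.Tensor,
                              k_nope: torch.Tensor,
                              k_rope: torch.Tensor) -> torch.Tensor:
    # k:      (num_tokens, num_heads, nope_dim + rope_dim), target dtype
    # k_nope: (num_tokens, num_heads, nope_dim)
    # k_rope: (num_tokens, 1,         rope_dim), shared across heads
    nope_dim = k_nope.shape[-1]
    k[..., :nope_dim] = k_nope.to(k.dtype)
    k[..., nope_dim:] = k_rope.to(k.dtype)  # broadcasts over the head dimension
    return k
```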
+ assert ( + len(k.shape) == 3 and len(k_nope.shape) == 3 and len(k_rope.shape) == 3 + ), f"shape should be 3d, but got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}" + assert ( + k.shape[0] == k_nope.shape[0] and k.shape[0] == k_rope.shape[0] + ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}" + assert ( + k.shape[1] == k_nope.shape[1] and 1 == k_rope.shape[1] + ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}" + assert ( + k.shape[-1] == k_nope.shape[-1] + k_rope.shape[-1] + ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}" + + nope_dim = k_nope.shape[-1] + rope_dim = k_rope.shape[-1] + grid = (k.shape[0],) + + concat_and_cast_mha_k_kernel[grid]( + k, + k_nope, + k_rope, + k.shape[1], + k.stride(0), + k.stride(1), + k_nope.stride(0), + k_nope.stride(1), + k_rope.stride(0), + nope_dim, + rope_dim, + ) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..f05e32900c5 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json new file mode 100644 index 00000000000..2d674e9ebb9 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "USE_TMA": false + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5, + "USE_TMA": false + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "USE_TMA": false + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "USE_TMA": false + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "USE_TMA": false + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + "USE_TMA": false + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5, + "USE_TMA": false + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "USE_TMA": false + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + "USE_TMA": false + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "USE_TMA": false + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "USE_TMA": false + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3, + "USE_TMA": true + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "USE_TMA": true + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "USE_TMA": true + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3, + "USE_TMA": true + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + "USE_TMA": true + }, + "3072": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "USE_TMA": true + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "USE_TMA": true + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 6d3fb53b051..bfbb4c0190a 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -23,7 +23,11 @@ ) from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_config -from .fused_moe_triton_kernels import invoke_fused_moe_kernel, moe_sum_reduce_triton +from .fused_moe_triton_kernels import ( + invoke_fused_moe_kernel, + moe_sum_reduce_triton, + support_tensor_descriptor, +) from .moe_align_block_size import moe_align_block_size if TYPE_CHECKING: @@ -78,6 +82,7 @@ def inplace_fused_experts( routed_scaling_factor: Optional[float] = None, gemm1_alpha: Optional[float] = None, gemm1_limit: Optional[float] = None, + filter_expert: bool = True, ) -> None: fused_experts_impl( hidden_states, @@ -106,6 +111,7 @@ def inplace_fused_experts( routed_scaling_factor, gemm1_alpha, gemm1_limit, + filter_expert, ) @@ -134,6 +140,7 @@ def inplace_fused_experts_fake( routed_scaling_factor: Optional[float] = None, gemm1_alpha: Optional[float] = None, gemm1_limit: Optional[float] = None, + filter_expert: bool = True, ) -> None: pass @@ -172,6 +179,7 @@ def outplace_fused_experts( routed_scaling_factor: Optional[float] = None, gemm1_alpha: Optional[float] = None, gemm1_limit: Optional[float] = None, + filter_expert: bool = True, ) -> torch.Tensor: return fused_experts_impl( hidden_states, @@ -200,6 +208,7 @@ def outplace_fused_experts( routed_scaling_factor=routed_scaling_factor, gemm1_alpha=gemm1_alpha, gemm1_limit=gemm1_limit, + filter_expert=filter_expert, ) @@ -229,6 +238,7 @@ def outplace_fused_experts_fake( routed_scaling_factor: Optional[float] = None, gemm1_alpha: Optional[float] = None, gemm1_limit: Optional[float] = None, + filter_expert: bool = True, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -263,6 +273,10 @@ def fused_experts( block_shape: Optional[List[int]] = None, ): topk_weights, topk_ids, _ = topk_output + filter_expert = ( + moe_runner_config.num_experts is None + or moe_runner_config.num_experts != moe_runner_config.num_local_experts + ) if moe_runner_config.inplace: assert not moe_runner_config.no_combine, "no combine + inplace makes no sense" torch.ops.sglang.inplace_fused_experts( @@ -290,6 +304,7 @@ def fused_experts( moe_runner_config.routed_scaling_factor, moe_runner_config.gemm1_alpha, moe_runner_config.gemm1_clamp_limit, + filter_expert, ) return hidden_states else: @@ -319,6 +334,7 @@ def fused_experts( routed_scaling_factor=moe_runner_config.routed_scaling_factor, gemm1_alpha=moe_runner_config.gemm1_alpha, gemm1_limit=moe_runner_config.gemm1_clamp_limit, + filter_expert=filter_expert, ) @@ -336,6 +352,11 @@ def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit): return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1) +@functools.lru_cache() +def _down_moe_use_tma(): + return support_tensor_descriptor() + + def fused_experts_impl( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -363,6 +384,7 @@ def fused_experts_impl( routed_scaling_factor: Optional[float] = None, gemm1_alpha: Optional[float] = None, gemm1_limit: Optional[float] = 
None, + filter_expert: bool = True, ): padded_size = padding_size if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter: @@ -402,25 +424,27 @@ def fused_experts_impl( topk_ids.shape[1], config_dtype, block_shape=block_shape, + return_down_config=True, ) - config = get_config_func(M) - - cache = torch.empty( - M * topk_ids.shape[1] * max(N, w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, + config, (down_config, max_block_m) = get_config_func(M) + down_moe_use_tma = ( + _down_moe_use_tma() + and down_config is not None + and down_config.pop("USE_TMA", False) ) - intermediate_cache1 = cache[: M * topk_ids.shape[1] * N].view( - (M, topk_ids.shape[1], N), + topk = topk_ids.shape[1] + max_padded_tokens = ( + min(M * topk, E + 1) * (max_block_m - 1) if down_moe_use_tma else 0 ) - intermediate_cache2 = torch.empty( - (M * topk_ids.shape[1], N // 2), + total_tokens = M * topk + max_padded_tokens + cache = torch.empty( + total_tokens * max(N, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype, ) - intermediate_cache3 = cache[: M * topk_ids.shape[1] * w2.shape[1]].view( - (M, topk_ids.shape[1], w2.shape[1]), + intermediate_cache3 = cache[: M * topk * w2.shape[1]].view( + (M, topk, w2.shape[1]), ) compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 @@ -428,7 +452,7 @@ def fused_experts_impl( if no_combine: assert not inplace out_hidden_states = torch.empty( - (num_tokens, topk_ids.shape[1], w2.shape[1]), + (num_tokens, topk, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype, ) @@ -453,12 +477,28 @@ def fused_experts_impl( # chunk. Note that in most cases we only have one chunk # so the cache size and config are already set correctly and # do not need to be adjusted. 
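When the down GEMM runs with TMA, `intermediate_cache1` holds the first GEMM's output in sorted, block-aligned order, so the buffer is over-allocated by the worst-case padding term used above (`min(M * topk, E + 1) * (BLOCK_SIZE_M - 1)`). A small sketch of that sizing with illustrative numbers:

```python
# Sketch of the sorted-cache sizing used when the down GEMM uses TMA: each
# expert block in sorted order may be padded up to BLOCK_SIZE_M rows, and at
# most min(M * topk, E + 1) blocks each contribute BLOCK_SIZE_M - 1 pad rows.
def sorted_cache_rows(M: int, topk: int, E: int, block_size_m: int,
                      down_moe_use_tma: bool) -> int:
    padded = min(M * topk, E + 1) * (block_size_m - 1) if down_moe_use_tma else 0
    return M * topk + padded

# Illustrative values: M=64 tokens, top-8 routing, E=257 experts, BLOCK_SIZE_M=64.
print(sorted_cache_rows(64, 8, 257, 64, True))   # 512 + 258 * 63 = 16766 rows
print(sorted_cache_rows(64, 8, 257, 64, False))  # 512 rows
```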
- intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] - intermediate_cache2 = intermediate_cache2[ - : tokens_in_chunk * topk_ids.shape[1] - ] + config, (down_config, _) = get_config_func(tokens_in_chunk) + down_moe_use_tma = ( + _down_moe_use_tma() + and down_config is not None + and down_config.pop("USE_TMA", False) + ) intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] - config = get_config_func(tokens_in_chunk) + + padded_tokens = ( + min(tokens_in_chunk * topk, E + 1) * (config["BLOCK_SIZE_M"] - 1) + if down_moe_use_tma + else 0 + ) + total_tokens = tokens_in_chunk * topk + padded_tokens + intermediate_cache1 = cache[: total_tokens * N].view( + (total_tokens, N), + ) + intermediate_cache2 = torch.empty( + (total_tokens, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] @@ -490,6 +530,8 @@ def fused_experts_impl( use_int4_w4a16=use_int4_w4a16, per_channel_quant=per_channel_quant, block_shape=block_shape, + c_sorted=down_moe_use_tma, + filter_expert=filter_expert, ) if activation == "silu": if gemm1_alpha is not None: @@ -536,7 +578,7 @@ def fused_experts_impl( num_tokens_post_padded, not apply_router_weight_on_input, 1, - config, + down_config or config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, @@ -544,6 +586,9 @@ def fused_experts_impl( use_int4_w4a16=use_int4_w4a16, per_channel_quant=per_channel_quant, block_shape=block_shape, + a_use_tma=down_moe_use_tma, + b_use_tma=down_moe_use_tma, + filter_expert=filter_expert, ) if routed_scaling_factor is None: diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py index 64f225ee631..12b7aec244b 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py @@ -21,6 +21,7 @@ def get_config_file_name( dtype: Optional[str], block_shape: Optional[int] = None, per_channel_quant: bool = False, + down_moe: bool = False, ) -> str: device_name = get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" @@ -28,7 +29,8 @@ def get_config_file_name( "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" ) per_channel_quant_selector = ",per_channel_quant=True" if per_channel_quant else "" - return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}{per_channel_quant_selector}.json" + down_moe_selector = "_down" if down_moe else "" + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}{per_channel_quant_selector}{down_moe_selector}.json" @functools.lru_cache @@ -39,6 +41,7 @@ def get_moe_configs( block_n: Optional[int] = 0, block_k: Optional[int] = 0, per_channel_quant: bool = False, + down_moe: bool = False, ) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. 
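The `_down` JSON written by the tuning script is keyed by batch size; `try_get_optimal_moe_config` (next hunk) picks the entry whose key is closest to the current M, and the consumer later pops `USE_TMA` from it. A self-contained sketch of that lookup, with config values copied from the `_down` JSON added in this patch:

```python
import json

# Keys are tuned batch sizes; values copied from the "_down" JSON in this patch.
configs_json = """
{
  "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
          "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3, "USE_TMA": true},
  "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
          "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3, "USE_TMA": true}
}
"""
configs = {int(k): v for k, v in json.loads(configs_json).items()}

def pick_down_config(M: int) -> dict:
    # Same nearest-key rule as try_get_optimal_moe_config.
    return dict(configs[min(configs.keys(), key=lambda x: abs(x - M))])

cfg = pick_down_config(300)              # closest tuned batch size is 256
use_tma = cfg.pop("USE_TMA", False)      # consumed separately in fused_experts_impl
print(cfg, use_tma)
```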
@@ -54,7 +57,12 @@ def get_moe_configs( # First look up if an optimized configuration is available in the configs # directory json_file_name = get_config_file_name( - E, N, dtype, [block_n, block_k], per_channel_quant + E, + N, + dtype, + [block_n, block_k], + per_channel_quant, + down_moe=down_moe, ) # We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains, @@ -177,9 +185,12 @@ def try_get_optimal_moe_config( M: int, is_marlin: bool = False, block_shape: Optional[List[int]] = None, + return_down_config: bool = False, ): from sglang.srt.layers.moe.fused_moe_triton import get_config + down_config = None + max_block_m = None override_config = get_config() if override_config: config = override_config @@ -188,7 +199,7 @@ def try_get_optimal_moe_config( E, _, N = w2_shape block_n = block_shape[0] if block_shape else 0 block_k = block_shape[1] if block_shape else 0 - configs = get_moe_configs(E, N, dtype, block_n, block_k) + configs = get_moe_configs(E, N, dtype, block_n, block_k, down_moe=False) if configs: # If an optimal configuration map has been found, look up the @@ -199,6 +210,21 @@ def try_get_optimal_moe_config( config = get_default_config( M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape ) + if return_down_config: + down_configs = get_moe_configs(E, N, dtype, block_n, block_k, down_moe=True) + if down_configs: + down_config = down_configs[ + min(down_configs.keys(), key=lambda x: abs(x - M)) + ] + down_config = dict(**down_config) + max_block_m = max( + [cfg["BLOCK_SIZE_M"] for cfg in down_configs.values()] + ) + if return_down_config: + assert ( + down_config is None or config["BLOCK_SIZE_M"] == down_config["BLOCK_SIZE_M"] + ) + return config, (down_config, max_block_m) return config diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py index 6a7229a9b1f..11b555a7833 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py @@ -25,6 +25,13 @@ is_hip, ) +try: + from triton.tools.tensor_descriptor import TensorDescriptor + + _support_tensor_descriptor = True +except: + _support_tensor_descriptor = False + _is_hip = is_hip() _is_cuda = is_cuda() _is_cpu_amx_available = cpu_has_amx_support() @@ -41,6 +48,10 @@ padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0 +def support_tensor_descriptor(): + return _support_tensor_descriptor + + @triton.jit def write_zeros_to_output( c_ptr, @@ -108,6 +119,7 @@ def fused_moe_kernel_gptq_awq( use_int4_w4a16: tl.constexpr, use_int8_w8a16: tl.constexpr, even_Ks: tl.constexpr, + filter_expert: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using @@ -161,7 +173,7 @@ def fused_moe_kernel_gptq_awq( token_mask = offs_token < num_valid_tokens off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) - if off_experts == -1: + if filter_expert and off_experts == -1: # ----------------------------------------------------------- # Write back zeros to the output when the expert is not # in the current expert parallel rank. 
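`filter_expert` reaches the kernels as a `tl.constexpr`: with expert parallelism, `expert_ids` can contain `-1` sentinels for experts that live on another rank, and the kernel writes zeros for those blocks; when every routed expert is local the check can never fire, so the branch is compiled out. A host-side sketch of how the flag is derived (mirroring `fused_experts` earlier in this patch):

```python
from typing import Optional

def need_filter_expert(num_experts: Optional[int], num_local_experts: int) -> bool:
    # Filter only when some routed experts live on another EP rank, i.e. when
    # expert_ids produced by moe_align_block_size may contain -1 sentinels.
    return num_experts is None or num_experts != num_local_experts

print(need_filter_expert(257, 257))  # False: no EP, the -1 guard is specialized away
print(need_filter_expert(257, 32))   # True: expert-parallel, keep the zero-fill path
```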
@@ -296,7 +308,9 @@ def fused_moe_kernel_gptq_awq( def fused_moe_kernel( # Pointers to matrices a_ptr, + a_desc, b_ptr, + b_desc, bias_ptr, c_ptr, a_scale_ptr, @@ -344,6 +358,8 @@ def fused_moe_kernel( use_int8_w8a16: tl.constexpr, per_channel_quant: tl.constexpr, even_Ks: tl.constexpr, + c_sorted: tl.constexpr, + filter_expert: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using @@ -399,9 +415,10 @@ def fused_moe_kernel( offs_token = offs_token.to(tl.int64) token_mask = offs_token < num_valid_tokens - off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + off_experts_i32 = tl.load(expert_ids_ptr + pid_m) + off_experts = off_experts_i32.to(tl.int64) - if off_experts == -1: + if filter_expert and off_experts == -1: # ----------------------------------------------------------- # Write back zeros to the output when the expert is not # in the current expert parallel rank. @@ -421,15 +438,23 @@ def fused_moe_kernel( offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + ( - offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak - ) + if a_desc is not None: + assert use_fp8_w8a8 and group_n > 0 and group_k > 0 + start_offs_m = pid_m * BLOCK_SIZE_M + else: + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + if b_desc is not None: + start_offs_n = pid_n * BLOCK_SIZE_N + else: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) - b_ptrs = ( - b_ptr - + off_experts * stride_be - + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - ) if bias_ptr is not None: bias = tl.load( bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n @@ -443,8 +468,14 @@ def fused_moe_kernel( if use_fp8_w8a8 or use_int8_w8a8: # block-wise if group_k > 0 and group_n > 0: - a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm - offs_bsn = offs_bn // group_n + if a_desc is not None: + a_scale_ptrs = a_scale_ptr + offs_token_id * stride_asm + else: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + if BLOCK_SIZE_N > group_n: + offs_bsn = offs_bn // group_n + else: + offs_bsn = pid_n * BLOCK_SIZE_N // group_n b_scale_ptrs = ( b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn ) @@ -469,37 +500,49 @@ def fused_moe_kernel( # `accumulator` will be converted back to fp16 after the loop. accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + for k_start in range(0, K, BLOCK_SIZE_K): # Load the next block of A and B, generate a mask by checking the # K dimension. - if even_Ks: + if a_desc is not None: + a = a_desc.load([start_offs_m, k_start]) + elif even_Ks: a = tl.load( a_ptrs, mask=token_mask[:, None], other=0.0, ) - b = tl.load(b_ptrs) else: a = tl.load( a_ptrs, - mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + mask=token_mask[:, None] & (offs_k[None, :] < K - k_start), other=0.0, ) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + if b_desc is not None: + b = ( + b_desc.load([off_experts_i32, start_offs_n, k_start]) + .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K) + .T + ) + elif even_Ks: + b = tl.load(b_ptrs) + else: + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_start, other=0.0) # We accumulate along the K dimension. 
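In the block-quantized fp8 path just below, each K-block partial product is scaled by a per-token activation scale and a per-(N-block, K-block) weight scale; when `BLOCK_SIZE_N` does not exceed `group_n`, the whole tile shares one weight scale, which the new `BLOCK_SIZE_N > group_n` branch exploits by folding both scales into a single multiplier. A minimal PyTorch sketch of that identity (illustrative shapes only):

```python
import torch

BLOCK_M, BLOCK_N, BLOCK_K = 16, 64, 128
a = torch.randn(BLOCK_M, BLOCK_K)      # activation tile for one K-block
b = torch.randn(BLOCK_K, BLOCK_N)      # weight tile for one K-block
a_scale = torch.rand(BLOCK_M)          # one scale per token for this K-block
b_scale = torch.rand(())               # single scale when the N tile fits one group

# General form: a full vector of per-column weight scales.
general = (a @ b) * a_scale[:, None] * b_scale.expand(BLOCK_N)[None, :]

# Folded form used when BLOCK_SIZE_N <= group_n: multiply the scales first.
folded = (a @ b) * (a_scale[:, None] * b_scale)

torch.testing.assert_close(general, folded)
```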
if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8 or use_int8_w8a8: if group_k > 0 and group_n > 0: - k_start = k * BLOCK_SIZE_K offs_ks = k_start // group_k a_scale = tl.load( a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 ) b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) - - accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + if BLOCK_SIZE_N > group_n: + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale) else: if use_fp8_w8a8: accumulator = tl.dot(a, b, acc=accumulator) @@ -508,8 +551,10 @@ def fused_moe_kernel( else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk + if a_desc is None: + a_ptrs += BLOCK_SIZE_K * stride_ak + if b_desc is None: + b_ptrs += BLOCK_SIZE_K * stride_bk if use_int8_w8a16: accumulator *= b_scale @@ -528,7 +573,12 @@ def fused_moe_kernel( # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + if c_sorted: + c_ptrs = ( + c_ptr + stride_cm * offs_token_id[:, None] + stride_cn * offs_cn[None, :] + ) + else: + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) @@ -557,6 +607,10 @@ def invoke_fused_moe_kernel( per_channel_quant: bool, block_shape: Optional[List[int]] = None, no_combine: bool = False, + a_use_tma: bool = False, + b_use_tma: bool = False, + c_sorted: bool = False, + filter_expert: bool = True, ) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 @@ -662,14 +716,38 @@ def invoke_fused_moe_kernel( use_int4_w4a16=use_int4_w4a16, use_int8_w8a16=use_int8_w8a16, even_Ks=even_Ks, + filter_expert=filter_expert, **config, ) else: + if a_use_tma or b_use_tma: + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + if a_use_tma: + a_desc = TensorDescriptor( + A, A.shape, A.stride(), [config["BLOCK_SIZE_M"], config["BLOCK_SIZE_K"]] + ) + else: + a_desc = None + if b_use_tma: + b_desc = TensorDescriptor( + B, + B.shape, + B.stride(), + [1, config["BLOCK_SIZE_N"], config["BLOCK_SIZE_K"]], + ) + else: + b_desc = None fused_moe_kernel[grid]( A, + a_desc, B, + b_desc, bias, C, A_scale, @@ -689,8 +767,8 @@ def invoke_fused_moe_kernel( B.stride(1), bias.stride(0) if bias is not None else 0, bias.stride(1) if bias is not None else 0, - C.stride(1), - C.stride(2), + C.stride(-2), + C.stride(-1), A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, @@ -706,6 +784,8 @@ def invoke_fused_moe_kernel( use_int8_w8a16=use_int8_w8a16, per_channel_quant=per_channel_quant, even_Ks=even_Ks, + c_sorted=c_sorted, + filter_expert=filter_expert, **config, ) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index c468269f3ff..1e22400b112 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ 
b/python/sglang/srt/mem_cache/memory_pool.py @@ -1211,6 +1211,65 @@ def set_mla_kv_buffer_triton( ) +@triton.jit +def get_mla_kv_buffer_kernel( + kv_buffer_ptr, + cache_k_nope_ptr, + cache_k_rope_ptr, + loc_ptr, + buffer_stride: tl.constexpr, + nope_stride: tl.constexpr, + rope_stride: tl.constexpr, + nope_dim: tl.constexpr, + rope_dim: tl.constexpr, +): + pid_loc = tl.program_id(0) + loc = tl.load(loc_ptr + pid_loc) + loc_src_ptr = kv_buffer_ptr + loc * buffer_stride + + nope_offs = tl.arange(0, nope_dim) + nope_src_ptr = loc_src_ptr + nope_offs + nope_src = tl.load(nope_src_ptr) + + tl.store( + cache_k_nope_ptr + pid_loc * nope_stride + nope_offs, + nope_src, + ) + + rope_offs = tl.arange(0, rope_dim) + rope_src_ptr = loc_src_ptr + nope_dim + rope_offs + rope_src = tl.load(rope_src_ptr) + tl.store( + cache_k_rope_ptr + pid_loc * rope_stride + rope_offs, + rope_src, + ) + + +def get_mla_kv_buffer_triton( + kv_buffer: torch.Tensor, + loc: torch.Tensor, + cache_k_nope: torch.Tensor, + cache_k_rope: torch.Tensor, +): + # The source data type will be implicitly converted to the target data type. + nope_dim = cache_k_nope.shape[-1] # 512 + rope_dim = cache_k_rope.shape[-1] # 64 + n_loc = loc.numel() + grid = (n_loc,) + + get_mla_kv_buffer_kernel[grid]( + kv_buffer, + cache_k_nope, + cache_k_rope, + loc, + kv_buffer.stride(0), + cache_k_nope.stride(0), + cache_k_rope.stride(0), + nope_dim, + rope_dim, + ) + + class MLATokenToKVPool(KVCache): def __init__( self, @@ -1361,6 +1420,29 @@ def set_mla_kv_buffer( cache_k_rope, ) + def get_mla_kv_buffer( + self, + layer: RadixAttention, + loc: torch.Tensor, + dst_dtype: Optional[torch.dtype] = None, + ): + # get k nope and k rope from the kv buffer, and optionally cast them to dst_dtype. + layer_id = layer.layer_id + kv_buffer = self.get_key_buffer(layer_id) + dst_dtype = dst_dtype or self.dtype + cache_k_nope = torch.empty( + (loc.shape[0], 1, self.kv_lora_rank), + dtype=dst_dtype, + device=kv_buffer.device, + ) + cache_k_rope = torch.empty( + (loc.shape[0], 1, self.qk_rope_head_dim), + dtype=dst_dtype, + device=kv_buffer.device, + ) + get_mla_kv_buffer_triton(kv_buffer, loc, cache_k_nope, cache_k_rope) + return cache_k_nope, cache_k_rope + def get_cpu_copy(self, indices): torch.cuda.synchronize() kv_cache_cpu = [] diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index f34f36d7085..2fc6de5a19f 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -39,6 +39,7 @@ import triton.language as tl from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.dp_attention import ( DpPaddingMode, get_attention_dp_rank, @@ -240,6 +241,8 @@ class ForwardBatch: # For MLA chunked prefix cache used in chunked prefill # Tell attention backend whether lse needs to be returned mha_return_lse: Optional[bool] = None + mha_one_shot_kv_indices: Optional[torch.Tensor] = None + mha_one_shot: Optional[bool] = None # For multimodal mm_inputs: Optional[List[MultimodalInputs]] = None @@ -852,6 +855,10 @@ def prepare_chunked_prefix_cache_info(self, device: torch.device): self.token_to_kv_pool, MLATokenToKVPool ), "Currently chunked prefix cache can only be used by Deepseek models" + if not any(self.extend_prefix_lens_cpu): + self.num_prefix_chunks = 0 + return + if self.prefix_chunk_len is not None: # 
Chunked kv cache info already prepared by prior modules return @@ -906,6 +913,34 @@ def prepare_chunked_prefix_cache_info(self, device: torch.device): def can_run_tbo(self): return self.tbo_split_seq_index is not None + def fetch_mha_one_shot_kv_indices(self): + if self.mha_one_shot_kv_indices is not None: + return self.mha_one_shot_kv_indices + batch_size = self.batch_size + paged_kernel_lens_sum = sum(self.seq_lens_cpu) + kv_indices = torch.empty( + paged_kernel_lens_sum, + dtype=torch.int32, + device=self.req_pool_indices.device, + ) + kv_indptr = torch.zeros( + batch_size + 1, + dtype=torch.int32, + device=self.req_pool_indices.device, + ) + kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0) + create_flashinfer_kv_indices_triton[(self.batch_size,)]( + self.req_to_token_pool.req_to_token, + self.req_pool_indices, + self.seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token_pool.req_to_token.shape[1], + ) + self.mha_one_shot_kv_indices = kv_indices + return kv_indices + def enable_num_token_non_padded(server_args): return get_moe_expert_parallel_world_size() > 1 diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index fb9cd4f6c9f..01e4aa3b053 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -55,6 +55,7 @@ is_mla_preprocess_enabled, ) from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer +from sglang.srt.layers.attention.utils import concat_and_cast_mha_k_triton from sglang.srt.layers.communicator import ( LayerCommunicator, LayerScatterModes, @@ -226,6 +227,10 @@ class AttnForwardMethod(IntEnum): # This method can avoid OOM when prefix lengths are long. MHA_CHUNKED_KV = auto() + # Use multi-head attention, execute the MHA for prefix and extended kv in one shot + # when the sequence lengths are below the threshold. + MHA_ONE_SHOT = auto() + # Use MLA but with fused RoPE MLA_FUSED_ROPE = auto() @@ -291,6 +296,14 @@ def _is_extend_without_speculative(forward_batch): ) +def _support_mha_one_shot(attn: DeepseekV2AttentionMLA, forward_batch, backend_name): + attn_supported = backend_name in ["fa3", "flashinfer", "flashmla"] + sum_seq_lens = ( + sum(forward_batch.seq_lens_cpu) if forward_batch.seq_lens_cpu is not None else 0 + ) + return attn_supported and sum_seq_lens <= forward_batch.get_max_chunk_capacity() + + def _handle_attention_backend( attn: DeepseekV2AttentionMLA, forward_batch, backend_name ): @@ -310,6 +323,8 @@ def _handle_attention_backend( or sum_extend_prefix_lens == 0 ) ): + if _support_mha_one_shot(attn, forward_batch, backend_name): + return AttnForwardMethod.MHA_ONE_SHOT return AttnForwardMethod.MHA_CHUNKED_KV else: return _dispatch_mla_subtype(attn, forward_batch) @@ -1037,6 +1052,7 @@ def __init__( self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings + self.kv_cache_dtype = get_global_server_args().kv_cache_dtype # NOTE modification to rope_scaling must be done early enough, b/c e.g. 
Indexer needs it if rope_scaling: @@ -1334,6 +1350,10 @@ def forward_prepare( inner_state = self.forward_normal_chunked_kv_prepare( positions, hidden_states, forward_batch, zero_allocator ) + elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT: + inner_state = self.forward_normal_one_shot_prepare( + positions, hidden_states, forward_batch, zero_allocator + ) elif attn_forward_method == AttnForwardMethod.MLA: if not self.is_mla_preprocess_enabled: inner_state = self.forward_absorb_prepare( @@ -1385,6 +1405,8 @@ def forward_core(self, intermediate_state): return self.forward_normal_core(*inner_state) elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV: return self.forward_normal_chunked_kv_core(*inner_state) + elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT: + return self.forward_normal_one_shot_core(*inner_state) elif attn_forward_method == AttnForwardMethod.MLA: return self.forward_absorb_core(*inner_state) elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE: @@ -1419,41 +1441,24 @@ def forward_normal_prepare( kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) latent_cache = latent_cache.unsqueeze(1) kv_a = self.kv_a_layernorm(kv_a) - kv = self.kv_b_proj(kv_a)[0] - kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope = kv[..., : self.qk_nope_head_dim] - v = kv[..., self.qk_nope_head_dim :] k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) q[..., self.qk_nope_head_dim :] = q_pe - k = torch.empty_like(q) - # Temporary for DeepSeek V3/R1 only, but can generalize if needed + self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch) if ( - _is_cuda - and (self.num_local_heads == 128) - and (self.qk_nope_head_dim == 128) - and (self.qk_rope_head_dim == 64) + forward_batch.mha_one_shot + and sum(forward_batch.extend_prefix_lens_cpu) != 0 ): - concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe) - else: - k[..., : self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim :] = k_pe - - if not _is_npu: - latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) - latent_cache[:, :, self.kv_lora_rank :] = k_pe - - # Save latent cache - forward_batch.token_to_kv_pool.set_kv_buffer( - self.attn_mha, forward_batch.out_cache_loc, latent_cache, None - ) - else: - # To reduce a time-costing split operation - forward_batch.token_to_kv_pool.set_kv_buffer( - self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe + kv_a, k_pe = self._get_mla_kv_buffer( + forward_batch.fetch_mha_one_shot_kv_indices(), q.dtype, forward_batch ) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope = kv[..., : self.qk_nope_head_dim] + v = kv[..., self.qk_nope_head_dim :] + k = self._concat_and_cast_mha_k(k_nope, k_pe, forward_batch) return q, k, v, forward_batch def forward_normal_core(self, q, k, v, forward_batch): @@ -2263,20 +2268,11 @@ def _chunked_prefix_attn_mha( for i in range(forward_batch.num_prefix_chunks): forward_batch.set_prefix_chunk_idx(i) + kv_indices = forward_batch.prefix_chunk_kv_indices[i] # Fetch latent cache from memory pool with precomputed chunked kv indices - latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer( - self.attn_mha.layer_id - ) - latent_cache = ( - latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]] - .contiguous() - .to(q.dtype) - ) - - kv_a_normed, k_pe = latent_cache.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + 
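# --- Illustrative sketch (not part of the patch) ----------------------------
# Once kv_a / k_pe for the whole sequence have been gathered from the pool,
# the latent kv_a is up-projected by kv_b_proj and split into per-head k_nope
# and v, as forward_normal_prepare does above. The dimensions below follow
# DeepSeek V3/R1 defaults and are assumptions for illustration only.
import torch

kv_lora_rank, num_local_heads = 512, 128
qk_nope_head_dim, v_head_dim = 128, 128

kv_b_proj = torch.nn.Linear(
    kv_lora_rank, num_local_heads * (qk_nope_head_dim + v_head_dim), bias=False
)
kv_a = torch.randn(6, kv_lora_rank)  # 6 tokens of normalized latent kv
kv = kv_b_proj(kv_a)
kv = kv.view(-1, num_local_heads, qk_nope_head_dim + v_head_dim)
k_nope, v = kv.split([qk_nope_head_dim, v_head_dim], dim=-1)
assert k_nope.shape == (6, num_local_heads, qk_nope_head_dim)
assert v.shape == (6, num_local_heads, v_head_dim)
# ----------------------------------------------------------------------------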
kv_a_normed, k_pe = self._get_mla_kv_buffer( + kv_indices, q.dtype, forward_batch ) - kv_a_normed = kv_a_normed.squeeze(1).contiguous() kv = self.kv_b_proj(kv_a_normed)[0] kv = kv.view( -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim @@ -2351,6 +2347,107 @@ def forward_normal_chunked_kv_core(self, q, k, v, forward_batch): output, _ = self.o_proj(attn_output) return output + def forward_normal_one_shot_prepare( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + zero_allocator: BumpAllocator, + ): + forward_batch.mha_one_shot = True + return self.forward_normal_prepare( + positions, hidden_states, forward_batch, zero_allocator + ) + + def forward_normal_one_shot_core(self, q, k, v, forward_batch): + has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu) + # Only initialize the info once + if has_extend_prefix and forward_batch.num_prefix_chunks is None: + forward_batch.num_prefix_chunks = 0 + if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"): + forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch) + forward_batch.mha_return_lse = False + # Do mha for extended part without prefix + forward_batch.set_attn_attend_prefix_cache(False) + return self.forward_normal_core(q, k, v, forward_batch) + + def _set_mla_kv_buffer( + self, + latent_cache: torch.Tensor, + kv_a: torch.Tensor, + k_pe: torch.Tensor, + forward_batch: ForwardBatch, + ): + if _is_cuda: + # Save latent cache + forward_batch.token_to_kv_pool.set_mla_kv_buffer( + self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe + ) + elif _is_npu: + # To reduce a time-costing split operation + forward_batch.token_to_kv_pool.set_kv_buffer( + self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe + ) + else: + latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) + latent_cache[:, :, self.kv_lora_rank :] = k_pe + + # Save latent cache + forward_batch.token_to_kv_pool.set_kv_buffer( + self.attn_mha, forward_batch.out_cache_loc, latent_cache, None + ) + + def _get_mla_kv_buffer( + self, + kv_indices: torch.Tensor, + dst_dtype: torch.dtype, + forward_batch: ForwardBatch, + ): + if _is_cuda: + kv_a, k_pe = forward_batch.token_to_kv_pool.get_mla_kv_buffer( + self.attn_mha, kv_indices, dst_dtype + ) + kv_a = kv_a.squeeze(1) + else: + latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer( + self.attn_mha.layer_id + ) + latent_cache = latent_cache_buf[kv_indices].contiguous().to(dst_dtype) + + kv_a, k_pe = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + kv_a = kv_a.squeeze(1).contiguous() + return kv_a, k_pe + + def _concat_and_cast_mha_k(self, k_nope, k_pe, forward_batch): + # Temporary for DeepSeek V3/R1 only, but can generalize if needed + k_shape = (k_nope.shape[0], self.num_local_heads, self.qk_head_dim) + if ( + _is_cuda + and (self.num_local_heads == 128) + and (self.qk_nope_head_dim == 128) + and (self.qk_rope_head_dim == 64) + ): + k = k_nope.new_empty(*k_shape) + concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe) + elif _is_cuda: + # fa3 mha support fp8 inputs + if ( + self.current_attention_backend == "fa3" + and self.kv_cache_dtype != "auto" + ): + attn_dtype = forward_batch.token_to_kv_pool.dtype + else: + attn_dtype = k_nope.dtype + k = k_nope.new_empty(*k_shape, dtype=attn_dtype) + concat_and_cast_mha_k_triton(k, k_nope, k_pe) + else: + k = k_nope.new_empty(*k_shape) + k[..., : self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim :] = k_pe + return k + class 
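# --- Illustrative sketch (not part of the patch) ----------------------------
# Minimal reference for the generic fallback path of _concat_and_cast_mha_k:
# the per-head k_nope and the single shared rope part k_pe are packed into one
# [tokens, heads, qk_head_dim] key tensor, optionally in the kv-cache dtype so
# that backends such as fa3 can consume lower-precision inputs directly.
# Shapes and the helper name below are assumptions for illustration only.
from typing import Optional

import torch

def concat_mha_k_reference(
    k_nope: torch.Tensor,  # [tokens, num_heads, qk_nope_head_dim]
    k_pe: torch.Tensor,    # [tokens, 1, qk_rope_head_dim], shared across heads
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    tokens, num_heads, nope_dim = k_nope.shape
    rope_dim = k_pe.shape[-1]
    k = k_nope.new_empty(
        tokens, num_heads, nope_dim + rope_dim, dtype=dtype or k_nope.dtype
    )
    k[..., :nope_dim] = k_nope
    k[..., nope_dim:] = k_pe  # broadcast over the head dimension
    return k

if __name__ == "__main__":
    k_nope = torch.randn(4, 128, 128, dtype=torch.bfloat16)
    k_pe = torch.randn(4, 1, 64, dtype=torch.bfloat16)
    k = concat_mha_k_reference(k_nope, k_pe)  # -> [4, 128, 192]
# ----------------------------------------------------------------------------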
DeepseekV2DecoderLayer(nn.Module):