Fix AMD moe_align and triton stage config
root committed Dec 27, 2024
1 parent d315402 commit 57a5006
Showing 3 changed files with 121 additions and 81 deletions.
2 changes: 1 addition & 1 deletion docker/Dockerfile.rocm
@@ -37,7 +37,7 @@ ENV SGLANG_SET_CPU_AFFINITY=1
 ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
 ENV NCCL_MIN_NCHANNELS=112
 
-ENV MOE_PADDING=0
+ENV MOE_PADDING=1
 ENV VLLM_FP8_PADDING=1
 ENV VLLM_FP8_ACT_PADDING=1
 ENV VLLM_FP8_WEIGHT_PADDING=1
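
Note: the flag flipped here is consumed at import time in fused_moe.py (next file). A minimal sketch of how the value is interpreted, mirroring the padding_size expression visible in the diff below:

```python
import os

# MOE_PADDING=1 (as set in Dockerfile.rocm above) enables a 128-element pad;
# this mirrors the padding_size expression in fused_moe.py.
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
print(padding_size)  # 128 when MOE_PADDING=1, 0 otherwise
```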
16 changes: 8 additions & 8 deletions python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -11,14 +11,14 @@
 import torch
 import triton
 import triton.language as tl
-from sglang.srt.utils import is_hip
-if not is_hip():
-    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 from vllm import _custom_ops as ops
 
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import direct_register_custom_op, get_device_name
+from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
 
+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
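
For reference, the is_hip() helper imported here is presumably the same check this commit adds to sgl-kernel/setup.py (shown further down); a minimal sketch of the platform flag that the Triton config code branches on:

```python
import torch

def is_hip() -> bool:
    """Return whether it is HIP on the AMD ROCm platform."""
    return torch.version.hip is not None

# Evaluated once at import time so get_default_config() below can branch on it cheaply.
is_hip_ = is_hip()
```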
@@ -270,7 +270,7 @@ def moe_align_block_size(
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
     # FIXME(zhyncs)
-    if num_experts >= 256 and not is_hip():
+    if num_experts >= 256:
         sgl_moe_align_block_size(
             topk_ids,
             num_experts,
@@ -432,14 +432,14 @@ def get_default_config(
     dtype: Optional[str],
     is_marlin: bool,
 ) -> Dict[str, int]:
-    if dtype == "fp8_w8a8" and not is_hip():
+    if dtype == "fp8_w8a8":
         config = {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_N": 256,
             "BLOCK_SIZE_K": 128,
             "GROUP_SIZE_M": 32,
             "num_warps": 8,
-            "num_stages": 4,
+            "num_stages": 2 if is_hip_ else 4,
         }
         if M <= E:
             config = {
@@ -448,7 +448,7 @@
                 "BLOCK_SIZE_K": 128,
                 "GROUP_SIZE_M": 1,
                 "num_warps": 4,
-                "num_stages": 4,
+                "num_stages": 2 if is_hip_ else 4,
             }
     else:
         config = {
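
Taken together, the two num_stages hunks mean the default fp8_w8a8 MoE config now pipelines with 2 stages on ROCm and keeps 4 on CUDA. A trimmed-down sketch (not the full get_default_config) using only the values visible in the hunks above:

```python
# Illustrative sketch of the platform-dependent default; is_hip_ is the
# module-level flag shown earlier.
def default_fp8_moe_config(is_hip_: bool) -> dict:
    return {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        # ROCm/HIP uses fewer software pipeline stages in the Triton kernel.
        "num_stages": 2 if is_hip_ else 4,
    }

print(default_fp8_moe_config(True)["num_stages"])   # 2 on ROCm
print(default_fp8_moe_config(False)["num_stages"])  # 4 on CUDA
```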
184 changes: 112 additions & 72 deletions sgl-kernel/setup.py
@@ -2,6 +2,10 @@
 import shutil
 import zipfile
 from pathlib import Path
+import torch
+def is_hip() -> bool:
+    """Return whether it is HIP on the AMD ROCm platform."""
+    return torch.version.hip is not None
 
 from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
@@ -58,80 +58,116 @@ def update_wheel_platform_tag():
     old_wheel.rename(new_wheel)
 
 
-setup(
-    name="sgl-kernel",
-    version=get_version(),
-    packages=["sgl_kernel"],
-    package_dir={"": "src"},
-    ext_modules=[
-        CUDAExtension(
-            "sgl_kernel.ops.warp_reduce_cuda",
-            [
-                "src/sgl-kernel/csrc/warp_reduce.cc",
-                "src/sgl-kernel/csrc/warp_reduce_kernel.cu",
-            ],
-            extra_compile_args={
-                "nvcc": [
-                    "-O3",
-                    "-Xcompiler",
-                    "-fPIC",
-                    "-gencode=arch=compute_75,code=sm_75",
-                    "-gencode=arch=compute_80,code=sm_80",
-                    "-gencode=arch=compute_89,code=sm_89",
-                    "-gencode=arch=compute_90,code=sm_90",
-                ],
-                "cxx": ["-O3"],
-            },
-            libraries=["c10", "torch", "torch_python"],
-            extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
-        ),
-        CUDAExtension(
-            "sgl_kernel.ops.custom_reduce_cuda",
-            [
-                "src/sgl-kernel/csrc/trt_reduce_internal.cu",
-                "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
-                "src/sgl-kernel/csrc/trt_reduce.cc",
-            ],
-            extra_compile_args={
-                "nvcc": [
-                    "-O3",
-                    "-Xcompiler",
-                    "-fPIC",
-                    "-gencode=arch=compute_75,code=sm_75",
-                    "-gencode=arch=compute_80,code=sm_80",
-                    "-gencode=arch=compute_89,code=sm_89",
-                    "-gencode=arch=compute_90,code=sm_90",
-                    "-U__CUDA_NO_HALF_OPERATORS__",
-                    "-U__CUDA_NO_HALF2_OPERATORS__",
-                ],
-                "cxx": ["-O3"],
-            },
-            libraries=["c10", "torch", "torch_python"],
-            extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
-        ),
-        CUDAExtension(
-            "sgl_kernel.ops.moe_align_block_size",
-            [
-                "src/sgl-kernel/csrc/moe_align_kernel.cu",
-            ],
-            extra_compile_args={
-                "nvcc": [
-                    "-O3",
-                    "-Xcompiler",
-                    "-fPIC",
-                    "-gencode=arch=compute_75,code=sm_75",
-                    "-gencode=arch=compute_80,code=sm_80",
-                    "-gencode=arch=compute_89,code=sm_89",
-                    "-gencode=arch=compute_90,code=sm_90",
-                ],
-                "cxx": ["-O3"],
-            },
-            libraries=["c10", "torch", "torch_python"],
-            extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
-        ),
-    ],
-    cmdclass={"build_ext": BuildExtension},
-    install_requires=["torch"],
-)
+if not is_hip():
+    setup(
+        name="sgl-kernel",
+        version=get_version(),
+        packages=["sgl_kernel"],
+        package_dir={"": "src"},
+        ext_modules=[
+            CUDAExtension(
+                "sgl_kernel.ops.warp_reduce_cuda",
+                [
+                    "src/sgl-kernel/csrc/warp_reduce.cc",
+                    "src/sgl-kernel/csrc/warp_reduce_kernel.cu",
+                ],
+                extra_compile_args={
+                    "nvcc": [
+                        "-O3",
+                        "-Xcompiler",
+                        "-fPIC",
+                        "-gencode=arch=compute_75,code=sm_75",
+                        "-gencode=arch=compute_80,code=sm_80",
+                        "-gencode=arch=compute_89,code=sm_89",
+                        "-gencode=arch=compute_90,code=sm_90",
+                    ],
+                    "cxx": ["-O3"],
+                },
+                libraries=["c10", "torch", "torch_python"],
+                extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
+            ),
+            CUDAExtension(
+                "sgl_kernel.ops.custom_reduce_cuda",
+                [
+                    "src/sgl-kernel/csrc/trt_reduce_internal.cu",
+                    "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
+                    "src/sgl-kernel/csrc/trt_reduce.cc",
+                ],
+                extra_compile_args={
+                    "nvcc": [
+                        "-O3",
+                        "-Xcompiler",
+                        "-fPIC",
+                        "-gencode=arch=compute_75,code=sm_75",
+                        "-gencode=arch=compute_80,code=sm_80",
+                        "-gencode=arch=compute_89,code=sm_89",
+                        "-gencode=arch=compute_90,code=sm_90",
+                        "-U__CUDA_NO_HALF_OPERATORS__",
+                        "-U__CUDA_NO_HALF2_OPERATORS__",
+                    ],
+                    "cxx": ["-O3"],
+                },
+                libraries=["c10", "torch", "torch_python"],
+                extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
+            ),
+            CUDAExtension(
+                "sgl_kernel.ops.moe_align_block_size",
+                [
+                    "src/sgl-kernel/csrc/moe_align_kernel.cu",
+                ],
+                extra_compile_args={
+                    "nvcc": [
+                        "-O3",
+                        "-Xcompiler",
+                        "-fPIC",
+                        "-gencode=arch=compute_75,code=sm_75",
+                        "-gencode=arch=compute_80,code=sm_80",
+                        "-gencode=arch=compute_89,code=sm_89",
+                        "-gencode=arch=compute_90,code=sm_90",
+                    ],
+                    "cxx": ["-O3"],
+                },
+                libraries=["c10", "torch", "torch_python"],
+                extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
+            ),
+        ],
+        cmdclass={"build_ext": BuildExtension},
+        install_requires=["torch"],
+    )
+else:
+    hipcc_flags = [
+        "-D__HIP_PLATFORM_AMD__=1",
+        "--amdgpu-target=gfx942",
+        "-DENABLE_BF16",  # Enable BF16 for cuda_version >= 11
+        "-DENABLE_FP8",  # Enable FP8 for cuda_version >= 11.8
+    ]
+    setup(
+        name="sgl-kernel",
+        version=get_version(),
+        packages=["sgl_kernel"],
+        package_dir={"": "src"},
+        ext_modules=[
+            CUDAExtension(
+                "sgl_kernel.ops.moe_align_block_size",
+                [
+                    "src/sgl-kernel/csrc/moe_align_kernel.cu",
+                ],
+                extra_compile_args={
+                    "nvcc": hipcc_flags + [
+                        "-O3",
+                        "-Xcompiler",
+                        "-fPIC",
+                    ],
+                    "cxx": ["-O3"],
+                },
+                libraries=["hiprtc", "amdhip64", "c10", "torch", "torch_python"],
+                extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
+            ),
+        ],
+        cmdclass={"build_ext": BuildExtension},
+        install_requires=["torch"],
+    )
 
 update_wheel_platform_tag()
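
With this split, a ROCm build compiles only the moe_align_block_size extension with hipcc flags, while CUDA builds keep all three extensions. A quick post-install smoke test (an assumed workflow, not part of this repo) that checks which branch applied and that the kernel imports the same way fused_moe.py imports it:

```python
import torch

def is_hip() -> bool:
    # Same platform check used by setup.py in this commit.
    return torch.version.hip is not None

print("ROCm/HIP build" if is_hip() else "CUDA build")

# fused_moe.py now imports this unconditionally on both platforms.
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
print(sgl_moe_align_block_size)
```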
