Writing GPU kernels with FlyDSL:
@flyc.jit,@flyc.kernel, expression API, launch configuration, shared memory, and synchronization.
API: This guide documents the
@flyc.kernel/@flyc.jitAPI fromflydsl.compilerandflydsl.expr(python/flydsl/).
| Concept | API | Description |
|---|---|---|
| JIT host func | @flyc.jit |
Emit host-side launcher with JIT compilation |
| GPU kernel | @flyc.kernel |
Define GPU kernel function |
| Launch | kernel(...).launch(grid=, block=) |
Configure and emit GPU launch |
| Thread ID | fx.gpu.thread_idx.x |
Get thread index in workgroup |
| Block ID | fx.gpu.block_idx.x |
Get block/workgroup index |
| Block dim | fx.gpu.block_dim.x |
Get block dimension size |
| Compile-time | fx.Constexpr[int] |
Compile-time constant parameter |
| Tensor arg | fx.Tensor |
GPU tensor argument (via DLPack) |
| Stream arg | fx.Stream |
CUDA/HIP stream argument |
| Barrier | fx.gpu.barrier() |
Workgroup synchronization |
| Constants | fx.Int32 / fx.Index / fx.Float32 |
Create typed DSL constants |
| Range loop | range_constexpr(n) |
Compile-time unrolled loop |
| Buffer load | buffer_ops.buffer_load(rsrc, off) |
AMD buffer load intrinsic |
import flydsl.compiler as flyc
import flydsl.expr as fx
from flydsl.expr import gpu
@flyc.kernel
def vec_add_kernel(
A: fx.Tensor,
B: fx.Tensor,
C: fx.Tensor,
N: fx.Constexpr[int],
):
tid = gpu.thread_idx.x
bid = gpu.block_idx.x
idx = bid * 256 + tid
# ... kernel body using fx.*, ArithValue, Vector, and buffer ops ...
@flyc.jit
def vec_add(
A: fx.Tensor,
B: fx.Tensor,
C: fx.Tensor,
N: fx.Constexpr[int],
stream: fx.Stream = fx.Stream(None),
):
vec_add_kernel(A, B, C, N).launch(
grid=(N // 256,),
block=(256,),
stream=stream,
)
# Usage:
import torch
A = torch.randn(1024, device="cuda", dtype=torch.float32)
B = torch.randn(1024, device="cuda", dtype=torch.float32)
C = torch.empty(1024, device="cuda", dtype=torch.float32)
vec_add(A, B, C, 1024)@flyc.kernelwraps the function as aKernelFunction@flyc.jitwraps the function as aJitFunction- On first call,
JitFunction.__call__triggers:- AST rewriting (Python loops/ifs → MLIR scf ops)
- MLIR module creation with
gpu.container_module - Tracing the jit function body to generate MLIR ops
- Calling
vec_add_kernel(...)emits agpu.funcingpu.module .launch()emitsgpu.launch_funcMlirCompiler.compile()runs the full pass pipelineJITCFunctionwraps the resulting ExecutionEngine
- Subsequent calls with the same type signature use the cached binary
Maps a PyTorch tensor to an MLIR memref descriptor via DLPack:
@flyc.kernel
def my_kernel(input: fx.Tensor, output: fx.Tensor):
# input and output are Tensor wrappers around ir.Value (memref)
...At the host boundary, torch.Tensor is automatically converted via TensorAdaptor.
Compile-time constant. Value is embedded directly in the generated IR:
@flyc.kernel
def my_kernel(data: fx.Tensor, N: fx.Constexpr[int], dtype: fx.Constexpr[str]):
for i in range_constexpr(N // 64): # unrolled at compile time
...Different Constexpr values produce different compiled kernels (separate cache entries).
Runtime integer parameter (passed as i32):
@flyc.jit
def launch(data: fx.Tensor, size: fx.Int32, stream: fx.Stream = fx.Stream(None)):
...Python int values are automatically converted to Int32 via the JitArgumentRegistry.
CUDA/HIP stream for asynchronous kernel launch:
@flyc.jit
def launch(data: fx.Tensor, stream: fx.Stream = fx.Stream(None)):
my_kernel(data).launch(grid=(1,), block=(256,), stream=stream)
# Launch on specific stream:
stream = torch.cuda.Stream()
launch(data, stream=fx.Stream(stream))Register new Python types for the JIT boundary:
from flydsl.compiler import JitArgumentRegistry
@JitArgumentRegistry.register(MyCustomType, dsl_type=MyDslType)
class MyCustomAdaptor:
def __init__(self, value: MyCustomType):
self.value = value
def __get_ir_types__(self):
return [...] # MLIR types for this argument
def __get_c_pointers__(self):
return [...] # ctypes pointers for invocationfrom flydsl.expr import gpu
# Thread index within workgroup (returns Int32)
tid_x = gpu.thread_idx.x
tid_y = gpu.thread_idx.y
tid_z = gpu.thread_idx.z
# Block (workgroup) index within grid
bid_x = gpu.block_idx.x
bid_y = gpu.block_idx.y
# Block dimensions
bdim_x = gpu.block_dim.x
# Grid dimensions
gdim_x = gpu.grid_dim.x
# Low-level (returns raw ir.Value)
raw_tid = gpu.thread_id("x")
raw_bid = gpu.block_id("x")import flydsl.expr as fx
# Constants (prefer DSL numeric types)
c42 = fx.Index(42) # index type constant
c3_14 = fx.Float32(3.14) # f32 constant
mask = fx.Int32(0xFF) # i32 constant
# Arithmetic (operator overloading via ArithValue / Numeric)
result = a + b
result = a * 2
result = a // 4
result = a % 16
# Cast (prefer DSL numeric constructors)
idx = fx.Index(int_val) # cast to index type
i32_val = fx.Int32(idx) # cast to i32
# Select
result = cond.select(true_val, false_val) # when cond is an ArithValue
# Bitwise
result = a & b
result = a ^ b
result = a << 4Use direct arith.*FOp(..., fastmath=...) only where explicit fastmath flags are performance-critical.
from flydsl.expr.typing import Vector as Vec
# Build vector from elements
vec = Vec.from_elements([a, b, c, d], fx.Float32)
# Vector store to memref
vec.store(memref, [idx])
# Extract, bitcast, and convert
elem = vec[idx]
as_i32 = vec.bitcast(fx.Int32)
as_bf16 = vec.to(fx.BFloat16)AMD buffer load/store intrinsics for efficient global memory access:
from flydsl.expr import buffer_ops
# Create buffer resource descriptor from memref
rsrc = buffer_ops.create_buffer_resource(memref_value)
# Buffer load (vectorized)
data = buffer_ops.buffer_load(rsrc, byte_offset, vec_width=4)
# Buffer store
buffer_ops.buffer_store(data, rsrc, byte_offset)from flydsl.expr import rocdl
# Buffer tensor — wraps a Tensor with AMD buffer resource descriptor
A_buf = rocdl.make_buffer_tensor(A)
# MFMA MMA atom constructor — returns MmaAtomCDNA3_MFMAType
atom_type = rocdl.MFMA(m=16, n=16, k=32, elem_ty_ab=fx.Float8E4M3FNUZ)
# Buffer copy atom types
copy_op = rocdl.BufferCopy128b() # 128-bit buffer copy
copy_op = rocdl.BufferCopy64b() # 64-bit buffer copy
copy_op = rocdl.BufferCopy32b() # 32-bit buffer copySignature: (result_type, [a, b, c, cbsz, abid, blgp]) — trailing ints default to 0.
result = rocdl.mfma_f32_16x16x16f16(result_type, [a, b, acc])
result = rocdl.mfma_f32_16x16x32_fp8_fp8(result_type, [a, b, acc])
result = rocdl.mfma_i32_16x16x32_i8(result_type, [a, b, acc])
result = rocdl.mfma_f32_16x16x16bf16_1k(result_type, [a, b, acc]) # BF16 1K variant
# GFX950 scaled MFMA (MXFP4/FP6/FP8)
result = rocdl.mfma_scale_f32_16x16x128_f8f6f4(
result_type, [a, b, acc, cbsz, blgp, opselA, scaleA, opselB, scaleB]
)Control instruction scheduling for performance tuning:
rocdl.sched_mfma(cnt) # wait for cnt MFMA instructions to complete
rocdl.sched_vmem(cnt) # wait for cnt VMEM reads to complete
rocdl.sched_dsrd(cnt) # wait for cnt DS (LDS) reads to complete
rocdl.sched_dswr(cnt) # wait for cnt DS (LDS) writes to completeSingle-instruction hardware math (guaranteed 1 VALU cycle, lower precision than math.*):
# Base-2 exponential (v_exp_f32)
result = rocdl.exp2(T.f32, x)
# Reciprocal (v_rcp_f32)
result = rocdl.rcp(T.f32, x)# Warp shuffle
val = rocdl.ds_bpermute(idx, src)
# Buffer load/store (raw)
data = rocdl.raw_ptr_buffer_load(rsrc, offset, soffset, aux)
rocdl.raw_ptr_buffer_store(data, rsrc, offset, soffset, aux)from flydsl.expr import gpu
# Barrier (workgroup synchronization)
gpu.barrier()
# Shared memory address space attribute
addrspace = gpu.smem_space()
addrspace_int = gpu.smem_space(int=True)The ASTRewriter automatically transforms Python for loops:
@flyc.kernel
def my_kernel(data: fx.Tensor, N: fx.Constexpr[int]):
# Compile-time unrolled loop
for i in range_constexpr(N):
# This loop is fully unrolled in the generated IR
...
# Runtime loop (lowered by the AST rewriter)
for i in range(runtime_value):
...Mark a value as compile-time constant:
from flydsl.expr import const_expr
@flyc.kernel
def my_kernel(data: fx.Tensor, N: fx.Constexpr[int]):
tile_size = const_expr(N // 4)
for i in range_constexpr(tile_size):
...from flydsl.utils.smem_allocator import SmemAllocator
from flydsl.expr.typing import T
# Create allocator for target architecture
allocator = SmemAllocator(None, arch="gfx942", global_sym_name="smem0")
# Allocate typed arrays
lds_a = allocator.allocate_array(T.f16, 8192)
lds_b = allocator.allocate_array(T.f16, 8192)
# Inside kernel: get base pointer and typed views
lds_base = allocator.get_base()
lds_a_ptr = lds_a(lds_base) # SmemPtr
lds_b_ptr = lds_b(lds_base) # SmemPtr
# Load/store through SmemPtr
val = lds_a_ptr.load([idx])
lds_b_ptr.store(val, [idx])For @flyc.kernel style kernels, finalize the allocator in the GPU module:
comp_ctx = CompilationContext.get_current()
with ir.InsertionPoint(comp_ctx.gpu_module_body):
allocator.finalize()| Architecture | LDS per CU |
|---|---|
gfx942 (MI300X) |
64 KB |
gfx950 (MI350/MI355X) |
160 KB |
gfx1201 (Radeon AI PRO R9700) |
64 KB |
gfx1250 |
320 KB |
@flyc.jit
def launch(data: fx.Tensor, stream: fx.Stream = fx.Stream(None)):
my_kernel(data).launch(
grid=(num_blocks_x, num_blocks_y, num_blocks_z),
block=(threads_x, threads_y, threads_z),
smem=shared_mem_bytes, # dynamic shared memory
stream=stream, # CUDA/HIP stream
)Grid and block dimensions accept:
int— static valueir.Value— dynamic MLIR value- Tuple of 1–3 values — missing dimensions default to 1
@flyc.jit
def launch(data: fx.Tensor, M: fx.Int32, stream: fx.Stream = fx.Stream(None)):
grid_x = M // 256
my_kernel(data, M).launch(
grid=(grid_x, 1, 1),
block=(256, 1, 1),
stream=stream,
)from flydsl.expr import gpu
# Workgroup barrier (s_barrier)
gpu.barrier()JIT-compiled functions are cached automatically:
- In-memory cache — keyed by argument type signature
- Disk cache — stored in
~/.flydsl/cache/(configurable viaFLYDSL_RUNTIME_CACHE_DIR) - Cache key includes: source code hash, dependency sources, closure values, FlyDSL version, LLVM version
Cache is invalidated when:
- Source code of the function or its dependencies changes
- Argument types change (different tensor shapes/dtypes)
Constexprvalues change- FlyDSL or LLVM version changes
The JIT disk cache auto-invalidates when kernel source code or closure values change. Set FLYDSL_RUNTIME_ENABLE_CACHE=0 only when modifying C++ passes or non-closure helper functions:
FLYDSL_RUNTIME_ENABLE_CACHE=0 python my_script.py # or: rm -rf ~/.flydsl/cacheCOMPILE_ONLY=1 python my_script.pyFLYDSL_DUMP_IR=1 FLYDSL_DUMP_DIR=./my_dumps python my_script.py# After compilation, access IR from the compiled function:
result = launch(A, B, C, 1024)
# Or use JITCFunction directly:
compiled_func.print_ir() # compiled MLIR IR
compiled_func.print_ir(compiled=False) # original IR before passesFLYDSL_DEBUG_AST_DIFF=1 python my_script.pyShows the diff between original and rewritten AST for debugging control flow transformations.
From kernels/preshuffle_gemm.py:
import flydsl.compiler as flyc
import flydsl.expr as fx
from flydsl.expr import gpu, buffer_ops, rocdl, range_constexpr
from flydsl.expr.typing import T
from flydsl.utils.smem_allocator import SmemAllocator
def compile_preshuffle_gemm_a8(*, M, N, K, tile_m, tile_n, tile_k,
in_dtype="fp8", lds_stage=2, ...):
allocator = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem0")
lds_a = allocator.allocate_array(T.i8, tile_m * tile_k)
# ... more allocations ...
@flyc.kernel
def gemm_kernel(
arg_c: fx.Tensor, arg_a: fx.Tensor, arg_b: fx.Tensor,
arg_scale_a: fx.Tensor, arg_scale_b: fx.Tensor,
m_in: fx.Int32, n_in: fx.Int32,
):
tid = gpu.thread_idx.x
bid = gpu.block_idx.x
# ... complex GEMM implementation using MFMA, LDS, tiling ...
@flyc.jit
def launch_fn(
arg_c: fx.Tensor, arg_a: fx.Tensor, arg_b: fx.Tensor,
arg_scale_a: fx.Tensor, arg_scale_b: fx.Tensor,
M_val: fx.Int32, N_val: fx.Int32,
stream: fx.Stream = fx.Stream(None),
):
gemm_kernel(arg_c, arg_a, arg_b, arg_scale_a, arg_scale_b,
M_val, N_val).launch(
grid=(grid_x, grid_y), block=(256,),
smem=smem_bytes, stream=stream,
)
return launch_fnWriting a new kernel?
│
├── Simple element-wise?
│ ├── Use @flyc.kernel + @flyc.jit
│ ├── fx.gpu.thread_idx.x for thread indexing
│ └── See tests/kernels/test_vec_add.py
│
├── Reduction (norm, softmax)?
│ ├── Use warp_reduce / block_reduce from kernels/reduce.py
│ └── See kernels/layernorm_kernel.py, kernels/softmax_kernel.py
│
├── Matrix multiply (GEMM)?
│ ├── Use @flyc.kernel + SmemAllocator + MFMA
│ ├── B-preshuffle layout from mfma_preshuffle_pipeline.py
│ └── See kernels/preshuffle_gemm.py
│
├── Need shared memory?
│ ├── Use SmemAllocator with target arch
│ ├── Call finalize() in GPU module body
│ └── Call get_base() inside @kernel
│
└── Need compile-time specialization?
├── Use Constexpr[T] parameters
└── Use range_constexpr() for unrolled loops
| File | Description |
|---|---|
python/flydsl/compiler/__init__.py |
Public API: jit, kernel, from_dlpack |
python/flydsl/compiler/jit_function.py |
@jit decorator, MlirCompiler, JitCacheManager |
python/flydsl/compiler/kernel_function.py |
@kernel decorator, KernelFunction, KernelLauncher |
python/flydsl/compiler/jit_executor.py |
JITCFunction (ExecutionEngine wrapper) |
python/flydsl/compiler/jit_argument.py |
JitArgumentRegistry, TensorAdaptor |
python/flydsl/compiler/ast_rewriter.py |
ASTRewriter — Python AST → MLIR control flow |
python/flydsl/expr/typing.py |
Types (T), Tensor, Stream, Constexpr |
python/flydsl/expr/arith.py |
Arithmetic operations |
python/flydsl/expr/vector.py |
Vector dialect operations |
python/flydsl/expr/gpu.py |
GPU operations (thread_id, barrier, ...) |
python/flydsl/expr/buffer_ops.py |
AMD buffer load/store operations |
python/flydsl/expr/rocdl/ |
ROCm dialect intrinsics (MFMA/WMMA, buffer, TDM, cluster) |
python/flydsl/expr/primitive.py |
Layout algebra primitives (make_shape, crd2idx, etc.) |
python/flydsl/utils/smem_allocator.py |
SmemAllocator, SmemPtr, LDS management |
kernels/preshuffle_gemm.py |
Preshuffle GEMM kernel example |
kernels/reduce.py |
Warp/block reduction primitives |
tests/kernels/test_vec_add.py |
Vector add kernel test |
tests/kernels/test_preshuffle_gemm.py |
Preshuffle GEMM test |