diff --git a/ds4_metal.m b/ds4_metal.m
index 7c3556b57..668329da3 100644
--- a/ds4_metal.m
+++ b/ds4_metal.m
@@ -833,6 +833,21 @@ static int ds4_gpu_use_mpp_attn_out_low_matmul(void) {
     return ds4_gpu_mpp_available();
 }
 
+static int ds4_gpu_use_decode_moe_minigemm(void) {
+    /* Env-gated opt-in for the mini-GEMM path; probed once, then cached. */
+    static int probed;
+    static int minigemm_on;
+    if (probed) return minigemm_on;
+    minigemm_on = ds4_gpu_env_bool("DS4_METAL_DECODE_MOE_MINIGEMM") > 0;
+    probed = 1;
+    if (minigemm_on) {
+        fprintf(stderr,
+                "ds4: experimental Metal decode MoE mini-GEMM enabled by "
+                "DS4_METAL_DECODE_MOE_MINIGEMM\n");
+    }
+    return minigemm_on;
+}
+
 enum {
     DS4_METAL_ATTN_OUT_MPP_TILE_N = 64,
 };
@@ -2584,6 +2599,18 @@ static int ds4_gpu_encode_mul_mm_id_mapped_tile(
         NSUInteger src1_off,
         id dst,
         NSUInteger dst_off);
+static int ds4_gpu_encode_mul_mm_selected_sum( /* forward declaration; defined further down this file */
+        id cb,
+        id mm_pipeline,
+        const ds4_gpu_mul_mm_id_args *mm_args,
+        id src0,
+        NSUInteger src0_off,
+        id src1,
+        NSUInteger src1_off,
+        id ids,
+        NSUInteger ids_off,
+        id dst,
+        NSUInteger dst_off);
 
 typedef struct {
     int32_t ne11;
@@ -3002,6 +3029,28 @@ static int ds4_gpu_encode_rope_tail_inplace(
     uint64_t dst_token_stride;
 } ds4_gpu_dsv4_moe_sum6_args;
 
+static int ds4_gpu_encode_mul_mm_selected_pair_swiglu( /* forward declaration; defined further down this file */
+        id cb,
+        id mm_pipeline,
+        const ds4_gpu_mul_mm_id_args *mm_args,
+        const ds4_gpu_dsv4_moe_swiglu_weight_args *act_args,
+        id src0_gate,
+        NSUInteger src0_gate_off,
+        id src0_up,
+        NSUInteger src0_up_off,
+        id src1,
+        NSUInteger src1_off,
+        id dst_gate,
+        NSUInteger dst_gate_off,
+        id dst_up,
+        NSUInteger dst_up_off,
+        id dst_mid,
+        NSUInteger dst_mid_off,
+        id ids,
+        NSUInteger ids_off,
+        id weights,
+        NSUInteger weights_off);
+
 /* Compile the single in-repo Metal source and create the pipelines that every
  * session uses.
Shape-dependent kernels with function constants are built
  * lazily by the small ds4_gpu_get_* caches, so startup stays predictable
@@ -12745,6 +12794,48 @@ static NSUInteger ds4_gpu_routed_mv_smem(uint32_t type) {
     }
 }
 
+static id ds4_gpu_routed_mm_selected_sum_pipeline( /* selected-sum (down) kernel for this weight type; nil when the type is unsupported */
+        uint32_t type,
+        bool rhs_f16) { /* rhs_f16 picks the *_f16 kernel name, otherwise *_f32 */
+    switch (type) {
+    case DS4_METAL_TENSOR_IQ2_XXS:
+        return ds4_gpu_get_mul_mm_id_pipeline(rhs_f16 ?
+            "kernel_mul_mm_selected_sum_iq2_xxs_f16" :
+            "kernel_mul_mm_selected_sum_iq2_xxs_f32", false);
+    case DS4_METAL_TENSOR_Q2_K:
+        return ds4_gpu_get_mul_mm_id_pipeline(rhs_f16 ?
+            "kernel_mul_mm_selected_sum_q2_K_f16" :
+            "kernel_mul_mm_selected_sum_q2_K_f32", false);
+    case DS4_METAL_TENSOR_Q4_K:
+        return ds4_gpu_get_mul_mm_id_pipeline(rhs_f16 ?
+            "kernel_mul_mm_selected_sum_q4_K_f16" :
+            "kernel_mul_mm_selected_sum_q4_K_f32", false);
+    default:
+        return nil;
+    }
+}
+
+static id ds4_gpu_routed_mm_selected_pair_swiglu_pipeline( /* fused gate+up+SwiGLU kernel for this weight type; nil when unsupported */
+        uint32_t type,
+        bool mid_f16) { /* mid_f16 picks the kernel that writes the mid buffer as f16 */
+    switch (type) {
+    case DS4_METAL_TENSOR_IQ2_XXS:
+        return ds4_gpu_get_mul_mm_id_pipeline(mid_f16 ?
+            "kernel_mul_mm_selected_iq2_xxs_pair_swiglu_f16" :
+            "kernel_mul_mm_selected_iq2_xxs_pair_swiglu_f32", false);
+    case DS4_METAL_TENSOR_Q2_K:
+        return ds4_gpu_get_mul_mm_id_pipeline(mid_f16 ?
+            "kernel_mul_mm_selected_q2_K_pair_swiglu_f16" :
+            "kernel_mul_mm_selected_q2_K_pair_swiglu_f32", false);
+    case DS4_METAL_TENSOR_Q4_K:
+        return ds4_gpu_get_mul_mm_id_pipeline(mid_f16 ?
+            "kernel_mul_mm_selected_q4_K_pair_swiglu_f16" :
+            "kernel_mul_mm_selected_q4_K_pair_swiglu_f32", false);
+    default:
+        return nil;
+    }
+}
+
 static int ds4_gpu_encode_mul_mv_id(
         id cb,
         id pipeline,
@@ -13104,6 +13195,92 @@ static int ds4_gpu_encode_mul_mm_id_mapped(
             dst_off);
 }
 
+static int ds4_gpu_encode_mul_mm_selected_sum( /* encodes one selected-sum dispatch; returns 1 on success, 0 on bad arguments */
+        id cb,
+        id mm_pipeline,
+        const ds4_gpu_mul_mm_id_args *mm_args,
+        id src0,
+        NSUInteger src0_off,
+        id src1,
+        NSUInteger src1_off,
+        id ids,
+        NSUInteger ids_off,
+        id dst,
+        NSUInteger dst_off) {
+    if (!cb || !mm_pipeline || !mm_args || !src0 || !src1 || !ids || !dst || /* reject missing objects and non-positive dims */
+        mm_args->ne00 <= 0 || mm_args->ne0 <= 0 ||
+        mm_args->ne20 <= 0 || mm_args->ne21 <= 0 || mm_args->ne02 <= 0) {
+        return 0;
+    }
+
+    const NSUInteger tile_n = 8u; /* ne21 entries per threadgroup; matches NR1 in the kernel */
+    id enc = ds4_gpu_compute_encoder(cb);
+    [enc setComputePipelineState:mm_pipeline];
+    [enc setBytes:mm_args length:sizeof(*mm_args) atIndex:0];
+    [enc setBuffer:src0 offset:src0_off atIndex:1];
+    [enc setBuffer:src1 offset:src1_off atIndex:2];
+    [enc setBuffer:ids offset:ids_off atIndex:3];
+    [enc setBuffer:dst offset:dst_off atIndex:4];
+    [enc setThreadgroupMemoryLength:8192u atIndex:0]; /* the kernel stages a 32x64 float result tile (8192 bytes) here */
+    [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)mm_args->ne21 + tile_n - 1u) / tile_n,
+                                          ((NSUInteger)mm_args->ne0 + 63u) / 64u, /* 64 output rows per threadgroup (NR0) */
+                                          1) /* grid z is 1: the kernel loops over the ne20 slots itself */
+        threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+    ds4_gpu_end_compute_encoder(cb, enc);
+    return 1;
+}
+
+static int ds4_gpu_encode_mul_mm_selected_pair_swiglu( /* encodes the fused gate+up+SwiGLU dispatch; returns 1 on success, 0 on bad arguments */
+        id cb,
+        id mm_pipeline,
+        const ds4_gpu_mul_mm_id_args *mm_args,
+        const ds4_gpu_dsv4_moe_swiglu_weight_args *act_args,
+        id src0_gate,
+        NSUInteger src0_gate_off,
+        id src0_up,
+        NSUInteger src0_up_off,
+        id src1,
+        NSUInteger src1_off,
+        id dst_gate,
+        NSUInteger dst_gate_off,
+        id dst_up,
+        NSUInteger dst_up_off,
+        id dst_mid,
+        NSUInteger dst_mid_off,
+        id ids,
+        NSUInteger ids_off,
+        id weights,
+        NSUInteger weights_off) {
+    if (!cb || !mm_pipeline || !mm_args || !act_args || /* reject missing objects and non-positive dims */
+        !src0_gate || !src0_up || !src1 || !dst_gate || !dst_up ||
+        !dst_mid || !ids || !weights ||
+        mm_args->ne00 <= 0 || mm_args->ne0 <= 0 ||
+        mm_args->ne20 <= 0 || mm_args->ne21 <= 0 || mm_args->ne02 <= 0) {
+        return 0;
+    }
+
+    const NSUInteger tile_n = 8u; /* ne21 entries per threadgroup; matches NR1 in the kernel */
+    id enc = ds4_gpu_compute_encoder(cb);
+    [enc setComputePipelineState:mm_pipeline];
+    [enc setBytes:mm_args length:sizeof(*mm_args) atIndex:0];
+    [enc setBytes:act_args length:sizeof(*act_args) atIndex:1];
+    [enc setBuffer:src0_gate offset:src0_gate_off atIndex:2];
+    [enc setBuffer:src0_up offset:src0_up_off atIndex:3];
+    [enc setBuffer:src1 offset:src1_off atIndex:4];
+    [enc setBuffer:dst_gate offset:dst_gate_off atIndex:5];
+    [enc setBuffer:dst_up offset:dst_up_off atIndex:6];
+    [enc setBuffer:dst_mid offset:dst_mid_off atIndex:7];
+    [enc setBuffer:ids offset:ids_off atIndex:8];
+    [enc setBuffer:weights offset:weights_off atIndex:9];
+    [enc setThreadgroupMemoryLength:16384u atIndex:0]; /* two 8192-byte float staging tiles: gate at 0, up at +8192 */
+    [enc dispatchThreadgroups:MTLSizeMake(((NSUInteger)mm_args->ne21 + tile_n - 1u) / tile_n,
+                                          ((NSUInteger)mm_args->ne0 + 63u) / 64u, /* 64 output rows per threadgroup (NR0) */
+                                          (NSUInteger)mm_args->ne20) /* one grid z layer per selected slot */
+        threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+    ds4_gpu_end_compute_encoder(cb, enc);
+    return 1;
+}
+
 static int ds4_gpu_encode_attn_out_low_q8_mpp(
         id cb,
         id pipeline,
@@ -14074,6 +14251,8 @@ int ds4_gpu_routed_moe_one_tensor(
         const uint32_t down_nr0 = ds4_gpu_routed_mv_nr0(down_type);
         id gate_mv_pipeline = ds4_gpu_routed_mv_pipeline(gate_type);
         id down_mv_pipeline = ds4_gpu_routed_mv_pipeline(down_type);
+        id gate_up_mm_pipeline = nil; /* stays nil unless the mini-GEMM path is selected below */
+        id down_mm_pipeline = nil;
         if (gate_nr0 == 0 || down_nr0 == 0 || !gate_mv_pipeline || !down_mv_pipeline) {
             fprintf(stderr, "ds4: unsupported Metal routed MoE quant types gate=%u down=%u\n",
                     gate_type, down_type);
@@ -14088,6 +14267,32 @@ int ds4_gpu_routed_moe_one_tensor(
             ds4_gpu_make_mul_mv_id_args(expert_mid_dim, out_dim, n_total_expert,
                                         down_row_bytes, down_expert_bytes,
                                         n_expert, n_expert, n_tokens, down_nr0);
+        const bool request_mid_f16 = !g_quality_mode; /* always true whenever use_minigemm is, since both require !g_quality_mode */
+        bool use_minigemm = /* opt-in fast path; disabled by quality mode, clamped-activation writes and graph dumps */
+            !g_quality_mode &&
+            ds4_gpu_use_decode_moe_minigemm() &&
+            !getenv("DS4_METAL_MOE_WRITE_CLAMPED_ACT") &&
+            !getenv("DS4_METAL_GRAPH_DUMP_PREFIX");
+        ds4_gpu_mul_mm_id_args gate_mm_args = { 0 };
+        ds4_gpu_mul_mm_id_args down_mm_args = { 0 };
+        if (use_minigemm) { /* NOTE(review): both selected-* kernels read ids[slot] with no per-token offset; confirm n_tokens > 1 cannot reach this path or shares one expert list */
+            gate_mm_args =
+                ds4_gpu_make_mul_mm_id_args(expert_in_dim, expert_mid_dim, n_total_expert,
+                                            gate_row_bytes, gate_expert_bytes,
+                                            1, n_expert, n_tokens);
+            down_mm_args =
+                ds4_gpu_make_mul_mm_id_args_src1_size(expert_mid_dim, out_dim, n_total_expert,
+                                                      down_row_bytes, down_expert_bytes,
+                                                      n_expert, n_expert, n_tokens,
+                                                      request_mid_f16 ? sizeof(uint16_t) : sizeof(float));
+            gate_up_mm_pipeline =
+                ds4_gpu_routed_mm_selected_pair_swiglu_pipeline(gate_type, request_mid_f16);
+            down_mm_pipeline =
+                ds4_gpu_routed_mm_selected_sum_pipeline(down_type, request_mid_f16);
+            if (!gate_up_mm_pipeline || !down_mm_pipeline) {
+                use_minigemm = false; /* unsupported quant type: fall back to the existing mat-vec path */
+            }
+        }
 
         int owned = 0;
         id cb = ds4_gpu_command_buffer(&owned);
@@ -14098,6 +14303,50 @@ int ds4_gpu_routed_moe_one_tensor(
 
         int ok = 1;
         const bool write_clamped_moe = getenv("DS4_METAL_MOE_WRITE_CLAMPED_ACT") != NULL;
+        const bool moe_stage_profile = /* per-stage timing; only honoured while a batch command buffer is active */
+            getenv("DS4_METAL_MOE_STAGE_PROFILE") != NULL && g_batch_cb != nil;
+        const char *moe_stage_filter = getenv("DS4_METAL_MOE_STAGE_PROFILE_FILTER"); /* optional substring filter on stage names */
+        const char *moe_path = use_minigemm ? "mini_mma" : "mv";
+        double moe_stage_t0 = moe_stage_profile ? ds4_gpu_now_ms() : 0.0;
+        if (moe_stage_profile) { /* end and reopen the batch so the first stage is timed from a fresh command buffer */
+            if (ds4_gpu_end_commands() == 0 || ds4_gpu_begin_commands() == 0) {
+                return 0;
+            }
+            cb = ds4_gpu_command_buffer(&owned);
+            if (!cb) return 0;
+            moe_stage_t0 = ds4_gpu_now_ms();
+        }
+#define DS4_METAL_PROFILE_DECODE_MOE_STAGE(name) do { \
+            if (ok && moe_stage_profile) { \
+                if (ds4_gpu_end_commands() == 0) { \
+                    ok = 0; \
+                } else { \
+                    const char *stage_name = (name); \
+                    const double now_ms = ds4_gpu_now_ms(); \
+                    const int print_stage = \
+                        !moe_stage_filter || !moe_stage_filter[0] || \
+                        strstr(stage_name, moe_stage_filter) != NULL; \
+                    if (print_stage) { \
+                        fprintf(stderr, \
+                                "ds4: Metal decode routed MoE stage tokens=%u pairs=%u experts=%u " \
+                                "gate=%s down=%s path=%s mid=%s %s=%.3f ms\n", \
+                                n_tokens, pair_rows, n_expert, \
+                                ds4_gpu_metal_tensor_type_name(gate_type), \
+                                ds4_gpu_metal_tensor_type_name(down_type), \
+                                moe_path, \
+                                (use_minigemm && request_mid_f16) ? "f16" : "f32", \
+                                stage_name, now_ms - moe_stage_t0); \
+                    } \
+                    moe_stage_t0 = now_ms; \
+                    if (ds4_gpu_begin_commands() == 0) { \
+                        ok = 0; \
+                    } else { \
+                        cb = ds4_gpu_command_buffer(&owned); \
+                        if (!cb) ok = 0; \
+                    } \
+                } \
+            } \
+        } while (0)
         id pair_swiglu_pipeline = nil;
         if (gate_type == DS4_METAL_TENSOR_IQ2_XXS) {
             pair_swiglu_pipeline = g_moe_mul_mv_id_iq2_xxs_pair_swiglu_pipeline;
@@ -14106,10 +14355,43 @@ int ds4_gpu_routed_moe_one_tensor(
         }
         const bool fuse_pair_swiglu =
             !g_quality_mode &&
+            !use_minigemm &&
             !write_clamped_moe &&
             getenv("DS4_METAL_DISABLE_ROUTED_PAIR_SWIGLU_FUSION") == NULL &&
             pair_swiglu_pipeline != nil;
-        if (fuse_pair_swiglu) {
+        if (use_minigemm) { /* one dispatch covers gate, up and the weighted SwiGLU; writes mid directly */
+            ds4_gpu_dsv4_moe_swiglu_weight_args act_args = {
+                .width = expert_mid_dim,
+                .rows = pair_rows,
+                .gate_row_stride = (uint64_t)expert_mid_dim * sizeof(float),
+                .up_row_stride = (uint64_t)expert_mid_dim * sizeof(float),
+                .mid_row_stride = (uint64_t)expert_mid_dim *
+                    (request_mid_f16 ? sizeof(uint16_t) : sizeof(float)),
+                .weight_stride = sizeof(float),
+                .write_clamped = 0, /* the kernel skips its gate/up stores when this is 0 */
+                .clamp_value = clamp,
+            };
+            ok = ds4_gpu_encode_mul_mm_selected_pair_swiglu(cb,
+                                                            gate_up_mm_pipeline,
+                                                            &gate_mm_args,
+                                                            &act_args,
+                                                            gate_buf,
+                                                            (NSUInteger)gate_inner,
+                                                            up_buf,
+                                                            (NSUInteger)up_inner,
+                                                            xbuf,
+                                                            ds4_gpu_tensor_offset(x),
+                                                            gatebuf,
+                                                            ds4_gpu_tensor_offset(gate),
+                                                            upbuf,
+                                                            ds4_gpu_tensor_offset(up),
+                                                            midbuf,
+                                                            ds4_gpu_tensor_offset(mid),
+                                                            selectedbuf,
+                                                            ds4_gpu_tensor_offset(selected),
+                                                            weightsbuf,
+                                                            ds4_gpu_tensor_offset(weights));
+        } else if (fuse_pair_swiglu) {
             ds4_gpu_dsv4_moe_swiglu_weight_args act_args = {
                 .width = expert_mid_dim,
                 .rows = pair_rows,
@@ -14215,7 +14497,8 @@ int ds4_gpu_routed_moe_one_tensor(
                                           2,
                                           false);
         }
-        if (ok && !fuse_pair_swiglu) {
+        DS4_METAL_PROFILE_DECODE_MOE_STAGE("gate_up");
+        if (ok && !fuse_pair_swiglu && !use_minigemm) { /* separate activation pass only when neither fused path already produced mid */
             ok = ds4_gpu_encode_moe_swiglu_weight(cb,
                                                   gatebuf,
                                                   ds4_gpu_tensor_offset(gate),
@@ -14228,8 +14511,9 @@ int ds4_gpu_routed_moe_one_tensor(
                                                   expert_mid_dim,
                                                   pair_rows,
                                                   clamp,
                                                   false);
         }
+        DS4_METAL_PROFILE_DECODE_MOE_STAGE("activation_weight");
         id down_dst = n_expert == 1 ? outbuf : (expertsbuf ? expertsbuf : g_moe_down_scratch_buffer);
         NSUInteger down_dst_off = n_expert == 1 ?
             ds4_gpu_tensor_offset(out) :
@@ -14242,6 +14526,7 @@ int ds4_gpu_routed_moe_one_tensor(
         }
         const bool direct_down_sum =
             !g_quality_mode &&
+            !use_minigemm &&
             n_expert == 6 &&
             n_tokens == 1 &&
             down_sum6_pipeline != nil;
@@ -14259,6 +14544,18 @@ int ds4_gpu_routed_moe_one_tensor(
                                           ds4_gpu_tensor_offset(selected),
                                           down_smem,
                                           2);
+        } else if (ok && use_minigemm) { /* down projection summed over all slots straight into out */
+            ok = ds4_gpu_encode_mul_mm_selected_sum(cb,
+                                                    down_mm_pipeline,
+                                                    &down_mm_args,
+                                                    down_buf,
+                                                    (NSUInteger)down_inner,
+                                                    midbuf,
+                                                    ds4_gpu_tensor_offset(mid),
+                                                    selectedbuf,
+                                                    ds4_gpu_tensor_offset(selected),
+                                                    outbuf,
+                                                    ds4_gpu_tensor_offset(out));
         } else if (ok) {
             ok = ds4_gpu_encode_mul_mv_id(cb,
                                           down_mv_pipeline,
@@ -14275,7 +14572,8 @@ int ds4_gpu_routed_moe_one_tensor(
                                           2,
                                           false);
         }
-        if (ok && n_expert > 1 && !direct_down_sum) {
+        DS4_METAL_PROFILE_DECODE_MOE_STAGE("down");
+        if (ok && n_expert > 1 && !direct_down_sum && !use_minigemm) { /* the mini-GEMM down kernel already summed the experts */
             ok = ds4_gpu_encode_moe_sum_experts(cb,
                                                 down_dst,
                                                 down_dst_off,
@@ -14285,9 +14583,11 @@ int ds4_gpu_routed_moe_one_tensor(
                                                 n_expert,
                                                 n_tokens);
         }
+        DS4_METAL_PROFILE_DECODE_MOE_STAGE("sum");
 
         if (!ok) return 0;
         if (!ds4_gpu_finish_command_buffer(cb, owned, "routed tensor MoE")) return 0;
+#undef DS4_METAL_PROFILE_DECODE_MOE_STAGE
     }
 
     return 1;
diff --git a/metal/moe.metal b/metal/moe.metal
index 884e05265..dc780f108 100644
--- a/metal/moe.metal
+++ b/metal/moe.metal
@@ -128,6 +128,25 @@ struct ds4_metal_dsv4_moe_sum6_args {
     uint64_t dst_token_stride;
 };
 
+struct ds4_metal_args_mul_mm_selected { // NOTE(review): the host binds a ds4_gpu_mul_mm_id_args with setBytes at index 0; confirm the two layouts match
+    int32_t ne00;
+    int32_t ne02;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t ne11;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t ne20;
+    int32_t ne21;
+    int32_t ne0;
+    int32_t ne1;
+    int16_t r2;
+    int16_t r3;
+};
+
 // Routed-MoE activation for the selected experts:
 // clamp(gate), clamp(up), silu(gate) * up * route_weight.
Normal inference
 // does not consume gate/up after this point, so the fast path avoids writing the
@@ -1759,6 +1778,398 @@ template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id_f16
 template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half4x4, half, half2x4>;
 template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id_f16_rhs kernel_mul_mm_id<32, half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half4x4, half, half2x4>;
 
+// Decode-time selected-expert mini-GEMM for gate+up followed by routed SwiGLU.
+// This keeps the heavy projections on the simdgroup MMA path while sharing the
+// RHS activation tile between gate and up. It is intentionally selected-slot
+// based like kernel_mul_mm_selected, not all-expert grouped like prefill.
+template // NOTE(review): the template parameter list is missing in this copy; the body relies on S0, S0_4x4, S0_8x8, S1, S1_2x4, S1_8x8, block_q, nl, dequantize_func, T1, T1_2x4 and MidT. Restore it from the original patch.
+kernel void kernel_mul_mm_selected_pair_swiglu(
+        constant ds4_metal_args_mul_mm_selected & args,
+        constant ds4_metal_dsv4_moe_swiglu_weight_args & act,
+        device const char * src0_gate,
+        device const char * src0_up,
+        device const char * src1,
+        device char * dst_gate,
+        device char * dst_up,
+        device char * dst_mid,
+        device const char * ids,
+        device const char * weights,
+        threadgroup char * shmem [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    threadgroup S0 * sa = (threadgroup S0 *)(shmem); // dequantized weight tile (gate first, then up)
+    threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); // activation tile, shared by both passes
+
+    constexpr int NR0 = 64; // output rows (ne0) per threadgroup
+    constexpr int NR1 = 8; // ne21 entries per threadgroup
+    constexpr int NK = 32; // K elements consumed per loop iteration
+    constexpr int NL0 = NK/16;
+    constexpr int NL1 = NK/8;
+
+    const int slot = tgpig.z; // one selected-expert slot per grid z layer
+    const int r0 = tgpig.y*NR0;
+    const int r1 = tgpig.x*NR1;
+
+    if (slot >= args.ne20 || r1 >= args.ne21) {
+        return;
+    }
+
+    const int32_t expert = ((device const int32_t *)ids)[slot]; // NOTE(review): no per-token offset into ids; confirm for ne21 > 1
+    if (expert < 0 || expert >= args.ne02) { // out-of-range expert id: this slot produces no output
+        return;
+    }
+
+    const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
+    const short nr1 = (args.ne21 - r1 < NR1) ? (args.ne21 - r1) : NR1;
+
+    const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // clamp to the last valid row of the tile
+    const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1;
+    const short il0 = tiitg % NL0;
+
+    short il = il0;
+    const uint64_t offset0 = (uint64_t)expert*args.nb02; // byte offset of this expert's weights
+    const short offset1 = il0/nl;
+    device const block_q * xg =
+        (device const block_q *)(src0_gate + args.nb01*(r0 + lr0) + offset0) + offset1;
+    device const block_q * xu =
+        (device const block_q *)(src0_up + args.nb01*(r0 + lr0) + offset0) + offset1;
+
+    const short iy = 8*(tiitg % NL1);
+    const short i11 = slot % args.ne11;
+    const short i12 = r1 + lr1;
+    device const T1 * y = (device const T1 *)(src1
+        + args.nb12*i12
+        + args.nb11*i11
+        + args.nb10*iy);
+
+    S0_8x8 ma[4];
+    S1_8x8 mb[2];
+    simdgroup_float8x8 mc_gate[8];
+    simdgroup_float8x8 mc_up[8];
+
+    for (short i = 0; i < 8; i++) {
+        mc_gate[i] = make_filled_simdgroup_matrix<float, 8>(0.f); // explicit <float, 8>: the column count cannot be deduced from the argument
+        mc_up[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+
+    for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
+        S0_4x4 temp_gate;
+        dequantize_func(xg, il, temp_gate);
+
+        threadgroup_barrier(mem_flags::mem_threadgroup); // previous iteration's tile reads must finish before sa/sb are overwritten
+
+        FOR_UNROLL (short i = 0; i < 16; i++) {
+            const short sx = 2*il0 + i/8;
+            const short sy = (tiitg/NL0)/8;
+            const short lx = (tiitg/NL0)%8;
+            const short ly = i%8;
+            const short ib = 8*sx + sy;
+            *(sa + 64*ib + 8*ly + lx) = temp_gate[i/4][i%4];
+        }
+
+        const short sx = tiitg%NL1;
+        const short sy = (tiitg/NL1)/8;
+        const short ly = (tiitg/NL1)%8;
+        const short ib = 4*sx + sy;
+        *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *)y)); // loaded once per iteration, reused by the up pass
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        threadgroup const S0 * lsma = sa + 4*64*(sgitg%2);
+        threadgroup const S1 * lsmb = sb + 2*64*(sgitg/2);
+
+        FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
+            simdgroup_barrier(mem_flags::mem_none);
+
+            FOR_UNROLL (short i = 0; i < 4; i++) {
+                simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
+            }
+
+            simdgroup_barrier(mem_flags::mem_none);
+
+            FOR_UNROLL (short i = 0; i < 2; i++) {
+                simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
+            }
+
+            simdgroup_barrier(mem_flags::mem_none);
+
+            FOR_UNROLL (short i = 0; i < 8; i++) {
+                simdgroup_multiply_accumulate(mc_gate[i], mb[i/4], ma[i%4], mc_gate[i]);
+            }
+
+            lsma += 8*64;
+            lsmb += 4*64;
+        }
+
+        S0_4x4 temp_up;
+        dequantize_func(xu, il, temp_up);
+
+        threadgroup_barrier(mem_flags::mem_threadgroup); // gate-pass reads of sa must finish before the up tile replaces it
+
+        FOR_UNROLL (short i = 0; i < 16; i++) {
+            const short sx = 2*il0 + i/8;
+            const short sy = (tiitg/NL0)/8;
+            const short lx = (tiitg/NL0)%8;
+            const short ly = i%8;
+            const short ib = 8*sx + sy;
+            *(sa + 64*ib + 8*ly + lx) = temp_up[i/4][i%4];
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        lsma = sa + 4*64*(sgitg%2);
+        lsmb = sb + 2*64*(sgitg/2);
+
+        FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
+            simdgroup_barrier(mem_flags::mem_none);
+
+            FOR_UNROLL (short i = 0; i < 4; i++) {
+                simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
+            }
+
+            simdgroup_barrier(mem_flags::mem_none);
+
+            FOR_UNROLL (short i = 0; i < 2; i++) {
+                simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
+            }
+
+            simdgroup_barrier(mem_flags::mem_none);
+
+            FOR_UNROLL (short i = 0; i < 8; i++) {
+                simdgroup_multiply_accumulate(mc_up[i], mb[i/4], ma[i%4], mc_up[i]);
+            }
+
+            lsma += 8*64;
+            lsmb += 4*64;
+        }
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        xg = (il < 2) ? xg + (2 + nl - 1)/nl : xg;
+        xu = (il < 2) ? xu + (2 + nl - 1)/nl : xu;
+        y += NK;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup); // all MMA reads done; threadgroup memory is reused for staging below
+
+    threadgroup float * temp_gate_base = (threadgroup float *)shmem;
+    threadgroup float * temp_up_base = (threadgroup float *)(shmem + 8192); // up results staged 8192 bytes after gate
+    threadgroup float * temp_gate =
+        temp_gate_base + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
+    threadgroup float * temp_up =
+        temp_up_base + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
+
+    for (short i = 0; i < 8; i++) {
+        simdgroup_store(mc_gate[i], temp_gate + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
+        simdgroup_store(mc_up[i], temp_up + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (short j = sgitg; j < nr1; j += 4) { // rows are split across the four simdgroups
+        const uint64_t row_index = (uint64_t)(r1 + j)*args.ne1 + (uint64_t)slot;
+        device float *G = act.write_clamped != 0 ?
+            (device float *)dst_gate + r0 + (uint64_t)slot*args.ne0 +
+            (uint64_t)(r1 + j)*args.ne1*args.ne0 : nullptr;
+        device float *U = act.write_clamped != 0 ?
+            (device float *)dst_up + r0 + (uint64_t)slot*args.ne0 +
+            (uint64_t)(r1 + j)*args.ne1*args.ne0 : nullptr;
+        device MidT *M = (device MidT *)(dst_mid + row_index*act.mid_row_stride);
+        device const float *route_w =
+            (device const float *)(weights + row_index*act.weight_stride);
+        const float route_weight = route_w[0];
+        const float c = act.clamp_value;
+
+        threadgroup float *Cg = temp_gate_base + j*NR0;
+        threadgroup float *Cu = temp_up_base + j*NR0;
+        for (int i = tiisg; i < nr0; i += 32) {
+            float g = Cg[i];
+            float u = Cu[i];
+            float gw = g;
+            float uw = u;
+            if (c > 1.0e-6f) { // clamping applies only for a positive clamp value
+                gw = min(gw, c); // gate is capped from above only
+                uw = clamp(uw, -c, c); // up is clamped on both sides
+            }
+            if (act.write_clamped != 0) { // G and U are nullptr otherwise
+                G[i] = gw;
+                U[i] = uw;
+            }
+            const float silu = gw / (1.0f + exp(-gw)); // silu(gw) = gw * sigmoid(gw)
+            M[r0 + i] = (MidT)(silu * uw * route_weight);
+        }
+    }
+}
+
+// Decode-time selected-expert down projection that directly accumulates all
+// selected experts into the final token row.
This is the mini-GEMM counterpart
+// of the scalar `*_sum6` kernels: avoid writing per-expert down rows and avoid a
+// separate sum kernel after the MMA path.
+template // NOTE(review): the template parameter list is missing in this copy; the body relies on S0, S0_4x4, S0_8x8, S1, S1_2x4, S1_8x8, block_q, nl, dequantize_func, T1 and T1_2x4. Restore it from the original patch.
+kernel void kernel_mul_mm_selected_sum(
+        constant ds4_metal_args_mul_mm_selected & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * ids,
+        device char * dst,
+        threadgroup char * shmem [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    threadgroup S0 * sa = (threadgroup S0 *)(shmem); // dequantized weight tile
+    threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); // activation tile
+
+    constexpr int NR0 = 64; // output rows (ne0) per threadgroup
+    constexpr int NR1 = 8; // ne21 entries per threadgroup
+    constexpr int NK = 32; // K elements consumed per loop iteration
+    constexpr int NL0 = NK/16;
+    constexpr int NL1 = NK/8;
+
+    const int r0 = tgpig.y*NR0;
+    const int r1 = tgpig.x*NR1;
+
+    if (r1 >= args.ne21) {
+        return;
+    }
+
+    const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
+    const short nr1 = (args.ne21 - r1 < NR1) ? (args.ne21 - r1) : NR1;
+
+    const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // clamp to the last valid row of the tile
+    const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1;
+    const short il0 = tiitg % NL0;
+    const short iy = 8*(tiitg % NL1);
+
+    S0_8x8 ma[4];
+    S1_8x8 mb[2];
+    simdgroup_float8x8 mc[8]; // accumulators persist across the slot loop, so every valid slot sums into one tile
+
+    for (short i = 0; i < 8; i++) {
+        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f); // explicit <float, 8>: the column count cannot be deduced from the argument
+    }
+
+    for (int slot = 0; slot < args.ne20; slot++) {
+        const int32_t expert = ((device const int32_t *)ids)[slot]; // NOTE(review): no per-token offset into ids; confirm for ne21 > 1
+        if (expert < 0 || expert >= args.ne02) { // out-of-range expert id contributes nothing
+            continue;
+        }
+
+        short il = il0;
+        const uint64_t offset0 = (uint64_t)expert*args.nb02; // byte offset of this expert's weights
+        const short offset1 = il0/nl;
+        device const block_q * x =
+            (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
+        device const T1 * y = (device const T1 *)(src1
+            + args.nb12*(r1 + lr1)
+            + args.nb11*slot
+            + args.nb10*iy);
+
+        for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
+            S0_4x4 temp_a;
+            dequantize_func(x, il, temp_a);
+
+            threadgroup_barrier(mem_flags::mem_threadgroup); // previous iteration's tile reads must finish before sa/sb are overwritten
+
+            FOR_UNROLL (short i = 0; i < 16; i++) {
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+                const short lx = (tiitg/NL0)%8;
+                const short ly = i%8;
+                const short ib = 8*sx + sy;
+                *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
+            }
+
+            const short sx = tiitg%NL1;
+            const short sy = (tiitg/NL1)/8;
+            const short ly = (tiitg/NL1)%8;
+            const short ib = 4*sx + sy;
+            *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) =
+                (S1_2x4)(*((device T1_2x4 *)y));
+
+            il = (il + 2 < nl) ? il + 2 : il % 2;
+            x = (il < 2) ? x + (2 + nl - 1)/nl : x;
+            y += NK;
+
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            threadgroup const S0 * lsma = sa + 4*64*(sgitg%2);
+            threadgroup const S1 * lsmb = sb + 2*64*(sgitg/2);
+
+            FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
+                simdgroup_barrier(mem_flags::mem_none);
+
+                FOR_UNROLL (short i = 0; i < 4; i++) {
+                    simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
+                }
+
+                simdgroup_barrier(mem_flags::mem_none);
+
+                FOR_UNROLL (short i = 0; i < 2; i++) {
+                    simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
+                }
+
+                simdgroup_barrier(mem_flags::mem_none);
+
+                FOR_UNROLL (short i = 0; i < 8; i++) {
+                    simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
+                }
+
+                lsma += 8*64;
+                lsmb += 4*64;
+            }
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup); // all MMA reads done; threadgroup memory is reused for staging below
+
+    threadgroup float * temp_str =
+        ((threadgroup float *)shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
+
+    for (short i = 0; i < 8; i++) {
+        simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (short j = sgitg; j < nr1; j += 4) { // rows are split across the four simdgroups
+        device float *D = (device float *)dst + r0 + (uint64_t)(r1 + j)*args.ne0;
+        device float4 *D4 = (device float4 *)D;
+
+        threadgroup float *C = (threadgroup float *)shmem + j*NR0;
+        threadgroup float4 *C4 = (threadgroup float4 *)C;
+
+        int i = tiisg;
+        for (; i < nr0/4; i += 32) { // bulk copy, four floats at a time
+            *(D4 + i) = *(C4 + i);
+        }
+
+        i = (4*(nr0/4)) + tiisg;
+        for (; i < nr0; i += 32) { // scalar tail when nr0 is not a multiple of 4
+            *(D + i) = *(C + i);
+        }
+    }
+}
+
+typedef decltype(kernel_mul_mm_selected_pair_swiglu) mul_mm_selected_pair_swiglu; // NOTE(review): the template argument lists on these typedefs and on the instantiations below are missing in this copy; restore from the original patch
+typedef decltype(kernel_mul_mm_selected_pair_swiglu) mul_mm_selected_pair_swiglu_f16_mid;
+typedef decltype(kernel_mul_mm_selected_sum) mul_mm_selected_sum;
+typedef decltype(kernel_mul_mm_selected_sum) mul_mm_selected_sum_f16_rhs;
+
+template [[host_name("kernel_mul_mm_selected_q2_K_pair_swiglu_f32")]] kernel mul_mm_selected_pair_swiglu kernel_mul_mm_selected_pair_swiglu;
+template [[host_name("kernel_mul_mm_selected_q4_K_pair_swiglu_f32")]] kernel mul_mm_selected_pair_swiglu kernel_mul_mm_selected_pair_swiglu;
+template [[host_name("kernel_mul_mm_selected_iq2_xxs_pair_swiglu_f32")]] kernel mul_mm_selected_pair_swiglu kernel_mul_mm_selected_pair_swiglu;
+template [[host_name("kernel_mul_mm_selected_q2_K_pair_swiglu_f16")]] kernel mul_mm_selected_pair_swiglu_f16_mid kernel_mul_mm_selected_pair_swiglu;
+template [[host_name("kernel_mul_mm_selected_q4_K_pair_swiglu_f16")]] kernel mul_mm_selected_pair_swiglu_f16_mid kernel_mul_mm_selected_pair_swiglu;
+template [[host_name("kernel_mul_mm_selected_iq2_xxs_pair_swiglu_f16")]] kernel mul_mm_selected_pair_swiglu_f16_mid kernel_mul_mm_selected_pair_swiglu;
+template [[host_name("kernel_mul_mm_selected_sum_q2_K_f32")]] kernel mul_mm_selected_sum kernel_mul_mm_selected_sum;
+template [[host_name("kernel_mul_mm_selected_sum_q4_K_f32")]] kernel mul_mm_selected_sum kernel_mul_mm_selected_sum;
+template [[host_name("kernel_mul_mm_selected_sum_iq2_xxs_f32")]] kernel mul_mm_selected_sum kernel_mul_mm_selected_sum;
+template [[host_name("kernel_mul_mm_selected_sum_q2_K_f16")]] kernel mul_mm_selected_sum_f16_rhs kernel_mul_mm_selected_sum;
+template [[host_name("kernel_mul_mm_selected_sum_q4_K_f16")]] kernel mul_mm_selected_sum_f16_rhs kernel_mul_mm_selected_sum;
+template [[host_name("kernel_mul_mm_selected_sum_iq2_xxs_f16")]] kernel mul_mm_selected_sum_f16_rhs kernel_mul_mm_selected_sum;
+
 #ifdef DS4_METAL_HAS_TENSOR
 // Attention-output low-rank projection retained for Metal4 prefill. It uses
 // the same direct-RHS idea as dense matmul: dequantize the Q8_0 low projection
diff --git a/tools/validate_metal.sh b/tools/validate_metal.sh
new file mode 100755
index 000000000..bd82d98c0
--- /dev/null
+++ b/tools/validate_metal.sh
@@ -0,0 +1,87 @@
+#!/bin/bash
+# Validate ds4 Metal shaders offline using xcrun metal.
+# Requires: Metal Toolchain (xcodebuild -downloadComponent MetalToolchain)
+# Usage: tools/validate_metal.sh
+
+set -euo pipefail
+
+if ! xcrun --sdk macosx metal --version &>/dev/null; then
+    echo "Error: Metal compiler not found. Install with:"
+    echo " xcodebuild -downloadComponent MetalToolchain"
+    exit 1
+fi
+
+cd "$(dirname "$0")/.."
+OUT=/tmp/ds4_shader_check.air
+WORK=$(mktemp -d "${TMPDIR:-/tmp}/ds4_metal_check.XXXXXX") # private scratch dir instead of fixed /tmp names
+
+# Base header matching ds4_gpu_source in ds4_metal.m
+cat > "$WORK/ds4_metal_base.h" << 'EOF'
+#include <metal_stdlib>
+using namespace metal;
+
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
+#define QK8_0 32
+#define N_SIMDWIDTH 32
+#define N_R0_Q8_0 2
+#define N_SG_Q8_0 4
+#define FC_MUL_MV 600
+#define FC_MUL_MM 700
+#define FC_BIN 1300
+#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)
+
+enum ds4_sort_order {
+    DS4_SORT_ORDER_ASC,
+    DS4_SORT_ORDER_DESC,
+};
+
+struct block_q8_0 {
+    half d;
+    int8_t qs[QK8_0];
+};
+EOF
+
+echo "// === base header ===" > "$WORK/ds4_full.metal"
+cat "$WORK/ds4_metal_base.h" >> "$WORK/ds4_full.metal"
+
+for f in \
+    metal/flash_attn.metal \
+    metal/dense.metal \
+    metal/moe.metal \
+    metal/dsv4_hc.metal \
+    metal/unary.metal \
+    metal/dsv4_kv.metal \
+    metal/dsv4_rope.metal \
+    metal/dsv4_misc.metal \
+    metal/argsort.metal \
+    metal/cpy.metal \
+    metal/concat.metal \
+    metal/get_rows.metal \
+    metal/sum_rows.metal \
+    metal/softmax.metal \
+    metal/repeat.metal \
+    metal/glu.metal \
+    metal/norm.metal \
+    metal/bin.metal \
+    metal/set_rows.metal; do
+    echo "" >> "$WORK/ds4_full.metal"
+    echo "// === $f ===" >> "$WORK/ds4_full.metal"
+    cat "$f" >> "$WORK/ds4_full.metal"
+done
+
+LINE_COUNT=$(wc -l < "$WORK/ds4_full.metal" | tr -d ' ') # not LINES: bash treats that name specially; tr drops wc's padding
+echo "Compiling $LINE_COUNT lines of Metal shader source..."
+
+if xcrun --sdk macosx metal -std=metal3.1 -c "$WORK/ds4_full.metal" -o "$OUT" 2>&1 | tee "$WORK/warnings.txt"; then
+    WARNINGS=$(grep -c "warning:" "$WORK/warnings.txt" || true)
+    echo ""
+    echo "SUCCESS: All Metal shaders compiled ($WARNINGS warnings)"
+    ls -la "$OUT"
+    rm -rf "$WORK"
+else
+    echo ""
+    echo "FAILED: See errors above (combined source kept in $WORK)"
+    exit 1
+fi