diff --git a/examples/93_hopper_dual_gemm/93_hopper_dual_gemm.cu b/examples/93_hopper_dual_gemm/93_hopper_dual_gemm.cu new file mode 100644 index 0000000000..77fec0a568 --- /dev/null +++ b/examples/93_hopper_dual_gemm/93_hopper_dual_gemm.cu @@ -0,0 +1,603 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file
+  \brief Hopper Dual-GEMM example using CUTLASS 3.0 APIs for the NVIDIA Hopper architecture
+
+    This example is based on example 48_hopper_warp_specialized_gemm.cu, adapted to a Dual-GEMM.
+    The main difference is that it uses TmaWarpSpecializedCooperative instead of the
+    TmaWarpSpecializedAdapter used in the original example. The kernel computes
+
+```
+D0 = epilogue0(X @ B0, C0)
+D1 = epilogue1(X @ B1, C1)
+D2 = element_wise(D0, D1)
+```
+    Usage:
+
+      $ ./examples/93_hopper_dual_gemm/93_hopper_dual_gemm --m=2048 --n=2048 --k=2048 --raster=N --swizzle=2
+*/
+
+#include <iostream>
+#include <vector>
+
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/kernel/tile_scheduler_params.h"
+
+#include "cutlass/util/command_line.h"
+#include "cutlass/util/distribution.h"
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/packed_stride.hpp"
+#include "cutlass/util/tensor_view_io.h"
+#include "cutlass/util/reference/device/gemm.h"
+#include "cutlass/util/reference/device/tensor_compare.h"
+#include "cutlass/util/reference/device/tensor_fill.h"
+
+#include "helper.h"
+
+#include "collective/dispatch_policy_extra.hpp"
+#include "collective/builder.hpp"
+#include "collective/builder_epilogue.hpp"
+#include "kernel/sm90_gemm_tma_warpspecialized_cooperative_dual.hpp"
+#include "device/gemm_universal_adapter.h"
+#include "thread/left_silu_and_mul.h"
+
+using namespace cute;
+
+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+/// GEMM kernel configurations
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// A matrix configuration
+using ElementA = float;                                                   // Element type for A matrix operand
+using LayoutA = cutlass::layout::RowMajor;                                // Layout type for A matrix operand
+constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;   // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
+
+// B matrix configuration
+using ElementB = float;                                                   // Element type for B matrix operand
+using LayoutB0 = cutlass::layout::ColumnMajor;                            // Layout type for B matrix operand
+using LayoutB1 = cutlass::layout::ColumnMajor;                            // Layout type for B matrix operand
+
+constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;   // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
+
+// C/D matrix configuration
+using ElementC = float;                                                   // Element type for C matrix operand
+using ElementD = float;                                                   // Element type for D matrix operand
+using LayoutC = cutlass::layout::ColumnMajor;                             // Layout type for C matrix operand
+using LayoutD = cutlass::layout::ColumnMajor;                             // Layout type for D matrix operand
+constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;   // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
+constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;   // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
+
+// Core kernel configurations
+using ElementAccumulator = float;                                         // Element type for internal accumulation
+using ArchTag = cutlass::arch::Sm90;                                      // Tag indicating the minimum SM that
supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,_32>; // Threadblock-level tile size +using ClusterShape = Shape<_4,_2,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::DualKernelTmaWarpSpecializedCooperative; +using EpilogueSchedule = cutlass::epilogue::DualTmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +using OpLeft = cutlass::epilogue::fusion::DualLinearCombination; +using OpRight = cutlass::epilogue::fusion::DualLinearCombination; +using DualPairOp = cutlass::epilogue::fusion::DualOpPair; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::DualCollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule, + DualPairOp + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::DualCollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB0, LayoutB1, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCount<3>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using DeviceGemmReferenceB0 = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB0, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +using DeviceGemmReferenceB1 = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB1, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB0 = typename Gemm::GemmKernel::StrideB0; +using StrideB1 = typename Gemm::GemmKernel::StrideB1; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB0 stride_B0; +StrideB1 stride_B1; +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B0; +cutlass::DeviceAllocation block_B1; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; +cutlass::DeviceAllocation block_ref_D0; +cutlass::DeviceAllocation block_ref_D1; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions; + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k; + RasterOrderOptions raster; + int swizzle; + + Options(): + help(false), + m(5120), n(4096), k(4096), + alpha(1.f), beta(0.f), + iterations(1000), + raster(RasterOrderOptions::Heuristic), + swizzle(1) + { } + + // 
Parses the command line
+  void parse(int argc, char const **args) {
+    cutlass::CommandLine cmd(argc, args);
+
+    if (cmd.check_cmd_line_flag("help")) {
+      help = true;
+      return;
+    }
+
+    cmd.get_cmd_line_argument("m", m);
+    cmd.get_cmd_line_argument("n", n);
+    cmd.get_cmd_line_argument("k", k);
+    cmd.get_cmd_line_argument("alpha", alpha, 1.f);
+    cmd.get_cmd_line_argument("beta", beta, 0.f);
+    cmd.get_cmd_line_argument("iterations", iterations);
+
+    char raster_char;
+    cmd.get_cmd_line_argument("raster", raster_char);
+
+    if (raster_char == 'N' || raster_char == 'n') {
+      raster = RasterOrderOptions::AlongN;
+    }
+    else if (raster_char == 'M' || raster_char == 'm') {
+      raster = RasterOrderOptions::AlongM;
+    }
+    else if (raster_char == 'H' || raster_char == 'h') {
+      raster = RasterOrderOptions::Heuristic;
+    }
+
+    cmd.get_cmd_line_argument("swizzle", swizzle, 1);
+  }
+
+  /// Prints the usage statement.
+  std::ostream & print_usage(std::ostream &out) const {
+
+    out << "93_hopper_dual_gemm\n\n"
+      << "  Hopper FP32 Dual-GEMM using a Warp Specialized kernel.\n\n"
+      << "Options:\n\n"
+      << "  --help                      If specified, displays this usage statement\n\n"
+      << "  --m=<int>                   Sets the M extent of the GEMM\n"
+      << "  --n=<int>                   Sets the N extent of the GEMM\n"
+      << "  --k=<int>                   Sets the K extent of the GEMM\n"
+      << "  --alpha=<f32>               Epilogue scalar alpha\n"
+      << "  --beta=<f32>                Epilogue scalar beta\n\n"
+      << "  --raster=<char>             CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n"
+      << "  --swizzle=<int>             CTA Rasterization swizzle\n\n"
+      << "  --iterations=<int>          Number of profiling iterations to perform.\n\n";
+
+    out
+      << "\n\nExamples:\n\n"
+      << "$ " << "93_hopper_dual_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
+
+    return out;
+  }
+
+  /// Compute performance in GFLOP/s
+  double gflops(double runtime_s) const
+  {
+    // Two flops per multiply-add, x2 for the dual GEMM
+    uint64_t flop = uint64_t(4) * m * n * k;
+    double gflop = double(flop) / double(1.0e9);
+    return gflop / runtime_s;
+  }
+};
+
+/// Result structure
+struct Result
+{
+  double avg_runtime_ms;
+  double gflops;
+  cutlass::Status status;
+  cudaError_t error;
+  bool passed;
+
+  Result(
+    double avg_runtime_ms = 0,
+    double gflops = 0,
+    cutlass::Status status = cutlass::Status::kSuccess,
+    cudaError_t error = cudaSuccess)
+  :
+    avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false)
+  {}
+
+};
+
+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+/// GEMM setup and evaluation
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Helper to initialize a block of device data
+template <class Element>
+bool initialize_block(
+  cutlass::DeviceAllocation<Element>& block,
+  uint64_t seed=2023) {
+
+  Element scope_max, scope_min;
+  int bits_input = cutlass::sizeof_bits<Element>::value;
+
+  if (bits_input == 1) {
+    scope_max = Element(2);
+    scope_min = Element(0);
+  } else if (bits_input <= 8) {
+    scope_max = Element(2);
+    scope_min = Element(-2);
+  } else {
+    scope_max = Element(8);
+    scope_min = Element(-8);
+  }
+
+  cutlass::reference::device::BlockFillRandomUniform(
+    block.get(), block.size(), seed, scope_max, scope_min, 0);
+
+  return true;
+}
+
+/// Initialize operands to be used in the GEMM and reference GEMM
+void initialize(const Options &options) {
+
+  stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1});
+  stride_B0 = cutlass::make_cute_packed_stride(StrideB0{}, {options.k,
options.n, 1}); + stride_B1 = cutlass::make_cute_packed_stride(StrideB1{}, {options.k, options.n, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + block_A.reset(options.m * options.k); + block_B0.reset(options.k * options.n); + block_B1.reset(options.k * options.n); + block_C.reset(options.m * options.n); + block_D.reset(options.m * options.n); + block_ref_D.reset(options.m * options.n); + block_ref_D0.reset(options.m * options.n); + block_ref_D1.reset(options.m * options.n); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B0, seed + 2022); + initialize_block(block_B1, seed + 3022); + initialize_block(block_C, seed + 2021); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + int device_id = 0; + cutlass::KernelHardwareInfo kernel_hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info(device_id); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k}, + {block_A.get(), stride_A, block_B0.get(), stride_B0, block_B1.get(), stride_B1}, + {{{options.alpha, options.beta}, {options.alpha, options.beta}}, block_C.get(), stride_C, block_D.get(), stride_D}, + kernel_hw_info + }; + + arguments.scheduler.raster_order = options.raster; + // The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8) + arguments.scheduler.max_swizzle_size = options.swizzle; + + return arguments; +} + +bool verify(const Options &options) { + cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k})); + cutlass::TensorRef ref_B0(block_B0.get(), Gemm::LayoutB0::packed({options.k, options.n})); + cutlass::TensorRef ref_B1(block_B1.get(), Gemm::LayoutB1::packed({options.k, options.n})); + cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n})); + cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n})); + cutlass::TensorRef ref_D0(block_ref_D0.get(), Gemm::LayoutD::packed({options.m, options.n})); + cutlass::TensorRef ref_D1(block_ref_D1.get(), Gemm::LayoutD::packed({options.m, options.n})); + + // + // Compute reference output + // + + DeviceGemmReferenceB0 gemm_ref_b0; + gemm_ref_b0( + {options.m, options.n, options.k}, + ElementAccumulator(options.alpha), + ref_A, + ref_B0, + ElementAccumulator(options.beta), + ref_C, + ref_D0); + + DeviceGemmReferenceB1 gemm_ref_b1; + gemm_ref_b1( + {options.m, options.n, options.k}, + ElementAccumulator(options.alpha), + ref_A, + ref_B1, + ElementAccumulator(options.beta), + ref_C, + ref_D1); + + // D = SiLU(D0) * D1 on host + size_t n_elem = size_t(options.m) * size_t(options.n); + std::vector h_D0(n_elem), h_D1(n_elem), h_D_ref(n_elem), h_D_dev(n_elem); + CUDA_CHECK(cudaMemcpy(h_D0.data(), block_ref_D0.get(), n_elem * sizeof(ElementD), cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_D1.data(), block_ref_D1.get(), n_elem * sizeof(ElementD), cudaMemcpyDeviceToHost)); + + using LeftOp = cutlass::epilogue::thread::LeftSiLUAndMul; + LeftOp leftop(typename LeftOp::Params{}); + for (size_t i = 0; i < n_elem; ++i) { + ElementAccumulator lhs = ElementAccumulator(h_D0[i]); + 
ElementAccumulator rhs = ElementAccumulator(h_D1[i]);
+    h_D_ref[i] = leftop(lhs, rhs);
+  }
+  CUDA_CHECK(cudaMemcpy(block_ref_D.get(), h_D_ref.data(), n_elem * sizeof(ElementD), cudaMemcpyHostToDevice));
+
+  // Copy device kernel output to host
+  CUDA_CHECK(cudaMemcpy(h_D_dev.data(), block_D.get(), n_elem * sizeof(ElementD), cudaMemcpyDeviceToHost));
+
+  // Tolerant compare (TF32 math on device vs FP32 host)
+  double max_abs_err = 0.0, max_rel_err = 0.0;
+  size_t mismatches = 0;
+  const double rel_eps = 2e-2;
+  const double abs_eps = 3e-3;
+
+  for (size_t i = 0; i < n_elem; ++i) {
+    double ref = static_cast<double>(h_D_ref[i]);
+    double got = static_cast<double>(h_D_dev[i]);
+    double abs_err = std::abs(got - ref);
+    double denom = std::max(1e-6, std::abs(ref));
+    double rel_err = abs_err / denom;
+
+    max_abs_err = std::max(max_abs_err, abs_err);
+    max_rel_err = std::max(max_rel_err, rel_err);
+
+    if (rel_err > rel_eps && abs_err > abs_eps) {
+      ++mismatches;
+      // Print the first few mismatches
+      if (mismatches <= 5) {
+        size_t r = i % options.m;
+        size_t c = i / options.m;
+        std::cout << "mismatch at (" << r << "," << c << "): ref=" << ref << " got=" << got
+                  << " abs=" << abs_err << " rel=" << rel_err << std::endl;
+      }
+    }
+  }
+
+  // Wait for kernel to finish
+  CUDA_CHECK(cudaDeviceSynchronize());
+
+  bool passed = (mismatches == 0);
+  std::cout << " Max abs err: " << max_abs_err
+            << " Max rel err: " << max_rel_err
+            << " Mismatches: " << mismatches << std::endl;
+
+  return passed;
+}
+
+/// Execute a given example GEMM computation
+template <typename Gemm>
+int run(Options &options)
+{
+  initialize(options);
+
+  // Instantiate CUTLASS kernel depending on templates
+  Gemm gemm;
+
+  // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
+  auto arguments = args_from_options(options);
+
+  // Using the arguments, query for extra workspace required for matrix multiplication computation
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+
+  // Allocate workspace memory
+  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+
+  // Check if the problem size is supported or not
+  CUTLASS_CHECK(gemm.can_implement(arguments));
+
+  // Initialize CUTLASS kernel with arguments and workspace pointer
+  CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
+
+  // Correctness / Warmup iteration
+  CUTLASS_CHECK(gemm.run());
+
+  // Check if output from CUTLASS kernel and reference kernel are equal or not
+  Result result;
+  result.passed = verify(options);
+
+  std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
+
+  if (!result.passed) {
+    exit(-1);
+  }
+
+  // Run profiling loop
+  if (options.iterations > 0)
+  {
+    GpuTimer timer;
+    timer.start();
+    for (int iter = 0; iter < options.iterations; ++iter) {
+      CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
+      CUTLASS_CHECK(gemm.run());
+    }
+    timer.stop();
+
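+    // Each of the two GEMMs performs m * n * k multiply-adds (2 * m * n * k flops),
+    // which is why Options::gflops() counts 4 * m * n * k flops per iteration.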
+    // Compute average runtime and GFLOPs.
+    float elapsed_ms = timer.elapsed_millis();
+    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
+    result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
+
+    std::string raster = "Heuristic";
+
+    if (options.raster == RasterOrderOptions::AlongN) {
+      raster = "Along N";
+    }
+    else if (options.raster == RasterOrderOptions::AlongM) {
+      raster = "Along M";
+    }
+
+    std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
+    std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
+    std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
+    std::cout << " GFLOPS: " << result.gflops << std::endl;
+  }
+
+  return 0;
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+int main(int argc, char const **args) {
+
+  // CUTLASS must be compiled with the CUDA 12.0 Toolkit to run this example
+  // and must have compute capability at least 90.
+  if (__CUDACC_VER_MAJOR__ < 12) {
+    std::cerr << "This example requires CUDA 12 or newer.\n";
+    // Returning zero so this test passes on older Toolkits. Its actions are a no-op.
+    return 0;
+  }
+
+  cudaDeviceProp props;
+  int current_device_id;
+  CUDA_CHECK(cudaGetDevice(&current_device_id));
+  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+  if (props.major != 9 || props.minor != 0) {
+    std::cerr
+      << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n";
+    return 0;
+  }
+
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, args);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  //
+  // Evaluate CUTLASS kernels
+  //
+
+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+  run<Gemm>(options);
+#endif
+
+  return 0;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/93_hopper_dual_gemm/CMakeLists.txt b/examples/93_hopper_dual_gemm/CMakeLists.txt
new file mode 100644
index 0000000000..03cd12684d
--- /dev/null
+++ b/examples/93_hopper_dual_gemm/CMakeLists.txt
@@ -0,0 +1,32 @@
+# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 93_hopper_dual_gemm + 93_hopper_dual_gemm.cu +) diff --git a/examples/93_hopper_dual_gemm/collective/builder.hpp b/examples/93_hopper_dual_gemm/collective/builder.hpp new file mode 100644 index 0000000000..c7e0b9978b --- /dev/null +++ b/examples/93_hopper_dual_gemm/collective/builder.hpp @@ -0,0 +1,202 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm90_common.inl" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/sm90_pipeline.hpp" +#include "cutlass/gemm/collective/collective_mma_decl.hpp" +#include "cutlass/gemm/collective/collective_builder_decl.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/tensor.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "dispatch_policy_extra.hpp" +#include "sm90_mma_tma_gmma_ss_warpspecialized_dual.hpp" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +template < + class ArchTag, + class OpClass, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB0, + class GmemLayoutB1, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType, + class Enable = void +> +struct DualCollectiveBuilder { + static_assert(sizeof(ElementA) == 0, + "DualCollectiveBuilder: unsupported configuration."); +}; + +// GMMA_TMA_WS_SS +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutB0Tag, + class GmemLayoutB1Tag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct DualCollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutB0Tag, + GmemLayoutB1Tag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + cute::is_same_v> +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + + static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v); + static constexpr bool IsFP8Input = detail::is_input_fp8(); + + // For fp32 types, map to tf32 MMA value type + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB0 = detail::gmma_ss_tag_to_major_B(); + static constexpr cute::GMMA::Major GmmaMajorB1 = detail::gmma_ss_tag_to_major_B(); + + static constexpr bool IsCooperative = cute::is_any_of_v; + using AtomLayoutMNK = cute::conditional_t>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB0>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB0 = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + using GmemTiledCopyB1 = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using 
SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB0 = decltype(detail::ss_smem_selector< + GmmaMajorB0, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB1 = decltype(detail::ss_smem_selector< + GmmaMajorB1, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; + static constexpr size_t SchedulerPipelineStorage = cute::is_pointer_v> ? + sizeof(cutlass::PipelineDetail::PipelineAsyncSharedStorage<8>) : 0; + static constexpr int KernelSmemCarveout = static_cast(TensorMapStorage + SchedulerPipelineStorage); + static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout; + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + /* For FP8 use a separate mainloop compared to other datatypes */ + using DispatchPolicy = cute::conditional_t, + MainloopSm90ArrayTmaGmmaWarpSpecialized + >, + cute::conditional_t, + DualMainloopSm90TmaGmmaWarpSpecialized + > + >; + + using SmemCopyAtomA = void; + using SmemCopyAtomB0 = void; + using SmemCopyAtomB1 = void; + + using CollectiveOp = DualCollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB0, + SmemLayoutAtomB0, + SmemCopyAtomB0, + cute::identity, + GmemTiledCopyB1, + SmemLayoutAtomB1, + SmemCopyAtomB1, + cute::identity + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/93_hopper_dual_gemm/collective/builder_epilogue.hpp b/examples/93_hopper_dual_gemm/collective/builder_epilogue.hpp new file mode 100644 index 0000000000..b50c8f3f84 --- /dev/null +++ b/examples/93_hopper_dual_gemm/collective/builder_epilogue.hpp @@ -0,0 +1,525 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/atom/mma_traits_sm90.hpp" +#include "cute/atom/mma_traits_sm90_gmma.hpp" +#include "cute/atom/copy_traits_sm90.hpp" + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/builders/sm90_common.inl" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/builders/sm90_common.inl" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" + +#include "dispatch_policy_extra.hpp" +#include "sm90_epilogue_tma_warpspecialized_dual.hpp" +#include "../fusion/callbacks.hpp" +#include "../fusion/sm90_callbacks_tma_warpspecialized_dual.hpp" + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +static constexpr bool dual_sm90_is_cooperative_v = + cute::is_base_of_v || + sm90_is_ptr_array_tma_cooperative_v; + +// Returns the parameterized dispatch policy for the TMA epilogue +template +constexpr auto +dual_sm90_get_tma_dispatch_policy() { + using namespace cute; + + constexpr int EpiTiles = size(shape_div(take<0,2>(TileShapeMNK{}), EpilogueTileMN{})); + constexpr int FragmentSize = size(EpilogueTileMN{}) / (detail::dual_sm90_is_cooperative_v ? 256 : 128); + // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation + constexpr bool ReuseSmem = (sizeof_bits_v == sizeof_bits_v) && (sizeof_bits_v > 8); + // TMA store delay performs worse with residual loads and compilicates tensormap updates for Ptr-Array GEMMs + constexpr bool DelayTmaStore = is_void_v && !detail::sm90_is_ptr_array_tma_v; + constexpr int StagesD = cute::min(EpiTiles, 2); + constexpr int StagesC = ReuseSmem ? 
cute::max(cute::min(EpiTiles, 4), StagesD+1) + : cute::min(EpiTiles, 4); + + if constexpr (detail::sm90_is_ptr_array_tma_v) { + return Sm90PtrArrayTmaWarpSpecialized{}; + } + else { + return DualSm90TmaWarpSpecialized{}; + } +} + +// Returns the smem layout atom to be used for C or D matrix +template +constexpr auto +dual_sm90_get_epilogue_smem_swizzle_layout_atom() { + using namespace cute; + + // ColMajor C/D (M-major) + if constexpr (cutlass::gemm::detail::is_major<0>(GmemStrideType{})) { + return cutlass::gemm::collective::detail::ss_smem_selector< + cute::GMMA::Major::MN, Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{})) + >(); + } + // RowMajor C/D (N-major) + else if constexpr (cutlass::gemm::detail::is_major<1>(GmemStrideType{})) { + return cutlass::gemm::collective::detail::ss_smem_selector< + cute::GMMA::Major::K , Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{})) + >(); + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported gmem layout."); + } +} + +template +static constexpr bool dual_sm90_is_warp_specialized_v = + (!sm90_is_ptr_array_tma_cooperative_v && sm90_is_ptr_array_tma_v) || + cute::is_base_of_v; + +// Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one. +template +constexpr auto +dual_sm90_compute_tile_shape_or_override() { + if constexpr (cute::is_same_v) { + auto epi_tile = [&] () { + if constexpr (detail::dual_sm90_is_cooperative_v) { + + auto tile_m = cute::min(_128{}, size<0>(TileShape_MNK{})); + auto tile_n = cute::gcd(cute::min(_32{}, size<1>(TileShape_MNK{})), size<1>(TileShape_MNK{})); + return make_shape(tile_m, tile_n); + } + else if constexpr (detail::dual_sm90_is_warp_specialized_v) { + constexpr int N_perf = (sizeof_bits_v == 8) && (size<1>(TileShape_MNK{}) % 64 == 0) ? 
64 : 32; + auto tile_m = cute::min(_64{}, size<0>(TileShape_MNK{})); + auto tile_n = cute::gcd(cute::min(Int{}, size<1>(TileShape_MNK{})), size<1>(TileShape_MNK{})); + return make_shape(tile_m, tile_n); + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported schedule."); + } + }(); + + return cute::transform(epi_tile, seq<0,1>{}, + [] (auto epi_tiler, auto I) { + auto cta_tiler = make_layout(get(TileShape_MNK{})); + // This is a multimodal CTA tiler, transform before returning + if constexpr (depth(cta_tiler) > 0) { + // This is an implicit multimodal tiler, match profile and return + if constexpr (tuple_size_v == 1) { + return make_tile(epi_tiler); + } + // This is an explicit multimodal tiler, compose out epi tiler + else { + return composition(cta_tiler, epi_tiler); + } + } + // This is a flat CTA tiler, no need for transformation + else { + return epi_tiler; + } + }); + } + else if constexpr (cute::is_tuple::value) { + EpilogueTileType epi_tile; + constexpr int M = size<0>(shape(epi_tile)); + constexpr int N = size<1>(shape(epi_tile)); + + static_assert(!is_layout::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(M == 64 && detail::dual_sm90_is_warp_specialized_v || + M == 128 && detail::dual_sm90_is_cooperative_v, "Unsupported tile shape"); + static_assert(N % 16 == 0, "Unsupported tile shape"); + + return epi_tile; + } + else { + static_assert(cutlass::detail::dependent_false, "Invalid type for EpilogueTileType."); + } +} +// callbacks builder with operation tag +template< + class DispatchPolicy, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator, + class AccLoadOp = cute::DefaultCopy, + class = void +> +struct DualCallbacksBuilder { + using Callbacks = fusion::DualFusionCallbacks; +}; + +// aux fusion callbacks builder for sm90 tma epilogue +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class AccLoadOp, + class ElementAccumulator +> +struct DualCallbacksBuilder< + DualSm90TmaWarpSpecialized, + FusionOp, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + AccLoadOp, + cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor + && not cute::is_subbyte_v> // aux subbyte tensor doesn't use smem +> { + using GmemStrideTypeAux = gemm::TagToStrideC_t; + using SmemLayoutAtomAux = decltype(detail::dual_sm90_get_epilogue_smem_swizzle_layout_atom< + GmemStrideTypeAux, typename FusionOp::ElementAux, EpilogueTile_MN>()); + using CopyOpR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator< + GmemStrideTypeAux, typename FusionOp::ElementAux, EpilogueTile_MN>()); + using CopyOpS2R = decltype(detail::sm90_get_smem_load_op_for_source< + GmemStrideTypeAux, typename FusionOp::ElementAux, EpilogueTile_MN>()); + using SmemCopyOpAux = cute::conditional_t; + + using Callbacks = fusion::DualFusionCallbacks< + DualSm90TmaWarpSpecialized, + FusionOp, TileShape_MNK, EpilogueTile_MN, + SmemLayoutAtomAux, SmemCopyOpAux + >; +}; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class AccLoadOp, + class ElementAccumulator +> +struct DualCallbacksBuilder< + DualSm90TmaWarpSpecialized, + FusionOp, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + AccLoadOp, + cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ 
FusionOp::IsAuxInSupported) // only one aux tensor + && sizeof_bits_v == 1> +> { + using Callbacks = fusion::DualFusionCallbacks< + DualSm90TmaWarpSpecialized, + FusionOp, TileShape_MNK, EpilogueTile_MN, + Layout<_1,_0>, DefaultCopy // aux bit tensor doesn't use smem + >; +}; + +// Helper for building TMA warp-specialized collective epilogues, specialized by +// the fusion operation performed and the dispatch policy to use. +template < + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD_, + class GmemLayoutTagD, + int AlignmentD, + class FusionOpOrCallbacks, + class DispatchPolicy +> +struct DualSm90TmaBuilderImpl { + // C/D should meet TMA alignment requirement if not void + static_assert(detail::is_aligned(), + "C/D Should meet TMA alignment requirement\n"); + // Passing void D disables destination store + smem allocation + using ElementD = cute::conditional_t, + fusion::get_element_aux_t, ElementD_>; + + // Passing void C disables source load + smem allocation + using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages + using GmemLayoutTagC = cute::conditional_t,GmemLayoutTagD,GmemLayoutTagC_>; + + using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; + using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; + + using UnderlyingGmemStrideTypeC = cute::remove_pointer_t; + using UnderlyingGmemStrideTypeD = cute::remove_pointer_t; + + using CopyOpS2G = cute::conditional_t, + SM90_TMA_STORE_IM2COL, + SM90_TMA_STORE + >; + using CopyOpG2S = cute::conditional_t, + SM90_TMA_LOAD_IM2COL, + SM90_TMA_LOAD + >; + + // Get the smallest tiled copy we can use to retile the accumulators + // using CopyAtomC = Copy_Atom; + using CopyAtomC = cute::conditional_t< + size<1>(EpilogueTile_MN{}) % 16 == 0, + Copy_Atom, + cute::conditional_t< + size<1>(EpilogueTile_MN{}) % 8 == 0, + Copy_Atom, + void + > + >; + static_assert(!cute::is_same_v, "CopyAtomC can't be void, divisiblity check for EpilogueTile_MN failed"); + // Get register to register tiled copy that happen before shared memory store. + // Apply void as no register transform op needed currently. + using CopyOpR2R = void; + + // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks + // instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination + using DualFusionCallbacks = + typename DualCallbacksBuilder< + DispatchPolicy, + FusionOpOrCallbacks, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator + >::Callbacks; + + using CollectiveOp = cutlass::epilogue::collective::DualCollectiveEpilogue< + DispatchPolicy, + TileShape_MNK, + EpilogueTile_MN, + ElementC_, // Need to pass void through to expose via GemmUniversal + GmemStrideTypeC, + ElementD_, + GmemStrideTypeD, + DualFusionCallbacks, + CopyOpG2S, + decltype(detail::dual_sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_load_op_for_source()), + CopyOpS2G, + decltype(detail::dual_sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_store_op_for_accumulator()), + CopyAtomC, + CopyOpR2R + >; +}; + +/////////////////////////////////////////////////////////////////////////////// +// Descriptor classes for defining EVT nodes +// Some of the epilogue visitor nodes require non-intuitive template arguments +// such as CopyOpS2R for AuxLoad node. Traditionaly, these are resolved by the +// builder classes. 
Here we provide a set of descriptor classes that resolve +// these template arguments from more intuitive types such as Stride, Layout + +// Get TileShape, EpilogueTile, Dispatch Policy, StagesC, and STagesD +template< + typename TileShape_MNK, + typename EpilogueTileType, + typename ElementC, + typename ElementD, + typename Schedule +> +struct DualEpilogueDescriptor { + using TileShape = TileShape_MNK; + using EpilogueTile = + decltype( + detail::dual_sm90_compute_tile_shape_or_override< + ElementD, EpilogueTileType, Schedule, TileShape_MNK + >() + ); + using DispatchPolicy = + decltype( + detail::dual_sm90_get_tma_dispatch_policy< + TileShape_MNK, EpilogueTile, + ElementC, ElementD, Schedule + >() + ); + constexpr static int StagesC = DispatchPolicy::StagesC; + constexpr static int StagesD = DispatchPolicy::StagesD; +}; + +// Get Stride, SmemLayout, and CopyOpS2R for AuxLoad node +template< + typename EpilogueDescriptor, + typename StrideOrLayoutTag, + typename ElementAux +> +struct DualAuxLoadDescriptor { + constexpr static int Stages = EpilogueDescriptor::StagesC; + using EpilogueTile = typename EpilogueDescriptor::EpilogueTile; + using Element = ElementAux; + using Stride = cutlass::detail::TagToStrideC_t; + using SmemLayoutAtom = + decltype( + detail::dual_sm90_get_epilogue_smem_swizzle_layout_atom< + Stride, ElementAux, typename EpilogueDescriptor::EpilogueTile + >() + ); + using CopyOpS2R = + decltype(detail::sm90_get_smem_load_op_for_source()); +}; + +// Get Stride, SmemLayout, and CopyOpS2R for AuxStore node +template< + typename EpilogueDescriptor, + typename StrideOrLayoutTag, + typename ElementAux +> +struct DualAuxStoreDescriptor { + constexpr static int Stages = EpilogueDescriptor::StagesD; + using EpilogueTile = typename EpilogueDescriptor::EpilogueTile; + using Element = ElementAux; + using Stride = cutlass::detail::TagToStrideC_t; + using SmemLayoutAtom = + decltype( + detail::dual_sm90_get_epilogue_smem_swizzle_layout_atom< + Stride, ElementAux, typename EpilogueDescriptor::EpilogueTile + >() + ); + using CopyOpR2S = + decltype(detail::sm90_get_smem_store_op_for_accumulator()); +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ArchTag, + class OpClass, + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class EpilogueScheduleType, + class FusionOpOrCallbacks = cutlass::epilogue::fusion::DualLinearCombination, + class Enable = void +> +struct DualCollectiveBuilder { + static_assert(cutlass::detail::dependent_false, + "Could not build a collective epilogue for given parameters."); +}; + +// Tma warp-specialized builder +template < + class OpClass, + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD_, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + class FusionOperation +> +struct DualCollectiveBuilder< + arch::Sm90, + OpClass, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD_, + GmemLayoutTagD, + AlignmentD, + Schedule, + FusionOperation, + cute::enable_if_t< + cute::is_same_v + >> { +private: + using ElementD = cute::conditional_t, 
+ fusion::dual_get_element_aux_t, ElementD_>; + using EpilogueTile_MN = + decltype(detail::dual_sm90_compute_tile_shape_or_override()); + using DispatchPolicy = + decltype(detail::dual_sm90_get_tma_dispatch_policy()); + +public: + using CollectiveOp = + typename detail::DualSm90TmaBuilderImpl< + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD_, + GmemLayoutTagD, + AlignmentD, + FusionOperation, + DispatchPolicy + >::CollectiveOp; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective diff --git a/examples/93_hopper_dual_gemm/collective/dispatch_policy_extra.hpp b/examples/93_hopper_dual_gemm/collective/dispatch_policy_extra.hpp new file mode 100644 index 0000000000..dbae061e66 --- /dev/null +++ b/examples/93_hopper_dual_gemm/collective/dispatch_policy_extra.hpp @@ -0,0 +1,81 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cute/layout.hpp" +#include "cute/numeric/integral_constant.hpp" // cute::false_type + +namespace cutlass::gemm { +using namespace cute; + +struct DualKernelTmaWarpSpecializedCooperative { + static constexpr int SchedulerPipelineStageCount = 0; +}; + +// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = DualKernelTmaWarpSpecializedCooperative +> +struct DualMainloopSm90TmaGmmaWarpSpecialized { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm90; + using Schedule = KernelSchedule; +}; + +} // namespace cutlass::gemm + + +namespace cutlass::epilogue { + +struct DualTmaWarpSpecialized {}; +struct DualTmaWarpSpecializedCooperative : DualTmaWarpSpecialized {}; + +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_ +> +struct DualSm90TmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; +}; + +} // namespace cutlass::epilogue diff --git a/examples/93_hopper_dual_gemm/collective/sm90_epilogue_tma_warpspecialized_dual.hpp b/examples/93_hopper_dual_gemm/collective/sm90_epilogue_tma_warpspecialized_dual.hpp new file mode 100644 index 0000000000..0bcbcd42c0 --- /dev/null +++ b/examples/93_hopper_dual_gemm/collective/sm90_epilogue_tma_warpspecialized_dual.hpp @@ -0,0 +1,1015 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/scale_type.h" + +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/detail/helper_macros.hpp" +#include "cutlass/trace.h" + +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +#include "../thread/left_silu_and_mul.h" + +#include "dispatch_policy_extra.hpp" +#include "../fusion/callbacks.hpp" +#include "../fusion/sm90_callbacks_tma_warpspecialized_dual.hpp" +#include "../fusion/operations.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class... Args +> +class DualCollectiveEpilogue { + static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); +}; + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K) + class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class DualFusionCallbacks_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_, + class CopyAtomC_, + class CopyOpR2R_ +> +class DualCollectiveEpilogue< + DualSm90TmaWarpSpecialized, + CtaTileMNK_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + DualFusionCallbacks_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyAtomC_, + CopyOpR2R_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = DualSm90TmaWarpSpecialized; + using CtaTileMNK = CtaTileMNK_; + using EpilogueTile = EpilogueTile_; + using DualFusionCallbacks = DualFusionCallbacks_; + using ElementC = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + using CopyOpG2S = CopyOpG2S_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpS2G = CopyOpS2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + using CopyAtomC = CopyAtomC_; + using CopyOpR2R = CopyOpR2R_; + + using ThreadEpilogueOp = typename epilogue::fusion::DualFusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2S; + using GmemTiledCopyD = CopyOpS2G; + + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + 
static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M"); + static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N"); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + +private: + constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v; + using NonVoidElementD = cute::conditional_t, ElementD>; + static_assert(not cute::is_void_v, "SmemElementD is void"); + using NonVoidElementC = cute::conditional_t; // prevents void ref breakages + + using TmaElementD = cute::conditional_t>, uint64_t, NonVoidElementD>; + using TmaElementC = cute::conditional_t>, uint64_t, NonVoidElementC>; + + using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; + using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; + + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static bool ReuseSmemC = ReuseSmemC_ and is_destination_supported; + constexpr static bool DelayTmaStore = DelayTmaStore_; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + + constexpr static bool is_im2col_C = cute::is_same_v; + constexpr static bool is_im2col_D = cute::is_same_v; + + // Check if register transformation is needed before copying register to shared memory. 
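+  // A non-void CopyOpR2R inserts an extra register-to-register shuffle of the computed fragment
+  // before the register-to-smem store; when it is void the fragment is written to smem directly
+  // (see the IsUseR2R branches in store()).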
+ constexpr static bool IsUseR2R = !cute::is_void_v; + + using SmemLayoutC = decltype(tile_to_shape( + SmemLayoutAtomC{}, + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + using SmemLayoutD = decltype(tile_to_shape( + SmemLayoutAtomD{}, + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + + constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC + && cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{})); + static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); + + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); + + using SmemArrayTypeC = cute::ArrayEngine>; + using SmemArrayTypeD = cute::ArrayEngine>; + + using EmptyType = cute::tuple<>; + using SmemCStorage = cute::conditional_t; + using SmemDStorage = cute::conditional_t; + + struct CollectiveStorageWithC { + alignas(SmemAlignmentC) ArrayEngine> smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageWithoutC { + cute::array smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageReuseC { + alignas(MaxSmemAlignment) ArrayEngine> smem_C; + alignas(MaxSmemAlignment) ArrayEngine> smem_D; + }; + +public: + // TMA pipeline for loading C + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; + constexpr static uint32_t TmaTransactionBytes = + (size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof_bits::value)) / 8; + constexpr static bool RequiresTransactionBytes = true; + + // TMA pipeline for storing D + using StorePipeline = cute::conditional_t, + cutlass::PipelineTmaStore>; + using StorePipelineState = cutlass::PipelineState; + + struct SharedStorage { + struct TensorStorage { + using CollectiveStorage = cute::conditional_t>; + CollectiveStorage collective; + + using DualFusionStorage = typename DualFusionCallbacks::SharedStorage; + DualFusionStorage thread; + } tensors; + + using PipelineStorage = typename LoadPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side epilogue arguments + struct Arguments { + typename DualFusionCallbacks::Arguments thread{}; + ElementC const* ptr_C; + StrideC dC; + ElementD const* ptr_D; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + using TMA_C = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor(make_gmem_ptr(nullptr), + repeat_like(StrideC{}, int32_t(0)), StrideC{}), + take<0,2>(SmemLayoutC{}), + EpilogueTile{}, + _1{})); + using TMA_D = decltype(make_tma_copy( + CopyOpS2G{}, + make_tensor(make_gmem_ptr(nullptr), + repeat_like(StrideD{}, int32_t(0)), StrideD{}), + take<0,2>(SmemLayoutD{}), + EpilogueTile{}, + _1{})); + + typename DualFusionCallbacks::Params thread{}; + TMA_C tma_load_c; + TMA_D tma_store_d; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* 
workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + uint32_t transaction_bytes = TmaTransactionBytes; + typename Params::TMA_C tma_load_c{}; + if constexpr (is_source_supported) { + Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M,N,L), args.dC)); + tma_load_c = make_tma_copy_C_sm90( + CopyOpG2S{}, + tensor_c, + take<0,2>(SmemLayoutC{}), + EpilogueTile{}); + } + + typename Params::TMA_D tma_store_d{}; + if constexpr (is_destination_supported) { + Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M,N,L), args.dD)); + tma_store_d = make_tma_copy_C_sm90( + CopyOpS2G{}, + tensor_d, + take<0,2>(SmemLayoutD{}), + EpilogueTile{}); + } + + return { + DualFusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + tma_load_c, + tma_store_d, + transaction_bytes + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return DualFusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return DualFusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + auto shape = cute::make_shape(M,N,L); + + bool implementable = true; + if constexpr (is_destination_supported) { + constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; + if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm + implementable = cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideD{})); + } + else { + implementable = cutlass::detail::check_alignment(shape, StrideD{}); + } + } + + if constexpr (not cute::is_void_v) { + constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; + if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm + implementable = implementable && cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideC{})); + } + else { + implementable = implementable && cutlass::detail::check_alignment(shape, StrideC{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + bool fusion_implementable = DualFusionCallbacks::can_implement(problem_shape, args.thread); + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for DualFusionCallbacks.\n"); + } + + bool beta_implementable = true; + + if constexpr (cute::is_void_v) { + if constexpr (detail::has_beta::value) { + beta_implementable = args.thread.beta == 0.0; + } + if constexpr (detail::has_beta_ptr::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; + } + } + + if (!beta_implementable) { + CUTLASS_TRACE_HOST(" 
CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n"); + } + + return implementable && fusion_implementable && beta_implementable; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment(TileShapeMNK tile_shape_MNK) { + // Compute number of epilogue subtiles + return size<1>(zipped_divide(make_layout(take<0,2>(tile_shape_MNK)), EpilogueTile{})); + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment(TileShapeMNK tile_shape_MNK) { + return get_load_pipe_increment(tile_shape_MNK); + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void + prefetch_tma_descriptors(Params const& epilogue_params) { + if constexpr (is_source_supported) { + cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor()); + } + if constexpr (is_destination_supported) { + cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor()); + } + } + + CUTLASS_HOST_DEVICE + DualCollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) + : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + int subtile_idx=-1) { + using namespace cute; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + // The tma tensor C under im2col mode only has two modes (M, N) which + // should be local tiled with only (m_coord, n_coord). 
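+    // conditional_return selects the rank-2 (m, n) coordinate for the im2col case and the
+    // rank-3 (m, n, l) coordinate otherwise.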
+ auto coord_shape = conditional_return( + make_coord(m_coord, n_coord), + make_coord(m_coord, n_coord, l_coord)); + + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for + Tensor mC_mn = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{})); + Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtile, get matching smem tensor + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) + ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); + Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (G2S,G2S_M,G2S_N,EPI_M,EPI_N) + Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (G2S,G2S_M,G2S_N,PIPE_C) + + // Get the fusion callbacks for the producer load warp + auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs( + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + tiled_mma, + EpilogueTile{}, + thread_idx + ); + auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Predication for TMA load (one thread issues TMA load) + bool issue_tma_load = cute::elect_one_sync(); + + // Pre-loop fusion callback entry point + pld_callbacks.begin(); + + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gC_epi); ++epi_m) { + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gC_epi)) + epi_m) != subtile_idx) { + continue; + } + // Acquire the lock for this stage + constexpr uint16_t mcast_mask = 0; + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Loop fusion callback entry point + pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); + + // Execute the TMA load for C if needed + if (issue_tma_load && is_C_load_needed) { + copy(params.tma_load_c.with(*tma_barrier, mcast_mask), + bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } + + // Commit TMA loads for this stage and release the lock + load_pipeline.producer_commit(load_pipe_producer_state); + ++load_pipe_producer_state; + } + } + + // Post-loop fusion callback entry point + pld_callbacks.end(); + + return load_pipe_producer_state; + } + + CUTLASS_DEVICE auto + load_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state) { + bool issue_tma_load = cute::elect_one_sync(); + if (issue_tma_load) { + load_pipeline.producer_tail(load_pipe_producer_state); + } + + return load_pipe_producer_state; + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine0, class AccLayout0, + class AccEngine1, class AccLayout1, + class TiledMma + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + 
TileCoordMNKL tile_coord_mnkl, + cute::Tensor accumulators0, + cute::Tensor accumulators1, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + int subtile_idx=-1) { + using namespace cute; + using ElementAccumulator0 = typename AccEngine0::value_type; + using ElementAccumulator1 = typename AccEngine1::value_type; + using ElementCompute_ = typename epilogue::fusion::DualFusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator0,ElementCompute_>; + + using CallbacksOperation = + typename epilogue::fusion::DualFusionCallbacksTraits::Operation; + constexpr bool IsDualOpPair = + epilogue::fusion::is_dual_op_pair_v; + + static_assert(is_rmem::value, "Accumulator0 must be RF resident."); + static_assert(is_rmem::value, "Accumulator1 must be RF resident."); + static_assert(rank(AccLayout0{}) == 3, "Accumulator0 must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); + static_assert(rank(AccLayout1{}) == 3, "Accumulator1 must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "TileShapeMNK must be static"); + static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + // The tma tensor D under im2col mode only has two modes (M, N) which + // should be local tiled with only (m_coord, n_coord). + auto coord_shape = conditional_return( + make_coord(m_coord, n_coord), + make_coord(m_coord, n_coord, l_coord)); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_mn = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{})); + Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + auto ptr_sD = shared_tensors.collective.smem_D.begin(); + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); + + // (t)hread-partition for (r)egister to (r)egister copy (tRR_) + TiledCopy tiled_r2r = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (IsUseR2R) { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + else { + return make_tiled_copy_S(Copy_Atom, + ElementCompute>{}, tiled_copy_C_atom); + } + }(); + ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (IsUseR2R) { + return make_tiled_copy_D(Copy_Atom{}, tiled_r2r); + } + else { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + }(); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_rAcc0 = thread_r2s.retile_S(accumulators0); // ((R2S,R2S_V),MMA_M,MMA_N) + Tensor tRS_rAcc1 = 
thread_r2s.retile_S(accumulators1); // ((R2S,R2S_V),MMA_M,MMA_N) + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + + auto mma_tile_m = size<0>(TileShapeMNK{}) / size<1>(tRS_rAcc0); + auto mma_tile_n = size<1>(TileShapeMNK{}) / size<2>(tRS_rAcc0); + auto epi_tile_m = size<0>(EpilogueTile{}); + auto epi_tile_n = size<1>(EpilogueTile{}); + + // Allocate D registers + Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi)))); + Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tRS_rAcc_frg0 = recast>(tRS_rAcc0); + Tensor tRS_rAcc_frg1 = recast>(tRS_rAcc1); + Tensor tRS_rD_frg = recast>(tRS_rD); + CUTE_STATIC_ASSERT(size<0>(tRS_rAcc0) % FragmentSize == 0, "Fragment size does not vectorize properly"); + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = thread_s2r.retile_D(tRS_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v> + && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tRS_rC = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (IsUseR2R) { + // (t)hread-partition for ConsumerStoreCallbacks. 
+ TiledCopy tiled_cst = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_cst = tiled_cst.get_slice(thread_idx); + return thread_cst.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + } + else { + return thread_r2s.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + } + }(); + // Relative coordinate tensors (static) + Tensor cD = make_coord_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tRS_cD = make_coord(M,N) - tRS_cD_mn(_0{}); + + CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); + + if constexpr (epi_tile_m * epi_tile_n > mma_tile_m * mma_tile_n) { + // When the epilogue subtile is larger than the MMA tiles, loop over multiple MMA tiles + CUTE_STATIC_ASSERT(epi_tile_n % mma_tile_n == 0, "MMA_TILE_N must divide EPI_TILE_N"); + } + else { + CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + } + + // Get TiledCopy for partition reference when consumer store. + TiledCopy tiled_copy_partition_ref = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + // Get the fusion callbacks for the consumer store warps + constexpr bool RefSrc = true; // Register tensors reference tiled copy src layout + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs( + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + tiled_mma, + EpilogueTile{}, + tiled_copy_partition_ref, + cD, + residue_cD, + tRS_cD, + residue_tRS_cD, + tRS_rC, + thread_idx + ); + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + using FragmentVisit = decltype(cst_callbacks.visit(tRS_rAcc_frg0(0), 0, 0, 0)); + constexpr bool IsDirectR2S = cute::is_same_v>; + using RegisterElementD = cute::conditional_t; + Tensor tRS_rCompute = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tRS_rCompute_frg = recast>(tRS_rCompute); + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [&] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0; + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. 
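+    // In the ReuseSmemC configuration the consumer wait state therefore tracks the store
+    // producer state with its phase inverted, because C and D share the same smem stages;
+    // otherwise load_pipe_consumer_state is used directly and releases happen inside the
+    // subtile loop below.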
+ LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD + [[maybe_unused]] int epi_m_prev = 0; + [[maybe_unused]] int epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + + // The TMA store sequence for one subtile iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) CUTLASS_LAMBDA_FUNC_INLINE { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if constexpr (is_destination_supported) { + if (issue_tma_store) { + copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + } + + // Post async fence, pre TMA commit callback entry point + cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + ++issued_stores; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = issued_stores > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + } + }; + + // + // BEGIN EPILOGUE + // + + // Pre-loop fusion callback entry point + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + + // For each output tile + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) { + [[maybe_unused]] bool is_first_iteration = epi_m == 0 && epi_n == 0; + bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1; + + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gD_epi)) + epi_m) != subtile_idx) { + continue; + } + + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); + // Ensure smem loads are complete before reusing smem for mixed types/layouts + if constexpr (ReuseSmemC && not (SmemLayoutC{} == SmemLayoutD{})) { + synchronize(); + } + } + } + + // First loop fusion callback entry point + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + + if (is_producer_load_needed) { + if constexpr (not ReuseSmemC) { + // Let producer load warp know smem buffers are consumed and empty + cutlass::arch::fence_view_async_shared(); + 
load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + if constexpr (epi_tile_m * epi_tile_n > mma_tile_m * mma_tile_n) { + // When the epilogue subtile is larger than the MMA tiles, loop over multiple + // MMA tiles + static constexpr int MmaMPerEpiM = epi_tile_m / mma_tile_m; + static constexpr int MmaNPerEpiN = epi_tile_n / mma_tile_n; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_in_epi = 0; mma_n_in_epi < MmaNPerEpiN; ++mma_n_in_epi) { + int mma_n = (epi_n * MmaNPerEpiN) + mma_n_in_epi; + + CUTLASS_PRAGMA_UNROLL + for (int mma_m_in_epi = 0; mma_m_in_epi < MmaMPerEpiM; ++mma_m_in_epi) { + int mma_m = (epi_m * MmaMPerEpiM) + mma_m_in_epi; + Tensor tRS_rAcc_frg_mn0 = tRS_rAcc_frg0(_,mma_m,mma_n); + Tensor tRS_rAcc_frg_mn1 = tRS_rAcc_frg1(_,mma_m,mma_n); + int idx_in_epi_subtile = (mma_n_in_epi * MmaMPerEpiM + mma_m_in_epi); + + if constexpr (IsDualOpPair) { + auto comp0 = cst_callbacks.visit0(tRS_rAcc_frg_mn0(0), idx_in_epi_subtile, epi_m, epi_n); + auto comp1 = cst_callbacks.visit1(tRS_rAcc_frg_mn1(0), idx_in_epi_subtile, epi_m, epi_n); + using LsmOp = cutlass::epilogue::thread::LeftSiLUAndMul; + typename LsmOp::Params lsm_params{}; + LsmOp lsm(lsm_params); + Array combined_comp = lsm(comp0, comp1); + tRS_rCompute_frg(idx_in_epi_subtile) = combined_comp; + } else { + auto comp0 = cst_callbacks.visit(tRS_rAcc_frg_mn0(0), idx_in_epi_subtile, epi_m, epi_n); + auto comp1 = cst_callbacks.visit(tRS_rAcc_frg_mn1(0), idx_in_epi_subtile, epi_m, epi_n); + using LsmOp = cutlass::epilogue::thread::LeftSiLUAndMul; + typename LsmOp::Params lsm_params{}; + LsmOp lsm(lsm_params); + Array combined_comp = lsm(comp0, comp1); + tRS_rCompute_frg(idx_in_epi_subtile) = combined_comp; + } + } + } + } + else { + int mma_m = epi_m; + int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; + Tensor tRS_rAcc_frg_mn0 = tRS_rAcc_frg0(_,mma_m,mma_n); + Tensor tRS_rAcc_frg_mn1 = tRS_rAcc_frg1(_,mma_m,mma_n); + + // Vectorized fragment loop with visitor callback entry point + int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); + int r2s_v = epi_n_in_mma * size(tRS_rCompute_frg); + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tRS_rCompute_frg); ++epi_v) { + if constexpr (IsDualOpPair) { + auto comp0 = cst_callbacks.visit0(tRS_rAcc_frg_mn0(r2s_v + epi_v), epi_v, epi_m, epi_n); + auto comp1 = cst_callbacks.visit1(tRS_rAcc_frg_mn1(r2s_v + epi_v), epi_v, epi_m, epi_n); + using LsmOp = cutlass::epilogue::thread::LeftSiLUAndMul; + typename LsmOp::Params lsm_params{}; + LsmOp lsm(lsm_params); + Array combined_comp = lsm(comp0, comp1); + tRS_rCompute_frg(epi_v) = combined_comp; + } else { + auto comp0 = cst_callbacks.visit(tRS_rAcc_frg_mn0(r2s_v + epi_v), epi_v, epi_m, epi_n); + auto comp1 = cst_callbacks.visit(tRS_rAcc_frg_mn1(r2s_v + epi_v), epi_v, epi_m, epi_n); + using LsmOp = cutlass::epilogue::thread::LeftSiLUAndMul; + typename LsmOp::Params lsm_params{}; + LsmOp lsm(lsm_params); + Array combined_comp = lsm(comp0, comp1); + tRS_rCompute_frg(epi_v) = combined_comp; + } + } + } + + // The latest we can delay the TMA store is right before the smem store of the next iteration + // since the current TMA store needs to be committed before we can acquire the next smem buffer + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration and subtile_idx == -1) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + // Smem reduction callback entry point using 
current store buffer for workspace + cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), + synchronize, epi_m, epi_n, is_last_iteration, tRS_rCompute_frg); + + // Copy tile from register to regiser if needed + if constexpr (IsUseR2R) { + // retile source and destination for tiled_r2r + Tensor tRR_rD_src = thread_r2r.retile_S(tRS_rCompute); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + Tensor tRR_rD_dst = thread_r2r.retile_D(tRS_rCompute); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + + // Output register transformation before copying to shared memory. + copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tRS_rD_frg); ++i) { + tRS_rD_frg(i) = cutlass::NumericArrayConverter{}(tRS_rCompute_frg(i)); + } + + // Copy tile from register to smem + if constexpr (is_destination_supported) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } + + // Post reduction, pre TMA store callback entry point + constexpr bool issue_smem_store = true; // No smem store predication + cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + cst_callbacks.end_loop(epi_m, epi_n); + + } // for epi_m + } // for epi_n + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + // Post-loop fusion callback entry point + cst_callbacks.end(); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state) { + // wait for all TMA stores to complete + store_pipeline.producer_tail(store_pipe_producer_state); + // reset store counter + issued_stores = 0; + + if constexpr (ReuseSmemC) { + if (fusion_callbacks.is_producer_load_needed()) { + // Issue releases on up to StagesD-1 previously issued TMA stores + constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(CtaTileMNK{})); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < release_stages; ++stage) { + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + } + } + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + +private: + Params const& params; + DualFusionCallbacks fusion_callbacks; + int issued_stores = 0; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/93_hopper_dual_gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_dual.hpp b/examples/93_hopper_dual_gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_dual.hpp new file mode 100644 index 0000000000..4f43e081d5 --- /dev/null +++ b/examples/93_hopper_dual_gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_dual.hpp @@ -0,0 +1,715 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "dispatch_policy_extra.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Primary template declaration to enable partial/specialized definitions below. 
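+// DualCollectiveMma mirrors the standard SM90 warp-specialized CollectiveMma interface but
+// carries two B operands (B0 and B1) with independent strides, gmem/smem copy atoms and smem
+// layouts, and accumulates into two separate fragments while sharing the A tiles loaded for
+// each k-tile.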
+template < + class DispatchPolicy, + class TileShape, + class ElementA, + class StrideA, + class ElementB, + class StrideB0, + class StrideB1, + class TiledMma, + class GmemTiledCopyA, + class SmemLayoutAtomsA, + class SmemCopyAtomsA, + class TransformA, + class GmemTiledCopyB0, + class SmemLayoutAtomsB0, + class SmemCopyAtomsB0, + class TransformB0, + class GmemTiledCopyB1, + class SmemLayoutAtomsB1, + class SmemCopyAtomsB1, + class TransformB1 +> +struct DualCollectiveMma { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB0_, + class StrideB1_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB0_, + class SmemLayoutAtomB0_, + class SmemCopyAtomB0_, + class TransformB0_, + class GmemTiledCopyB1_, + class SmemLayoutAtomB1_, + class SmemCopyAtomB1_, + class TransformB1_> +struct DualCollectiveMma< + DualMainloopSm90TmaGmmaWarpSpecialized, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB0_, + StrideB1_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB0_, + SmemLayoutAtomB0_, + SmemCopyAtomB0_, + TransformB0_, + GmemTiledCopyB1_, + SmemLayoutAtomB1_, + SmemCopyAtomB1_, + TransformB1_> +{ + // + // Type Aliases + // + using DispatchPolicy = DualMainloopSm90TmaGmmaWarpSpecialized; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB0 = StrideB0_; + using StrideB1 = StrideB1_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB0 = GmemTiledCopyB0_; + using GmemTiledCopyB1 = GmemTiledCopyB1_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB0 = SmemLayoutAtomB0_; + using SmemLayoutAtomB1 = SmemLayoutAtomB1_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB0 = SmemCopyAtomB0_; + using SmemCopyAtomB1 = SmemCopyAtomB1_; + using TransformA = TransformA_; + using TransformB0 = TransformB0_; + using TransformB1 = TransformB1_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB0{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB0{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB0{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB1{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); 
+ static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB1{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB1{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB0 = decltype(tile_to_shape( + SmemLayoutAtomB0{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB0>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB1 = decltype(tile_to_shape( + SmemLayoutAtomB1{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB1>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
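+  // (For illustration: a half_t operand would be staged through smem as uint16_t so the TMA
+  //  copy is a plain bitwise transfer with no rounding; this example uses float operands.)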
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B0; + cute::array_aligned> smem_B1; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B0; + StrideB0 dB0; + ElementB const* ptr_B1; + StrideB1 dB1; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB0 is congruent with Problem_NK + using TMA_B0 = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB0{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB0{}, int32_t(0)), StrideB0{}), + SmemLayoutB0{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB1 is congruent with Problem_NK + using TMA_B1 = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB1{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB1{}, int32_t(0)), StrideB1{}), + SmemLayoutB1{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + TMA_A tma_load_a; + TMA_B0 tma_load_b0; + TMA_B1 tma_load_b1; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk0 = TmaTransactionBytesNK0; + uint32_t tma_transaction_bytes_nk1 = TmaTransactionBytesNK1; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B0 = reinterpret_cast(args.ptr_B0); + auto ptr_B1 = reinterpret_cast(args.ptr_B1); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b0 = make_tensor(ptr_B0, make_layout(make_shape(N,K,L), args.dB0)); + Tensor tensor_b1 = make_tensor(ptr_B1, make_layout(make_shape(N,K,L), args.dB1)); + + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B0 tma_load_b0 = make_tma_copy_B_sm90( + GmemTiledCopyB0{}, + tensor_b0, + SmemLayoutB0{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B1 tma_load_b1 = make_tma_copy_B_sm90( + GmemTiledCopyB1{}, + tensor_b1, + SmemLayoutB1{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk0 = TmaTransactionBytesNK0; + uint32_t transaction_bytes_nk1 = TmaTransactionBytesNK1; + uint32_t transaction_bytes 
= transaction_bytes_mk + transaction_bytes_nk0 + transaction_bytes_nk1; + + return { + tma_load_a, + tma_load_b0, + tma_load_b1, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk0, + transaction_bytes_nk1 + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B0 = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB0{}); + constexpr int min_tma_aligned_elements_B1 = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB1{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK0 = + cutlass::bits_to_bytes(size<0>(SmemLayoutB0{}) * size<1>(SmemLayoutB0{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK1 = + cutlass::bits_to_bytes(size<0>(SmemLayoutB1{}) * size<1>(SmemLayoutB1{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK0 + TmaTransactionBytesNK1; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b0.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b1.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. 
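+  /// For this dual mainloop the returned tuple is (gA_mkl, gB0_nkl, gB1_nkl): a single tiled
+  /// view of A shared by both GEMMs plus one tiled view per B operand.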
+ template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB0_nkl = mainloop_params.tma_load_b0.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + Tensor mB1_nkl = mainloop_params.tma_load_b1.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB0_nkl = local_tile(mB0_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + Tensor gB1_nkl = local_tile(mB1_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB0_nkl, gB1_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB0, class TensorB1, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB0 = make_tensor(make_smem_ptr(shared_tensors.smem_B0.data()), SmemLayoutB0{}); // (BLK_N,BLK_K,PIPE) + Tensor sB1 = make_tensor(make_smem_ptr(shared_tensors.smem_B1.data()), SmemLayoutB1{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB0_nkl = get<1>(load_inputs); + Tensor gB1_nkl = get<2>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b0 = mainloop_params.tma_load_b0.get_slice(cluster_local_block_id.x); + auto block_tma_b1 = mainloop_params.tma_load_b1.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
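+      // Both B operands are sliced with the same (n, k, l) coordinates, so each k-tile
+      // iteration issues one TMA load for A and one for each of B0 and B1 into their own smem
+      // buffers, all counted against the same transaction barrier for this stage.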
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB0 = gB0_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + Tensor gB1 = gB1_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB0 = block_tma_b0.partition_S(gB0); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB0 = block_tma_b0.partition_D(sB0); // (TMA,TMA_N,TMA_K,PIPE) + Tensor tBgB1 = block_tma_b1.partition_S(gB1); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB1 = block_tma_b1.partition_D(sB1); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b0 = 0; + uint16_t mcast_mask_b1 = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b0 |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b1 |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b0.with(*tma_barrier, mcast_mask_b0), tBgB0(_,_,_,*k_tile_iter), tBsB0(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b1.with(*tma_barrier, mcast_mask_b1), tBgB1(_,_,_,*k_tile_iter), tBsB1(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC0, class FrgTensorC1 + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC0& accum0, + FrgTensorC1& accum1, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(is_rmem::value, "C tensor must be 
rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB0{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB1{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB0 = make_tensor(make_smem_ptr(shared_tensors.smem_B0.data()), SmemLayoutB0{}); // (BLK_N,BLK_K,PIPE) + Tensor sB1 = make_tensor(make_smem_ptr(shared_tensors.smem_B1.data()), SmemLayoutB1{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB0 = thread_mma.partition_B(sB0); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCsB1 = thread_mma.partition_B(sB1); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB0 = thread_mma.make_fragment_B(tCsB0); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB1 = thread_mma.make_fragment_B(tCsB1); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB0) == size<2>(accum0)); // N + CUTE_STATIC_ASSERT_V(size<1>(tCsB1) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB0)); // K + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB1)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB0)); // PIPE + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB1)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB0)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB1)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + assert(k_tile_count >= 1); + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + 
pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB0(_,_,k_block,read_stage), accum0); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + // Do the same for B1 on the same stage + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB1(_,_,k_block,read_stage), accum1); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count - 1; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB0(_,_,_,read_stage), accum0); + cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB1(_,_,_,read_stage), accum1); + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + warpgroup_arrive(); + // (V,M,K) x (V,N,K) => (V,M,N) + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB0(_,_,k_block,read_stage), accum0); + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB1(_,_,k_block,read_stage), accum1); + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; 
++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/93_hopper_dual_gemm/device/gemm_universal_adapter.h b/examples/93_hopper_dual_gemm/device/gemm_universal_adapter.h new file mode 100644 index 0000000000..c72e7cfe4e --- /dev/null +++ b/examples/93_hopper_dual_gemm/device/gemm_universal_adapter.h @@ -0,0 +1,785 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. 
+*/
+
+#pragma once
+
+// common
+#include "cutlass/cutlass.h"
+#include "cutlass/device_kernel.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/detail/layout.hpp"
+#include "cutlass/detail/mma.hpp"
+#include "cutlass/cuda_host_adapter.hpp"
+
+#include "cutlass/kernel_launch.h"
+#if !defined(__CUDACC_RTC__)
+#include "cutlass/cluster_launch.hpp"
+#include "cutlass/trace.h"
+#endif // !defined(__CUDACC_RTC__)
+
+// 2.x
+#include "cutlass/gemm/device/gemm_universal_base.h"
+#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
+#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
+#include "cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h"
+
+// 3.x
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass::gemm::device {
+
+////////////////////////////////////////////////////////////////////////////////
+
+/*!
+  GemmUniversalAdapter is a stateful, reusable GEMM handle built around a kernel
+  of type cutlass::gemm::kernel::Gemm or cutlass::gemm::kernel::GemmUniversal.
+
+  It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs
+  to create it from the host-facing arguments. For power users, new static methods
+  are exposed in 3.x APIs that bypass the stateful methods or args->params lowering.
+
+  It supports kernel types that implement both the 2.x and 3.0 APIs;
+  however, this is done by specializing the implementation of GemmUniversalAdapter
+  on the two kernel API types, and thus GemmUniversalAdapter's behaviour might
+  differ between the two specializations.
+*/
+template <class GemmKernel_, class Enable = void>
+class GemmUniversalAdapter;
+
+////////////////////////////////////////////////////////////////////////////////
+////////////////////////////// CUTLASS 3.x API /////////////////////////////////
+////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+
+// Work-around for some DispatchPolicy types not having a Stages member.
+// In that case, the Stages value is 0. Most code should static_assert
+// that the number of stages is valid.
+
+// Whether DispatchPolicy::Stages is valid.
+// It should also be convertible to int, but if not, that will show up
+// as a build error when GemmUniversalAdapter attempts to assign it to kStages.
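+//
+// A minimal sketch of how the helpers defined below behave. `PolicyWithStages` and
+// `PolicyWithoutStages` are hypothetical stand-ins, not actual CUTLASS dispatch policies:
+//
+// ```
+// struct PolicyWithStages    { static constexpr int Stages = 4; };
+// struct PolicyWithoutStages { };
+//
+// static_assert(detail::stages_member(PolicyWithStages{})    == 4);
+// static_assert(detail::stages_member(PolicyWithoutStages{}) == 0);
+// ```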
+template +struct has_Stages : cute::false_type {}; + +template +struct has_Stages> : cute::true_type {}; + +template +constexpr int stages_member(DispatchPolicy) { + if constexpr (has_Stages::value) { + return DispatchPolicy::Stages; + } + else { + return 0; + } +} + +} // namespace detail + +template +class GemmUniversalAdapter< + GemmKernel_, + cute::enable_if_t>::value>> +{ +public: + using GemmKernel = GetUnderlyingKernel_t; + using TileShape = typename GemmKernel::TileShape; + using ElementA = typename GemmKernel::ElementA; + using ElementB = typename GemmKernel::ElementB; + using ElementC = typename GemmKernel::ElementC; + using ElementD = typename GemmKernel::ElementD; + using ElementAccumulator = typename GemmKernel::ElementAccumulator; + using DispatchPolicy = typename GemmKernel::DispatchPolicy; + using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; + using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; + + // Map back to 2.x type as best as possible + using LayoutA = gemm::detail::StrideToLayoutTagA_t; + using LayoutB0 = gemm::detail::StrideToLayoutTagB_t; + using LayoutB1 = gemm::detail::StrideToLayoutTagB_t; + using LayoutC = gemm::detail::StrideToLayoutTagC_t; + using LayoutD = gemm::detail::StrideToLayoutTagC_t; + + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + static ComplexTransform const kTransformA = cute::is_same_v ? + ComplexTransform::kConjugate : ComplexTransform::kNone; + static ComplexTransform const kTransformB0 = cute::is_same_v ? + ComplexTransform::kConjugate : ComplexTransform::kNone; + static ComplexTransform const kTransformB1 = cute::is_same_v ? + ComplexTransform::kConjugate : ComplexTransform::kNone; + // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 + using MathOperator = cutlass::arch::OpMultiplyAdd; + + using OperatorClass = cutlass::detail::get_operator_class_t; + + using ArchTag = typename GemmKernel::ArchTag; + + // NOTE: Assume identity swizzle for now + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape + using ThreadblockShape = cutlass::gemm::GemmShape< + cute::size<0>(TileShape{}), + cute::size<1>(TileShape{}), + cute::size<2>(TileShape{})>; + + using ClusterShape = cutlass::gemm::GemmShape< + cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})>; + + // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape + using InstructionShape = cutlass::gemm::GemmShape< + cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; + + // Legacy: provide a correct warp count, but no reliable warp shape + static int const kThreadCount = GemmKernel::MaxThreadsPerBlock; + + // Warp shape is not a primary API type in 3.x + // But we can best approximate it by inspecting the TiledMma + // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K + // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads + static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32); + static constexpr int 
WarpsInMmaM = 4; + static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); + using WarpCount = cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape< + CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM, + CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN, + CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; + + static int constexpr kStages = detail::stages_member(typename CollectiveMainloop::DispatchPolicy{}); + + // Inspect TiledCopy for A and B to compute the alignment size + static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyA, ElementA, typename CollectiveMainloop::TiledMma::ValTypeA>(); + static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyB, ElementB, typename CollectiveMainloop::TiledMma::ValTypeB>(); + static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); + static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); + + using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + // Split-K preserves splits that are 128b aligned + static int constexpr kSplitKAlignment = cute::max( + 128 / sizeof_bits::value, 128 / sizeof_bits::value); + + /// Argument structure: User API + using Arguments = typename GemmKernel::Arguments; + /// Argument structure: Kernel API + using Params = typename GemmKernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status + can_implement(Arguments const& args) { + if (GemmKernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{})); + } + + workspace_bytes += GemmKernel::get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Arguments const& args, void* workspace = nullptr) { + auto tmp_params = GemmKernel::to_underlying_arguments(args, workspace); + return GemmKernel::get_grid_shape(tmp_params); + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return GemmKernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("GemmUniversal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = GemmKernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + GemmKernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + + CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = GemmKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + if (status != Status::kSuccess) { + return status; + } + // Initialize the Params structure + params_ = GemmKernel::to_underlying_arguments(args, workspace); + // Don't set the function attributes - require the CudaHostAdapter to set it. 
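+    // When kEnableCudaHostAdapter is set, the adapter is expected to raise the kernel's
+    // dynamic shared memory limit; otherwise cudaFuncSetAttribute() is called below whenever
+    // SharedStorageSize exceeds the 48 KB static limit.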
+ if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + // + // Account for dynamic smem capacity if needed + // + int smem_size = GemmKernel::SharedStorageSize; + + CUTLASS_ASSERT(cuda_adapter == nullptr); + + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + } + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversal()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = GemmKernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling GemmKernel::to_underlying_arguments() + static Status + run(Params& params, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + CUTLASS_TRACE_HOST("GemmUniversal::run()"); + dim3 const block = GemmKernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = GemmKernel::SharedStorageSize; + + Status launch_result{ Status::kSuccess }; + // Use extended launch API only for mainloops that use it + if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Use extended launch API"); +#endif + [[maybe_unused]] constexpr bool is_static_1x1x1 = + cute::is_static_v and + cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; + [[maybe_unused]] dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); + + // Dynamic cluster support + [[maybe_unused]] dim3 fallback_cluster = dim3{0,0,0}; + if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 + || GemmKernel::ArchTag::kMinComputeCapability == 101 + ) { + if constexpr (!cute::is_static_v) { + fallback_cluster = params.hw_info.cluster_shape_fallback; + cluster = params.hw_info.cluster_shape; + } + } + + [[maybe_unused]] void* kernel_params[] = {¶ms}; + + if constexpr (kEnableCudaHostAdapter) { + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + if (launch_with_pdl) { + CUTLASS_TRACE_HOST( + "GemmUniversal::run() does not support launching with PDL and a custom cuda adapter."); + return Status::kErrorInternal; + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with CUDA host adapter"); +#endif + if constexpr (is_static_1x1x1) { + launch_result = cuda_adapter->launch(grid, + block, + smem_size, + stream, + kernel_params, + 0); + } + else { + launch_result = cuda_adapter->launch(grid, + 
cluster, + fallback_cluster, + block, + smem_size, + stream, + kernel_params, + 0); + } + } + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: kEnableCudaHostAdapter is true, but CUDA host adapter is null"); + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + [[maybe_unused]] void const* kernel = (void const*) device_kernel; + static constexpr bool kClusterLaunch = GemmKernel::ArchTag::kMinComputeCapability == 90; + if constexpr (kClusterLaunch) { + if constexpr (is_static_1x1x1) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching static 1x1x1 kernel"); +#endif + launch_result = cutlass::kernel_launch( + grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + } +#endif + } + else { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching dynamic cluster kernel"); +#endif + launch_result = ClusterLauncher::launch( + grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); + } + } + + else { + if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 + || GemmKernel::ArchTag::kMinComputeCapability == 101 + || GemmKernel::ArchTag::kMinComputeCapability == 120 + ) { + if constexpr (is_static_1x1x1) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching static 1x1x1 kernel"); +#endif + launch_result = cutlass::kernel_launch(grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + } +#endif + } + else { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with fall-back cluster"); +#endif + launch_result = ClusterLauncher::launch_with_fallback_cluster( + grid, + cluster, + fallback_cluster, + block, + smem_size, + stream, + kernel, + kernel_params, + launch_with_pdl); + } + } + } + + } + } + else { + launch_result = Status::kSuccess; + cutlass::arch::synclog_setup(); + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms}; +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with CUDA host adapter"); +#endif + launch_result = cuda_adapter->launch( + grid, block, smem_size, stream, kernel_params, 0 + ); + + } + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: CUDA host adapter is null"); + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with cutlass::kernel_launch"); +#endif + launch_result = cutlass::kernel_launch( + grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + } +#endif + } + } + + cudaError_t result = cudaGetLastError(); + if 
(cudaSuccess == result && Status::kSuccess == launch_result) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: cudaGetLastError reports success"); +#endif + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false + ) { + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (Status::kSuccess == status) { + status = run(params_, stream, cuda_adapter, launch_with_pdl); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(args, workspace, stream, cuda_adapter, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 2.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalAdapter< + GemmKernel_, + cute::enable_if_t>::value>> +{ +public: + + using GemmKernel = GetUnderlyingKernel_t; + + static bool const kInternalTranspose = + !cutlass::epilogue::threadblock::detail::is_2x_evt_v && // 2.x EVT does not require internal transpose + cute::is_same::value; + + using ThreadblockShape = typename GemmKernel::Mma::Shape; + using WarpShape = typename GemmKernel::WarpShape; + using InstructionShape = typename GemmKernel::InstructionShape; + + // warp-level, arch-level (instruction), math operator + using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator; + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + + // Operator class and arch tag extract bottom-up + // set it for top-level gemm device-level template + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + // Type, layout, and complex transform deliberately exchanged with B + using MapArguments = kernel::detail::MapArguments< + typename GemmKernel::ElementA, + typename GemmKernel::LayoutA, + GemmKernel::kTransformA, + GemmKernel::kAlignmentA, + typename GemmKernel::ElementB, + typename GemmKernel::LayoutB, + GemmKernel::kTransformB, + GemmKernel::kAlignmentB, + typename GemmKernel::LayoutC, + kInternalTranspose + >; + + 
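+  // When kInternalTranspose is set, the adapter computes the transposed problem instead:
+  // since D = A * B implies transpose(D) = transpose(B) * transpose(A), exchanging the A and B
+  // operands (together with their layouts and complex transforms) yields the requested output
+  // layout from a kernel that natively produces the opposite one. to_underlying_arguments()
+  // below applies this swap via Arguments::transposed_problem().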
using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static int const kAlignmentA = MapArguments::kAlignmentA; + + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + static int const kAlignmentB = MapArguments::kAlignmentB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename MapArguments::LayoutC; + static int const kAlignmentC = GemmKernel::kAlignmentC; + + // C and D same type for 2.x kernel + using ElementD = ElementC; + using LayoutD = LayoutC; + + using TensorRefA = TensorRef; + using TensorRefB = TensorRef; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + static int const kStages = GemmKernel::Mma::kStages; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using UnderlyingOperator = GemmUniversalBase; + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmUniversalAdapter() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + if (kInternalTranspose) { + return args.transposed_problem(); + } + else { + return args; + } + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args), cuda_adapter); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args), cuda_adapter); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr + ) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream, cuda_adapter); + } + + /// Lightweight update given a subset of arguments. + Status update(Arguments const &args) { + + return underlying_operator_.update(to_underlying_arguments(args)); + } + + /// Runs the kernel using initialized state. + Status run( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + return underlying_operator_.run(stream, cuda_adapter); + } + + /// Runs the kernel using initialized state. + Status operator()( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (status == Status::kSuccess) { + status = run(stream, cuda_adapter); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/93_hopper_dual_gemm/fusion/callbacks.hpp b/examples/93_hopper_dual_gemm/fusion/callbacks.hpp new file mode 100644 index 0000000000..4b6569c168 --- /dev/null +++ b/examples/93_hopper_dual_gemm/fusion/callbacks.hpp @@ -0,0 +1,91 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/detail/dependent_false.hpp" +#include "operations.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Dispatch interface for epilogue fusion callbacks +// For visitor fusions, this is just a convenience wrapper to provide metadata and non-nested args. +// It is also valid to just pass visitor callbacks directly to the collective, e.g. fusion::Sm90LinearCombination, +// provided the collective supports a visitor callbacks interface. This is useful for implementing custom fusions. 
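+//
+// For illustration only: a sketch of how a dual fusion operation is paired with this dispatch
+// interface. The element types, tile shapes, and the dispatch policy instantiation are
+// placeholders; the epilogue collective builder normally performs this lookup internally.
+//
+// ```
+// using FusionOp = cutlass::epilogue::fusion::DualOpPair<
+//     cutlass::epilogue::fusion::DualLinearCombination<float, float>,   // D0 = alpha0 * acc0 + beta0 * C0
+//     cutlass::epilogue::fusion::DualLinearCombination<float, float>>;  // D1 = alpha1 * acc1 + beta1 * C1
+//
+// using Callbacks = cutlass::epilogue::fusion::DualFusionCallbacks<
+//     DispatchPolicy,     // e.g. an epilogue::DualSm90TmaWarpSpecialized instantiation
+//     FusionOp,
+//     CtaTile_MNK,        // CTA tile shape, e.g. Shape<_128,_256,_64>
+//     EpilogueTile_MN>;   // epilogue subtile shape
+//
+// // Callbacks::Arguments then nests one argument struct per operation: {op0, op1}.
+// ```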
+template < + class DispatchPolicy, // specialize on collective's dispatch policy since callbacks API will depend on collective's algorithm + class Operation, // the fusion operation being performed, e.g. fusion::LinearCombination + class CtaTile_MNK, // computed tile per CTA + class EpilogueTile_MN, // epilogue subtile size + class... Args // callbacks implementation dependent args (e.g. copy atoms, smem layouts) +> +struct DualFusionCallbacks { + static_assert(cutlass::detail::dependent_false, "Could not find a callbacks specialization."); +}; + +// Metadata helper to handle custom EVTs or other non-FusionCallbacks types +template +struct DualFusionCallbacksTraits { + using DispatchPolicy = void; + using Callbacks = T; + using Operation = DualFusionOperation; + using CtaTile_MNK = void; + using EpilogueTile_MN = void; + using ElementCompute = void; +}; + +template < + class DispatchPolicy_, + class Operation_, + class CtaTile_MNK_, + class EpilogueTile_MN_, + class... Args +> +struct DualFusionCallbacksTraits< + DualFusionCallbacks +> { + using DispatchPolicy = DispatchPolicy_; + using Callbacks = DualFusionCallbacks; + using Operation = Operation_; + using CtaTile_MNK = CtaTile_MNK_; + using EpilogueTile_MN = EpilogueTile_MN_; + using ElementCompute = typename Operation::ElementCompute; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/93_hopper_dual_gemm/fusion/operations.hpp b/examples/93_hopper_dual_gemm/fusion/operations.hpp new file mode 100644 index 0000000000..8eaf392994 --- /dev/null +++ b/examples/93_hopper_dual_gemm/fusion/operations.hpp @@ -0,0 +1,151 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include // cute::false_type +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Fusion Operations +// Template args must not be implementation dependent +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct DualFusionOperation { + // metadata types/queries that can be overrided + using ElementOutput = void; + using ElementCompute = void; + FloatRoundStyle RoundStyle = FloatRoundStyle::round_indeterminate; + + using ElementSource = void; + static constexpr bool IsSourceSupported = false; + static constexpr bool IsResidualSupported = false; // Source is added after activation + + using ElementScalar = void; + static constexpr int AlignmentScalar = 0; + static constexpr bool IsScaleFactorSupported = false; + static constexpr bool IsPerRowScaleSupported = false; + static constexpr bool IsPerColScaleSupported = false; + + using ElementBias = void; + static constexpr int AlignmentBias = 0; + static constexpr bool IsPerRowBiasSupported = false; + static constexpr bool IsPerColBiasSupported = false; + static constexpr bool IsDePerRowBiasSupported = false; + + using ActivationFn = void; + static constexpr bool IsEltActSupported = false; + static constexpr bool IsDeEltActSupported = false; + + using ElementAux = void; + using GmemLayoutTagAux = void; + static constexpr int AlignmentAux = 0; + static constexpr bool IsAuxOutSupported = false; + static constexpr bool IsAuxInSupported = false; + + using ElementAmax = void; + static constexpr bool IsAbsMaxSupported = false; + + using ElementBlockScaleFactor = void; + static constexpr int SFVecSize = 0; + static constexpr bool IsBlockScaleSupported = false; // Umbrella variable to check BlockScaling support in the epilogues + using GmemLayoutTagScalefactor = void; +}; + +// D = alpha * acc +template< + class ElementOutput_, + class ElementCompute_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct DualScaledAcc : DualFusionOperation { + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementScalar = ElementScalar_; + static constexpr int AlignmentScalar = 1; + static constexpr auto RoundStyle = RoundStyle_; +}; + +// D = alpha * acc + beta * C +template< + class ElementOutput_, + class ElementCompute_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct DualLinearCombination + : DualScaledAcc { + 
  using ElementSource = ElementSource_;
+  static constexpr bool IsSourceSupported = true;
+};
+
+template <class Op0, class Op1>
+struct DualOpPair : DualFusionOperation {
+  using LeftOp = Op0;
+  using RightOp = Op1;
+  using ElementCompute =
+    std::conditional_t<
+      !std::is_void_v<typename Op0::ElementCompute>,
+      typename Op0::ElementCompute,
+      typename Op1::ElementCompute>;
+  using ElementOutput = std::conditional_t<
+      !std::is_void_v<typename Op0::ElementOutput>,
+      typename Op0::ElementOutput,
+      typename Op1::ElementOutput>;
+  // Provide pass-through flags (OR of the two operations' flags)
+  static constexpr bool IsSourceSupported =
+    (bool)Op0::IsSourceSupported || (bool)Op1::IsSourceSupported;
+};
+
+// Trait to detect DualOpPair
+template <class T> struct is_dual_op_pair : std::false_type {};
+template <class Op0, class Op1> struct is_dual_op_pair<DualOpPair<Op0, Op1>> : std::true_type {};
+template <class T>
+inline constexpr bool is_dual_op_pair_v = is_dual_op_pair<T>::value;
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass::epilogue::fusion
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/93_hopper_dual_gemm/fusion/sm90_callbacks_tma_warpspecialized_dual.hpp b/examples/93_hopper_dual_gemm/fusion/sm90_callbacks_tma_warpspecialized_dual.hpp
new file mode 100644
index 0000000000..ef70ca97bc
--- /dev/null
+++ b/examples/93_hopper_dual_gemm/fusion/sm90_callbacks_tma_warpspecialized_dual.hpp
@@ -0,0 +1,435 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*!
\file + \brief Fusion callbacks specializations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" + +#include "cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp" +#include "../collective/dispatch_policy_extra.hpp" +#include "callbacks.hpp" +#include "operations.hpp" +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +using Sm90EVT = Sm90TreeVisitor; + +// D = alpha * acc +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct DualFusionCallbacks< + epilogue::DualSm90TmaWarpSpecialized, + fusion::DualScaledAcc, + CtaTileShapeMNK, + EpilogueTile +> : Sm90EVT, + Sm90ScalarBroadcast>, + Sm90AccFetch + > { + using Impl = + Sm90EVT, + Sm90ScalarBroadcast>, + Sm90AccFetch + >; + using Operation = fusion::DualScaledAcc; + + struct Arguments { + // Give a name and flat ordering to the fusion callback args + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + + // Conversion to the args expected by the visitor implementation + // to_underlying_arguments will implicitly call this + operator typename Impl::Arguments() const { + return + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C +template< + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using DualSm90LinearCombination = + Sm90EVT, // beta * C + (alpha * acc) + Sm90ScalarBroadcast>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast>, // alpha + Sm90AccFetch // acc + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct DualFusionCallbacks< + epilogue::DualSm90TmaWarpSpecialized, + fusion::DualLinearCombination, + CtaTileShapeMNK, + EpilogueTile +> : DualSm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { + + using Impl = DualSm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = 
fusion::DualLinearCombination; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// DualOpPair specialization: allow distinct ops for accumulator0 and accumulator1 +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class Op0, + class Op1, + class CtaTileShapeMNK, + class EpilogueTile +> +struct DualFusionCallbacks< + epilogue::DualSm90TmaWarpSpecialized, + DualOpPair, + CtaTileShapeMNK, + EpilogueTile +> { + + using DispatchPolicy = epilogue::DualSm90TmaWarpSpecialized; + using Operation = DualOpPair; + + using CB0 = DualFusionCallbacks; + using CB1 = DualFusionCallbacks; + + struct Arguments { + typename CB0::Arguments op0{}; + typename CB1::Arguments op1{}; + }; + + struct Params { + typename CB0::Params op0; + typename CB1::Params op1; + }; + + struct SharedStorage { + typename CB0::SharedStorage left; + typename CB1::SharedStorage right; + }; + + template + static Params to_underlying_arguments(ProblemShape const& ps, Arguments const& a, void* ws) { + return { + CB0::to_underlying_arguments(ps, a.op0, ws), + CB1::to_underlying_arguments(ps, a.op1, ws) + }; + } + + template + static size_t get_workspace_size(ProblemShape const& ps, Arguments const& a) { + return CB0::get_workspace_size(ps, a.op0) + CB1::get_workspace_size(ps, a.op1); + } + + template + static cutlass::Status initialize_workspace( + ProblemShape const& ps, Arguments const& a, void* ws, cudaStream_t stream, + CudaHostAdapter* adapter = nullptr) { + size_t size0 = CB0::get_workspace_size(ps, a.op0); + void* ws1 = static_cast(ws) + size0; + CUTLASS_CHECK(CB0::initialize_workspace(ps, a.op0, ws, stream, adapter)); + CUTLASS_CHECK(CB1::initialize_workspace(ps, a.op1, ws1, stream, adapter)); + return cutlass::Status::kSuccess; + } + + template + static bool can_implement(ProblemShape const& ps, Arguments const& a) { + return CB0::can_implement(ps, a.op0) && CB1::can_implement(ps, a.op1); + } + + CUTLASS_DEVICE + DualFusionCallbacks(Params const& p, SharedStorage& st) + : cb0_(p.op0, st.left), cb1_(p.op1, st.right) {} + + CUTLASS_DEVICE bool is_producer_load_needed() const { + return cb0_.is_producer_load_needed() || cb1_.is_producer_load_needed(); + } + + CUTLASS_DEVICE bool is_C_load_needed() const { + return detail_is_C_load_needed(cb0_) || detail_is_C_load_needed(cb1_); + } + +private: + CB0 cb0_; + CB1 cb1_; + template + struct has_is_C_load_needed : std::false_type {}; + + template + struct has_is_C_load_needed().is_C_load_needed())>> : std::true_type {}; + + template + CUTLASS_DEVICE static bool detail_is_C_load_needed(CbT const& cb) { + if constexpr (has_is_C_load_needed::value) return 
cb.is_C_load_needed(); + else return false; + } + + template + CUTLASS_DEVICE static auto obtain_consumer_impl(int, Cb& cb, Args const& a) + -> decltype(cb.template get_consumer_store_callbacks(a)) { + return cb.template get_consumer_store_callbacks(a); + } + template + CUTLASS_DEVICE static auto obtain_consumer_impl(long, Cb& cb, Args const& a) + -> decltype(cb.get_consumer_store_callbacks(a)) { + return cb.get_consumer_store_callbacks(a); + } + template + CUTLASS_DEVICE static auto obtain_consumer(Cb& cb, Args const& a) { + return obtain_consumer_impl(0, cb, a); + } + + // Wrapper structs + template + struct ProducerPair { + P0 a; P1 b; + CUTLASS_DEVICE void begin(){ a.begin(); b.begin(); } + CUTLASS_DEVICE void step(uint64_t* barrier,int em,int en,int stage,bool issue){ + a.step(barrier,em,en,stage,issue); + b.step(barrier,em,en,stage,issue); + } + CUTLASS_DEVICE void end(){ a.end(); b.end(); } + }; + + template + struct ConsumerPair { + C0 a; C1 b; + + // Trait helpers + template struct has_begin_sync_needed : std::false_type {}; + template struct has_begin_sync_needed().begin_sync_needed())>> : std::true_type {}; + + template struct has_tma_store : std::false_type {}; + template struct has_tma_store().tma_store(0,0,0,false))>> : std::true_type {}; + + CUTLASS_DEVICE void begin(){ a.begin(); b.begin(); } + + CUTLASS_DEVICE bool begin_sync_needed(){ + bool r0=false,r1=false; + if constexpr (has_begin_sync_needed::value) r0 = a.begin_sync_needed(); + if constexpr (has_begin_sync_needed::value) r1 = b.begin_sync_needed(); + return r0 || r1; + } + + CUTLASS_DEVICE void begin_loop(int em,int en){ a.begin_loop(em,en); b.begin_loop(em,en); } + + CUTLASS_DEVICE void previsit(int em,int en,int stage,bool has_src){ + a.previsit(em,en,stage,has_src); b.previsit(em,en,stage,has_src); + } + + template + CUTLASS_DEVICE auto visit0(Frag const& f,int idx,int em,int en){ + return a.visit(f,idx,em,en); + } + template + CUTLASS_DEVICE auto visit1(Frag const& f,int idx,int em,int en){ + return b.visit(f,idx,em,en); + } + template + CUTLASS_DEVICE auto visit(Frag const& f,int idx,int em,int en){ + return a.visit(f,idx,em,en); + } + + template + CUTLASS_DEVICE void reduce(SmemTensor&& t, SyncFn&& sync,int em,int en,bool last, FragTensor& fr){ + a.reduce(t,sync,em,en,last,fr); + b.reduce(t,sync,em,en,last,fr); + } + + CUTLASS_DEVICE void postreduce(int em,int en,int stage,bool issue){ + a.postreduce(em,en,stage,issue); + b.postreduce(em,en,stage,issue); + } + + CUTLASS_DEVICE void tma_store(int em,int en,int stage,bool issue){ + if constexpr (has_tma_store::value) a.tma_store(em,en,stage,issue); + if constexpr (has_tma_store::value) b.tma_store(em,en,stage,issue); + } + + CUTLASS_DEVICE void end_loop(int em,int en){ a.end_loop(em,en); b.end_loop(em,en); } + CUTLASS_DEVICE void end(){ a.end(); b.end(); } + }; + +public: + template + CUTLASS_DEVICE auto get_producer_load_callbacks(ProducerLoadArgs const& args) { + auto p0 = cb0_.get_producer_load_callbacks(args); + auto p1 = cb1_.get_producer_load_callbacks(args); + return ProducerPair{p0,p1}; + } + + template + CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto c0 = obtain_consumer(cb0_, args); + auto c1 = obtain_consumer(cb1_, args); + return ConsumerPair{c0,c1}; + } + +}; + +namespace detail { +template > +struct dual_get_element_aux { + using type = void; +}; + +template +struct dual_get_element_aux> { + using type = typename DualFusionOpOrCallbacks::ElementAux; +}; + +template +struct dual_get_element_aux, 
cute::void_t<>> { + using type = typename dual_get_element_aux::type; +}; + +template +struct dual_get_element_aux, cute::void_t::Operation>> { + private: + using Operation = typename DualFusionCallbacks::Operation; + public: + using type = typename dual_get_element_aux::type; +}; + +template +struct dual_get_element_aux, cute::void_t<>> { + using type = typename dual_get_element_aux::type; +}; +} // namespace cutlass:epilogue::fusion::detail + +template +using dual_get_element_aux_t = typename detail::dual_get_element_aux::type; + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/93_hopper_dual_gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative_dual.hpp b/examples/93_hopper_dual_gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative_dual.hpp new file mode 100644 index 0000000000..84eac5c031 --- /dev/null +++ b/examples/93_hopper_dual_gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative_dual.hpp @@ -0,0 +1,870 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/tensor.hpp" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/gemm_universal_decl.h" +#include "cutlass/arch/grid_dependency_control.h" + +#include "../collective/dispatch_policy_extra.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileSchedulerTag_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileSchedulerTag_, + cute::enable_if_t>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB0 = typename CollectiveMainloop::StrideB0; + using StrideB1 = typename CollectiveMainloop::StrideB1; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 90); + + static constexpr uint32_t TileSchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + using TileSchedulerTag = TileSchedulerTag_; + + using TileScheduler = typename detail::TileSchedulerSelector< + TileSchedulerTag, + ArchTag, + TileShape, + ClusterShape + ,TileSchedulerPipelineStageCount + >::Scheduler; + + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = size(TiledMma{}); // 8 warps + static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t 
NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp for C + + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = NumMMAThreads / NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = NumMMAThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static constexpr uint32_t NumFixupBarriers = NumMmaWarpGroups; + static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents; + static constexpr bool IsMainloopAuxiliaryLoadNeeded = detail::HasAuxiliaryLoad_v; + + /// Register requirement for Load and Math WGs + static constexpr int RegsPerThread = + size<0>(TileShape{}) * size<1>(TileShape{}) / NumMMAThreads * + sizeof(ElementAccumulator) / sizeof(uint32_t); + static constexpr bool HeavyRegisterPressure = RegsPerThread >= 208; + static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24; + static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 232 : 240; + + // 1 stage ordered sequence between mainloop and epilogue producer load threads + using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + using TileSchedulerPipeline = typename TileScheduler::Pipeline; + using TileSchedulerPipelineState = typename TileSchedulerPipeline::PipelineState; + using TileSchedulerStorage = typename TileScheduler::SharedStorage; + using TileSchedulerThrottlePipeline = typename TileScheduler::ThrottlePipeline; + using TileSchedulerThrottlePipelineState = typename TileSchedulerThrottlePipeline::PipelineState; + + // Kernel level shared memory storage + struct SharedStorage { + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; + } pipelines; + + alignas(16) TileSchedulerStorage scheduler; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + void* workspace{nullptr}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
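+  //
+  // Host-side call sequence sketch (illustrative only; the device adapter normally
+  // drives these steps, and every name other than this kernel's own members is a
+  // placeholder):
+  //
+  //   Arguments args{GemmUniversalMode::kGemm, problem_shape,
+  //                  mainloop_args, epilogue_args, hw_info, scheduler_args};
+  //   if (!can_implement(args)) { /* reject the problem */ }
+  //   size_t workspace_bytes = get_workspace_size(args);
+  //   /* allocate workspace_bytes of device memory as `workspace` */
+  //   initialize_workspace(args, workspace, stream);
+  //   Params params = to_underlying_arguments(args, workspace);
+  //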
+ static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + auto problem_shape = args.problem_shape; + if constexpr (detail::Has_SwapAB_v) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + // Get maximum number of clusters that could co-exist on the target device + int max_active_clusters = args.hw_info.max_active_clusters; + if (max_active_clusters <= 0) { + max_active_clusters = 0; + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid max cluster count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the max_active_clusters."); + } + else { + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid cluster count to " << max_active_clusters); + } + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count, max_active_clusters}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used + // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means + // subtile will not be used, therefore separate reduction will not be enabled. 
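+    //
+    // The workspace carved out above is laid out as
+    //   [ epilogue workspace | tile-scheduler workspace ]   (the mainloop takes none),
+    // with each slice rounded up to MinWorkspaceAlignment. get_workspace_size() and
+    // initialize_workspace() below walk the same slices in the same order.
+    //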
+ constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles + ); + + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), + hw_info, + scheduler, + workspace + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + static constexpr uint32_t NumAccumulatorMtxs = 1; + + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // Given device SM count, set grid size s.t. 
we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; + return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + +#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200)) +# define ENABLE_SM90_KERNEL_LEVEL 1 +#endif +// Any Tensor Op MMA Atom in the ISA is arch conditional. +#if ! defined(ENABLE_SM90_KERNEL_LEVEL) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else + + // Preconditions + static_assert(NumMMAThreads == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); + static_assert(size<0>(TileShape{}) >= 128, + "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); + + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB0{}) == 3, "StrideB0 must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB1{}) == 3, "StrideB1 must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); + + /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ + enum class WarpGroupRole { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2 + }; + enum class ProducerWarpRole { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + MainloopAux = 3 + }; + + + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int mma_thread_idx = thread_idx % NumMMAThreads; + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + // TileScheduler pipeline + typename TileSchedulerPipeline::Params scheduler_pipeline_params; + typename TileSchedulerThrottlePipeline::Params scheduler_throttle_pipeline_params; + if constexpr (IsSchedDynamicPersistent) { + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Warp1) { + scheduler_pipeline_params.role = TileSchedulerPipeline::ThreadCategory::ProducerConsumer; + } + else { + scheduler_pipeline_params.role = TileSchedulerPipeline::ThreadCategory::Consumer; + } + scheduler_pipeline_params.producer_blockid = 0; + scheduler_pipeline_params.producer_arv_count = 1; + scheduler_pipeline_params.consumer_arv_count = NumSchedThreads + NumMainloopLoadThreads + NumMMAThreads; + + if (is_epi_load_needed) { + scheduler_pipeline_params.consumer_arv_count += NumEpilogueLoadThreads; + } + scheduler_pipeline_params.transaction_bytes = sizeof(typename TileScheduler::CLCResponse); + + scheduler_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + scheduler_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + scheduler_throttle_pipeline_params.dst_blockid = 0; + scheduler_throttle_pipeline_params.initializing_warp = 3; + if (warp_group_role == WarpGroupRole::Producer && + producer_warp_role == ProducerWarpRole::Warp1) { + scheduler_throttle_pipeline_params.role = + TileSchedulerThrottlePipeline::ThreadCategory::Consumer; + } + // set role when it is for DMA warp in Mainloop + else if (warp_group_role == WarpGroupRole::Producer && + producer_warp_role == ProducerWarpRole::Mainloop) { + scheduler_throttle_pipeline_params.role = + TileSchedulerThrottlePipeline::ThreadCategory::Producer; + } + } + TileSchedulerPipeline scheduler_pipeline(shared_storage.scheduler.pipeline(), scheduler_pipeline_params); + TileSchedulerPipelineState scheduler_pipe_consumer_state; + + TileSchedulerThrottlePipeline scheduler_throttle_pipeline(shared_storage.scheduler.throttle_pipeline(), scheduler_throttle_pipeline_params); + TileSchedulerThrottlePipelineState scheduler_pipe_throttle_consumer_state; + TileSchedulerThrottlePipelineState scheduler_pipe_throttle_producer_state = 
cutlass::make_producer_start_state(); + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && (producer_warp_role == ProducerWarpRole::Mainloop || + producer_warp_role == ProducerWarpRole::MainloopAux)) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = NumMMAThreads; + mainloop_pipeline_params.num_producers = NumProducerThreads; + mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumMMAThreads; + if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { + epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; + } + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 
0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + + auto cluster_wait_fn = [] () { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + return [] () { cute::cluster_wait(); }; + } + else { + __syncthreads(); + return [] () {}; // do nothing + } + } (); + + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + + TileScheduler scheduler{params.scheduler}; + if constexpr (IsSchedDynamicPersistent) { + scheduler.set_data_ptr(shared_storage.scheduler.data()); + } + // Declare work_tile_info, then define it in each of warps that use it. + typename TileScheduler::WorkTileInfo work_tile_info; + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B0 after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + // get<2>(load_inputs) is the tma tensor B1 after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 3, "Output of load_init must have at least two elements (A, B0, B1)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB0_nkl = get<1>(load_inputs); + Tensor gB1_nkl = get<2>(load_inputs); + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) { + work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); + cutlass::arch::warpgroup_reg_dealloc(); + + // Scheduler Producer Warp + if (producer_warp_role == ProducerWarpRole::Warp1) { + if constexpr (IsSchedDynamicPersistent) { + bool requires_clc_query = true; + TileSchedulerPipelineState scheduler_pipe_producer_state = cutlass::make_producer_start_state(); + + cutlass::arch::wait_on_dependent_grids(); + while (work_tile_info.is_valid()) { + + if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. 
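+            // The throttle handshake pairs this warp with the mainloop DMA warp: the DMA
+            // warp produces one token per tile (producer_acquire/producer_commit in the
+            // Mainloop branch below), and this scheduler warp consumes it here before
+            // issuing the next CLC query, so the scheduler cannot run far ahead of the
+            // loads that actually consume the scheduled tiles.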
+ scheduler_throttle_pipeline.consumer_wait(scheduler_pipe_throttle_consumer_state); + scheduler_throttle_pipeline.consumer_release(scheduler_pipe_throttle_consumer_state); + ++scheduler_pipe_throttle_consumer_state; + + // Query next work tile + scheduler_pipe_producer_state = scheduler.advance_to_next_work(scheduler_pipeline, scheduler_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + scheduler_pipeline, + scheduler_pipe_consumer_state + ); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++scheduler_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } + scheduler_pipeline.producer_tail(scheduler_pipe_producer_state); + } + } // Scheduler Producer Warp End + else + + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) { + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + bool do_load_order_arrive = true; + bool requires_clc_query = true; + while (work_tile_info.is_valid()) { + if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; + continue; + } + + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB0_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB0_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
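+          // For the default data-parallel schedule this covers the full K extent starting
+          // at tile 0; under a stream-K schedule a work unit may begin at a non-zero K
+          // tile and cover only a slice of it, which is why the iterator below is offset
+          // by work_k_tile_start.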
+ auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + + if (requires_clc_query) { + scheduler_throttle_pipeline.producer_acquire(scheduler_pipe_throttle_producer_state); + scheduler_throttle_pipeline.producer_commit(scheduler_pipe_throttle_producer_state); + ++scheduler_pipe_throttle_producer_state; + } + + collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + blk_coord, + k_tile_iter, work_k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(work_k_tile_count); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + // Get next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, + scheduler_pipeline, + scheduler_pipe_consumer_state + ); + + work_tile_info = next_work_tile_info; + if constexpr (IsSchedDynamicPersistent) { + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++scheduler_pipe_consumer_state; + } + } + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + } + else if (producer_warp_role == ProducerWarpRole::MainloopAux) { + if constexpr (IsMainloopAuxiliaryLoadNeeded) { + while (work_tile_info.is_valid()) { + if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; + continue; + } + + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB0_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB0_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
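+          // Same tile and K-slice bookkeeping as the main DMA warp above; this warp only
+          // issues the auxiliary loads via load_auxiliary() and advances its own copy of
+          // the producer pipeline state in lockstep with the main loads.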
+ auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + + collective_mainloop.load_auxiliary( + params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + blk_coord, + k_tile_iter, work_k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(work_k_tile_count); + + // Get next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + scheduler_pipeline, + scheduler_pipe_consumer_state + ); + + work_tile_info = next_work_tile_info; + } // Scheduler work fetch loop + + } + } + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && is_epi_load_needed) { + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + if (!TileScheduler::requires_separate_reduction(params.scheduler) && work_tile_info.is_valid()) { + load_order_barrier.wait(); + } + + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + while (work_tile_info.is_valid()) { + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB0_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB0_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + epi_load_pipe_producer_state = + collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + lane_idx, + shared_storage.tensors.epilogue, + work_tile_info.reduction_subtile_idx() + ); + } + + // Get next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, + scheduler_pipeline, + scheduler_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + if constexpr (IsSchedDynamicPersistent) { + if (increment_pipe) { + ++scheduler_pipe_consumer_state; + } + } + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } // Epilogue Producer Warp End + } // Producer Warp Group End + + else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); + cutlass::arch::warpgroup_reg_alloc(); + + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it + bool do_store_tail = false; + while (work_tile_info.is_valid()) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB0_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB0_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + auto 
work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + // Allocate the accumulators for the (M,N) blk_shape + // + // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. + auto accumulators0 = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + auto accumulators1 = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators0, + accumulators1, + work_k_tile_count, + mma_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + work_k_tile_count + ); + + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(work_k_tile_count); + } + #ifdef CUTLASS_ENABLE_GDC_FOR_SM90 + if (scheduler.is_last_tile(work_tile_info)) { + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); + + } + #endif + + // Index of warp group within consumer warp groups + int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; + + // Perform reduction across splits, if needed + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators0, NumMmaWarpGroups, consumer_warp_group_idx); + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators1, NumMmaWarpGroups, consumer_warp_group_idx); + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators0, + accumulators1, + tiled_mma, + mma_thread_idx, + shared_storage.tensors.epilogue, + work_tile_info.reduction_subtile_idx() + ); + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; + do_store_tail = true; + } + + // Get next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, + scheduler_pipeline, + scheduler_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + if constexpr (IsSchedDynamicPersistent) { + if (increment_pipe) { + ++scheduler_pipe_consumer_state; + } + } + } // Scheduler work fetch loop + + if (do_store_tail) { + collective_epilogue.store_tail( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state + ); + } + } // Consumer Warp Groups End +#endif + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/examples/93_hopper_dual_gemm/thread/left_silu_and_mul.h b/examples/93_hopper_dual_gemm/thread/left_silu_and_mul.h new file mode 100644 index 0000000000..f51677062f --- /dev/null +++ b/examples/93_hopper_dual_gemm/thread/left_silu_and_mul.h @@ -0,0 +1,152 @@ +/*************************************************************************************************** + * Copyright 
(c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Functor performing linear combination operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/thread/linear_combination_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation. 
+ ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LeftSiLUAndMul { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + + static FloatRoundStyle const kRound = Round; + + struct Params{}; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LeftSiLUAndMul(Params const &/*params*/) {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + assert(false); + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &lhs, + FragmentAccumulator const &rhs) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_to_compute; + + // Convert to destination numeric type + NumericArrayConverter compute_to_output; + + ComputeFragment converted_lhs = accumulator_to_compute(lhs); + ComputeFragment converted_rhs = accumulator_to_compute(rhs); + + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(converted_lhs); + return compute_to_output(mul(silu_lhs, converted_rhs)); + } + + CUTLASS_HOST_DEVICE + ElementOutput operator()( + ElementAccumulator const& lhs, + ElementAccumulator const& rhs + ) const { + ElementCompute convert_lhs(lhs); + ElementCompute convert_rhs(rhs); + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(convert_lhs); + return ElementOutput(mul(silu_lhs, convert_rhs)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/94_blackwell_geforce_dual_gemm/94_blackwell_geforce_dual_gemm.cu b/examples/94_blackwell_geforce_dual_gemm/94_blackwell_geforce_dual_gemm.cu new file mode 100644 index 0000000000..685316255b --- /dev/null +++ b/examples/94_blackwell_geforce_dual_gemm/94_blackwell_geforce_dual_gemm.cu @@ -0,0 +1,620 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Blackwell GeForce Dual-GEMM example using CUTLASS 3.0 APIs for the NVIDIA Blackwell SM120 architecture. + + This example is based on example 79a_blackwell_geforce_nvfp4_bf16_gemm.cu but for Dual-GEMM + +``` +D0 = epilogue0(X @ B0, C0) +D1 = epilogue1(X @ B1, C1) +D2 = element_wise(D0, D1) +``` + Usage: + + $ ./examples/94_blackwell_geforce_dual_gemm/94_blackwell_geforce_dual_gemm --m=2048 --n=2048 --k=2048 +*/ + +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "helper.h" + +#include "collective/dispatch_policy_extra.hpp" +#include "collective/builder.hpp" +#include "kernel/sm90_gemm_tma_warpspecialized_cooperative_dual.hpp" +#include "device/gemm_universal_adapter.h" +#include "collective/builder_epilogue.hpp" +#include "thread/left_silu_and_mul.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::nv_float4_t; // Element type for A matrix operand +using LayoutATag = cutlass::layout::RowMajor; // Layout 
type for A matrix operand +constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::nv_float4_t; // Element type for B matrix operand +using LayoutB0Tag = cutlass::layout::ColumnMajor; // Layout type for B0 matrix operand +using LayoutB1Tag = cutlass::layout::ColumnMajor; // Layout type for B1 matrix operand +constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand +using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand +using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag + +// Kernel Perf config +using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::DualKernelTmaWarpSpecializedCooperativeBlockScaledSm120<3>; +using EpilogueSchedule = cutlass::epilogue::DualTmaWarpSpecialized; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +using ElementCompute = float; + +using OpLeft = cutlass::epilogue::fusion::DualLinearCombination; +using OpRight = cutlass::epilogue::fusion::DualLinearCombination; +using DualPairOp = cutlass::epilogue::fusion::DualOpPair; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::DualCollectiveBuilder< + ArchTag, OperatorClass, + ThreadBlockShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + EpilogueSchedule, + DualPairOp + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::DualCollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutB0Tag, LayoutB1Tag, AlignmentB, + ElementAccumulator, + ThreadBlockShape, ClusterShape, + cutlass::gemm::collective::StageCount<3>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using StrideA = typename Gemm::GemmKernel::StrideA; +using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{})); +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. 
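+// Unlike A/B/C/D, which are described by a packed stride, the scale-factor tensors carry
+// the full CuTe layout type because the block-scaled format interleaves scale factors in
+// an atom pattern that a plain (row, col, batch) stride cannot express. The same pattern
+// repeats below for the B0/B1/C/D strides and for the SFB0/SFB1 layouts.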
+using StrideB0 = typename Gemm::GemmKernel::StrideB0; +using LayoutB0 = decltype(cute::make_layout(make_shape(0,0,0), StrideB0{})); +using LayoutSFB0 = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB0; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using StrideB1 = typename Gemm::GemmKernel::StrideB1; +using LayoutB1 = decltype(cute::make_layout(make_shape(0,0,0), StrideB1{})); +using LayoutSFB1 = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB1; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using StrideC = typename Gemm::GemmKernel::StrideC; +using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); +using StrideD = typename Gemm::GemmKernel::StrideD; +using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); + +// +// Data members +// + +/// Initialization +StrideA stride_A; +LayoutA layout_A; +LayoutSFA layout_SFA; +StrideB0 stride_B0; +LayoutB0 layout_B0; +LayoutSFB0 layout_SFB0; +StrideB1 stride_B1; +LayoutB1 layout_B1; +LayoutSFB1 layout_SFB1; +StrideC stride_C; +LayoutC layout_C; +StrideD stride_D; +LayoutD layout_D; +uint64_t seed; + +// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device +// Use cute::Tensor and cute::Layout for iterating thru the matrix elements +cutlass::HostTensor block_A; +cutlass::HostTensor block_SFA; +cutlass::HostTensor block_B0; +cutlass::HostTensor block_B1; +cutlass::HostTensor block_SFB0; +cutlass::HostTensor block_SFB1; +cutlass::HostTensor block_C; +// Output Tensor +cutlass::HostTensor block_D; +// Reference Output Tensor +cutlass::HostTensor block_reference_D; +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k; + + Options(): + help(false), + m(1024), n(1024), k(1024), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "94_blackwell_geforce_dual_gemm\n\n" + << " Blackwell NVFP4 Dual-GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ " << "./examples/94_blackwell_geforce_dual_gemm/94_blackwell_geforce_dual_gemm.cu" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add x2 for dual gemm + uint64_t flop = uint64_t(4) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + using namespace cute; + // For SFA and SFB tensors layouts + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B0 = cutlass::make_cute_packed_stride(StrideB0{}, {options.n, options.k, 1}); + stride_B1 = cutlass::make_cute_packed_stride(StrideB1{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A); + layout_B0 = make_layout(make_shape(options.n, options.k, 1), stride_B0); + layout_B1 = make_layout(make_shape(options.n, options.k, 1), stride_B1); + layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C); + layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D); + layout_SFA = 
Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFB0 = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFB1 = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); + + block_A.reset(cutlass::make_Coord(size(layout_A))); + block_B0.reset(cutlass::make_Coord(size(layout_B0))); + block_B1.reset(cutlass::make_Coord(size(layout_B1))); + block_C.reset(cutlass::make_Coord(size(layout_C))); + block_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_D.reset(cutlass::make_Coord(size(layout_D))); + block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA)))); + block_SFB0.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB0)))); + block_SFB1.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB1)))); + + initialize_block(block_A.host_view(), seed + 2021); + initialize_block(block_B0.host_view(), seed + 2022); + initialize_block(block_B1.host_view(), seed + 3022); + initialize_block(block_C.host_view(), seed + 2023); + initialize_block(block_SFA.host_view(), seed + 2024); + initialize_block(block_SFB0.host_view(), seed + 2025); + initialize_block(block_SFB1.host_view(), seed + 3025); + + block_A.sync_device(); + block_B0.sync_device(); + block_B1.sync_device(); + block_C.sync_device(); + block_SFA.sync_device(); + block_SFB0.sync_device(); + block_SFB1.sync_device(); +} + +// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + { // Mainloop arguments + block_A.device_data(), stride_A, + block_B0.device_data(), stride_B0, + block_B1.device_data(), stride_B1, + block_SFA.device_data(), layout_SFA, + block_SFB0.device_data(), layout_SFB0, + block_SFB1.device_data(), layout_SFB1 + }, + { // Epilogue arguments + { + { options.alpha, options.beta }, // op0 + { options.alpha, options.beta } // op1 + }, + block_C.device_data(), stride_C, + block_D.device_data(), stride_D + } + }; + + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA); + Tensor tensor_B0 = make_tensor(make_iterator(block_B0.host_data()), layout_B0); + Tensor tensor_SFB0 = make_tensor(block_SFB0.host_data(), layout_SFB0); + Tensor tensor_B1 = make_tensor(make_iterator(block_B1.host_data()), layout_B1); + Tensor tensor_SFB1 = make_tensor(block_SFB1.host_data(), layout_SFB1); + + // Prepare host buffers for two separate GEMM outputs: + cutlass::HostTensor block_reference_D0; + cutlass::HostTensor block_reference_D1; + block_reference_D0.reset(cutlass::make_Coord(size(layout_D))); + block_reference_D1.reset(cutlass::make_Coord(size(layout_D))); + + auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); + auto tensor_D0 = cute::make_tensor(make_iterator(block_reference_D0.host_data()), layout_D); + auto tensor_D1 = cute::make_tensor(make_iterator(block_reference_D1.host_data()), layout_D); + + // First GEMM: D0 = alpha * (A @ B0) + beta * C (linear combination epilogue) + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, // ElementAccumulator + 
decltype(tensor_A), // TensorA + decltype(tensor_SFA), // TensorSfA + decltype(tensor_B0), // TensorB0 + decltype(tensor_SFB0) // TensorSfB0 + > mainloop_params0{tensor_A, tensor_SFA, tensor_B0, tensor_SFB0}; + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementAccumulator, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementAccumulator, // ElementCompute + decltype(tensor_C), // TensorC + decltype(tensor_D0) // TensorD0 + > epilogue_params0{options.alpha, options.beta, tensor_C, tensor_D0}; + + cutlass::reference::host::Gemm3x(mainloop_params0, epilogue_params0); + + // Second GEMM: D1 = alpha * (A @ B1) + beta * C (linear combination epilogue) + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, + decltype(tensor_A), + decltype(tensor_SFA), + decltype(tensor_B1), + decltype(tensor_SFB1) + > mainloop_params1{tensor_A, tensor_SFA, tensor_B1, tensor_SFB1}; + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementAccumulator, + ElementAccumulator, + ElementAccumulator, + decltype(tensor_C), + decltype(tensor_D1) + > epilogue_params1{options.alpha, options.beta, tensor_C, tensor_D1}; + + cutlass::reference::host::Gemm3x(mainloop_params1, epilogue_params1); + + // Compute final D2_ref = LeftSiLUAndMul(D0_ref, D1_ref) + int64_t total_elems = static_cast(options.m) * static_cast(options.n); + + using LeftOp = cutlass::epilogue::thread::LeftSiLUAndMul< + ElementD, // ElementOutput (store as ElementD) + 1, // Count: scalar path + ElementAccumulator, // ElementAccumulator (accum type) + ElementAccumulator, // ElementCompute (compute type) + cutlass::FloatRoundStyle::round_to_nearest + >; + + LeftOp leftop(typename LeftOp::Params{}); // Params is empty struct + + // elementwise loop (host) + ElementD *d0_ptr = block_reference_D0.host_data(); + ElementD *d1_ptr = block_reference_D1.host_data(); + ElementD *d2_ptr = block_reference_D.host_data(); + + for (int64_t i = 0; i < total_elems; ++i) { + ElementAccumulator lhs_acc = ElementAccumulator(d0_ptr[i]); + ElementAccumulator rhs_acc = ElementAccumulator(d1_ptr[i]); + + d2_ptr[i] = leftop(lhs_acc, rhs_acc); + } + + // Comparison + block_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view()); + passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + 
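+  // Note on the check above: verify() recomputes both block-scaled GEMMs on the host with
+  // cutlass::reference::host::Gemm3x and then fuses the two results elementwise through the
+  // LeftSiLUAndMul functor. Assuming the usual SiLU definition, that elementwise step is
+  // expected to behave like
+  //
+  //   silu(x) = x / (1 + exp(-x));
+  //   D2[i]   = silu(D0[i]) * D1[i];
+  //
+  // The comparison is an exact TensorEquals, so it relies on host and device performing the
+  // same arithmetic in the same precision rather than on an error tolerance.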
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example + // and must have compute capability at least 100. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + + if (!(props.major == 12 && props.minor == 0)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/94_blackwell_geforce_dual_gemm/CMakeLists.txt b/examples/94_blackwell_geforce_dual_gemm/CMakeLists.txt new file mode 100644 index 0000000000..768a8fe20f --- /dev/null +++ b/examples/94_blackwell_geforce_dual_gemm/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 94_blackwell_geforce_dual_gemm + 94_blackwell_geforce_dual_gemm.cu +) diff --git a/examples/94_blackwell_geforce_dual_gemm/collective/builder.hpp b/examples/94_blackwell_geforce_dual_gemm/collective/builder.hpp new file mode 100644 index 0000000000..e61a8eb4be --- /dev/null +++ b/examples/94_blackwell_geforce_dual_gemm/collective/builder.hpp @@ -0,0 +1,282 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/builders/sm120_common.inl" +#include "dispatch_policy_extra.hpp" +#include "sm120_blockscaled_mma_tma_dual.hpp" + +namespace cutlass::gemm::collective { + +template < + class ArchTag, + class OpClass, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB0, + class GmemLayoutB1, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType, + class Enable = void +> +struct DualCollectiveBuilder { + static_assert(sizeof(ElementA) == 0, + "DualCollectiveBuilder: unsupported configuration."); +}; + +template < + class ElementPairA, + class GmemLayoutATag, + int AlignmentA, + class ElementPairB, + class GmemLayoutB0Tag, + class GmemLayoutB1Tag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct DualCollectiveBuilder< + arch::Sm120, + arch::OpClassBlockScaledTensorOp, + ElementPairA, + GmemLayoutATag, + AlignmentA, + ElementPairB, + GmemLayoutB0Tag, + GmemLayoutB1Tag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + cute::is_same_v>>> +{ + + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + // Let the blockscaled helpers auto-deduce instruction and SF vector size from ElementPair types. + using KernelScheduleTag = KernelScheduleAuto; + + // Deduce data/scalefactor types from the pair wrappers + using ElementSFA = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementSFB = typename detail::blockscaled::blockscaled_type::sf_type; + static_assert(cute::is_same_v, "Scale factor types for A and B must be the same."); + using ElementA = typename detail::blockscaled::blockscaled_type::data_type; + using ElementB = typename detail::blockscaled::blockscaled_type::data_type; + using ElementSF = ElementSFA; + static constexpr auto SFVectorSizeA = detail::blockscaled::blockscaled_type::SfVectorSize; + static constexpr auto SFVectorSizeB = detail::blockscaled::blockscaled_type::SfVectorSize; + static_assert(SFVectorSizeA == SFVectorSizeB, "Scale factor vector size for A and B must be the same."); + static constexpr int SFVectorSize = SFVectorSizeA; + + // Supported layout + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB0 = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + static constexpr cute::UMMA::Major UmmaMajorB1 = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + static_assert((UmmaMajorA == UMMA::Major::K && UmmaMajorB0 == UMMA::Major::K && UmmaMajorB1 == UMMA::Major::K), "Only TN layout is supported."); + + static_assert(cute::is_static_v, "TileShape has to be static"); + static_assert(cute::is_static_v, "Cluster has to be static"); + static_assert(detail::blockscaled::check_input_datatypes(), "Incorrect input types"); + static_assert(detail::blockscaled::check_input_datatypes(), "Incorrect input types"); + static_assert(cute::size(ClusterShape_MNK{}) == Int<1>{}, "No programmatic multicast on this arch"); + 
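+  // Because this architecture offers no programmatic multicast (see the assert above), the
+  // cluster shape is pinned to 1x1x1 and the GmemTiledCopy atoms chosen below are the plain
+  // SM90_TMA_LOAD variants rather than their multicast counterparts.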
static_assert(size<1>(TileShape_MNK{}) >= 32, "Invalid tile shape N."); + + static constexpr auto Instr0 = detail::blockscaled::select_instr(); + static constexpr auto Instr1 = detail::blockscaled::select_instr(); + static constexpr bool UseMxf8f6f4 = (Instr0 == detail::blockscaled::BlockScaledInstr::MXF4F6F8); + + using PermTileM = decltype(cute::min(size<0>(TileShape_MNK{}), _128{})); + using PermTileN = decltype(detail::sm120_tile_n_permute_selector()); + using PermTileK = cute::conditional_t<(UseMxf8f6f4), _32, _64>; + + // Cooperative schedule only + using AtomLayoutMNK = Layout>; + + // MMA + using TiledMma = decltype(cute::make_tiled_mma( + cute::rr_blockscaled_op_selector_sm120(), + AtomLayoutMNK{}, + Tile{} + )); + + static constexpr int MMA_NSF = size<2>(typename TiledMma::AtomShape_MNK{}) / SFVectorSize; + + // SMEM allocation types + using SmemAllocTypeA = cute::conditional_t; + using SmemAllocTypeB = cute::conditional_t; + using SmemAllocTypeSF = ElementSF; + + // GMEM TMA copies + using GmemTiledCopyA = SM90_TMA_LOAD; + using GmemTiledCopyB0 = SM90_TMA_LOAD; + using GmemTiledCopyB1 = SM90_TMA_LOAD; + using GmemTiledCopySFA = SM90_TMA_LOAD; + using GmemTiledCopySFB0 = SM90_TMA_LOAD; + using GmemTiledCopySFB1 = SM90_TMA_LOAD; + + using GmemTiledCopyPairA = decltype(cute::make_tuple(GmemTiledCopyA{}, GmemTiledCopySFA{})); + using GmemTiledCopyPairB0 = decltype(cute::make_tuple(GmemTiledCopyB0{}, GmemTiledCopySFB0{})); + using GmemTiledCopyPairB1 = decltype(cute::make_tuple(GmemTiledCopyB1{}, GmemTiledCopySFB1{})); + + // Config and SMEM layouts + using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + + using SmemLayoutAtomA = decltype(detail::sm120_rr_smem_selector(TileShape_MNK{}))>()); + using SmemLayoutAtomB0 = decltype(detail::sm120_rr_smem_selector(TileShape_MNK{}))>()); + using SmemLayoutAtomB1 = decltype(detail::sm120_rr_smem_selector(TileShape_MNK{}))>()); + + using SmemCopyAtomA = Copy_Atom()), SmemAllocTypeA>; + using SmemCopyAtomB0 = Copy_Atom()), SmemAllocTypeB>; + using SmemCopyAtomB1 = Copy_Atom()), SmemAllocTypeB>; + + using SmemCopyAtomSF = Copy_Atom, SmemAllocTypeSF>; + using SmemCopyAtomSFA = SmemCopyAtomSF; + using SmemCopyAtomSFB0 = SmemCopyAtomSF; + using SmemCopyAtomSFB1 = SmemCopyAtomSF; + + using SmemCopyAtomsA = decltype(cute::make_tuple(SmemCopyAtomA{}, SmemCopyAtomSFA{})); + using SmemCopyAtomsB0 = decltype(cute::make_tuple(SmemCopyAtomB0{}, SmemCopyAtomSFB0{})); + using SmemCopyAtomsB1 = decltype(cute::make_tuple(SmemCopyAtomB1{}, SmemCopyAtomSFB1{})); + + // Construct SMEM layout for SF + // A single indivisible block will hold 4 scale factors of 128 rows/columns (A/B matrix). + // 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row (col). 
32bits corresponds to the TMEM word size + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; + using Blk_Elems = decltype(Blk_MN{} * Blk_SF{}); + + // Basic storage block for new Scaling Factor Layouts + using mnBasicBlockShape = Shape<_32,_4>; + using mnBasicBlockStride = Stride<_16,_4>; + using kBasicBlockShape = Shape, Int>; + using kBasicBlockStride = Stride<_0,_1>; + + using sSFA_shapeM = decltype(prepend(size<0>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{})); + using sSF_strideMN = decltype(prepend(Blk_Elems{}, mnBasicBlockStride{})); + using sSFA_strideM = sSF_strideMN; + using sSF_shapeK = decltype(prepend(make_shape(Blk_SF{} / Int{}, size<2>(TileShape_MNK{}) / Int{} / Blk_SF{}), kBasicBlockShape{})); + + using sSFA_strideK = decltype(prepend(make_stride(Int{}, size<0>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{})); + using sSFA_shape = decltype(make_shape(sSFA_shapeM{}, sSF_shapeK{})); + using sSFA_stride = decltype(make_stride(sSFA_strideM{}, sSFA_strideK{})); + using SmemLayoutAtomSFA = decltype(make_layout(sSFA_shape{}, sSFA_stride{})); + + using sSFB0_shapeN = decltype(prepend(size<1>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{})); + using sSFB0_strideN = sSF_strideMN; + using sSFB0_strideK = decltype(prepend(make_stride(Int{}, size<1>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{})); + using sSFB1_shapeN = decltype(prepend(size<1>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{})); + using sSFB1_strideN = sSF_strideMN; + using sSFB1_strideK = decltype(prepend(make_stride(Int{}, size<1>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{})); + using sSFB0_shape = decltype(make_shape(sSFB0_shapeN{}, sSF_shapeK{})); + using sSFB0_stride = decltype(make_stride(sSFB0_strideN{}, sSFB0_strideK{})); + using sSFB1_shape = decltype(make_shape(sSFB1_shapeN{}, sSF_shapeK{})); + using sSFB1_stride = decltype(make_stride(sSFB1_strideN{}, sSFB1_strideK{})); + using SmemLayoutAtomSFB0 = decltype(make_layout(sSFB0_shape{}, sSFB0_stride{})); + using SmemLayoutAtomSFB1 = decltype(make_layout(sSFB1_shape{}, sSFB1_stride{})); + + using SmemLayoutAtomsA = decltype(cute::make_tuple(SmemLayoutAtomA{}, SmemLayoutAtomSFA{})); + using SmemLayoutAtomsB0 = decltype(cute::make_tuple(SmemLayoutAtomB0{}, SmemLayoutAtomSFB0{})); + using SmemLayoutAtomsB1 = decltype(cute::make_tuple(SmemLayoutAtomB1{}, SmemLayoutAtomSFB1{})); + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled< + detail::sm120_smem_capacity_bytes, SmemAllocTypeA, SmemAllocTypeB, TileShape_MNK, SmemLayoutAtomSFA, SmemLayoutAtomSFB0>(StageCountType{}); + + static constexpr uint32_t SchedulerPipelineStageCount = 3; + + // Strides and interleaved SF layouts + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB0 = cutlass::gemm::TagToStrideB_t; + using StrideB1 = cutlass::gemm::TagToStrideB_t; + using InternalStrideA = cute::remove_pointer_t; + using InternalStrideB0 = cute::remove_pointer_t; + using InternalStrideB1 = cute::remove_pointer_t; + using InternalLayoutSFA = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFA()); + using InternalLayoutSFB0 = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFB()); + using InternalLayoutSFB1 = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFB()); + using LayoutSFA = cute::conditional_t, InternalLayoutSFA, InternalLayoutSFA*>; + using LayoutSFB0 = cute::conditional_t, InternalLayoutSFB0, InternalLayoutSFB0*>; + using 
LayoutSFB1 = cute::conditional_t, InternalLayoutSFB1, InternalLayoutSFB1*>; + using StridePairA = decltype(cute::make_tuple(StrideA{}, LayoutSFA{})); + using StridePairB0 = decltype(cute::make_tuple(StrideB0{}, LayoutSFB0{})); + using StridePairB1 = decltype(cute::make_tuple(StrideB1{}, LayoutSFB1{})); + + using DispatchPolicy = DualMainloopSm120TmaWarpSpecializedBlockScaled< + PipelineStages, + SchedulerPipelineStageCount, + ClusterShape_MNK, + KernelScheduleType>; + + using CollectiveOp = DualCollectiveMma< + DispatchPolicy, + TileShape_MNK, + cute::tuple, + StridePairA, + cute::tuple, + StridePairB0, + StridePairB1, + TiledMma, + GmemTiledCopyPairA, + SmemLayoutAtomsA, + SmemCopyAtomsA, + cute::identity, + GmemTiledCopyPairB0, + SmemLayoutAtomsB0, + SmemCopyAtomsB0, + cute::identity, + GmemTiledCopyPairB1, + SmemLayoutAtomsB1, + SmemCopyAtomsB1, + cute::identity + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/94_blackwell_geforce_dual_gemm/collective/builder_epilogue.hpp b/examples/94_blackwell_geforce_dual_gemm/collective/builder_epilogue.hpp new file mode 100644 index 0000000000..9de9c0ef63 --- /dev/null +++ b/examples/94_blackwell_geforce_dual_gemm/collective/builder_epilogue.hpp @@ -0,0 +1,485 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/detail/collective.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/collective/builders/sm90_common.inl" +#include "cutlass/epilogue/collective/builders/sm120_common.inl" + +#include "dispatch_policy_extra.hpp" +#include "sm90_epilogue_tma_warpspecialized_dual.hpp" +#include "../fusion/callbacks.hpp" +#include "../fusion/sm90_callbacks_tma_warpspecialized_dual.hpp" +#include "../fusion/sm120_callbacks_tma_warpspecialized_dual.hpp" + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the parameterized dispatch policy for the TMA epilogue +template +constexpr auto +dual_sm120_get_tma_dispatch_policy() { + using namespace cute; + + constexpr int EpiTiles = size(shape_div(take<0,2>(TileShapeMNK{}), EpilogueTileMN{})); + using StrideD = cutlass::detail::TagToStrideC_t; + using InternalStrideD = cute::remove_pointer_t; + constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // For 120, a FragmentSize of 4 is used to match the + // output per thread from each MMA. Epilogue subtiles iterate over multiple of these + // fragments before storing the subtile's outputs to shared memory. + constexpr int FragmentSize = 4; + + // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation + constexpr bool ReuseSmem = (sizeof_bits_v == sizeof_bits_v) && (sizeof_bits_v > 8); + constexpr bool DelayTmaStore = is_void_v; // TMA store delay performs worse with residual loads + + constexpr bool IsFP6 = cute::is_same_v || cute::is_same_v; + constexpr bool IsRowMajorD = cutlass::gemm::detail::is_major<1, StrideD>(); + constexpr int StagesD = (IsFP6 && IsRowMajorD) ? 1 : cute::min(EpiTiles, 2); + + // SM120 epilogues use smaller stage counts in order to fit within the limited shared memory capacity. + constexpr int StagesC = ReuseSmem ? 
cute::max(cute::min(EpiTiles, 2), StagesD+1) + : StagesD; + + constexpr int NumEpilogueWarpGroups = GetNumEpilogueWarpGroups::value; + + if constexpr (IsGroupedGemmKernel) { + return Sm120PtrArrayTmaWarpSpecialized{}; + } + else { + return DualSm120TmaWarpSpecialized{}; + } +} + +// Returns the smem layout atom to be used for C or D matrix +template +constexpr auto +dual_sm120_get_epilogue_smem_swizzle_layout_atom() { + using namespace cute; + + // FP6 data is always stored in 8-bit containers in the epilogue + using Element = cute::conditional_t< + cute::is_same_v || cute::is_same_v, + uint8_t, Element_ + >; + + // ColMajor C/D (M-major) + if constexpr (cutlass::gemm::detail::is_major<0>(GmemStrideType{})) { + return cutlass::gemm::collective::detail::ss_smem_selector< + cute::GMMA::Major::MN, Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{})) + >(); + } + // RowMajor C/D (N-major) + else if constexpr (cutlass::gemm::detail::is_major<1>(GmemStrideType{})) { + return cutlass::gemm::collective::detail::ss_smem_selector< + cute::GMMA::Major::K , Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{})) + >(); + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported gmem layout."); + } +} + +template +constexpr auto +dual_sm120_compute_tile_shape_or_override() { + + constexpr int CTA_M = size<0>(TileShape_MNK{}); + constexpr int CTA_N = size<1>(TileShape_MNK{}); + + constexpr bool IsFP6 = cute::is_same_v || cute::is_same_v; + + constexpr bool IsColMajorD = cutlass::gemm::detail::is_major<0, StrideD>(); + constexpr bool IsRowMajorD = cutlass::gemm::detail::is_major<1, StrideD>(); + static_assert(IsColMajorD || IsRowMajorD, "SM120 LayoutD must be either row or column major."); + + static_assert(!IsFP6 || + (CTA_M % 128 == 0 && IsColMajorD) || + (CTA_N % 128 == 0 && IsRowMajorD), + "CTA tile for FP6 ElementD must have a contiguous extent that is a multiple of 128."); + + if constexpr (cute::is_same_v) { + // If ElementD is FP6, use an epilogue subtile with an extent + // of 128 along the continuous dimension to meet TMA requirements. + if constexpr (IsFP6) { + if constexpr (IsRowMajorD) { + return Shape<_64, _128>{}; + } + else { + return Shape<_128, _32>{}; + } + } + else { + if constexpr (cute::is_same_v) { + // sm120 sparse kernels require more shared memory budget than dense kernels in the mainloop + // so selecting a smaller EpilogueTileN (16) for some cases. + if constexpr (FusionOp::SFVecSize == 64 && IsRowMajorD) { + return Shape<_32, _64>{}; + } + else { + constexpr int M = 64; + constexpr int N = cute::is_void_v + // When C is void, let N = 16 when D is fp32 for lesser SMEM consumption, otherwise 32. + ? cute::sizeof_bits_v == 32 ? 16 : 32 + // When C is not void + : cute::sizeof_bits_v <= 16 + ? 
32 // 16-bit or smaller C needs lesser SMEM for epilogue so we keep N = 32 + : 16; // 32-bit needs to let N = 16 + return Shape, Int>{}; + } + } + else { + return Shape<_64, _32>{}; + } + } + } // EpilogueTileAuto + else if constexpr (cute::is_tuple::value) { + static_assert(!is_layout::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + + EpilogueTileType epi_tile; + constexpr int M = size<0>(shape(epi_tile)); + constexpr int N = size<1>(shape(epi_tile)); + + static_assert(!IsFP6 || + (M % 128 == 0 && IsColMajorD) || + (N % 128 == 0 && IsRowMajorD), + "EpilogueTile for narrow ElementD must have a contiguous extent that is a multiple of 128."); + + static_assert(CTA_M % M == 0 && CTA_N % N == 0, "EpilogueTile must evenly divide CTA tile"); + + return epi_tile; + } + else { + static_assert(cutlass::detail::dependent_false, "Invalid type for EpilogueTileType."); + } +} + +template +constexpr auto +dual_sm120_get_register_transform_op() { + using namespace cute; + + [[maybe_unused]] constexpr bool is_m_major = cutlass::detail::is_major<0>(GmemStrideTypeD{}); + [[maybe_unused]] constexpr bool is_n_major = cutlass::detail::is_major<1>(GmemStrideTypeD{}); + static_assert(is_m_major || is_n_major, "Unsupported gmem layout"); + + if constexpr (sizeof_bits_v == 4 && is_m_major) { + // Before store fp4 along M major, row0 column{0,1} is kept in one thread, and row1 column{0,1} + // is kept in another thread. It is expected to have row{0,1} column0 in one thread, + // while row{0,1} column1 in another thread, so that the store could keep granularity + // 8bits at least. The shuffle is a 2x2 transpose, like below diagram, switching N major to + // M major from a register view. + // + // Before After + // Column0 Column1 Column0 Column1 + // Row0 d0(t0) d1(t0) Row0 d0(t0) d0(t4) + // Row1 d0(t4) d1(t4) Row1 d1(t0) d1(t4) + // + return SM50_Shuffle_U32_2x2Trans_XOR4{}; + } + else { + return; // void + } +} + +// callbacks builder with operation tag +template< + class DispatchPolicy, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator, + class AccLoadOp = cute::DefaultCopy, + class = void +> +struct DualCallbacksBuilder { + using Callbacks = fusion::DualFusionCallbacks; +}; + +// Overload CallbacksBuilder to pick the correct copy atoms +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class AccLoadOp, + class ElementAccumulator +> +struct DualCallbacksBuilder< + DualSm120TmaWarpSpecialized, + FusionOp, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + AccLoadOp, + cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor + && not cute::is_subbyte_v> +> { + using GmemStrideTypeAux = gemm::TagToStrideC_t; + using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom< + GmemStrideTypeAux, typename FusionOp::ElementAux, EpilogueTile_MN>()); + + using CopyOpR2S = decltype(detail::sm120_get_smem_store_op_for_accumulator()); + + using CopyOpS2R = decltype(detail::sm120_get_smem_load_op_for_source()); + + using SmemCopyOpAux = cute::conditional_t; + + using Callbacks = fusion::DualFusionCallbacks< + DualSm120TmaWarpSpecialized, + FusionOp, TileShape_MNK, EpilogueTile_MN, + SmemLayoutAtomAux, SmemCopyOpAux + >; +}; + +// Helper for building TMA warp-specialized collective epilogues, specialized by +// the fusion operation performed and the dispatch policy to use. 
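+// In this dual variant, the builder below (DualSm120TmaBuilderImpl) deduces the C/D smem swizzle
+// atoms and copy ops, builds the fused DualFusionCallbacks for the requested fusion operation,
+// maps the DualSm120TmaWarpSpecialized dispatch policy onto its DualSm90TmaWarpSpecialized base,
+// and finally instantiates DualCollectiveEpilogue from all of these pieces.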
+template < + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD_, + class GmemLayoutTagD, + int AlignmentD, + class FusionOpOrCallbacks, + class DispatchPolicy +> +struct DualSm120TmaBuilderImpl { + // Passing void D disables destination store + smem allocation + using ElementD = cute::conditional_t, + fusion::dual_get_element_aux_t, ElementD_>; + + // Passing void C disables source load + smem allocation + using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages + using GmemLayoutTagC = cute::conditional_t,GmemLayoutTagD,GmemLayoutTagC_>; + + using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; + using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; + + using UnderlyingGmemStrideTypeC = cute::remove_pointer_t; + using UnderlyingGmemStrideTypeD = cute::remove_pointer_t; + + using CopyOpS2G = + cute::conditional_t, + SM90_TMA_STORE_IM2COL, + SM90_TMA_STORE + >; + + using CopyOpG2S = + cute::conditional_t, + SM90_TMA_LOAD_IM2COL, + SM90_TMA_LOAD + >; + + // Get the smallest tiled copy we can use to retile the accumulators + using CopyAtomC = Copy_Atom; + + using SmemLayoutAtomC = decltype(detail::dual_sm120_get_epilogue_smem_swizzle_layout_atom()); + using SmemLayoutAtomD = decltype(detail::dual_sm120_get_epilogue_smem_swizzle_layout_atom()); + + using CopyOpS2R = decltype(detail::sm120_get_smem_load_op_for_source()); + + using CopyOpR2S = decltype(detail::sm120_get_smem_store_op_for_accumulator()); + + // Get register to register tiled copy that happen before shared memory store. + using CopyOpR2R = decltype(detail::dual_sm120_get_register_transform_op()); + + // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks + // instance or a direct visitor implementation, e.g. 
fusion::Sm90LinearCombination + using DualFusionCallbacks = + typename DualCallbacksBuilder< + DispatchPolicy, + FusionOpOrCallbacks, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator + >::Callbacks; + + // Re-use Sm90 collective epilogue implementation + constexpr static int StagesC = DispatchPolicy::StagesC; + constexpr static int StagesD = DispatchPolicy::StagesD; + constexpr static int FragmentSize = DispatchPolicy::FragmentSize; + constexpr static bool ReuseSmemC = DispatchPolicy::ReuseSmemC; + constexpr static bool DelayTmaStore = DispatchPolicy::DelayTmaStore; + + //Helper to deduce BaseDispatchPolicy based on DispatchPolicy + template + struct GetBaseDispatchPolicy { + using Type = T; + }; + + template + struct GetBaseDispatchPolicy> { + using Type = typename cutlass::epilogue::DualSm90TmaWarpSpecialized; + }; + + using BaseDispatchPolicy = typename GetBaseDispatchPolicy::Type; + + using CollectiveOp = DualCollectiveEpilogue< + BaseDispatchPolicy, + TileShape_MNK, + EpilogueTile_MN, + ElementC_, // Need to pass void through to expose via GemmUniversal + GmemStrideTypeC, + ElementD_, + GmemStrideTypeD, + DualFusionCallbacks, + CopyOpG2S, + SmemLayoutAtomC, + CopyOpS2R, + CopyOpS2G, + SmemLayoutAtomD, + CopyOpR2S, + CopyAtomC, + CopyOpR2R + >; +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ArchTag, + class OpClass, + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class EpilogueScheduleType, + class FusionOpOrCallbacks = cutlass::epilogue::fusion::DualLinearCombination, + class Enable = void +> +struct DualCollectiveBuilder { + static_assert(cutlass::detail::dependent_false, + "Could not build a collective epilogue for given parameters."); +}; + +// Tma warp-specialized builder +template < + class OpClass, + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + class FusionOperation +> +struct DualCollectiveBuilder< + arch::Sm120, + OpClass, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + Schedule, + FusionOperation, + cute::enable_if_t< + cute::is_same_v + >> { +private: + using EpilogueTile_MN = + decltype(detail::dual_sm120_compute_tile_shape_or_override, FusionOperation>()); + using DispatchPolicy = + decltype(detail::dual_sm120_get_tma_dispatch_policy()); + + +public: + using CollectiveOp = + typename detail::DualSm120TmaBuilderImpl< + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + FusionOperation, + DispatchPolicy + >::CollectiveOp; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective diff --git a/examples/94_blackwell_geforce_dual_gemm/collective/dispatch_policy_extra.hpp b/examples/94_blackwell_geforce_dual_gemm/collective/dispatch_policy_extra.hpp new file mode 100644 index 0000000000..682580cd82 --- /dev/null +++ 
b/examples/94_blackwell_geforce_dual_gemm/collective/dispatch_policy_extra.hpp @@ -0,0 +1,98 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/arch/arch.h" + +namespace cutlass::gemm { + +struct DualKernelTmaWarpSpecializedCooperative { + static constexpr int SchedulerPipelineStageCount = 0; +}; + +template< + int Stages_, + int SchedulerPipelineStageCount_, + class ClusterShape_, + class KernelSchedule_ +> +struct DualMainloopSm120TmaWarpSpecializedBlockScaled { + constexpr static int Stages = Stages_; + constexpr static int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + using ClusterShape = ClusterShape_; + using Schedule = KernelSchedule_; + constexpr static int PipelineAsyncMmaStages = 0; + using ArchTag = arch::Sm120; +}; + +template +struct DualKernelTmaWarpSpecializedCooperativeBlockScaledSm120 : DualKernelTmaWarpSpecializedCooperative { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; +}; + +} // namespace cutlass::gemm + +namespace cutlass::epilogue { + +struct DualTmaWarpSpecialized {}; + +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_ +> +struct DualSm90TmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; +}; + +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_ +> +struct DualSm120TmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; +}; + +} // namespace cutlass::epilogue \ No newline at end of file diff --git a/examples/94_blackwell_geforce_dual_gemm/collective/sm120_blockscaled_mma_tma_dual.hpp b/examples/94_blackwell_geforce_dual_gemm/collective/sm120_blockscaled_mma_tma_dual.hpp new file mode 100755 index 0000000000..4ddaca82e6 --- /dev/null +++ b/examples/94_blackwell_geforce_dual_gemm/collective/sm120_blockscaled_mma_tma_dual.hpp @@ -0,0 +1,1140 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "dispatch_policy_extra.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class TileShape, + class ElementPairA, + class StridePairA, + class ElementPairB, + class StridePairB0, + class StridePairB1, + class TiledMma, + class GmemTiledCopyPairA, + class SmemLayoutAtomsA, + class SmemCopyAtomsA, + class TransformA, + class GmemTiledCopyPairB0, + class SmemLayoutAtomsB0, + class SmemCopyAtomsB0, + class TransformB0, + class GmemTiledCopyPairB1, + class SmemLayoutAtomsB1, + class SmemCopyAtomsB1, + class TransformB1 +> +struct DualCollectiveMma { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +template < + int Stages, + int SchedulerPipelineStageCount, + class ClusterShape, + class KernelScheduleType, + class TileShape_, + class ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB0_, + class StridePairB1_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomsA_, + class SmemCopyAtomsA_, + class TransformA_, + class GmemTiledCopyPairB0_, + class SmemLayoutAtomsB0_, + class SmemCopyAtomsB0_, + class TransformB0_, + class GmemTiledCopyPairB1_, + class SmemLayoutAtomsB1_, + class SmemCopyAtomsB1_, + class TransformB1_> +struct DualCollectiveMma< + DualMainloopSm120TmaWarpSpecializedBlockScaled, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB0_, + StridePairB1_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomsA_, + SmemCopyAtomsA_, + TransformA_, + GmemTiledCopyPairB0_, + SmemLayoutAtomsB0_, + SmemCopyAtomsB0_, + TransformB0_, + GmemTiledCopyPairB1_, + SmemLayoutAtomsB1_, + SmemCopyAtomsB1_, + TransformB1_> { + // + // Type Aliases + // + using DispatchPolicy = DualMainloopSm120TmaWarpSpecializedBlockScaled; + using TileShape = TileShape_; + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using StridePairA = StridePairA_; + using StridePairB0 = StridePairB0_; + using StridePairB1 = StridePairB1_; + + static_assert(cute::is_same_v(ElementPairA{}))>, + 
remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + using RuntimeDataTypeA = void*; + using RuntimeDataTypeB = void*; + + // A and B matrices + using ElementA = remove_cvref_t(ElementPairA{}))>; + using StrideA = remove_cvref_t(StridePairA{}))>; + + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB0 = remove_cvref_t(StridePairB0{}))>; + using StrideB1 = remove_cvref_t(StridePairB1{}))>; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using LayoutSFB0 = remove_cvref_t(StridePairB0{}))>; + using LayoutSFB1 = remove_cvref_t(StridePairB1{}))>; + + using ArrayElementA = ElementA; + using ArrayElementB = ElementB; + + using TiledMma = TiledMma_; + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using ElementAccumulator = typename TiledMma::ValTypeC; + + static constexpr int SFVecSize = TiledMma::Traits::SFVecSize; + using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + + // Gmem copies + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB0 = GmemTiledCopyPairB0_; + using GmemTiledCopyPairB1 = GmemTiledCopyPairB1_; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB0 = remove_cvref_t(GmemTiledCopyPairB0{}))>; + using GmemTiledCopySFB0 = remove_cvref_t(GmemTiledCopyPairB0{}))>; + using GmemTiledCopyB1 = remove_cvref_t(GmemTiledCopyPairB1{}))>; + using GmemTiledCopySFB1 = remove_cvref_t(GmemTiledCopyPairB1{}))>; + + // Smem copies + using SmemLayoutAtomsA = SmemLayoutAtomsA_; + using SmemLayoutAtomsB0 = SmemLayoutAtomsB0_; + using SmemLayoutAtomsB1 = SmemLayoutAtomsB1_; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomsA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomsA{}))>; + using SmemLayoutAtomB0 = remove_cvref_t(SmemLayoutAtomsB0{}))>; + using SmemLayoutAtomSFB0 = remove_cvref_t(SmemLayoutAtomsB0{}))>; + using SmemLayoutAtomB1 = remove_cvref_t(SmemLayoutAtomsB1{}))>; + using SmemLayoutAtomSFB1 = remove_cvref_t(SmemLayoutAtomsB1{}))>; + + using SmemCopyAtomsA = SmemCopyAtomsA_; + using SmemCopyAtomsB0 = SmemCopyAtomsB0_; + using SmemCopyAtomsB1 = SmemCopyAtomsB1_; + + using SmemCopyAtomA = remove_cvref_t(SmemCopyAtomsA{}))>; + using SmemCopyAtomSFA = remove_cvref_t(SmemCopyAtomsA{}))>; + + using SmemCopyAtomB0 = remove_cvref_t(SmemCopyAtomsB0{}))>; + using SmemCopyAtomSFB0 = remove_cvref_t(SmemCopyAtomsB0{}))>; + using SmemCopyAtomB1 = remove_cvref_t(SmemCopyAtomsB1{}))>; + using SmemCopyAtomSFB1 = remove_cvref_t(SmemCopyAtomsB1{}))>; + + using TransformA = TransformA_; + using TransformB0 = TransformB0_; + using TransformB1 = TransformB1_; + + using ArchTag = typename DispatchPolicy::ArchTag; + + static constexpr int ThreadCount = size(TiledMma{}); + + using MainloopPipeline = cutlass::PipelineTmaAsync; + + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename cutlass::PipelineState; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + 
static_assert(rank(SmemLayoutAtomB0{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB0{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB0{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB1{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB1{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB1{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for A operand smem->rmem reads."); + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for B0 operand smem->rmem reads."); + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for B1 operand smem->rmem reads."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB0 = decltype(tile_to_shape( + SmemLayoutAtomB0{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB0>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB1 = decltype(tile_to_shape( + SmemLayoutAtomB1{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB1>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. + // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + + using SmemLayoutSFB0 = decltype(make_layout( + append(shape(SmemLayoutAtomSFB0{}), Int{}), + append(stride(SmemLayoutAtomSFB0{}), size(filter_zeros(SmemLayoutAtomSFB0{}))) + )); + + using SmemLayoutSFB1 = decltype(make_layout( + append(shape(SmemLayoutAtomSFB1{}), Int{}), + append(stride(SmemLayoutAtomSFB1{}), size(filter_zeros(SmemLayoutAtomSFB1{}))) + )); + + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB0{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB1{}) == 3, "Smem layout must be rank 3."); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + not cute::is_base_of::value, + "MMA atom must source both A and B operands from rmem for this mainloop."); + static_assert(cute::is_same_v, "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v, "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v, "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm120_f8f6f4(); + + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
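+  // Sub-byte float formats (e2m1 / e2m3 / e3m2) are remapped to their *_unpacksmem_t variants so
+  // that TMA stores them unpacked (one value per 8-bit container) in shared memory, in line with
+  // the SmemAllocTypeA/B definitions below; every other element type is viewed as a same-width
+  // unsigned integer so the TMA engine copies raw bits and cannot introduce rounding.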
+ using TmaInternalElementA = cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + + using TmaInternalElementB = cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + + using TmaInternalElementSF = ElementSF; + + using SmemAllocTypeA = cute::conditional_t; + using SmemAllocTypeB = cute::conditional_t; + + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesMK = static_cast( + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutA{})) * sizeof_bits::value)); + + static constexpr uint32_t TmaTransactionBytesNK0 = static_cast( + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutSFB0{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutB0{})) * sizeof_bits::value)); + + static constexpr uint32_t TmaTransactionBytesNK1 = static_cast( + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutSFB1{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutB1{})) * sizeof_bits::value)); + + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK0 + TmaTransactionBytesNK1; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + alignas(1024) cute::ArrayEngine> smem_A; + alignas(1024) cute::ArrayEngine> smem_B0; + alignas(1024) cute::ArrayEngine> smem_B1; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB0; + cute::ArrayEngine> smem_SFB1; + } tensors; + using PipelineStorage = typename MainloopPipeline::SharedStorage; + alignas(16) PipelineStorage pipeline_storage; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{nullptr}; + StrideA dA{}; + ElementB const* ptr_B0{nullptr}; + StrideB0 dB0{}; + ElementB const* ptr_B1{nullptr}; + StrideB1 dB1{}; + ElementSF const* ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const* ptr_SFB0{nullptr}; + LayoutSFB0 layout_SFB0{}; + ElementSF const* ptr_SFB1{nullptr}; + LayoutSFB1 layout_SFB1{}; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{})); // No programmatic multicast + + // Assumption: StrideB is congruent with Problem_NK + using TMA_B0 = decltype(make_tma_copy( + GmemTiledCopyB0{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB0{}, int32_t(0)), StrideB0{}), + SmemLayoutB0{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + _1{})); // No programmatic multicast + + using TMA_B1 = decltype(make_tma_copy( + GmemTiledCopyB1{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB1{}, int32_t(0)), StrideB1{}), + SmemLayoutB1{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + _1{})); // No programmatic multicast + + 
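+    // The scale-factor operands get their own TMA descriptors. Unlike A/B0/B1, which are
+    // described by plain strides, SFA/SFB0/SFB1 are addressed through the interleaved
+    // LayoutSFA/LayoutSFB0/LayoutSFB1 layouts deduced from Sm1xxBlkScaledConfig (blocks of
+    // 4 scale factors covering 128 rows or columns).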
using TMA_SFA = decltype(make_tma_copy( + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), LayoutSFA{}), + SmemLayoutSFA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{})); // No programmatic multicast + + using TMA_SFB0 = decltype(make_tma_copy( + GmemTiledCopySFB0{}, + make_tensor(static_cast(nullptr), LayoutSFB0{}), + SmemLayoutSFB0{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + _1{})); // No programmatic multicast + + using TMA_SFB1 = decltype(make_tma_copy( + GmemTiledCopySFB1{}, + make_tensor(static_cast(nullptr), LayoutSFB1{}), + SmemLayoutSFB1{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + _1{})); // No programmatic multicast + + TMA_A tma_load_a; + TMA_B0 tma_load_b0; + TMA_B1 tma_load_b1; + TMA_SFA tma_load_sfa; + TMA_SFB0 tma_load_sfb0; + TMA_SFB1 tma_load_sfb1; + LayoutSFA layout_SFA; + LayoutSFB0 layout_SFB0; + LayoutSFB1 layout_SFB1; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk0 = TmaTransactionBytesNK0; + uint32_t tma_transaction_bytes_nk1 = TmaTransactionBytesNK1; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B0 = recast_ptr(args.ptr_B0); + auto ptr_B1 = recast_ptr(args.ptr_B1); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b0 = make_tensor(ptr_B0, make_layout(make_shape(N,K,L), args.dB0)); + Tensor tensor_b1 = make_tensor(ptr_B1, make_layout(make_shape(N,K,L), args.dB1)); + + Tensor tensor_sfa = make_tensor(args.ptr_SFA, args.layout_SFA); + Tensor tensor_sfb0 = make_tensor(args.ptr_SFB0, args.layout_SFB0); + Tensor tensor_sfb1 = make_tensor(args.ptr_SFB1, args.layout_SFB1); + + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{}); // No programmatic multicast + typename Params::TMA_B0 tma_load_b0 = make_tma_copy( + GmemTiledCopyB0{}, + tensor_b0, + SmemLayoutB0{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + _1{}); // No programmatic multicast + typename Params::TMA_B1 tma_load_b1 = make_tma_copy( + GmemTiledCopyB1{}, + tensor_b1, + SmemLayoutB1{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + _1{}); // No programmatic multicast + + typename Params::TMA_SFA tma_load_sfa = make_tma_copy( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + _1{}); // No programmatic multicast + + typename Params::TMA_SFB0 tma_load_sfb0 = make_tma_copy( + GmemTiledCopySFB0{}, + tensor_sfb0, + SmemLayoutSFB0{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + _1{}); // No programmatic multicast + + typename Params::TMA_SFB1 tma_load_sfb1 = make_tma_copy( + GmemTiledCopySFB1{}, + tensor_sfb1, + SmemLayoutSFB1{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + _1{}); // No 
programmatic multicast
+
+    return {
+      tma_load_a,
+      tma_load_b0,
+      tma_load_b1,
+      tma_load_sfa,
+      tma_load_sfb0,
+      tma_load_sfb1,
+      args.layout_SFA,
+      args.layout_SFB0,
+      args.layout_SFB1,
+      TmaTransactionBytes,
+      TmaTransactionBytesMK,
+      TmaTransactionBytesNK0,
+      TmaTransactionBytesNK1
+    };
+  }
+
+  template <class ProblemShape>
+  CUTLASS_HOST_DEVICE static bool
+  can_implement(
+      ProblemShape const& problem_shape,
+      [[maybe_unused]] Arguments const& args) {
+    auto problem_shape_MNKL = append<4>(problem_shape, 1);
+    auto [M, N, K, L] = problem_shape_MNKL;
+
+    constexpr int tma_alignment_bits_A  = cutlass::detail::get_input_alignment_bits<ElementA>();
+    constexpr int tma_alignment_bits_B0 = cutlass::detail::get_input_alignment_bits<ElementB>();
+    constexpr int tma_alignment_bits_B1 = cutlass::detail::get_input_alignment_bits<ElementB>();
+
+    bool implementable = true;
+    constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cutlass::sizeof_bits<ElementA>::value;
+    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
+    constexpr int min_tma_aligned_elements_B0 = tma_alignment_bits_B0 / cutlass::sizeof_bits<ElementB>::value;
+    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B0>(cute::make_shape(N,K,L), StrideB0{});
+    constexpr int min_tma_aligned_elements_B1 = tma_alignment_bits_B1 / cutlass::sizeof_bits<ElementB>::value;
+    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B1>(cute::make_shape(N,K,L), StrideB1{});
+
+    if (!implementable) {
+      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
+    }
+    return implementable;
+  }
+
+  /// Issue TMA descriptor prefetch -- ideally from a single thread for best performance
+  CUTLASS_DEVICE
+  static void prefetch_tma_descriptors(Params const& params) {
+    cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor());
+    cute::prefetch_tma_descriptor(params.tma_load_b0.get_tma_descriptor());
+    cute::prefetch_tma_descriptor(params.tma_load_b1.get_tma_descriptor());
+    cute::prefetch_tma_descriptor(params.tma_load_sfa.get_tma_descriptor());
+    cute::prefetch_tma_descriptor(params.tma_load_sfb0.get_tma_descriptor());
+    cute::prefetch_tma_descriptor(params.tma_load_sfb1.get_tma_descriptor());
+  }
+
+  // Temporary ad hoc partitioning for scaling factors.
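The `thrfrg_SFA` / `thrfrg_SFB0` / `thrfrg_SFB1` helpers that follow carve the scale-factor tensors into (atom, rest) and then (thread, fragment) modes with CuTe layout algebra (`logical_divide`, `zipped_divide`). As a reminder of what `zipped_divide` produces, here is a small host-only sketch with made-up shapes; it only needs the CuTe headers on the include path:

```
#include <cute/tensor.hpp>

int main() {
  using namespace cute;
  // An 8x4 "scale factor" layout and a 2x2 tiler -- sizes are illustrative only.
  auto layout = make_layout(make_shape(Int<8>{}, Int<4>{}));
  auto tiler  = make_tile(make_layout(Int<2>{}), make_layout(Int<2>{}));
  auto tiled  = zipped_divide(layout, tiler);        // ((TileM,TileN),(RestM,RestN)) = ((2,2),(4,2))
  CUTE_STATIC_ASSERT_V(size<0>(tiled) == Int<4>{});  // elements inside one tile
  CUTE_STATIC_ASSERT_V(size<1>(tiled) == Int<8>{});  // number of such tiles
  print(tiled); print("\n");
  return 0;
}
```

The helpers below apply the same split twice: once by the MMA atom shape and once by the thread layout, which is how the ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) comments in their bodies arise.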
+ template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_SFA(SFATensor&& sfatensor, TiledMMA& mma) + { + CUTE_STATIC_ASSERT_V(rank(sfatensor) >= Int<2>{}); + + using AtomShape_MNK = typename Atom::Shape_MNK; + using AtomLayoutSFA_TV = typename Atom::Traits::SFALayout; + + auto permutation_mnk = TiledPerm{}; + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<0>(permutation_mnk), + get<2>(permutation_mnk)); + auto t_tensor = logical_divide(sfatensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutSFA_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(thr_layout_vmnk)), + make_layout(size<3>(thr_layout_vmnk)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_SFB0(SFB0Tensor&& sfb0tensor, TiledMMA& mma) + { + CUTE_STATIC_ASSERT_V(rank(sfb0tensor) >= Int<2>{}); + + using AtomShape_MNK = typename Atom::Shape_MNK; + using AtomLayoutSFB0_TV = typename Atom::Traits::SFBLayout; + + auto permutation_mnk = TiledPerm{}; + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<1>(permutation_mnk), + get<2>(permutation_mnk)); + auto t_tensor = logical_divide(sfb0tensor, t_tile); // (PermN,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomN,AtomK),(RestN,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutSFB0_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<2>(thr_layout_vmnk)), + make_layout(size<3>(thr_layout_vmnk)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK))) + return thr_tensor; + } + + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_SFB1(SFB1Tensor&& sfb1tensor, TiledMMA& mma) + { + CUTE_STATIC_ASSERT_V(rank(sfb1tensor) >= Int<2>{}); + + using AtomShape_MNK = typename Atom::Shape_MNK; + using AtomLayoutSFB1_TV = typename Atom::Traits::SFBLayout; + + auto permutation_mnk = TiledPerm{}; + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<1>(permutation_mnk), + get<2>(permutation_mnk)); + auto t_tensor = logical_divide(sfb1tensor, t_tile); // (PermN,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomN,AtomK),(RestN,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutSFB1_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<2>(thr_layout_vmnk)), + make_layout(size<3>(thr_layout_vmnk)))); + auto thr_tensor 
= zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK))) + return thr_tensor; + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_SFA(SFATensor&& sfatensor, ThrMma& thread_mma) + { + using ValTypeSF = typename ThrMma::Atom::Traits::ValTypeSF; + auto thr_tensor = make_tensor(static_cast(sfatensor).data(), thrfrg_SFA(sfatensor.layout(),thread_mma)); + auto thr_vmnk = thread_mma.thr_vmnk_; + auto thr_vmk = make_coord(get<0>(thr_vmnk), make_coord(get<1>(thr_vmnk), get<3>(thr_vmnk))); + auto partition_SFA = thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); + return make_fragment_like(partition_SFA); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_SFB0(SFB0Tensor&& sfb0tensor, ThrMma& thread_mma) + { + using ValTypeSF = typename ThrMma::Atom::Traits::ValTypeSF; + auto thr_tensor = make_tensor(static_cast(sfb0tensor).data(), thrfrg_SFB0(sfb0tensor.layout(),thread_mma)); + auto thr_vmnk = thread_mma.thr_vmnk_; + auto thr_vnk = make_coord(get<0>(thr_vmnk), make_coord(get<2>(thr_vmnk), get<3>(thr_vmnk))); + auto partition_SFB0 = thr_tensor(thr_vnk, make_coord(_, repeat(thr_tensor)>(_))); + return make_fragment_like(partition_SFB0); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_SFB1(SFB1Tensor&& sfb1tensor, ThrMma& thread_mma) + { + using ValTypeSF = typename ThrMma::Atom::Traits::ValTypeSF; + auto thr_tensor = make_tensor(static_cast(sfb1tensor).data(), thrfrg_SFB1(sfb1tensor.layout(),thread_mma)); + auto thr_vmnk = thread_mma.thr_vmnk_; + auto thr_vnk = make_coord(get<0>(thr_vmnk), make_coord(get<2>(thr_vmnk), get<3>(thr_vmnk))); + auto partition_SFB1 = thr_tensor(thr_vnk, make_coord(_, repeat(thr_tensor)>(_))); + return make_fragment_like(partition_SFB1); + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_layoutSFA_TV(TiledMma& mma) + { + // (M,K) -> (M,K) + auto tile_shape_mnk = tile_shape(mma); + auto ref_A = make_layout(make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto atile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk), size<2>(thr_layout_vmnk)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk); + // (thr_idx,val) -> (M,K) + return thrfrg_SFA(ref_A, mma).compose(atile, _).compose(thridx_2_thrid, _); + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_layoutSFB0_TV(TiledMma& mma) + { + // (N,K) -> (N,K) + auto tile_shape_mnk = tile_shape(mma); + auto ref_B0 = make_layout(make_shape(size<1>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto btile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk), size<2>(thr_layout_vmnk)), + make_stride( Int<0>{} , Int<1>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk); + // (thr_idx,val) -> (M,K) + return thrfrg_SFB0(ref_B0, mma).compose(btile, _).compose(thridx_2_thrid, _); + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_layoutSFB1_TV(TiledMma& mma) + { + // (N,K) -> (N,K) + auto tile_shape_mnk = tile_shape(mma); + auto ref_B1 = make_layout(make_shape(size<1>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // (ThrV,(ThrM,ThrK)) -> 
(ThrV,(ThrM,ThrN,ThrK)) + auto btile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk), size<2>(thr_layout_vmnk)), + make_stride( Int<0>{} , Int<1>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk); + // (thr_idx,val) -> (M,K) + return thrfrg_SFB1(ref_B1, mma).compose(btile, _).compose(thridx_2_thrid, _); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB0_nkl = params.tma_load_b0.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + Tensor mB1_nkl = params.tma_load_b1.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + Tensor mSFA_mkl = params.tma_load_sfa.get_tma_tensor(shape(params.layout_SFA)); + Tensor mSFB0_nkl = params.tma_load_sfb0.get_tma_tensor(shape(params.layout_SFB0)); + Tensor mSFB1_nkl = params.tma_load_sfb1.get_tma_tensor(shape(params.layout_SFB1)); + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB0_nkl = local_tile(mB0_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + Tensor gB1_nkl = local_tile(mB1_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB0_nkl = local_tile(mSFB0_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + Tensor gSFB1_nkl = local_tile(mSFB1_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB0_nkl, gB1_nkl, gSFA_mkl, gSFB0_nkl, gSFB1_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB0, class TensorB1, + class TensorSFA, class TensorSFB0, class TensorSFB1, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB0 = make_tensor(make_smem_ptr(shared_tensors.smem_B0.begin()), SmemLayoutB0{}); // (BLK_N,BLK_K,PIPE) + Tensor sB1 = make_tensor(make_smem_ptr(shared_tensors.smem_B1.begin()), SmemLayoutB1{}); // 
(BLK_N,BLK_K,PIPE) + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); // (BLK_M,BLK_K,PIPE) + Tensor sSFB0 = make_tensor(make_smem_ptr(shared_tensors.smem_SFB0.begin()), SmemLayoutSFB0{}); // (BLK_N,BLK_K,PIPE) + Tensor sSFB1 = make_tensor(make_smem_ptr(shared_tensors.smem_SFB1.begin()), SmemLayoutSFB1{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A, B, SFA and SFB + // + + auto [gA_mkl, gB0_nkl, gB1_nkl, gSFA_mkl, gSFB0_nkl, gSFB1_nkl] = load_inputs; + + auto block_tma_a = params.tma_load_a.get_slice(0); + auto block_tma_b0 = params.tma_load_b0.get_slice(0); + auto block_tma_b1 = params.tma_load_b1.get_slice(0); + + auto block_tma_sfa = params.tma_load_sfa.get_slice(0); + auto block_tma_sfb0 = params.tma_load_sfb0.get_slice(0); + auto block_tma_sfb1 = params.tma_load_sfb1.get_slice(0); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB0 = gB0_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + Tensor gB1 = gB1_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + Tensor gSFA = gSFA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gSFB0 = gSFB0_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + Tensor gSFB1 = gSFB1_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Partition source and destination tensors for tma copies + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB0 = block_tma_b0.partition_S(gB0); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB0 = block_tma_b0.partition_D(sB0); // (TMA,TMA_N,TMA_K,PIPE) + Tensor tBgB1 = block_tma_b1.partition_S(gB1); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB1 = block_tma_b1.partition_D(sB1); // (TMA,TMA_N,TMA_K,PIPE) + + Tensor tAgSFA = block_tma_sfa.partition_S(gSFA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsSFA = block_tma_sfa.partition_D(sSFA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgSFB0 = block_tma_sfb0.partition_S(gSFB0); // (TMA,TMA_N,TMA_K,k) + Tensor tBsSFB0 = block_tma_sfb0.partition_D(sSFB0); // (TMA,TMA_N,TMA_K,PIPE) + Tensor tBgSFB1 = block_tma_sfb1.partition_S(gSFB1); // (TMA,TMA_N,TMA_K,k) + Tensor tBsSFB1 = block_tma_sfb1.partition_D(sSFB1); // (TMA,TMA_N,TMA_K,PIPE) + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(params.tma_load_a.with(*tma_barrier), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(params.tma_load_b0.with(*tma_barrier), tBgB0(_,_,_,*k_tile_iter), tBsB0(_,_,_,write_stage)); + copy(params.tma_load_b1.with(*tma_barrier), tBgB1(_,_,_,*k_tile_iter), tBsB1(_,_,_,write_stage)); + + copy(params.tma_load_sfa.with(*tma_barrier), tAgSFA(_,_,_,*k_tile_iter), tAsSFA(_,_,_,write_stage)); + copy(params.tma_load_sfb0.with(*tma_barrier), tBgSFB0(_,_,_,*k_tile_iter), tBsSFB0(_,_,_,write_stage)); + copy(params.tma_load_sfb1.with(*tma_barrier), tBgSFB1(_,_,_,*k_tile_iter), tBsSFB1(_,_,_,write_stage)); + + // Advance k tile + ++k_tile_iter; + ++smem_pipe_write; + } + } + __syncwarp(); + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a 
Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC0, class FrgTensorC1 + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC0& accum0, + FrgTensorC1& accum1, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + [[maybe_unused]] Params const& params) { + using namespace cute; + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + clear(accum0); + clear(accum1); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB0 = make_tensor(make_smem_ptr(shared_tensors.smem_B0.begin()), SmemLayoutB0{}); // (BLK_N,BLK_K,PIPE) + Tensor sB1 = make_tensor(make_smem_ptr(shared_tensors.smem_B1.begin()), SmemLayoutB1{}); // (BLK_N,BLK_K,PIPE) + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); // (BLK_M,BLK_K,PIPE) + Tensor sSFB0 = make_tensor(make_smem_ptr(shared_tensors.smem_SFB0.begin()), SmemLayoutSFB0{}); // (BLK_N,BLK_K,PIPE) + Tensor sSFB1 = make_tensor(make_smem_ptr(shared_tensors.smem_SFB1.begin()), SmemLayoutSFB1{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + // Allocate fragments and descriptors + Tensor tCrA = thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K) + Tensor tCrB0 = thread_mma.partition_fragment_B(sB0(_,_,Int<0>{})); // (MMA,MMA_N,MMA_K) + Tensor tCrB1 = thread_mma.partition_fragment_B(sB1(_,_,Int<0>{})); // (MMA,MMA_N,MMA_K) + + Tensor tCrSFA = partition_fragment_SFA(sSFA(_,_,Int<0>{}), thread_mma); // (MMA,MMA_M,MMA_K) + Tensor tCrSFB0 = partition_fragment_SFB0(sSFB0(_,_,Int<0>{}), thread_mma); // (MMA,MMA_N,MMA_K) + Tensor tCrSFB1 = partition_fragment_SFB1(sSFB1(_,_,Int<0>{}), thread_mma); // (MMA,MMA_N,MMA_K) + + // + // Copy from smem to registers + // + + // A + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S( + as_position_independent_swizzle_tensor(sA)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + + // B0 + auto smem_tiled_copy_B0 = make_tiled_copy_B(SmemCopyAtomB0{}, tiled_mma); + auto smem_thr_copy_B0 = smem_tiled_copy_B0.get_thread_slice(thread_idx); + Tensor tCsB0 = smem_thr_copy_B0.partition_S( + as_position_independent_swizzle_tensor(sB0)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrB0_copy_view = smem_thr_copy_B0.retile_D(tCrB0); // (CPY,CPY_M,CPY_K) + + // B1 + auto smem_tiled_copy_B1 = make_tiled_copy_B(SmemCopyAtomB1{}, tiled_mma); + auto smem_thr_copy_B1 = smem_tiled_copy_B1.get_thread_slice(thread_idx); + Tensor tCsB1 = smem_thr_copy_B1.partition_S( + 
as_position_independent_swizzle_tensor(sB1)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrB1_copy_view = smem_thr_copy_B1.retile_D(tCrB1); // (CPY,CPY_M,CPY_K) + + // SFA + auto tile_shape_mnk = tile_shape(tiled_mma); + auto smem_tiled_copy_SFA = make_tiled_copy_impl(SmemCopyAtomSFA{}, + get_layoutSFA_TV(tiled_mma), + make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk)) + ); + auto smem_thr_copy_SFA = smem_tiled_copy_SFA.get_thread_slice(thread_idx); + Tensor tCsSFA = smem_thr_copy_SFA.partition_S( + as_position_independent_swizzle_tensor(sSFA)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrSFA_copy_view = smem_thr_copy_SFA.retile_D(tCrSFA); // (CPY,CPY_M,CPY_K) + + // SFB0 + auto smem_tiled_copy_SFB0 = make_tiled_copy_impl(SmemCopyAtomSFB0{}, + get_layoutSFB0_TV(tiled_mma), + make_shape(size<1>(tile_shape_mnk), size<2>(tile_shape_mnk)) + ); + auto smem_thr_copy_SFB0 = smem_tiled_copy_SFB0.get_thread_slice(thread_idx); + Tensor tCsSFB0 = smem_thr_copy_SFB0.partition_S( + as_position_independent_swizzle_tensor(sSFB0)); // (CPY,CPY_N,CPY_K,PIPE) + Tensor tCrSFB0_copy_view = smem_thr_copy_SFB0.retile_D(tCrSFB0); // (CPY,CPY_N,CPY_K) + + // SFB1 + auto smem_tiled_copy_SFB1 = make_tiled_copy_impl(SmemCopyAtomSFB1{}, + get_layoutSFB1_TV(tiled_mma), + make_shape(size<1>(tile_shape_mnk), size<2>(tile_shape_mnk)) + ); + auto smem_thr_copy_SFB1 = smem_tiled_copy_SFB1.get_thread_slice(thread_idx); + Tensor tCsSFB1 = smem_thr_copy_SFB1.partition_S( + as_position_independent_swizzle_tensor(sSFB1)); // (CPY,CPY_N,CPY_K,PIPE) + Tensor tCrSFB1_copy_view = smem_thr_copy_SFB1.retile_D(tCrSFB1); // (CPY,CPY_N,CPY_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum0)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum1)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB0) == size<2>(accum0)); // MMA_N + CUTE_STATIC_ASSERT_V(size<1>(tCrB1) == size<2>(accum1)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB0)); // CPY_K + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB1)); // CPY_K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB0)); // PIPE + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB1)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB0)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB1)); // PIPE + + CUTE_STATIC_ASSERT_V(size<1>(tCsSFA) == size<1>(tCrSFA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCrSFA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrSFA) == size<1>(accum0)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrSFA) == size<1>(accum1)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrSFB0) == size<2>(accum0)); // MMA_N + CUTE_STATIC_ASSERT_V(size<1>(tCrSFB1) == size<2>(accum1)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCsSFB0)); // CPY_K + CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCsSFB1)); // CPY_K + CUTE_STATIC_ASSERT_V(size<3>(tCsSFA) == size<3>(tCsSFB0)); // PIPE + CUTE_STATIC_ASSERT_V(size<3>(tCsSFA) == size<3>(tCsSFB1)); // PIPE + CUTE_STATIC_ASSERT_V(size<2>(sA) == size<2>(sSFA)); // PIPE + CUTE_STATIC_ASSERT_V(size<2>(sB0) == size<2>(sSFA)); // PIPE + CUTE_STATIC_ASSERT_V(size<2>(sB1) == size<2>(sSFA)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + // Size of the register pipeline + auto K_BLOCK_MAX = size<2>(tCrA); + + int read_stage = smem_pipe_read.index(); + auto tCsA_stage = 
tCsA(_,_,_,read_stage); + auto tCsB0_stage = tCsB0(_,_,_,read_stage); + auto tCsB1_stage = tCsB1(_,_,_,read_stage); + auto tCsSFA_stage = tCsSFA(_,_,_,read_stage); + auto tCsSFB0_stage = tCsSFB0(_,_,_,read_stage); + auto tCsSFB1_stage = tCsSFB1(_,_,_,read_stage); + + auto copy_kblock = [&](auto k_block) { + // copy smem->rmem for A/B operand + copy(smem_tiled_copy_A, tCsA_stage(_,_,k_block), tCrA_copy_view(_,_,k_block)); + copy(smem_tiled_copy_B0, tCsB0_stage(_,_,k_block), tCrB0_copy_view(_,_,k_block)); + copy(smem_tiled_copy_B1, tCsB1_stage(_,_,k_block), tCrB1_copy_view(_,_,k_block)); + + // Left shift A,B for FP4 + using MMAOp = typename TiledMma::MMA_Op; + fp4_shift_A(MMAOp{}, tCrA_copy_view(_,_,k_block)); + fp4_shift_B(MMAOp{}, tCrB0_copy_view(_,_,k_block)); + fp4_shift_B(MMAOp{}, tCrB1_copy_view(_,_,k_block)); + + + // Copy smem->rmem for SFA/SFB operand + copy(tCsSFA_stage(_,_,k_block), tCrSFA_copy_view(_,_,k_block)); + copy(tCsSFB0_stage(_,_,k_block), tCrSFB0_copy_view(_,_,k_block)); + copy(tCsSFB1_stage(_,_,k_block), tCrSFB1_copy_view(_,_,k_block)); + }; + + auto gemm_kblock = [&](auto k_block) { + cute::gemm(tiled_mma, + make_zip_tensor(tCrA (_,_,k_block), tCrSFA (_,_,k_block)), + make_zip_tensor(tCrB0 (_,_,k_block), tCrSFB0 (_,_,k_block)), + accum0); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA (_,_,k_block), tCrSFA (_,_,k_block)), + make_zip_tensor(tCrB1 (_,_,k_block), tCrSFB1 (_,_,k_block)), + accum1); + }; + + pipeline.consumer_wait(smem_pipe_read); + + copy_kblock(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + for_each(make_int_sequence{}, [&] (auto k_block) { + + auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); + + if (k_block == K_BLOCK_MAX - 1) { + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + // UNLOCK smem_pipe_read, done _computing_ on it + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + read_stage = smem_pipe_read.index(); + tCsA_stage = tCsA(_,_,_,read_stage); + tCsB0_stage = tCsB0(_,_,_,read_stage); + tCsB1_stage = tCsB1(_,_,_,read_stage); + tCsSFA_stage = tCsSFA(_,_,_,read_stage); + tCsSFB0_stage = tCsSFB0(_,_,_,read_stage); + tCsSFB1_stage = tCsSFB1(_,_,_,read_stage); + pipeline.consumer_wait(smem_pipe_read); + } + + copy_kblock(k_block_next); + gemm_kblock(k_block); + + }); + } // k_tile_count + + // + // Hoist out last k_tile + // + for_each(make_int_sequence{}, [&] (auto k_block) { + + auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 
0 : (k_block + 1); + + if (k_block == K_BLOCK_MAX - 1) { + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + // UNLOCK smem_pipe_read, done _computing_ on it + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + } + + if (k_block_next > 0) { + copy_kblock(k_block_next); + } + gemm_kblock(k_block); + + }); +} + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline, PipelineState, int) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/94_blackwell_geforce_dual_gemm/collective/sm90_epilogue_tma_warpspecialized_dual.hpp b/examples/94_blackwell_geforce_dual_gemm/collective/sm90_epilogue_tma_warpspecialized_dual.hpp new file mode 100644 index 0000000000..10305ce9d7 --- /dev/null +++ b/examples/94_blackwell_geforce_dual_gemm/collective/sm90_epilogue_tma_warpspecialized_dual.hpp @@ -0,0 +1,1015 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/scale_type.h" + +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/detail/helper_macros.hpp" +#include "cutlass/trace.h" + +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +#include "../thread/left_silu_and_mul.h" +#include "../fusion/callbacks.hpp" +#include "../fusion/sm90_callbacks_tma_warpspecialized_dual.hpp" +#include "dispatch_policy_extra.hpp" + +#include "../fusion/operations.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class... Args +> +class DualCollectiveEpilogue { + static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); +}; + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K) + class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class DualFusionCallbacks_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_, + class CopyAtomC_, + class CopyOpR2R_ +> +class DualCollectiveEpilogue< + DualSm90TmaWarpSpecialized, + CtaTileMNK_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + DualFusionCallbacks_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyAtomC_, + CopyOpR2R_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = DualSm90TmaWarpSpecialized; + using CtaTileMNK = CtaTileMNK_; + using EpilogueTile = EpilogueTile_; + using DualFusionCallbacks = DualFusionCallbacks_; + using ElementC = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + using CopyOpG2S = CopyOpG2S_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpS2G = CopyOpS2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + using CopyAtomC = CopyAtomC_; + using CopyOpR2R = CopyOpR2R_; + + using ThreadEpilogueOp = typename epilogue::fusion::DualFusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2S; + using GmemTiledCopyD = CopyOpS2G; + + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M"); + static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N"); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + +private: + constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v; + 
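`is_source_supported` / `is_destination_supported` above and the `NonVoid*` aliases just after implement the usual source-less epilogue convention: a `void` element type disables that tensor, and a placeholder type is substituted so the rest of the class still instantiates. A tiny stand-alone sketch of that substitution (helper name is made up), assuming C++17:

```
#include <type_traits>

// Substitute a fallback type for void, mirroring the NonVoidElementC / NonVoidElementD idea below.
template <class MaybeVoid, class Fallback>
using non_void_t = std::conditional_t<std::is_void_v<MaybeVoid>, Fallback, MaybeVoid>;

static_assert(std::is_same_v<non_void_t<void,   float>, float>);   // source-less epilogue: C is void
static_assert(std::is_same_v<non_void_t<double, float>, double>);  // normal epilogue keeps its type

int main() { return 0; }
```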
using NonVoidElementD = cute::conditional_t, ElementD>; + static_assert(not cute::is_void_v, "SmemElementD is void"); + using NonVoidElementC = cute::conditional_t; // prevents void ref breakages + + using TmaElementD = cute::conditional_t>, uint64_t, NonVoidElementD>; + using TmaElementC = cute::conditional_t>, uint64_t, NonVoidElementC>; + + using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; + using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; + + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static bool ReuseSmemC = ReuseSmemC_ and is_destination_supported; + constexpr static bool DelayTmaStore = DelayTmaStore_; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + + constexpr static bool is_im2col_C = cute::is_same_v; + constexpr static bool is_im2col_D = cute::is_same_v; + + // Check if register transformation is needed before copying register to shared memory. + constexpr static bool IsUseR2R = !cute::is_void_v; + + using SmemLayoutC = decltype(tile_to_shape( + SmemLayoutAtomC{}, + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + using SmemLayoutD = decltype(tile_to_shape( + SmemLayoutAtomD{}, + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + + constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC + && cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{})); + static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); + + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); + + using SmemArrayTypeC = cute::ArrayEngine>; + using SmemArrayTypeD = cute::ArrayEngine>; + + using EmptyType = cute::tuple<>; + using SmemCStorage = cute::conditional_t; + using SmemDStorage = cute::conditional_t; + + struct CollectiveStorageWithC { + alignas(SmemAlignmentC) ArrayEngine> smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageWithoutC { + cute::array smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageReuseC { + alignas(MaxSmemAlignment) ArrayEngine> smem_C; + alignas(MaxSmemAlignment) ArrayEngine> smem_D; + }; + +public: + // TMA pipeline for loading C + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; + constexpr static uint32_t TmaTransactionBytes = + (size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof_bits::value)) / 8; + constexpr static bool RequiresTransactionBytes = true; + + // TMA pipeline for storing D + using StorePipeline = cute::conditional_t, + cutlass::PipelineTmaStore>; + using StorePipelineState = cutlass::PipelineState; + + struct SharedStorage { + struct TensorStorage { + using CollectiveStorage = cute::conditional_t>; + CollectiveStorage collective; + + using DualFusionStorage = typename DualFusionCallbacks::SharedStorage; + DualFusionStorage thread; + } tensors; + + using PipelineStorage = typename LoadPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename 
SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side epilogue arguments + struct Arguments { + typename DualFusionCallbacks::Arguments thread{}; + ElementC const* ptr_C; + StrideC dC; + ElementD const* ptr_D; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + using TMA_C = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor(make_gmem_ptr(nullptr), + repeat_like(StrideC{}, int32_t(0)), StrideC{}), + take<0,2>(SmemLayoutC{}), + EpilogueTile{}, + _1{})); + using TMA_D = decltype(make_tma_copy( + CopyOpS2G{}, + make_tensor(make_gmem_ptr(nullptr), + repeat_like(StrideD{}, int32_t(0)), StrideD{}), + take<0,2>(SmemLayoutD{}), + EpilogueTile{}, + _1{})); + + typename DualFusionCallbacks::Params thread{}; + TMA_C tma_load_c; + TMA_D tma_store_d; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + uint32_t transaction_bytes = TmaTransactionBytes; + typename Params::TMA_C tma_load_c{}; + if constexpr (is_source_supported) { + Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M,N,L), args.dC)); + tma_load_c = make_tma_copy_C_sm90( + CopyOpG2S{}, + tensor_c, + take<0,2>(SmemLayoutC{}), + EpilogueTile{}); + } + + typename Params::TMA_D tma_store_d{}; + if constexpr (is_destination_supported) { + Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M,N,L), args.dD)); + tma_store_d = make_tma_copy_C_sm90( + CopyOpS2G{}, + tensor_d, + take<0,2>(SmemLayoutD{}), + EpilogueTile{}); + } + + return { + DualFusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + tma_load_c, + tma_store_d, + transaction_bytes + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return DualFusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return DualFusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + auto shape = cute::make_shape(M,N,L); + + bool implementable = true; + if constexpr (is_destination_supported) { + constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; + if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm + implementable = cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideD{})); + } + else { + implementable = cutlass::detail::check_alignment(shape, StrideD{}); + } + } + + if constexpr (not cute::is_void_v) { + constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_C = 
tma_alignment_bits_C / cutlass::sizeof_bits::value; + if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm + implementable = implementable && cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideC{})); + } + else { + implementable = implementable && cutlass::detail::check_alignment(shape, StrideC{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + bool fusion_implementable = DualFusionCallbacks::can_implement(problem_shape, args.thread); + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for DualFusionCallbacks.\n"); + } + + bool beta_implementable = true; + + if constexpr (cute::is_void_v) { + if constexpr (detail::has_beta::value) { + beta_implementable = args.thread.beta == 0.0; + } + if constexpr (detail::has_beta_ptr::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; + } + } + + if (!beta_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n"); + } + + return implementable && fusion_implementable && beta_implementable; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment(TileShapeMNK tile_shape_MNK) { + // Compute number of epilogue subtiles + return size<1>(zipped_divide(make_layout(take<0,2>(tile_shape_MNK)), EpilogueTile{})); + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment(TileShapeMNK tile_shape_MNK) { + return get_load_pipe_increment(tile_shape_MNK); + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void + prefetch_tma_descriptors(Params const& epilogue_params) { + if constexpr (is_source_supported) { + cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor()); + } + if constexpr (is_destination_supported) { + cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor()); + } + } + + CUTLASS_HOST_DEVICE + DualCollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) + : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + int subtile_idx=-1) { + using namespace cute; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + // The tma tensor C under im2col mode only has two modes (M, N) which + // should be local tiled with only (m_coord, n_coord). 
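As the comment above notes, an im2col-mode TMA tensor for C has no separate batch (L) mode, so the slice coordinate drops `l_coord`; `conditional_return` picks between the two coordinate shapes at compile time. A toy illustration of that selection with plain tuples (hypothetical helper, not the CuTe one):

```
#include <tuple>

// Hypothetical compile-time pick between a 2-mode and a 3-mode tile coordinate.
template <bool IsIm2Col, class M, class N, class L>
constexpr auto pick_coord(M m, N n, L l) {
  if constexpr (IsIm2Col) { return std::make_tuple(m, n); }     // L folded into the (M,N) modes
  else                    { return std::make_tuple(m, n, l); }  // plain GEMM keeps the batch mode
}

int main() {
  static_assert(std::tuple_size_v<decltype(pick_coord<true >(0, 0, 0))> == 2);
  static_assert(std::tuple_size_v<decltype(pick_coord<false>(0, 0, 0))> == 3);
  return 0;
}
```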
+ auto coord_shape = conditional_return( + make_coord(m_coord, n_coord), + make_coord(m_coord, n_coord, l_coord)); + + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for + Tensor mC_mn = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{})); + Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtile, get matching smem tensor + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) + ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); + Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (G2S,G2S_M,G2S_N,EPI_M,EPI_N) + Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (G2S,G2S_M,G2S_N,PIPE_C) + + // Get the fusion callbacks for the producer load warp + auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs( + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + tiled_mma, + EpilogueTile{}, + thread_idx + ); + auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Predication for TMA load (one thread issues TMA load) + bool issue_tma_load = cute::elect_one_sync(); + + // Pre-loop fusion callback entry point + pld_callbacks.begin(); + + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gC_epi); ++epi_m) { + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gC_epi)) + epi_m) != subtile_idx) { + continue; + } + // Acquire the lock for this stage + constexpr uint16_t mcast_mask = 0; + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Loop fusion callback entry point + pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); + + // Execute the TMA load for C if needed + if (issue_tma_load && is_C_load_needed) { + copy(params.tma_load_c.with(*tma_barrier, mcast_mask), + bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } + + // Commit TMA loads for this stage and release the lock + load_pipeline.producer_commit(load_pipe_producer_state); + ++load_pipe_producer_state; + } + } + + // Post-loop fusion callback entry point + pld_callbacks.end(); + + return load_pipe_producer_state; + } + + CUTLASS_DEVICE auto + load_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state) { + bool issue_tma_load = cute::elect_one_sync(); + if (issue_tma_load) { + load_pipeline.producer_tail(load_pipe_producer_state); + } + + return load_pipe_producer_state; + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine0, class AccLayout0, + class AccEngine1, class AccLayout1, + class TiledMma + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + 
TileCoordMNKL tile_coord_mnkl, + cute::Tensor accumulators0, + cute::Tensor accumulators1, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + int subtile_idx=-1) { + using namespace cute; + using ElementAccumulator0 = typename AccEngine0::value_type; + using ElementAccumulator1 = typename AccEngine1::value_type; + using ElementCompute_ = typename epilogue::fusion::DualFusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator0,ElementCompute_>; + + using CallbacksOperation = + typename epilogue::fusion::DualFusionCallbacksTraits::Operation; + constexpr bool IsDualOpPair = + epilogue::fusion::is_dual_op_pair_v; + + static_assert(is_rmem::value, "Accumulator0 must be RF resident."); + static_assert(is_rmem::value, "Accumulator1 must be RF resident."); + static_assert(rank(AccLayout0{}) == 3, "Accumulator0 must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); + static_assert(rank(AccLayout1{}) == 3, "Accumulator1 must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "TileShapeMNK must be static"); + static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + // The tma tensor D under im2col mode only has two modes (M, N) which + // should be local tiled with only (m_coord, n_coord). + auto coord_shape = conditional_return( + make_coord(m_coord, n_coord), + make_coord(m_coord, n_coord, l_coord)); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_mn = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{})); + Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + auto ptr_sD = shared_tensors.collective.smem_D.begin(); + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); + + // (t)hread-partition for (r)egister to (r)egister copy (tRR_) + TiledCopy tiled_r2r = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (IsUseR2R) { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + else { + return make_tiled_copy_S(Copy_Atom, + ElementCompute>{}, tiled_copy_C_atom); + } + }(); + ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (IsUseR2R) { + return make_tiled_copy_D(Copy_Atom{}, tiled_r2r); + } + else { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + }(); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_rAcc0 = thread_r2s.retile_S(accumulators0); // ((R2S,R2S_V),MMA_M,MMA_N) + Tensor tRS_rAcc1 = 
thread_r2s.retile_S(accumulators1); // ((R2S,R2S_V),MMA_M,MMA_N) + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + + auto mma_tile_m = size<0>(TileShapeMNK{}) / size<1>(tRS_rAcc0); + auto mma_tile_n = size<1>(TileShapeMNK{}) / size<2>(tRS_rAcc0); + auto epi_tile_m = size<0>(EpilogueTile{}); + auto epi_tile_n = size<1>(EpilogueTile{}); + + // Allocate D registers + Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi)))); + Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tRS_rAcc_frg0 = recast>(tRS_rAcc0); + Tensor tRS_rAcc_frg1 = recast>(tRS_rAcc1); + Tensor tRS_rD_frg = recast>(tRS_rD); + CUTE_STATIC_ASSERT(size<0>(tRS_rAcc0) % FragmentSize == 0, "Fragment size does not vectorize properly"); + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = thread_s2r.retile_D(tRS_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v> + && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tRS_rC = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (IsUseR2R) { + // (t)hread-partition for ConsumerStoreCallbacks. 
+ TiledCopy tiled_cst = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_cst = tiled_cst.get_slice(thread_idx); + return thread_cst.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + } + else { + return thread_r2s.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + } + }(); + // Relative coordinate tensors (static) + Tensor cD = make_coord_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tRS_cD = make_coord(M,N) - tRS_cD_mn(_0{}); + + CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); + + if constexpr (epi_tile_m * epi_tile_n > mma_tile_m * mma_tile_n) { + // When the epilogue subtile is larger than the MMA tiles, loop over multiple MMA tiles + CUTE_STATIC_ASSERT(epi_tile_n % mma_tile_n == 0, "MMA_TILE_N must divide EPI_TILE_N"); + } + else { + CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + } + + // Get TiledCopy for partition reference when consumer store. + TiledCopy tiled_copy_partition_ref = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + // Get the fusion callbacks for the consumer store warps + constexpr bool RefSrc = true; // Register tensors reference tiled copy src layout + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs( + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + tiled_mma, + EpilogueTile{}, + tiled_copy_partition_ref, + cD, + residue_cD, + tRS_cD, + residue_tRS_cD, + tRS_rC, + thread_idx + ); + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + using FragmentVisit = decltype(cst_callbacks.visit(tRS_rAcc_frg0(0), 0, 0, 0)); + constexpr bool IsDirectR2S = cute::is_same_v>; + using RegisterElementD = cute::conditional_t; + Tensor tRS_rCompute = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tRS_rCompute_frg = recast>(tRS_rCompute); + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [&] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0; + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. 
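To make the bookkeeping described in the comment above concrete, here is a toy model of a circular stage counter with a phase bit; it is not the CUTLASS `PipelineState` class, and the stage count is made up. The gap between committed stores and issued releases is what the `ReuseSmemC` path below keeps bounded by roughly `StagesD`:

```
#include <cassert>

// Toy stand-in for a pipeline state: circular index plus a phase bit that flips on wrap-around.
struct ToyPipeState {
  int stages;
  int index = 0;
  int phase = 0;
  int count = 0;
  void advance() {
    ++count;
    if (++index == stages) { index = 0; phase ^= 1; }
  }
};

int main() {
  constexpr int StagesD = 2;            // hypothetical number of D smem buffers
  ToyPipeState store_state{StagesD};    // advanced once per committed TMA store
  ToyPipeState release_state{StagesD};  // advanced once per smem-buffer release
  for (int i = 0; i < 3; ++i) store_state.advance();
  release_state.advance();
  assert(store_state.index == 1 && store_state.phase == 1);    // wrapped once through 2 stages
  assert(store_state.count - release_state.count == StagesD);  // releases trail the commits
  return 0;
}
```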
+ LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD + [[maybe_unused]] int epi_m_prev = 0; + [[maybe_unused]] int epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + + // The TMA store sequence for one subtile iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) CUTLASS_LAMBDA_FUNC_INLINE { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if constexpr (is_destination_supported) { + if (issue_tma_store) { + copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + } + + // Post async fence, pre TMA commit callback entry point + cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + ++issued_stores; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = issued_stores > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + } + }; + + // + // BEGIN EPILOGUE + // + + // Pre-loop fusion callback entry point + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + + // For each output tile + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) { + [[maybe_unused]] bool is_first_iteration = epi_m == 0 && epi_n == 0; + bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1; + + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gD_epi)) + epi_m) != subtile_idx) { + continue; + } + + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); + // Ensure smem loads are complete before reusing smem for mixed types/layouts + if constexpr (ReuseSmemC && not (SmemLayoutC{} == SmemLayoutD{})) { + synchronize(); + } + } + } + + // First loop fusion callback entry point + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + + if (is_producer_load_needed) { + if constexpr (not ReuseSmemC) { + // Let producer load warp know smem buffers are consumed and empty + cutlass::arch::fence_view_async_shared(); + 
load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + if constexpr (epi_tile_m * epi_tile_n > mma_tile_m * mma_tile_n) { + // When the epilogue subtile is larger than the MMA tiles, loop over multiple + // MMA tiles + static constexpr int MmaMPerEpiM = epi_tile_m / mma_tile_m; + static constexpr int MmaNPerEpiN = epi_tile_n / mma_tile_n; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_in_epi = 0; mma_n_in_epi < MmaNPerEpiN; ++mma_n_in_epi) { + int mma_n = (epi_n * MmaNPerEpiN) + mma_n_in_epi; + + CUTLASS_PRAGMA_UNROLL + for (int mma_m_in_epi = 0; mma_m_in_epi < MmaMPerEpiM; ++mma_m_in_epi) { + int mma_m = (epi_m * MmaMPerEpiM) + mma_m_in_epi; + Tensor tRS_rAcc_frg_mn0 = tRS_rAcc_frg0(_,mma_m,mma_n); + Tensor tRS_rAcc_frg_mn1 = tRS_rAcc_frg1(_,mma_m,mma_n); + int idx_in_epi_subtile = (mma_n_in_epi * MmaMPerEpiM + mma_m_in_epi); + + if constexpr (IsDualOpPair) { + auto comp0 = cst_callbacks.visit0(tRS_rAcc_frg_mn0(0), idx_in_epi_subtile, epi_m, epi_n); + auto comp1 = cst_callbacks.visit1(tRS_rAcc_frg_mn1(0), idx_in_epi_subtile, epi_m, epi_n); + using LsmOp = cutlass::epilogue::thread::LeftSiLUAndMul; + typename LsmOp::Params lsm_params{}; + LsmOp lsm(lsm_params); + Array combined_comp = lsm(comp0, comp1); + tRS_rCompute_frg(idx_in_epi_subtile) = combined_comp; + } else { + auto comp0 = cst_callbacks.visit(tRS_rAcc_frg_mn0(0), idx_in_epi_subtile, epi_m, epi_n); + auto comp1 = cst_callbacks.visit(tRS_rAcc_frg_mn1(0), idx_in_epi_subtile, epi_m, epi_n); + using LsmOp = cutlass::epilogue::thread::LeftSiLUAndMul; + typename LsmOp::Params lsm_params{}; + LsmOp lsm(lsm_params); + Array combined_comp = lsm(comp0, comp1); + tRS_rCompute_frg(idx_in_epi_subtile) = combined_comp; + } + } + } + } + else { + int mma_m = epi_m; + int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; + Tensor tRS_rAcc_frg_mn0 = tRS_rAcc_frg0(_,mma_m,mma_n); + Tensor tRS_rAcc_frg_mn1 = tRS_rAcc_frg1(_,mma_m,mma_n); + + // Vectorized fragment loop with visitor callback entry point + int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); + int r2s_v = epi_n_in_mma * size(tRS_rCompute_frg); + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tRS_rCompute_frg); ++epi_v) { + if constexpr (IsDualOpPair) { + auto comp0 = cst_callbacks.visit0(tRS_rAcc_frg_mn0(r2s_v + epi_v), epi_v, epi_m, epi_n); + auto comp1 = cst_callbacks.visit1(tRS_rAcc_frg_mn1(r2s_v + epi_v), epi_v, epi_m, epi_n); + using LsmOp = cutlass::epilogue::thread::LeftSiLUAndMul; + typename LsmOp::Params lsm_params{}; + LsmOp lsm(lsm_params); + Array combined_comp = lsm(comp0, comp1); + tRS_rCompute_frg(epi_v) = combined_comp; + } else { + auto comp0 = cst_callbacks.visit(tRS_rAcc_frg_mn0(r2s_v + epi_v), epi_v, epi_m, epi_n); + auto comp1 = cst_callbacks.visit(tRS_rAcc_frg_mn1(r2s_v + epi_v), epi_v, epi_m, epi_n); + using LsmOp = cutlass::epilogue::thread::LeftSiLUAndMul; + typename LsmOp::Params lsm_params{}; + LsmOp lsm(lsm_params); + Array combined_comp = lsm(comp0, comp1); + tRS_rCompute_frg(epi_v) = combined_comp; + } + } + } + + // The latest we can delay the TMA store is right before the smem store of the next iteration + // since the current TMA store needs to be committed before we can acquire the next smem buffer + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration and subtile_idx == -1) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + // Smem reduction callback entry point using 
current store buffer for workspace + cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), + synchronize, epi_m, epi_n, is_last_iteration, tRS_rCompute_frg); + + // Copy tile from register to regiser if needed + if constexpr (IsUseR2R) { + // retile source and destination for tiled_r2r + Tensor tRR_rD_src = thread_r2r.retile_S(tRS_rCompute); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + Tensor tRR_rD_dst = thread_r2r.retile_D(tRS_rCompute); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + + // Output register transformation before copying to shared memory. + copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tRS_rD_frg); ++i) { + tRS_rD_frg(i) = cutlass::NumericArrayConverter{}(tRS_rCompute_frg(i)); + } + + // Copy tile from register to smem + if constexpr (is_destination_supported) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } + + // Post reduction, pre TMA store callback entry point + constexpr bool issue_smem_store = true; // No smem store predication + cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + cst_callbacks.end_loop(epi_m, epi_n); + + } // for epi_m + } // for epi_n + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + // Post-loop fusion callback entry point + cst_callbacks.end(); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state) { + // wait for all TMA stores to complete + store_pipeline.producer_tail(store_pipe_producer_state); + // reset store counter + issued_stores = 0; + + if constexpr (ReuseSmemC) { + if (fusion_callbacks.is_producer_load_needed()) { + // Issue releases on up to StagesD-1 previously issued TMA stores + constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(CtaTileMNK{})); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < release_stages; ++stage) { + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + } + } + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + +private: + Params const& params; + DualFusionCallbacks fusion_callbacks; + int issued_stores = 0; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/94_blackwell_geforce_dual_gemm/device/gemm_universal_adapter.h b/examples/94_blackwell_geforce_dual_gemm/device/gemm_universal_adapter.h new file mode 100644 index 0000000000..c72e7cfe4e --- /dev/null +++ b/examples/94_blackwell_geforce_dual_gemm/device/gemm_universal_adapter.h @@ -0,0 +1,785 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/detail/mma.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cutlass/kernel_launch.h" +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +// 2.x +#include "cutlass/gemm/device/gemm_universal_base.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h" + +// 3.x +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::device { + +//////////////////////////////////////////////////////////////////////////////// + +/*! + GemmUniversalAdapter is a stateful, reusable GEMM handle built around a kernel + of type cutlass::gemm::kernel::Gemm or cutlass::gemm::kernel::GemmUniversal. + + It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs + to create it from the host facing arguments. For power users, new static methods + are exposed in 3.x APIs that bypass the stateful methods or args->params lowering. + + It supports kernel types that implement both the 2.x and 3.0 APIs, + however, this is done by specializing the implementation of GemmUniversalAdapter + on the two kernel API types, and thus, GemmUniversalAdapter's behaviour might + differ between the two specializations. 
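  A minimal host-side usage sketch (illustrative only: `GemmKernel` stands for a fully
  specified 3.x kernel type, `args` for a user-populated `GemmKernel::Arguments`, and the
  CUTLASS_CHECK macro and cutlass/util/device_memory.h helpers are assumed to come from
  the examples tree):

```
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

Gemm gemm_op;
size_t workspace_bytes = Gemm::get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_bytes);

CUTLASS_CHECK(Gemm::can_implement(args));                  // validate problem shape / alignments
CUTLASS_CHECK(gemm_op.initialize(args, workspace.get()));  // arguments -> kernel Params lowering
CUTLASS_CHECK(gemm_op.run());                              // launch on the default stream
```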
+*/ +template +class GemmUniversalAdapter; + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Work-around for some DispatchPolicy types not having a Stages member. +// In that case, the Stages value is 0. Most code should static_assert +// that the number of stages is valid. + +// Whether DispatchPolicy::Stages is valid. +// It should also be convertible to int, but if not, that will show up +// as a build error when GemmUniversalAdapter attempts to assign it to kStages. +template +struct has_Stages : cute::false_type {}; + +template +struct has_Stages> : cute::true_type {}; + +template +constexpr int stages_member(DispatchPolicy) { + if constexpr (has_Stages::value) { + return DispatchPolicy::Stages; + } + else { + return 0; + } +} + +} // namespace detail + +template +class GemmUniversalAdapter< + GemmKernel_, + cute::enable_if_t>::value>> +{ +public: + using GemmKernel = GetUnderlyingKernel_t; + using TileShape = typename GemmKernel::TileShape; + using ElementA = typename GemmKernel::ElementA; + using ElementB = typename GemmKernel::ElementB; + using ElementC = typename GemmKernel::ElementC; + using ElementD = typename GemmKernel::ElementD; + using ElementAccumulator = typename GemmKernel::ElementAccumulator; + using DispatchPolicy = typename GemmKernel::DispatchPolicy; + using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; + using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; + + // Map back to 2.x type as best as possible + using LayoutA = gemm::detail::StrideToLayoutTagA_t; + using LayoutB0 = gemm::detail::StrideToLayoutTagB_t; + using LayoutB1 = gemm::detail::StrideToLayoutTagB_t; + using LayoutC = gemm::detail::StrideToLayoutTagC_t; + using LayoutD = gemm::detail::StrideToLayoutTagC_t; + + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + static ComplexTransform const kTransformA = cute::is_same_v ? + ComplexTransform::kConjugate : ComplexTransform::kNone; + static ComplexTransform const kTransformB0 = cute::is_same_v ? + ComplexTransform::kConjugate : ComplexTransform::kNone; + static ComplexTransform const kTransformB1 = cute::is_same_v ? 
+ ComplexTransform::kConjugate : ComplexTransform::kNone; + // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 + using MathOperator = cutlass::arch::OpMultiplyAdd; + + using OperatorClass = cutlass::detail::get_operator_class_t; + + using ArchTag = typename GemmKernel::ArchTag; + + // NOTE: Assume identity swizzle for now + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape + using ThreadblockShape = cutlass::gemm::GemmShape< + cute::size<0>(TileShape{}), + cute::size<1>(TileShape{}), + cute::size<2>(TileShape{})>; + + using ClusterShape = cutlass::gemm::GemmShape< + cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})>; + + // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape + using InstructionShape = cutlass::gemm::GemmShape< + cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; + + // Legacy: provide a correct warp count, but no reliable warp shape + static int const kThreadCount = GemmKernel::MaxThreadsPerBlock; + + // Warp shape is not a primary API type in 3.x + // But we can best approximate it by inspecting the TiledMma + // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K + // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads + static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32); + static constexpr int WarpsInMmaM = 4; + static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); + using WarpCount = cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape< + CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM, + CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN, + CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; + + static int constexpr kStages = detail::stages_member(typename CollectiveMainloop::DispatchPolicy{}); + + // Inspect TiledCopy for A and B to compute the alignment size + static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyA, ElementA, typename CollectiveMainloop::TiledMma::ValTypeA>(); + static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyB, ElementB, typename CollectiveMainloop::TiledMma::ValTypeB>(); + static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); + static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); + + using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + // Split-K preserves splits that are 128b aligned + static int constexpr kSplitKAlignment = cute::max( + 128 / sizeof_bits::value, 128 / sizeof_bits::value); + + /// Argument structure: User API + using Arguments = typename GemmKernel::Arguments; + /// 
Argument structure: Kernel API + using Params = typename GemmKernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (GemmKernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{})); + } + + workspace_bytes += GemmKernel::get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Arguments const& args, void* workspace = nullptr) { + auto tmp_params = GemmKernel::to_underlying_arguments(args, workspace); + return GemmKernel::get_grid_shape(tmp_params); + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return GemmKernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("GemmUniversal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = GemmKernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + GemmKernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + + CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = GemmKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + if (status != Status::kSuccess) { + return status; + } + // Initialize the Params structure + params_ = GemmKernel::to_underlying_arguments(args, workspace); + // Don't set the function attributes - require the CudaHostAdapter to set it. 
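    // Note: kernels that request more than 48 KB of dynamic shared memory must opt in
    // explicitly via cudaFuncAttributeMaxDynamicSharedMemorySize; without this attribute
    // the launch would fail with cudaErrorInvalidValue. When a CudaHostAdapter is enabled,
    // setting the attribute is the adapter's responsibility, which is why the branch below
    // returns early without touching the kernel's function attributes.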
+ if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + // + // Account for dynamic smem capacity if needed + // + int smem_size = GemmKernel::SharedStorageSize; + + CUTLASS_ASSERT(cuda_adapter == nullptr); + + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + } + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversal()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = GemmKernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling GemmKernel::to_underlying_arguments() + static Status + run(Params& params, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + CUTLASS_TRACE_HOST("GemmUniversal::run()"); + dim3 const block = GemmKernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = GemmKernel::SharedStorageSize; + + Status launch_result{ Status::kSuccess }; + // Use extended launch API only for mainloops that use it + if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Use extended launch API"); +#endif + [[maybe_unused]] constexpr bool is_static_1x1x1 = + cute::is_static_v and + cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; + [[maybe_unused]] dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); + + // Dynamic cluster support + [[maybe_unused]] dim3 fallback_cluster = dim3{0,0,0}; + if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 + || GemmKernel::ArchTag::kMinComputeCapability == 101 + ) { + if constexpr (!cute::is_static_v) { + fallback_cluster = params.hw_info.cluster_shape_fallback; + cluster = params.hw_info.cluster_shape; + } + } + + [[maybe_unused]] void* kernel_params[] = {¶ms}; + + if constexpr (kEnableCudaHostAdapter) { + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + if (launch_with_pdl) { + CUTLASS_TRACE_HOST( + "GemmUniversal::run() does not support launching with PDL and a custom cuda adapter."); + return Status::kErrorInternal; + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with CUDA host adapter"); +#endif + if constexpr (is_static_1x1x1) { + launch_result = cuda_adapter->launch(grid, + block, + smem_size, + stream, + kernel_params, + 0); + } + else { + launch_result = cuda_adapter->launch(grid, + 
cluster, + fallback_cluster, + block, + smem_size, + stream, + kernel_params, + 0); + } + } + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: kEnableCudaHostAdapter is true, but CUDA host adapter is null"); + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + [[maybe_unused]] void const* kernel = (void const*) device_kernel; + static constexpr bool kClusterLaunch = GemmKernel::ArchTag::kMinComputeCapability == 90; + if constexpr (kClusterLaunch) { + if constexpr (is_static_1x1x1) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching static 1x1x1 kernel"); +#endif + launch_result = cutlass::kernel_launch( + grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + } +#endif + } + else { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching dynamic cluster kernel"); +#endif + launch_result = ClusterLauncher::launch( + grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); + } + } + + else { + if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 + || GemmKernel::ArchTag::kMinComputeCapability == 101 + || GemmKernel::ArchTag::kMinComputeCapability == 120 + ) { + if constexpr (is_static_1x1x1) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching static 1x1x1 kernel"); +#endif + launch_result = cutlass::kernel_launch(grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + } +#endif + } + else { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with fall-back cluster"); +#endif + launch_result = ClusterLauncher::launch_with_fallback_cluster( + grid, + cluster, + fallback_cluster, + block, + smem_size, + stream, + kernel, + kernel_params, + launch_with_pdl); + } + } + } + + } + } + else { + launch_result = Status::kSuccess; + cutlass::arch::synclog_setup(); + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms}; +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with CUDA host adapter"); +#endif + launch_result = cuda_adapter->launch( + grid, block, smem_size, stream, kernel_params, 0 + ); + + } + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: CUDA host adapter is null"); + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with cutlass::kernel_launch"); +#endif + launch_result = cutlass::kernel_launch( + grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports failure"); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("GemmUniversal::run: cutlass::kernel_launch reports success"); + } +#endif + } + } + + cudaError_t result = cudaGetLastError(); + if 
(cudaSuccess == result && Status::kSuccess == launch_result) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: cudaGetLastError reports success"); +#endif + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false + ) { + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (Status::kSuccess == status) { + status = run(params_, stream, cuda_adapter, launch_with_pdl); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(args, workspace, stream, cuda_adapter, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 2.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalAdapter< + GemmKernel_, + cute::enable_if_t>::value>> +{ +public: + + using GemmKernel = GetUnderlyingKernel_t; + + static bool const kInternalTranspose = + !cutlass::epilogue::threadblock::detail::is_2x_evt_v && // 2.x EVT does not require internal transpose + cute::is_same::value; + + using ThreadblockShape = typename GemmKernel::Mma::Shape; + using WarpShape = typename GemmKernel::WarpShape; + using InstructionShape = typename GemmKernel::InstructionShape; + + // warp-level, arch-level (instruction), math operator + using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator; + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + + // Operator class and arch tag extract bottom-up + // set it for top-level gemm device-level template + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + // Type, layout, and complex transform deliberately exchanged with B + using MapArguments = kernel::detail::MapArguments< + typename GemmKernel::ElementA, + typename GemmKernel::LayoutA, + GemmKernel::kTransformA, + GemmKernel::kAlignmentA, + typename GemmKernel::ElementB, + typename GemmKernel::LayoutB, + GemmKernel::kTransformB, + GemmKernel::kAlignmentB, + typename GemmKernel::LayoutC, + kInternalTranspose + >; + + 
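  // When kInternalTranspose is enabled (typically a row-major C/D request on a 2.x kernel
  // that natively produces column-major output), the adapter runs the transposed problem,
  // effectively computing D^T = B^T * A^T; MapArguments above swaps the roles of A and B
  // accordingly, and to_underlying_arguments() below calls args.transposed_problem().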
using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static int const kAlignmentA = MapArguments::kAlignmentA; + + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + static int const kAlignmentB = MapArguments::kAlignmentB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename MapArguments::LayoutC; + static int const kAlignmentC = GemmKernel::kAlignmentC; + + // C and D same type for 2.x kernel + using ElementD = ElementC; + using LayoutD = LayoutC; + + using TensorRefA = TensorRef; + using TensorRefB = TensorRef; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + static int const kStages = GemmKernel::Mma::kStages; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using UnderlyingOperator = GemmUniversalBase; + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmUniversalAdapter() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + if (kInternalTranspose) { + return args.transposed_problem(); + } + else { + return args; + } + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args), cuda_adapter); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args, CudaHostAdapter *cuda_adapter = nullptr) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args), cuda_adapter); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr + ) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream, cuda_adapter); + } + + /// Lightweight update given a subset of arguments. + Status update(Arguments const &args) { + + return underlying_operator_.update(to_underlying_arguments(args)); + } + + /// Runs the kernel using initialized state. + Status run( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + return underlying_operator_.run(stream, cuda_adapter); + } + + /// Runs the kernel using initialized state. + Status operator()( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (status == Status::kSuccess) { + status = run(stream, cuda_adapter); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/94_blackwell_geforce_dual_gemm/fusion/callbacks.hpp b/examples/94_blackwell_geforce_dual_gemm/fusion/callbacks.hpp new file mode 100644 index 0000000000..4b6569c168 --- /dev/null +++ b/examples/94_blackwell_geforce_dual_gemm/fusion/callbacks.hpp @@ -0,0 +1,91 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/detail/dependent_false.hpp" +#include "operations.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Dispatch interface for epilogue fusion callbacks +// For visitor fusions, this is just a convenience wrapper to provide metadata and non-nested args. +// It is also valid to just pass visitor callbacks directly to the collective, e.g. fusion::Sm90LinearCombination, +// provided the collective supports a visitor callbacks interface. This is useful for implementing custom fusions. 
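For orientation, one hypothetical instantiation of this dispatch interface, using the SM90 dual
dispatch policy and the fusion operations defined later in this example (every numeric template
parameter and tile shape below is a placeholder, not a recommendation):

```
// Illustrative only: stage counts, fragment size, and tile shapes are placeholders.
using namespace cute;

using FusionOp  = cutlass::epilogue::fusion::DualLinearCombination<float, float>;

using Callbacks = cutlass::epilogue::fusion::DualFusionCallbacks<
    cutlass::epilogue::DualSm90TmaWarpSpecialized<
        /*StagesC=*/4, /*StagesD=*/2, /*FragmentSize=*/4,
        /*ReuseSmemC=*/false, /*DelayTmaStore=*/false>,
    FusionOp,
    Shape<_128, _128, _64>,   // CtaTile_MNK
    Shape<_64, _32>           // EpilogueTile_MN
>;
```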
+template < + class DispatchPolicy, // specialize on collective's dispatch policy since callbacks API will depend on collective's algorithm + class Operation, // the fusion operation being performed, e.g. fusion::LinearCombination + class CtaTile_MNK, // computed tile per CTA + class EpilogueTile_MN, // epilogue subtile size + class... Args // callbacks implementation dependent args (e.g. copy atoms, smem layouts) +> +struct DualFusionCallbacks { + static_assert(cutlass::detail::dependent_false, "Could not find a callbacks specialization."); +}; + +// Metadata helper to handle custom EVTs or other non-FusionCallbacks types +template +struct DualFusionCallbacksTraits { + using DispatchPolicy = void; + using Callbacks = T; + using Operation = DualFusionOperation; + using CtaTile_MNK = void; + using EpilogueTile_MN = void; + using ElementCompute = void; +}; + +template < + class DispatchPolicy_, + class Operation_, + class CtaTile_MNK_, + class EpilogueTile_MN_, + class... Args +> +struct DualFusionCallbacksTraits< + DualFusionCallbacks +> { + using DispatchPolicy = DispatchPolicy_; + using Callbacks = DualFusionCallbacks; + using Operation = Operation_; + using CtaTile_MNK = CtaTile_MNK_; + using EpilogueTile_MN = EpilogueTile_MN_; + using ElementCompute = typename Operation::ElementCompute; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/94_blackwell_geforce_dual_gemm/fusion/operations.hpp b/examples/94_blackwell_geforce_dual_gemm/fusion/operations.hpp new file mode 100644 index 0000000000..8dfe213a0b --- /dev/null +++ b/examples/94_blackwell_geforce_dual_gemm/fusion/operations.hpp @@ -0,0 +1,152 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include // cute::false_type +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Fusion Operations +// Template args must not be implementation dependent +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct DualFusionOperation { + // metadata types/queries that can be overrided + using ElementOutput = void; + using ElementCompute = void; + FloatRoundStyle RoundStyle = FloatRoundStyle::round_indeterminate; + + using ElementSource = void; + static constexpr bool IsSourceSupported = false; + static constexpr bool IsResidualSupported = false; // Source is added after activation + + using ElementScalar = void; + static constexpr int AlignmentScalar = 0; + static constexpr bool IsScaleFactorSupported = false; + static constexpr bool IsPerRowScaleSupported = false; + static constexpr bool IsPerColScaleSupported = false; + + using ElementBias = void; + static constexpr int AlignmentBias = 0; + static constexpr bool IsPerRowBiasSupported = false; + static constexpr bool IsPerColBiasSupported = false; + static constexpr bool IsDePerRowBiasSupported = false; + + using ActivationFn = void; + static constexpr bool IsEltActSupported = false; + static constexpr bool IsDeEltActSupported = false; + + using ElementAux = void; + using GmemLayoutTagAux = void; + static constexpr int AlignmentAux = 0; + static constexpr bool IsAuxOutSupported = false; + static constexpr bool IsAuxInSupported = false; + + using ElementAmax = void; + static constexpr bool IsAbsMaxSupported = false; + + using ElementBlockScaleFactor = void; + static constexpr int SFVecSize = 0; + static constexpr bool IsBlockScaleSupported = false; // Umbrella variable to check BlockScaling support in the epilogues + using GmemLayoutTagScalefactor = void; +}; + +// D = alpha * acc +template< + class ElementOutput_, + class ElementCompute_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct DualScaledAcc : DualFusionOperation { + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementScalar = ElementScalar_; + static constexpr int AlignmentScalar = 1; + static constexpr auto RoundStyle = RoundStyle_; +}; + +// D = alpha * acc + beta * C +template< + class ElementOutput_, + class ElementCompute_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct DualLinearCombination + : DualScaledAcc { + 
using ElementSource = ElementSource_; + static constexpr bool IsSourceSupported = true; +}; + +template +struct DualOpPair : DualFusionOperation { + using LeftOp = Op0; + using RightOp = Op1; + using ElementCompute = + std::conditional_t< + !std::is_void_v, + typename Op0::ElementCompute, + typename Op1::ElementCompute>; + // Inherit element metadata heuristically from left if available else right + using ElementOutput = std::conditional_t< + !std::is_void_v, + typename Op0::ElementOutput, + typename Op1::ElementOutput>; + // Provide passthrough flags (OR) + static constexpr bool IsSourceSupported = + (bool)Op0::IsSourceSupported || (bool)Op1::IsSourceSupported; +}; + +// trait to detect DualOpPair +template struct is_dual_op_pair : std::false_type {}; +template struct is_dual_op_pair> : std::true_type {}; +template +inline constexpr bool is_dual_op_pair_v = is_dual_op_pair::value; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/94_blackwell_geforce_dual_gemm/fusion/sm120_callbacks_tma_warpspecialized_dual.hpp b/examples/94_blackwell_geforce_dual_gemm/fusion/sm120_callbacks_tma_warpspecialized_dual.hpp new file mode 100644 index 0000000000..3f0b6cf975 --- /dev/null +++ b/examples/94_blackwell_geforce_dual_gemm/fusion/sm120_callbacks_tma_warpspecialized_dual.hpp @@ -0,0 +1,121 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +/*! 
\file + \brief Fusion callbacks specializations for the SM120 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp" +#include "callbacks.hpp" +#include "../fusion/sm90_callbacks_tma_warpspecialized_dual.hpp" +#include "../collective/dispatch_policy_extra.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Sm120 Tma warp specialized callbacks just alias to their sm90 counterpart +// ... existing includes ... + +// Generic forwarder (ALL operations, including single-op) – keep Args... +template < + int StagesC, int StagesD, int FragmentSize, + bool ReuseSmemC, bool DelayTmaStore, + class Operation, + class CtaTileShapeMNK, + class EpilogueTile, + class... Args +> +struct DualFusionCallbacks< + epilogue::DualSm120TmaWarpSpecialized, + Operation, + CtaTileShapeMNK, + EpilogueTile, + Args... +> : DualFusionCallbacks< + epilogue::DualSm90TmaWarpSpecialized, + Operation, + CtaTileShapeMNK, + EpilogueTile, + Args...> { + using Base = DualFusionCallbacks< + epilogue::DualSm90TmaWarpSpecialized, + Operation, + CtaTileShapeMNK, + EpilogueTile, + Args...>; + using Base::Base; +}; + +// DualOpPair specific forwarder (preserves Args...) to ensure dual visit0/visit1 path propagates +template < + int StagesC, int StagesD, int FragmentSize, + bool ReuseSmemC, bool DelayTmaStore, + class Op0, class Op1, + class CtaTileShapeMNK, + class EpilogueTile, + class... Args +> +struct DualFusionCallbacks< + epilogue::DualSm120TmaWarpSpecialized, + DualOpPair, + CtaTileShapeMNK, + EpilogueTile, + Args... +> : DualFusionCallbacks< + epilogue::DualSm90TmaWarpSpecialized, + DualOpPair, + CtaTileShapeMNK, + EpilogueTile, + Args...> { + using Base = DualFusionCallbacks< + epilogue::DualSm90TmaWarpSpecialized, + DualOpPair, + CtaTileShapeMNK, + EpilogueTile, + Args...>; + using Base::Base; +}; + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/94_blackwell_geforce_dual_gemm/fusion/sm90_callbacks_tma_warpspecialized_dual.hpp b/examples/94_blackwell_geforce_dual_gemm/fusion/sm90_callbacks_tma_warpspecialized_dual.hpp new file mode 100644 index 0000000000..e14aa6e376 --- /dev/null +++ b/examples/94_blackwell_geforce_dual_gemm/fusion/sm90_callbacks_tma_warpspecialized_dual.hpp @@ -0,0 +1,436 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Fusion callbacks specializations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp" + +#include "callbacks.hpp" +#include "../collective/dispatch_policy_extra.hpp" +#include "sm120_callbacks_tma_warpspecialized_dual.hpp" +#include "operations.hpp" +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +using Sm90EVT = Sm90TreeVisitor; + +// D = alpha * acc +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct DualFusionCallbacks< + epilogue::DualSm90TmaWarpSpecialized, + fusion::DualScaledAcc, + CtaTileShapeMNK, + EpilogueTile +> : Sm90EVT, + Sm90ScalarBroadcast>, + Sm90AccFetch + > { + using Impl = + Sm90EVT, + Sm90ScalarBroadcast>, + Sm90AccFetch + >; + using Operation = fusion::DualScaledAcc; + + struct Arguments { + // Give a name and flat ordering to the fusion callback args + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + + // Conversion to the args expected by the visitor implementation + // to_underlying_arguments will implicitly call this + operator typename Impl::Arguments() const { + return + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + 
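A sketch of how the flattened Arguments above are populated on the host and then lowered into
the nested visitor tree (the `Callbacks` alias is hypothetical and stands for the DualScaledAcc
specialization defined above):

```
// Hypothetical alias for the DualScaledAcc specialization above:
// using Callbacks = cutlass::epilogue::fusion::DualFusionCallbacks<
//     DispatchPolicy, cutlass::epilogue::fusion::DualScaledAcc<float, float>,
//     CtaTileShapeMNK, EpilogueTile>;

Callbacks::Arguments args;
args.alpha     = 1.25f;    // D = alpha * acc, with a host-side scalar ...
args.alpha_ptr = nullptr;  // ... or point this at a device-side scalar instead

// The conversion operator defined above lowers these flat fields into the nested
// {alpha broadcast, acc fetch, multiplies} visitor-tree arguments when the collective
// calls to_underlying_arguments().
```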
+///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C +template< + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using DualSm90LinearCombination = + Sm90EVT, // beta * C + (alpha * acc) + Sm90ScalarBroadcast>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast>, // alpha + Sm90AccFetch // acc + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct DualFusionCallbacks< + epilogue::DualSm90TmaWarpSpecialized, + fusion::DualLinearCombination, + CtaTileShapeMNK, + EpilogueTile +> : DualSm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { + + using Impl = DualSm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::DualLinearCombination; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// DualOpPair specialization: allow distinct ops for accumulator0 and accumulator1 +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class Op0, + class Op1, + class CtaTileShapeMNK, + class EpilogueTile +> +struct DualFusionCallbacks< + epilogue::DualSm90TmaWarpSpecialized, + DualOpPair, + CtaTileShapeMNK, + EpilogueTile +> { + + using DispatchPolicy = epilogue::DualSm90TmaWarpSpecialized; + using Operation = DualOpPair; + + using CB0 = DualFusionCallbacks; + using CB1 = DualFusionCallbacks; + + struct Arguments { + typename CB0::Arguments op0{}; + typename CB1::Arguments op1{}; + }; + + struct Params { + typename CB0::Params op0; + typename CB1::Params op1; + }; + + struct SharedStorage { + typename CB0::SharedStorage left; + typename CB1::SharedStorage right; + }; + + template + static Params to_underlying_arguments(ProblemShape const& ps, Arguments const& a, void* ws) { + return { + CB0::to_underlying_arguments(ps, a.op0, ws), + CB1::to_underlying_arguments(ps, a.op1, ws) + }; + } + + template + static size_t get_workspace_size(ProblemShape const& ps, Arguments const& a) { + return CB0::get_workspace_size(ps, a.op0) + CB1::get_workspace_size(ps, a.op1); + } + + template + static cutlass::Status initialize_workspace( + ProblemShape const& ps, Arguments const& a, void* ws, cudaStream_t stream, + CudaHostAdapter* adapter = 
nullptr) { + size_t size0 = CB0::get_workspace_size(ps, a.op0); + void* ws1 = static_cast(ws) + size0; + CUTLASS_CHECK(CB0::initialize_workspace(ps, a.op0, ws, stream, adapter)); + CUTLASS_CHECK(CB1::initialize_workspace(ps, a.op1, ws1, stream, adapter)); + return cutlass::Status::kSuccess; + } + + template + static bool can_implement(ProblemShape const& ps, Arguments const& a) { + return CB0::can_implement(ps, a.op0) && CB1::can_implement(ps, a.op1); + } + + CUTLASS_DEVICE + DualFusionCallbacks(Params const& p, SharedStorage& st) + : cb0_(p.op0, st.left), cb1_(p.op1, st.right) {} + + CUTLASS_DEVICE bool is_producer_load_needed() const { + return cb0_.is_producer_load_needed() || cb1_.is_producer_load_needed(); + } + + CUTLASS_DEVICE bool is_C_load_needed() const { + return detail_is_C_load_needed(cb0_) || detail_is_C_load_needed(cb1_); + } + +private: + CB0 cb0_; + CB1 cb1_; + template + struct has_is_C_load_needed : std::false_type {}; + + template + struct has_is_C_load_needed().is_C_load_needed())>> : std::true_type {}; + + template + CUTLASS_DEVICE static bool detail_is_C_load_needed(CbT const& cb) { + if constexpr (has_is_C_load_needed::value) return cb.is_C_load_needed(); + else return false; + } + + template + CUTLASS_DEVICE static auto obtain_consumer_impl(int, Cb& cb, Args const& a) + -> decltype(cb.template get_consumer_store_callbacks(a)) { + return cb.template get_consumer_store_callbacks(a); + } + template + CUTLASS_DEVICE static auto obtain_consumer_impl(long, Cb& cb, Args const& a) + -> decltype(cb.get_consumer_store_callbacks(a)) { + return cb.get_consumer_store_callbacks(a); + } + template + CUTLASS_DEVICE static auto obtain_consumer(Cb& cb, Args const& a) { + return obtain_consumer_impl(0, cb, a); + } + + // Wrapper structs + template + struct ProducerPair { + P0 a; P1 b; + CUTLASS_DEVICE void begin(){ a.begin(); b.begin(); } + CUTLASS_DEVICE void step(uint64_t* barrier,int em,int en,int stage,bool issue){ + a.step(barrier,em,en,stage,issue); + b.step(barrier,em,en,stage,issue); + } + CUTLASS_DEVICE void end(){ a.end(); b.end(); } + }; + + template + struct ConsumerPair { + C0 a; C1 b; + + // Trait helpers + template struct has_begin_sync_needed : std::false_type {}; + template struct has_begin_sync_needed().begin_sync_needed())>> : std::true_type {}; + + template struct has_tma_store : std::false_type {}; + template struct has_tma_store().tma_store(0,0,0,false))>> : std::true_type {}; + + CUTLASS_DEVICE void begin(){ a.begin(); b.begin(); } + + CUTLASS_DEVICE bool begin_sync_needed(){ + bool r0=false,r1=false; + if constexpr (has_begin_sync_needed::value) r0 = a.begin_sync_needed(); + if constexpr (has_begin_sync_needed::value) r1 = b.begin_sync_needed(); + return r0 || r1; + } + + CUTLASS_DEVICE void begin_loop(int em,int en){ a.begin_loop(em,en); b.begin_loop(em,en); } + + CUTLASS_DEVICE void previsit(int em,int en,int stage,bool has_src){ + a.previsit(em,en,stage,has_src); b.previsit(em,en,stage,has_src); + } + + template + CUTLASS_DEVICE auto visit0(Frag const& f,int idx,int em,int en){ + return a.visit(f,idx,em,en); + } + template + CUTLASS_DEVICE auto visit1(Frag const& f,int idx,int em,int en){ + return b.visit(f,idx,em,en); + } + template + CUTLASS_DEVICE auto visit(Frag const& f,int idx,int em,int en){ + return a.visit(f,idx,em,en); + } + + template + CUTLASS_DEVICE void reduce(SmemTensor&& t, SyncFn&& sync,int em,int en,bool last, FragTensor& fr){ + a.reduce(t,sync,em,en,last,fr); + b.reduce(t,sync,em,en,last,fr); + } + + CUTLASS_DEVICE void postreduce(int 
em,int en,int stage,bool issue){ + a.postreduce(em,en,stage,issue); + b.postreduce(em,en,stage,issue); + } + + CUTLASS_DEVICE void tma_store(int em,int en,int stage,bool issue){ + if constexpr (has_tma_store::value) a.tma_store(em,en,stage,issue); + if constexpr (has_tma_store::value) b.tma_store(em,en,stage,issue); + } + + CUTLASS_DEVICE void end_loop(int em,int en){ a.end_loop(em,en); b.end_loop(em,en); } + CUTLASS_DEVICE void end(){ a.end(); b.end(); } + }; + +public: + template + CUTLASS_DEVICE auto get_producer_load_callbacks(ProducerLoadArgs const& args) { + auto p0 = cb0_.get_producer_load_callbacks(args); + auto p1 = cb1_.get_producer_load_callbacks(args); + return ProducerPair{p0,p1}; + } + + template + CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto c0 = obtain_consumer(cb0_, args); + auto c1 = obtain_consumer(cb1_, args); + return ConsumerPair{c0,c1}; + } + +}; + +namespace detail { +template > +struct dual_get_element_aux { + using type = void; +}; + +template +struct dual_get_element_aux> { + using type = typename DualFusionOpOrCallbacks::ElementAux; +}; + +template +struct dual_get_element_aux, cute::void_t<>> { + using type = typename dual_get_element_aux::type; +}; + +template +struct dual_get_element_aux, cute::void_t::Operation>> { + private: + using Operation = typename DualFusionCallbacks::Operation; + public: + using type = typename dual_get_element_aux::type; +}; + +template +struct dual_get_element_aux, cute::void_t<>> { + using type = typename dual_get_element_aux::type; +}; +} // namespace cutlass:epilogue::fusion::detail + +template +using dual_get_element_aux_t = typename detail::dual_get_element_aux::type; + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/94_blackwell_geforce_dual_gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative_dual.hpp b/examples/94_blackwell_geforce_dual_gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative_dual.hpp new file mode 100644 index 0000000000..711aa882b6 --- /dev/null +++ b/examples/94_blackwell_geforce_dual_gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative_dual.hpp @@ -0,0 +1,870 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/tensor.hpp" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/gemm_universal_decl.h" +#include "cutlass/arch/grid_dependency_control.h" + +#include "../collective/dispatch_policy_extra.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileSchedulerTag_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileSchedulerTag_, + cute::enable_if_t, typename CollectiveMainloop_::DispatchPolicy::Schedule>>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB0 = typename CollectiveMainloop::StrideB0; + using StrideB1 = typename CollectiveMainloop::StrideB1; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename 
CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 90); + + static constexpr uint32_t TileSchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + using TileSchedulerTag = TileSchedulerTag_; + + using TileScheduler = typename detail::TileSchedulerSelector< + TileSchedulerTag, + ArchTag, + TileShape, + ClusterShape + ,TileSchedulerPipelineStageCount + >::Scheduler; + + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = size(TiledMma{}); // 8 warps + static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp for C + + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = NumMMAThreads / NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = NumMMAThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static constexpr uint32_t NumFixupBarriers = NumMmaWarpGroups; + static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents; + static constexpr bool IsMainloopAuxiliaryLoadNeeded = detail::HasAuxiliaryLoad_v; + + /// Register requirement for Load and Math WGs + static constexpr int RegsPerThread = + size<0>(TileShape{}) * size<1>(TileShape{}) / NumMMAThreads * + sizeof(ElementAccumulator) / sizeof(uint32_t); + static constexpr bool HeavyRegisterPressure = RegsPerThread >= 208; + static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24; + static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 
232 : 240; + + // 1 stage ordered sequence between mainloop and epilogue producer load threads + using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + using TileSchedulerPipeline = typename TileScheduler::Pipeline; + using TileSchedulerPipelineState = typename TileSchedulerPipeline::PipelineState; + using TileSchedulerStorage = typename TileScheduler::SharedStorage; + using TileSchedulerThrottlePipeline = typename TileScheduler::ThrottlePipeline; + using TileSchedulerThrottlePipelineState = typename TileSchedulerThrottlePipeline::PipelineState; + + // Kernel level shared memory storage + struct SharedStorage { + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; + } pipelines; + + alignas(16) TileSchedulerStorage scheduler; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + void* workspace{nullptr}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
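+  // Host-side call sequence that reaches this function (illustrative; assumes the usual
+  // CUTLASS 3.x device-adapter flow, with `Gemm` standing for the adapter type wrapping
+  // this kernel -- it is not defined in this header):
+  //
+  //   Gemm gemm;
+  //   size_t workspace_size = Gemm::get_workspace_size(arguments);
+  //   cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  //   CUTLASS_CHECK(gemm.can_implement(arguments));
+  //   CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));  // initialize_workspace() + to_underlying_arguments()
+  //   CUTLASS_CHECK(gemm.run());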
+ static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + auto problem_shape = args.problem_shape; + if constexpr (detail::Has_SwapAB_v) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + // Get maximum number of clusters that could co-exist on the target device + int max_active_clusters = args.hw_info.max_active_clusters; + if (max_active_clusters <= 0) { + max_active_clusters = 0; + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid max cluster count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the max_active_clusters."); + } + else { + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid cluster count to " << max_active_clusters); + } + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count, max_active_clusters}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used + // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means + // subtile will not be used, therefore separate reduction will not be enabled. 
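+    // For reference, the workspace carved out above is laid out as
+    //
+    //   [ epilogue workspace | pad to MinWorkspaceAlignment | tile-scheduler workspace ]
+    //   ^ workspace_ptr                                      ^ scheduler_workspace
+    //
+    // e.g. a (hypothetical) 1000-byte epilogue workspace is rounded up to the next
+    // MinWorkspaceAlignment boundary before the scheduler workspace begins. The mainloop
+    // requests no workspace here, so mainloop_workspace stays nullptr.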
+ constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles + ); + + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), + hw_info, + scheduler, + workspace + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + static constexpr uint32_t NumAccumulatorMtxs = 1; + + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // Given device SM count, set grid size s.t. 
we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; + return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + +#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200)) +# define ENABLE_SM90_KERNEL_LEVEL 1 +#endif +// Any Tensor Op MMA Atom in the ISA is arch conditional. +#if ! defined(ENABLE_SM90_KERNEL_LEVEL) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else + + // Preconditions + static_assert(NumMMAThreads == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); + static_assert(size<0>(TileShape{}) >= 128, + "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); + + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB0{}) == 3, "StrideB0 must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB1{}) == 3, "StrideB1 must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); + + /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ + enum class WarpGroupRole { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2 + }; + enum class ProducerWarpRole { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + MainloopAux = 3 + }; + + + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int mma_thread_idx = thread_idx % NumMMAThreads; + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + // TileScheduler pipeline + typename TileSchedulerPipeline::Params scheduler_pipeline_params; + typename TileSchedulerThrottlePipeline::Params scheduler_throttle_pipeline_params; + if constexpr (IsSchedDynamicPersistent) { + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Warp1) { + scheduler_pipeline_params.role = TileSchedulerPipeline::ThreadCategory::ProducerConsumer; + } + else { + scheduler_pipeline_params.role = TileSchedulerPipeline::ThreadCategory::Consumer; + } + scheduler_pipeline_params.producer_blockid = 0; + scheduler_pipeline_params.producer_arv_count = 1; + scheduler_pipeline_params.consumer_arv_count = NumSchedThreads + NumMainloopLoadThreads + NumMMAThreads; + + if (is_epi_load_needed) { + scheduler_pipeline_params.consumer_arv_count += NumEpilogueLoadThreads; + } + scheduler_pipeline_params.transaction_bytes = sizeof(typename TileScheduler::CLCResponse); + + scheduler_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + scheduler_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + scheduler_throttle_pipeline_params.dst_blockid = 0; + scheduler_throttle_pipeline_params.initializing_warp = 3; + if (warp_group_role == WarpGroupRole::Producer && + producer_warp_role == ProducerWarpRole::Warp1) { + scheduler_throttle_pipeline_params.role = + TileSchedulerThrottlePipeline::ThreadCategory::Consumer; + } + // set role when it is for DMA warp in Mainloop + else if (warp_group_role == WarpGroupRole::Producer && + producer_warp_role == ProducerWarpRole::Mainloop) { + scheduler_throttle_pipeline_params.role = + TileSchedulerThrottlePipeline::ThreadCategory::Producer; + } + } + TileSchedulerPipeline scheduler_pipeline(shared_storage.scheduler.pipeline(), scheduler_pipeline_params); + TileSchedulerPipelineState scheduler_pipe_consumer_state; + + TileSchedulerThrottlePipeline scheduler_throttle_pipeline(shared_storage.scheduler.throttle_pipeline(), scheduler_throttle_pipeline_params); + TileSchedulerThrottlePipelineState scheduler_pipe_throttle_consumer_state; + TileSchedulerThrottlePipelineState scheduler_pipe_throttle_producer_state = 
cutlass::make_producer_start_state(); + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && (producer_warp_role == ProducerWarpRole::Mainloop || + producer_warp_role == ProducerWarpRole::MainloopAux)) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = NumMMAThreads; + mainloop_pipeline_params.num_producers = NumProducerThreads; + mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumMMAThreads; + if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { + epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; + } + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 
0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + + auto cluster_wait_fn = [] () { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + return [] () { cute::cluster_wait(); }; + } + else { + __syncthreads(); + return [] () {}; // do nothing + } + } (); + + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + + TileScheduler scheduler{params.scheduler}; + if constexpr (IsSchedDynamicPersistent) { + scheduler.set_data_ptr(shared_storage.scheduler.data()); + } + // Declare work_tile_info, then define it in each of warps that use it. + typename TileScheduler::WorkTileInfo work_tile_info; + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + // get<2>(load_inputs) is the tma tensor B1 after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 3, "Output of load_init must have at least two elements (A, B0, B1)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB0_nkl = get<1>(load_inputs); + Tensor gB1_nkl = get<2>(load_inputs); + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) { + work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); + cutlass::arch::warpgroup_reg_dealloc(); + + // Scheduler Producer Warp + if (producer_warp_role == ProducerWarpRole::Warp1) { + if constexpr (IsSchedDynamicPersistent) { + bool requires_clc_query = true; + TileSchedulerPipelineState scheduler_pipe_producer_state = cutlass::make_producer_start_state(); + + cutlass::arch::wait_on_dependent_grids(); + while (work_tile_info.is_valid()) { + + if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. 
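+                // Handshake (summary of the code below plus the matching producer_acquire /
+                // producer_commit issued by the mainloop DMA warp for the same tile):
+                //
+                //   mainloop DMA warp:  producer_acquire(s); producer_commit(s); ++s;
+                //   scheduler warp:     consumer_wait(s);    consumer_release(s); ++s;
+                //
+                // so a new CLC query is only issued once the DMA warp has picked up the
+                // previously scheduled tile.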
+ scheduler_throttle_pipeline.consumer_wait(scheduler_pipe_throttle_consumer_state); + scheduler_throttle_pipeline.consumer_release(scheduler_pipe_throttle_consumer_state); + ++scheduler_pipe_throttle_consumer_state; + + // Query next work tile + scheduler_pipe_producer_state = scheduler.advance_to_next_work(scheduler_pipeline, scheduler_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + scheduler_pipeline, + scheduler_pipe_consumer_state + ); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++scheduler_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } + scheduler_pipeline.producer_tail(scheduler_pipe_producer_state); + } + } // Scheduler Producer Warp End + else + + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) { + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + bool do_load_order_arrive = true; + bool requires_clc_query = true; + while (work_tile_info.is_valid()) { + if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; + continue; + } + + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB0_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB0_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
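+          // Illustrative numbers (hypothetical, not tied to this example's tile shape):
+          // with K = 4096 and a 64-wide K tile there are 64 k-tiles. A data-parallel work
+          // unit gets work_k_tile_start == 0 and work_k_tile_count == 64, while a stream-K
+          // work unit may get e.g. work_k_tile_start == 40 and work_k_tile_count == 24;
+          // make_coord_iterator() below then starts the TMA load loop at that k-tile offset.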
+ auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + + if (requires_clc_query) { + scheduler_throttle_pipeline.producer_acquire(scheduler_pipe_throttle_producer_state); + scheduler_throttle_pipeline.producer_commit(scheduler_pipe_throttle_producer_state); + ++scheduler_pipe_throttle_producer_state; + } + + collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + blk_coord, + k_tile_iter, work_k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(work_k_tile_count); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + // Get next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, + scheduler_pipeline, + scheduler_pipe_consumer_state + ); + + work_tile_info = next_work_tile_info; + if constexpr (IsSchedDynamicPersistent) { + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++scheduler_pipe_consumer_state; + } + } + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + } + else if (producer_warp_role == ProducerWarpRole::MainloopAux) { + if constexpr (IsMainloopAuxiliaryLoadNeeded) { + while (work_tile_info.is_valid()) { + if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info); + work_tile_info = next_work_tile_info; + continue; + } + + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB0_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB0_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
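+            // The k-tile bookkeeping below is identical to the primary mainloop producer
+            // warp above; this auxiliary warp (compiled in only when
+            // IsMainloopAuxiliaryLoadNeeded) calls collective_mainloop.load_auxiliary()
+            // instead of load() and advances its own local mainloop_pipe_producer_state by
+            // the same work_k_tile_count.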
+ auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + + collective_mainloop.load_auxiliary( + params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + blk_coord, + k_tile_iter, work_k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(work_k_tile_count); + + // Get next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + scheduler_pipeline, + scheduler_pipe_consumer_state + ); + + work_tile_info = next_work_tile_info; + } // Scheduler work fetch loop + + } + } + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && is_epi_load_needed) { + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + if (!TileScheduler::requires_separate_reduction(params.scheduler) && work_tile_info.is_valid()) { + load_order_barrier.wait(); + } + + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + while (work_tile_info.is_valid()) { + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB0_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB0_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + epi_load_pipe_producer_state = + collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + lane_idx, + shared_storage.tensors.epilogue, + work_tile_info.reduction_subtile_idx() + ); + } + + // Get next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, + scheduler_pipeline, + scheduler_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + if constexpr (IsSchedDynamicPersistent) { + if (increment_pipe) { + ++scheduler_pipe_consumer_state; + } + } + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } // Epilogue Producer Warp End + } // Producer Warp Group End + + else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); + cutlass::arch::warpgroup_reg_alloc(); + + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it + bool do_store_tail = false; + while (work_tile_info.is_valid()) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB0_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB0_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + auto 
work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + // Allocate the accumulators for the (M,N) blk_shape + // + // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. + auto accumulators0 = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + auto accumulators1 = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators0, + accumulators1, + work_k_tile_count, + mma_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + work_k_tile_count + ); + + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(work_k_tile_count); + } + #ifdef CUTLASS_ENABLE_GDC_FOR_SM90 + if (scheduler.is_last_tile(work_tile_info)) { + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); + + } + #endif + + // Index of warp group within consumer warp groups + int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; + + // Perform reduction across splits, if needed + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators0, NumMmaWarpGroups, consumer_warp_group_idx); + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators1, NumMmaWarpGroups, consumer_warp_group_idx); + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators0, + accumulators1, + tiled_mma, + mma_thread_idx, + shared_storage.tensors.epilogue, + work_tile_info.reduction_subtile_idx() + ); + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; + do_store_tail = true; + } + + // Get next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, + scheduler_pipeline, + scheduler_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + if constexpr (IsSchedDynamicPersistent) { + if (increment_pipe) { + ++scheduler_pipe_consumer_state; + } + } + } // Scheduler work fetch loop + + if (do_store_tail) { + collective_epilogue.store_tail( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state + ); + } + } // Consumer Warp Groups End +#endif + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/examples/94_blackwell_geforce_dual_gemm/thread/left_silu_and_mul.h b/examples/94_blackwell_geforce_dual_gemm/thread/left_silu_and_mul.h new file mode 100644 index 0000000000..f51677062f --- /dev/null +++ b/examples/94_blackwell_geforce_dual_gemm/thread/left_silu_and_mul.h @@ -0,0 +1,152 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Functor performing linear combination operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/thread/linear_combination_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation. 
+                                             ///< Usually it is 128/sizeof_bits<ElementOutput_>,
+                                             ///< but we use 64 or 32 sometimes when there are not enough data to store
+  typename ElementAccumulator_ = ElementOutput_,       ///< Accumulator data type
+  typename ElementCompute_ = ElementOutput_,           ///< Data type used to compute linear combination
+  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
+>
+class LeftSiLUAndMul {
+public:
+
+  using ElementOutput = ElementOutput_;
+  using ElementAccumulator = ElementAccumulator_;
+  using ElementCompute = ElementCompute_;
+
+  static int const kCount = Count;
+  using FragmentOutput = Array<ElementOutput, kCount>;
+  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
+  using ComputeFragment = Array<ElementCompute, kCount>;
+
+  static FloatRoundStyle const kRound = Round;
+
+  struct Params{};
+
+private:
+
+  //
+  // Data members
+  //
+
+  ElementCompute alpha_;
+  ElementCompute beta_;
+
+public:
+
+  /// Constructs the function object, possibly loading from pointers in host memory
+  CUTLASS_HOST_DEVICE
+  LeftSiLUAndMul(Params const &/*params*/) {}
+
+  /// Returns true if source is needed
+  CUTLASS_HOST_DEVICE
+  bool is_source_needed() const {
+    return true;
+  }
+
+  /// Functionally required for serial reduction in the epilogue
+  CUTLASS_HOST_DEVICE
+  void set_k_partition(int k_partition, int k_partition_count) {
+    assert(false);
+  }
+
+  /// Computes D = SiLU(lhs) * rhs elementwise over a fragment
+  CUTLASS_HOST_DEVICE
+  FragmentOutput operator()(
+    FragmentAccumulator const &lhs,
+    FragmentAccumulator const &rhs) const {
+
+    // Convert accumulators to the internal compute numeric type
+    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_to_compute;
+
+    // Convert to destination numeric type
+    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> compute_to_output;
+
+    ComputeFragment converted_lhs = accumulator_to_compute(lhs);
+    ComputeFragment converted_rhs = accumulator_to_compute(rhs);
+
+    cutlass::epilogue::thread::SiLu<ComputeFragment> silu;
+    cutlass::multiplies<ComputeFragment> mul;
+    auto silu_lhs = silu(converted_lhs);
+    return compute_to_output(mul(silu_lhs, converted_rhs));
+  }
+
+  /// Scalar variant: D = SiLU(lhs) * rhs
+  CUTLASS_HOST_DEVICE
+  ElementOutput operator()(
+    ElementAccumulator const& lhs,
+    ElementAccumulator const& rhs
+  ) const {
+    ElementCompute convert_lhs(lhs);
+    ElementCompute convert_rhs(rhs);
+    cutlass::epilogue::thread::SiLu<ElementCompute> silu;
+    cutlass::multiplies<ElementCompute> mul;
+    auto silu_lhs = silu(convert_lhs);
+    return ElementOutput(mul(silu_lhs, convert_rhs));
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace thread
+} // namespace epilogue
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 6d97ea564e..b1cdb47f8f 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -170,6 +170,8 @@ foreach(EXAMPLE
   90_sm103_fp4_ultra_grouped_gemm
   91_fp4_gemv
   92_blackwell_moe_gemm
+  93_hopper_dual_gemm
+  94_blackwell_geforce_dual_gemm
 )
 add_subdirectory(${EXAMPLE})