From ab4d3a1b3b90bd1427198cb40fb8059f94afaecd Mon Sep 17 00:00:00 2001 From: siyuanf Date: Fri, 3 Oct 2025 17:34:47 -0700 Subject: [PATCH] fix --- .../detail/collective/mixed_input_utils.hpp | 12 ++++----- .../collective/builders/sm90_gmma_builder.inl | 25 ++++++++++++------- .../cutlass/util/mixed_dtype_utils.hpp | 6 ++--- 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/include/cutlass/detail/collective/mixed_input_utils.hpp b/include/cutlass/detail/collective/mixed_input_utils.hpp index 89d250001e..41fc4a9ee9 100644 --- a/include/cutlass/detail/collective/mixed_input_utils.hpp +++ b/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -347,7 +347,7 @@ struct LayoutAwareConvertImpl< // Specialization for INT8 -> BF16 with [3120] value order template <> struct LayoutAwareConvertImpl< - cutlass::int8_t, + int8_t, cutlass::bfloat16_t, cute::Layout, cute::Stride<_2,_1>>, cute::Layout<_4> @@ -362,9 +362,9 @@ struct LayoutAwareConvertImpl< cute::Layout<_4> >& dst) { - static_assert(cute::is_same_v && + static_assert(cute::is_same_v && cute::is_same_v); - using SrcArray = cutlass::Array; + using SrcArray = cutlass::Array; using DstArray = cutlass::Array; using RegArray = cutlass::AlignedArray; @@ -402,7 +402,7 @@ struct LayoutAwareConvertImpl< // Specialization for INT8 -> FP16 with [3120] value order template <> struct LayoutAwareConvertImpl< - cutlass::int8_t, + int8_t, cutlass::half_t, cute::Layout, cute::Stride<_2,_1>>, cute::Layout<_4> @@ -417,9 +417,9 @@ struct LayoutAwareConvertImpl< cute::Layout<_4> >& dst) { - static_assert(cute::is_same_v && + static_assert(cute::is_same_v && cute::is_same_v); - using SrcArray = cutlass::Array; + using SrcArray = cutlass::Array; using DstArray = cutlass::Array; using RegArray = cutlass::AlignedArray; diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index a1ea257e7f..c3445be3e3 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -52,21 +52,21 @@ namespace cutlass::gemm::collective { namespace detail { // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. -template +template constexpr int compute_stage_count_or_override(StageCount stage_count) { return stages; } // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. -template +template constexpr int compute_stage_count_or_override(cute::Int stage_count) { return stages; } // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. -template +template constexpr int compute_stage_count_or_override(StageCountAutoCarveout stage_count) { constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); @@ -85,7 +85,7 @@ compute_stage_count_or_override(StageCountAutoCarveout stage_co } // Returns the maximum number of smem tiles that can be used with a given smem capacity in gemm of blockwise/groupwise scale. -template +template constexpr int compute_stage_count_with_blockwise_scale(StageCountAutoCarveout stage_count) { constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); @@ -107,7 +107,14 @@ compute_stage_count_with_blockwise_scale(StageCountAutoCarveout } // Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. -template +template +constexpr int +compute_stage_count_or_override_single_affine_transformed_input(cute::Int stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. +template constexpr int compute_stage_count_or_override_single_affine_transformed_input(StageCount stage_count) { return stages; @@ -124,7 +131,7 @@ constexpr int get_bits_for_possibly_void_element() { } // Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. -template +template constexpr int compute_stage_count_or_override_single_affine_transformed_input(StageCountAutoCarveout stage_count) { @@ -456,12 +463,12 @@ public: static constexpr int PipelineStages = IsMixedInput ? ( IsArrayOfPointersGemm ? detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}) : + RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, SmemAlignment>(StageCountType{}) : detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}) + RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, SmemAlignment>(StageCountType{}) ) : detail::compute_stage_count_or_override(StageCountType{}); + ElementAMma, ElementBMma, TileShape_MNK, SmemAlignment>(StageCountType{}); using DispatchPolicy = cute::conditional_t class packed_scale_t { public: - static_assert(cute::is_same_v || - cute::is_same_v || - cute::is_same_v || - cute::is_same_v, + static_assert(cute::sizeof_bits_v == 8, "only 8 bit arithmetic types are supported."); CUTLASS_HOST_DEVICE explicit packed_scale_t(T val) {