Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions include/cutlass/detail/collective/mixed_input_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ struct LayoutAwareConvertImpl<
// Specialization for INT8 -> BF16 with [3120] value order
template <>
struct LayoutAwareConvertImpl<
cutlass::int8_t,
int8_t,
cutlass::bfloat16_t,
cute::Layout<cute::Shape<_2,_2>, cute::Stride<_2,_1>>,
cute::Layout<_4>
Expand All @@ -362,9 +362,9 @@ struct LayoutAwareConvertImpl<
cute::Layout<_4>
>& dst) {

static_assert(cute::is_same_v<cutlass::int8_t, typename EngineIn::value_type> &&
static_assert(cute::is_same_v<int8_t, typename EngineIn::value_type> &&
cute::is_same_v<cutlass::bfloat16_t, typename EngineOut::value_type>);
using SrcArray = cutlass::Array<cutlass::int8_t, 8>;
using SrcArray = cutlass::Array<int8_t, 8>;
using DstArray = cutlass::Array<cutlass::bfloat16_t, 8>;
using RegArray = cutlass::AlignedArray<uint32_t, 4, sizeof(DstArray)>;

Expand Down Expand Up @@ -402,7 +402,7 @@ struct LayoutAwareConvertImpl<
// Specialization for INT8 -> FP16 with [3120] value order
template <>
struct LayoutAwareConvertImpl<
cutlass::int8_t,
int8_t,
cutlass::half_t,
cute::Layout<cute::Shape<_2,_2>, cute::Stride<_2,_1>>,
cute::Layout<_4>
Expand All @@ -417,9 +417,9 @@ struct LayoutAwareConvertImpl<
cute::Layout<_4>
>& dst) {

static_assert(cute::is_same_v<cutlass::int8_t, typename EngineIn::value_type> &&
static_assert(cute::is_same_v<int8_t, typename EngineIn::value_type> &&
cute::is_same_v<cutlass::half_t, typename EngineOut::value_type>);
using SrcArray = cutlass::Array<cutlass::int8_t, 8>;
using SrcArray = cutlass::Array<int8_t, 8>;
using DstArray = cutlass::Array<cutlass::half_t, 8>;
using RegArray = cutlass::AlignedArray<uint32_t, 4, sizeof(DstArray)>;

Expand Down
25 changes: 16 additions & 9 deletions include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,21 @@ namespace cutlass::gemm::collective {
namespace detail {

// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template<int capacity_bytes, class ElementA, class ElementB, class TileShapeMNK, int stages, int alignment = 128>
template<int capacity_bytes, class ElementA, class ElementB, class TileShapeMNK, int alignment = 128, int stages>
constexpr int
compute_stage_count_or_override(StageCount<stages> stage_count) {
return stages;
}

// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template<int capacity_bytes, class ElementA, class ElementB, class TileShapeMNK, int stages, int alignment = 128>
template<int capacity_bytes, class ElementA, class ElementB, class TileShapeMNK, int alignment = 128, int stages>
constexpr int
compute_stage_count_or_override(cute::Int<stages> stage_count) {
return stages;
}

// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template<int capacity_bytes_, class ElementA, class ElementB, class TileShapeMNK, int carveout_bytes_, int alignment = 128>
template<int capacity_bytes_, class ElementA, class ElementB, class TileShapeMNK, int alignment = 128, int carveout_bytes_>
constexpr int
compute_stage_count_or_override(StageCountAutoCarveout<carveout_bytes_> stage_count) {
constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
Expand All @@ -85,7 +85,7 @@ compute_stage_count_or_override(StageCountAutoCarveout<carveout_bytes_> stage_co
}

// Returns the maximum number of smem tiles that can be used with a given smem capacity in gemm of blockwise/groupwise scale.
template<int capacity_bytes_, class ElementA, class ElementB, class ElementBlockScale, class TileShapeMNK, int ScaleMsPerTile, int ScaleNsPerTile, int carveout_bytes_, int alignment = 128>
template<int capacity_bytes_, class ElementA, class ElementB, class ElementBlockScale, class TileShapeMNK, int ScaleMsPerTile, int ScaleNsPerTile, int alignment = 128, int carveout_bytes_>
constexpr int
compute_stage_count_with_blockwise_scale(StageCountAutoCarveout<carveout_bytes_> stage_count) {
constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
Expand All @@ -107,7 +107,14 @@ compute_stage_count_with_blockwise_scale(StageCountAutoCarveout<carveout_bytes_>
}

// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
template<int capacity_bytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int stages, int alignment = 128>
template<int capacity_bytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int alignment = 128, int stages>
constexpr int
compute_stage_count_or_override_single_affine_transformed_input(cute::Int<stages> stage_count) {
return stages;
}

// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
template<int capacity_bytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int alignment = 128, int stages>
constexpr int
compute_stage_count_or_override_single_affine_transformed_input(StageCount<stages> stage_count) {
return stages;
Expand All @@ -124,7 +131,7 @@ constexpr int get_bits_for_possibly_void_element() {
}

// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
template<int capacity_bytes_, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int carveout_bytes_, int alignment = 128>
template<int capacity_bytes_, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int alignment = 128, int carveout_bytes_>
constexpr int
compute_stage_count_or_override_single_affine_transformed_input(StageCountAutoCarveout<carveout_bytes_> stage_count) {

Expand Down Expand Up @@ -456,12 +463,12 @@ public:
static constexpr int PipelineStages = IsMixedInput ?
( IsArrayOfPointersGemm ?
detail::compute_stage_count_or_override_single_affine_transformed_input<Sm90ReducedSmemCapacityBytes,
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{}) :
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, SmemAlignment>(StageCountType{}) :
detail::compute_stage_count_or_override_single_affine_transformed_input<detail::sm90_smem_capacity_bytes,
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{})
RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, SmemAlignment>(StageCountType{})
)
: detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
ElementAMma, ElementBMma, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{});
ElementAMma, ElementBMma, TileShape_MNK, SmemAlignment>(StageCountType{});

using DispatchPolicy = cute::conditional_t<IsMixedInput,
cute::conditional_t<IsArrayOfPointersGemm,
Expand Down
6 changes: 2 additions & 4 deletions tools/util/include/cutlass/util/mixed_dtype_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include "cutlass/util/device_memory.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cute/util/type_traits.hpp"
#include "cute/numeric/numeric_types.hpp"

namespace cutlass {

Expand Down Expand Up @@ -177,10 +178,7 @@ static void dequantize(DequantizedElement* dq_buffer,
template <typename T>
class packed_scale_t {
public:
static_assert(cute::is_same_v<T, cutlass::int8_t> ||
cute::is_same_v<T, cutlass::uint8_t> ||
cute::is_same_v<T, cutlass::float_e4m3_t> ||
cute::is_same_v<T, cutlass::float_e5m2_t>,
static_assert(cute::sizeof_bits_v<T> == 8,
"only 8 bit arithmetic types are supported.");
CUTLASS_HOST_DEVICE
explicit packed_scale_t(T val) {
Expand Down