
[BUG] For mixed type gemms on SM90, custom stage count doesn't work. #2654

@Algy

Description


Which component has the problem?

CUTLASS C++

Bug Report

Describe the bug

In sm90_gmma_builder.inl, the builder class expects the template argument StageCountType to have a static member named bytes, which is not always the case. Only StageCountAutoCarveout<...> defines that member; StageCount<...> does not.

static constexpr int PipelineStages = IsMixedInput ?
    ( IsArrayOfPointersGemm ?
        detail::compute_stage_count_or_override_single_affine_transformed_input<Sm90ReducedSmemCapacityBytes,
            RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{}) :
        detail::compute_stage_count_or_override_single_affine_transformed_input<detail::sm90_smem_capacity_bytes,
            RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{})
    )
    : detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
        ElementAMma, ElementBMma, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{});

For example, passing StageCount<5> instead of StageCountAutoCarveout<...> to the builder for a mixed-type GEMM on SM90 causes a compilation error.

One way to address this would be to dispatch StageCount<...> to a separate path that uses the requested stage count directly instead of reading StageCountType::bytes.
