Which component has the problem?
CUTLASS C++
Bug Report
Describe the bug
In sm90_gmma_builder.inl, the builder assumes that the template argument StageCountType has a static member named bytes, which is not always the case. Only StageCountAutoCarveout<...> defines that member; StageCount<...> does not.
cutlass/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl
Lines 456 to 464 in 6457918
```cpp
static constexpr int PipelineStages = IsMixedInput ?
  ( IsArrayOfPointersGemm ?
    detail::compute_stage_count_or_override_single_affine_transformed_input<Sm90ReducedSmemCapacityBytes,
      RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{}) :
    detail::compute_stage_count_or_override_single_affine_transformed_input<detail::sm90_smem_capacity_bytes,
      RealElementA, RealElementB, ElementScale, ElementZero, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{})
  )
  : detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes,
      ElementAMma, ElementBMma, TileShape_MNK, StageCountType::bytes, SmemAlignment>(StageCountType{});
```
For example, passing StageCount<5> instead of StageCountAutoCarveout<...> with mixed-input types causes a compilation error on SM90.
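To illustrate the mismatch, the sketch below uses simplified, hypothetical stand-ins (mock::StageCount and mock::StageCountAutoCarveout) rather than the real CUTLASS types, and assumes C++17 for std::void_t. A detection trait (has_bytes, also hypothetical) reports whether a policy type declares a static member named bytes. The static_asserts compile only because the fixed-count policy lacks that member, which is exactly why an expression like StageCountType::bytes is ill-formed for it.

```cpp
#include <type_traits>

// Hypothetical, simplified stand-ins for the two stage-count policies.
// These are NOT the actual CUTLASS definitions.
namespace mock {
template <int Stages>
struct StageCount {
  static constexpr int value = Stages;  // fixed stage count, no `bytes` member
};

template <int CarveoutBytes>
struct StageCountAutoCarveout {
  static constexpr int bytes = CarveoutBytes;  // shared memory reserved elsewhere
};
}  // namespace mock

// Detection idiom (hypothetical helper): true only if T declares `bytes`.
template <class T, class = void>
struct has_bytes : std::false_type {};

template <class T>
struct has_bytes<T, std::void_t<decltype(T::bytes)>> : std::true_type {};

// The carveout policy exposes `bytes`...
static_assert(has_bytes<mock::StageCountAutoCarveout<1024>>::value,
              "carveout policy declares bytes");
// ...but the fixed-count policy does not, so `StageCountType::bytes` fails.
static_assert(!has_bytes<mock::StageCount<5>>::value,
              "fixed stage count has no bytes member");
```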
A possible fix is to dispatch StageCount to a separate code path that does not access StageCountType::bytes, so that it is handled properly.
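One possible shape for such a dispatch is sketched here under stated assumptions: the policy types, the carveout_bytes trait, compute_stages, and pipeline_stages are all hypothetical, and the capacity and per-stage sizes are illustrative numbers, not the builder's real values or API. The call site reads the carveout through a trait that yields 0 for a fixed stage count, and the stage computation is overloaded on the policy type so that StageCount<N> simply returns N.

```cpp
#include <type_traits>

// Hypothetical, simplified stand-ins (NOT the actual CUTLASS types).
namespace mock {
template <int Stages>
struct StageCount { static constexpr int value = Stages; };

template <int CarveoutBytes>
struct StageCountAutoCarveout { static constexpr int bytes = CarveoutBytes; };
}  // namespace mock

// Hypothetical replacement for reading `StageCountType::bytes` directly.
// Primary template left undefined so an unknown policy fails to compile.
template <class StageCountType>
struct carveout_bytes;

template <int Stages>
struct carveout_bytes<mock::StageCount<Stages>>
    : std::integral_constant<int, 0> {};  // fixed count: no carveout

template <int Bytes>
struct carveout_bytes<mock::StageCountAutoCarveout<Bytes>>
    : std::integral_constant<int, Bytes> {};

// Fixed count: use the requested number of stages as-is.
template <int CapacityBytes, int BytesPerStage, int Carveout, int Stages>
constexpr int compute_stages(mock::StageCount<Stages>) {
  return Stages;
}

// Carveout: fit as many stages as possible into the remaining capacity.
template <int CapacityBytes, int BytesPerStage, int Carveout, int Bytes>
constexpr int compute_stages(mock::StageCountAutoCarveout<Bytes>) {
  return (CapacityBytes - Carveout) / BytesPerStage;
}

// Call site modeled loosely on the builder: the carveout is passed as a
// template argument, but obtained via the trait instead of `::bytes`.
template <class StageCountType>
constexpr int pipeline_stages() {
  constexpr int kCapacityBytes = 232448;  // illustrative value
  constexpr int kBytesPerStage = 32768;   // illustrative value
  return compute_stages<kCapacityBytes, kBytesPerStage,
                        carveout_bytes<StageCountType>::value>(StageCountType{});
}

static_assert(pipeline_stages<mock::StageCount<5>>() == 5, "fixed count");
static_assert(pipeline_stages<mock::StageCountAutoCarveout<16384>>() == 6,
              "(232448 - 16384) / 32768 == 6");
```

Leaving the primary carveout_bytes template undefined is a deliberate choice in this sketch: an unrecognized policy type produces a compile error at the trait instead of silently defaulting to a zero carveout.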