@@ -42,6 +42,17 @@ SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
   return ret;
 }
 
+// TODO Have order be a mandatory argument of standardOutDimNames.
+SmallVector<StringAttr> permuteDimNames(const SmallVector<StringAttr> &names,
+                                        const SmallVector<unsigned> &order) {
+  assert(names.size() == order.size());
+  SmallVector<StringAttr> ret;
+  for (unsigned i : order) {
+    ret.push_back(names[i]);
+  }
+  return ret;
+}
+
 void assertIsRegisterLayout(const LinearLayout &layout) {
   assert(layout.getNumInDims() > 0);
   MLIRContext *ctx = layout.getInDimNames().begin()->getContext();
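For reference, a minimal standalone sketch of what permuteDimNames does, using plain std::string in place of mlir::StringAttr (an assumption made purely for illustration): applying the row-major order {1, 0} to {"dim0", "dim1"} lists the fastest-varying dim first.

#include <cassert>
#include <string>
#include <vector>

int main() {
  // Row-major order for a rank-2 layout: dim 1 is the fastest-varying.
  std::vector<std::string> names = {"dim0", "dim1"};
  std::vector<unsigned> order = {1, 0};

  // Same permutation logic as permuteDimNames above.
  std::vector<std::string> permuted;
  for (unsigned i : order)
    permuted.push_back(names[i]);

  assert(permuted == (std::vector<std::string>{"dim1", "dim0"}));
}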
@@ -282,15 +293,19 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef<int64_t> shape,
 
   MLIRContext *ctx = mma.getContext();
   SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
+  auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma));
+  // getOrder(mma) is guaranteed to be row-major (see the assert below), so
+  // orderedDimNames lists the out dims in row-major order.
+  assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
 
   LinearLayout ctaLayout(
       {{S("register"), {{1, 0}, {0, 8}}},
        {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}},
-      llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));
-
-  ctaLayout *= identityND(
-      S("warp"), mma.getWarpsPerCTA(),
-      llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank))), dimNames);
+      ArrayRef(orderedDimNames).take_front(2));
+  assert(getWarpOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
+  // FIXME(Lezcano). identityND should not have an `order` param as it's
+  // redundant with the order of the out dims.
+  ctaLayout *=
+      identityND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder(), dimNames);
 
   return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
 }
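The hunk above swaps reverse(dimNames).take_back(2) for permuteDimNames(...).take_front(2); the two spellings agree whenever the order is row-major, which is exactly what the new asserts pin down. A standalone sketch of that equivalence, with plain std::string standing in for StringAttr and the rank-3 row-major order {2, 1, 0} hard-coded as an assumption:

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> dimNames = {"dim0", "dim1", "dim2"};
  // What getMatrixOrder(3, /*rowMajor=*/true) yields in the real code.
  std::vector<unsigned> rowMajor = {2, 1, 0};

  // Old spelling: take the last two names, then reverse them.
  std::vector<std::string> oldNames(dimNames.end() - 2, dimNames.end());
  std::reverse(oldNames.begin(), oldNames.end()); // {"dim2", "dim1"}

  // New spelling: permute by the order, then take the first two.
  std::vector<std::string> permuted;
  for (unsigned i : rowMajor)
    permuted.push_back(dimNames[i]);
  std::vector<std::string> newNames(permuted.begin(), permuted.begin() + 2);

  assert(oldNames == newNames); // both are {"dim2", "dim1"}
}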
@@ -323,10 +338,14 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
   ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")),
                                         S("register"), S("dim1"));
 
-  // Expand the `warp` dimension according to warpsPerCTA.
-  //
-  // It's weird that this is order [0,1] when MMAv2's warpsPerCTA is [1,0], but
-  // this really does seem to be correct.
+  // The order given by choosing (`dim1`, `dim0`) is [1, 0], that is, N-major.
+  // Since the warpOrder needs to be M-major, we have to transpose the out
+  // dimensions AND transpose the order.
+  // FIXME(Lezcano). identityND should not have an `order` param as it's
+  //                 redundant. The order is already given by the order of the
+  //                 out dims, and if identityND takes an order, it should not
+  //                 change the order of the out dims.
+  assert(getWarpOrder(mma) == SmallVector<unsigned>({0, 1}));
   ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1},
                           {S("dim0"), S("dim1")})
                    .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
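As a rough standalone illustration (not Triton's identityND itself) of why order = {0, 1} is the M-major choice here, the sketch below enumerates warp coordinates for a made-up warpsPerCTA = {4, 2}; both the shape and the coordinate decoding are assumptions of the example:

#include <cstdio>

int main() {
  // warpsPerCTA = {4, 2}: four warps along dim0 (M), two along dim1 (N).
  const int warpsPerCTA[2] = {4, 2};
  for (int warp = 0; warp < warpsPerCTA[0] * warpsPerCTA[1]; ++warp) {
    // order = {0, 1}: dim0 takes the low bits of the warp id, so it varies
    // fastest as the warp id increases.
    int dim0 = warp % warpsPerCTA[0];
    int dim1 = warp / warpsPerCTA[0];
    std::printf("warp %d -> (dim0=%d, dim1=%d)\n", warp, dim0, dim1);
  }
  // Warps 0..3 walk down dim0 (M) before dim1 (N) increments: M-major.
  return 0;
}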
@@ -844,18 +863,24 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
 
 LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
                                      DotOperandEncodingAttr dot) {
-  // TODO,BE. Implement ampereMMA in terms of this one
+  // Note that, even though MMAv2 looks similar to this layout, they only
+  // coincide at the register and lane level. The warp treatment is different!
   int rank = shape.size();
   auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
   int kWidth = dot.getKWidth();
   bool isA = dot.getOpIdx() == 0;
 
-  assert(mma.isAmpere());
   assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
          (rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
+  assert(mma.isAmpere());
 
   MLIRContext *ctx = mma.getContext();
-  SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
+  // A and B have kMajor order
+  assert(getOrder(dot) ==
+         getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));
+
+  auto kMajorDims =
+      permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot));
 
   // Implement A. For B transpose in the end
   std::vector<std::vector<int32_t>> registers;
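A small standalone sketch of the "A and B have kMajor order" claim for rank 2, with the orders hard-coded rather than calling getOrderForDotOperand (the helper name kMajorOrder and the rank-2 restriction are assumptions of the sketch): K is dim 1 for A ([M, K]) and dim 0 for B ([K, N]), and listing dims fastest-first gives {1, 0} for A and {0, 1} for B.

#include <cassert>
#include <vector>

// Dims listed fastest-first, with K always the fastest-varying dimension.
std::vector<unsigned> kMajorOrder(unsigned opIdx /*0 = A, 1 = B*/) {
  return opIdx == 0 ? std::vector<unsigned>{1, 0}   // A: [M, K] -> K, then M
                    : std::vector<unsigned>{0, 1};  // B: [K, N] -> K, then N
}

int main() {
  assert(kMajorOrder(0) == (std::vector<unsigned>{1, 0}));
  assert(kMajorOrder(1) == (std::vector<unsigned>{0, 1}));
}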
@@ -882,24 +907,51 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
   }
   registers.push_back({i, 0});
 
-  if (!isA) {
-    for (auto &r : registers) {
-      std::swap(r[0], r[1]);
+  LinearLayout ctaLayout({{S("register"), registers}, {S("lane"), lanes}},
+                         ArrayRef(kMajorDims).take_front(2));
+
+  // Let warpsPerCTAMma = {2, 2}; then
+  // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB.
+  // Assume warpOrder = {1, 0} and that C is tiled by 2x2 tiles. Since
+  // warpOrder = {1, 0}, C is distributed across warps as follows:
+  // C: 0 | 1
+  //    - | -
+  //    2 | 3
+  // In order to be able to compute C, we need the following warp tiling of
+  // A and B:
+  // A: 0 1 | 0 1    B: 0 2 | 1 3
+  //    - - | - -       - - | - -
+  //    2 3 | 2 3       0 2 | 1 3
+  // In particular, for A and B we need to broadcast along K.
+  assert(mma.getWarpOrder() == getMatrixOrder(rank, /*rowMajor=*/true));
+  auto warpsPerCTAMma = mma.getWarpsPerCTA();
+  std::vector<std::vector<int32_t>> warps;
+  if (isA) {
+    for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
+      warps.push_back({0, 0});
+    }
+    for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
+      warps.push_back({0, i});
+    }
+  } else {
+    for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
+      warps.push_back({0, i});
     }
-    for (auto &l : lanes) {
-      std::swap(l[0], l[1]);
+    for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
+      warps.push_back({0, 0});
+    }
+  }
+  if (rank == 3) {
+    for (auto &w : warps) {
+      w.push_back(0);
     }
   }
 
-  LinearLayout ctaLayout(
-      {{S("register"), registers}, {S("lane"), lanes}},
-      llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));
-
-  auto order = dot.getCTAOrder();
-  assert(order[0] == rank - 1 && order[1] == rank - 2);
-  ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);
+  ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims);
 
-  return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
+  return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
 }
 
 std::optional<LinearLayout>
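To make the warp-basis construction above concrete, here is a standalone re-derivation for the warpsPerCTAMma = {2, 2} example from the comment. It mirrors the loops in the diff using only std:: containers; the (K, M) / (K, N) coordinate reading in the comments follows the kMajor out-dim order established earlier and is stated here as an assumption of the sketch:

#include <cstdio>
#include <vector>

int main() {
  int warpsPerCTAMma[2] = {2, 2}; // {M warps, N warps} of the MMA layout
  bool isA = true;                // flip to false for operand B

  // Bases are in kMajor coordinates: (K, M) for A, (K, N) for B.
  // A zero basis means that warp bit is broadcast (owns the same data).
  std::vector<std::vector<int>> warps;
  if (isA) {
    // N-warp bits do not move A: broadcast.
    for (int i = 1; i < warpsPerCTAMma[1]; i *= 2)
      warps.push_back({0, 0});
    // M-warp bits tile A along M (the second, non-K coordinate).
    for (int i = 1; i < warpsPerCTAMma[0]; i *= 2)
      warps.push_back({0, i});
  } else {
    // N-warp bits tile B along N; M-warp bits broadcast.
    for (int i = 1; i < warpsPerCTAMma[1]; i *= 2)
      warps.push_back({0, i});
    for (int i = 1; i < warpsPerCTAMma[0]; i *= 2)
      warps.push_back({0, 0});
  }

  for (size_t w = 0; w < warps.size(); ++w)
    std::printf("warp bit %zu -> basis (%d, %d)\n", w, warps[w][0],
                warps[w][1]);
  // For A this prints (0, 0) then (0, 1): the low warp bit (the N bit of the
  // MMA tiling) is broadcast, while the next bit steps along M.
}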
@@ -908,7 +960,7 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
     return mfmaDotToLinearLayout(*this, shape);
   } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
-    if (mma.getVersionMajor() == 2 && mma.getVersionMinor() == 0) {
+    if (mma.isAmpere()) {
       return ampereDotToLinearLayout(shape, *this);
     }
   } else if (auto dpasLayout =