
Commit d96a80e

Merge commit '73df068b8e24d68f7afe776e798db12a75ba9271'
2 parents: 555d666 + 73df068

File tree

17 files changed, +676 -515 lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 11 additions & 0 deletions
```diff
@@ -130,6 +130,17 @@ unsigned getNumWarpsPerCTA(Attribute layout);
 
 unsigned getNumCTAs(Attribute layout);
 
+// Return the order that represents that the batch is in row-major or
+// column-major order for a batch of matrices of shape [*, m, n] with
+// len(shape) == rank.
+SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);
+
+// Return the order that represents that the dot operand is in kMajor
+// (contiguous in the inner dimension) or it's contiguous on the outer
+// dimension.
+SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
+                                            bool kMajor);
+
 bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
 
 // Return true if a view between the two types cannot be implemented as a no-op.
```
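
To make the convention concrete, here is a minimal standalone sketch of `getMatrixOrder` (illustration only: `std::vector` stands in for `llvm::SmallVector`, and the body mirrors the definition added in Dialect.cpp below):

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Illustrative re-implementation of the new helper, not the Triton code.
std::vector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
  assert(rank >= 2);
  std::vector<unsigned> order(rank);
  // Fill as [rank-1, ..., 1, 0]: fastest-running dimension listed first.
  std::iota(order.rbegin(), order.rend(), 0);
  if (!rowMajor)
    std::swap(order[0], order[1]); // make the row dimension the fastest
  return order;
}

int main() {
  // Row-major batch of [*, m, n] matrices: n is fastest, batch is slowest.
  assert((getMatrixOrder(3, /*rowMajor=*/true) ==
          std::vector<unsigned>{2, 1, 0}));
  // Column-major: m becomes the fastest; the batch dim stays slowest.
  assert((getMatrixOrder(3, /*rowMajor=*/false) ==
          std::vector<unsigned>{1, 2, 0}));
  assert((getMatrixOrder(2, /*rowMajor=*/false) ==
          std::vector<unsigned>{0, 1}));
}
```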

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 34 additions & 40 deletions
```diff
@@ -238,6 +238,19 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
   return resOrder;
 }
 
+SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
+  // Return the order that represents that the batch is in row-major or
+  // column-major order for a batch of matrices of shape [*, m, n] with
+  // len(shape) == rank.
+  assert(rank >= 2);
+  SmallVector<unsigned> order(rank);
+  std::iota(order.rbegin(), order.rend(), 0);
+  if (!rowMajor) {
+    std::swap(order[0], order[1]);
+  }
+  return order;
+}
+
 SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
                                             bool kMajor) {
   // kMajor: if true, the matrix is fastest-running on k,
@@ -247,15 +260,8 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
   // batch (if rank == 3) is always the slowest running dimension
   assert(rank == 2 || rank == 3);
   assert(opIdx == 0 || opIdx == 1);
-  SmallVector<unsigned> order(rank);
-  std::iota(order.rbegin(), order.rend(), 0);
-  // If opIdx is 1 and kMajor is true, the order is [0, 1]
-  // (resp. [1, 2, 0] if rank == 3)
-  // Same if opIdx is 0 and kMajor is false
-  if (bool(opIdx) == kMajor) {
-    std::swap(order[0], order[1]);
-  }
-  return order;
+  auto rowMajor = bool(opIdx) != kMajor;
+  return getMatrixOrder(rank, rowMajor);
 }
 
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
```
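
Spelled out for rank 2: the A operand (`opIdx == 0`) with `kMajor == true` gives `rowMajor == true` and hence order `[1, 0]`, K fastest (K is dim 1 of the `[m, k]` operand); the B operand (`opIdx == 1`) with `kMajor == true` gives `rowMajor == false` and order `[0, 1]`, again K fastest (K is dim 0 of the `[k, n]` operand). For rank 3 these become `[2, 1, 0]` and `[1, 2, 0]`, matching the examples in the deleted comment.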
```diff
@@ -265,20 +271,21 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
     }
   }
   auto order = getOrder(layout);
-  // FIXME: This mmaLayout if should just return
-  // getOrderForDotOperand(0, order.size(), kMajor=false)
-  // as mma has the same order as DotOperand(opIdx=0)
+  // FIXME: At the moment, warpOrder in Ampere is N-major but in Hopper it's
+  // M-major. This is awkward. Since we can choose any warpOrder in Ampere, we
+  // should probably choose M-major and change `LinearLayoutConversion.cpp` and
+  // `MMAv2.cpp` to match.
   if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
     if (mmaLayout.isHopper()) {
-      // Hopper MMA instructions force a warp order of [0, 1]. See docs:
+      // Hopper MMA instructions force warps to be column-major
       // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8
-      auto it = std::find(order.begin(), order.end(), 0);
-      order.erase(it);
-      order.insert(order.begin(), 0);
+      return getMatrixOrder(order.size(), /*rowMajor*/ false);
     }
   } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
-                                  /*kMajor*/ false);
+    // It's quite weird to talk about warp order when the warps
+    // are broadcasted along the K dimension
+    llvm::report_fatal_error(
+        "DotOperandEncoding::getWarpOrder not implemented");
   }
   return order;
 }
```
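
For the rank-2 Hopper case the refactor is behavior-preserving: `getMatrixOrder(2, /*rowMajor=*/false)` returns `[0, 1]`, exactly the order the deleted find/erase/insert sequence produced by moving dimension 0 to the front. The `DotOperandEncodingAttr` branch, by contrast, is a deliberate behavior change: asking for a warp order there is now a fatal error instead of a `kMajor=false` order, since the warps are broadcast along K.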
```diff
@@ -288,11 +295,11 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     return llvm::to_vector(blockedLayout.getOrder());
   }
   if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
+    // Order doesn't really matter. We just have to be consistent when unpacking
+    // the elements in the MMAv2/V3 lowerings. We choose row-major
     auto distributedLayout = cast<DistributedEncodingTrait>(layout);
     auto rank = distributedLayout.getWarpsPerCTA().size();
-    SmallVector<unsigned> order(rank);
-    std::iota(order.rbegin(), order.rend(), 0);
-    return order;
+    return getMatrixOrder(rank, /*rowMajor*/ true);
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
     auto rank = dotLayout.getWarpsPerCTA().size();
@@ -434,7 +441,7 @@ unsigned getNumWarpsPerCTA(Attribute layout) {
   else if (auto wmmaLayout = dyn_cast<AMDWmmaEncodingAttr>(layout))
     warpsPerCTA = wmmaLayout.getWarpsPerCTA();
   else if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout))
-    return getNumWarpsPerCTA(dotLayout.getParent());
+    warpsPerCTA = dotLayout.getWarpsPerCTA();
   else if (auto sharedLayout = dyn_cast<SharedEncodingAttr>(layout))
     llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr");
   else
```
```diff
@@ -2176,25 +2183,12 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand(
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
     ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
   assert(isAmpere() && "mmaLayout version = 1 is not implemented yet");
-  auto parentShapePerCTATile = getShapePerCTATile(shape);
-  auto rank = parentShapePerCTATile.size();
+  auto shapePerCTATile = getShapePerCTATile(shape);
+  auto rank = shapePerCTATile.size();
+  auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
   // 4 threads * 2 subtiles
-  unsigned kWidthTile = kWidth * 2 * 4;
-  if (opIdx == 0) {
-    if (rank == 2)
-      return {parentShapePerCTATile[rank - 2], kWidthTile};
-    else
-      return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2],
-              kWidthTile};
-  } else if (opIdx == 1) {
-    if (rank == 2)
-      return {kWidthTile, parentShapePerCTATile[rank - 1]};
-    else
-      return {parentShapePerCTATile[0], kWidthTile,
-              parentShapePerCTATile[rank - 1]};
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
-  }
+  shapePerCTATile[kDim] = kWidth * 2 * 4;
+  return shapePerCTATile;
 }
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
```
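
Worked example: for a rank-2 A operand (`opIdx == 0`), `kDim = rank - 1 = 1`, so only the K extent of the tile is overwritten; with `kWidth = 4` it becomes `4 * 2 * 4 = 32`, i.e. `kWidth` elements per thread times 4 threads times 2 subtiles, as the retained comment indicates. The refactor collapses the four explicit `opIdx`/`rank` cases into a single in-place update while returning the same shapes.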

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 78 additions & 26 deletions
```diff
@@ -42,6 +42,17 @@ SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
   return ret;
 }
 
+// TODO Have order be a mandatory argument of standardOutDimNames.
+SmallVector<StringAttr> permuteDimNames(const SmallVector<StringAttr> &names,
+                                        const SmallVector<unsigned> &order) {
+  assert(names.size() == order.size());
+  SmallVector<StringAttr> ret;
+  for (unsigned i : order) {
+    ret.push_back(names[i]);
+  }
+  return ret;
+}
+
 void assertIsRegisterLayout(const LinearLayout &layout) {
   assert(layout.getNumInDims() > 0);
   MLIRContext *ctx = layout.getInDimNames().begin()->getContext();
```
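
A self-contained sketch of `permuteDimNames`, with `std::string` standing in for `StringAttr` (illustration only):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Illustrative stand-in for the new helper in LinearLayoutConversions.cpp.
std::vector<std::string> permuteDimNames(const std::vector<std::string> &names,
                                         const std::vector<unsigned> &order) {
  assert(names.size() == order.size());
  std::vector<std::string> ret;
  for (unsigned i : order)
    ret.push_back(names[i]); // order[j] names the j-th fastest dimension
  return ret;
}

int main() {
  // standardOutDimNames(ctx, 3) yields {"dim0", "dim1", "dim2"}; permuting by
  // the row-major order [2, 1, 0] lists the fastest dimension first.
  std::vector<std::string> names = {"dim0", "dim1", "dim2"};
  assert((permuteDimNames(names, {2, 1, 0}) ==
          std::vector<std::string>{"dim2", "dim1", "dim0"}));
}
```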
```diff
@@ -282,15 +293,19 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef<int64_t> shape,
 
   MLIRContext *ctx = mma.getContext();
   SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
+  auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma));
+  // By using `reverse(dimNames)` below, we set the order to be row-major
+  assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
 
   LinearLayout ctaLayout(
       {{S("register"), {{1, 0}, {0, 8}}},
        {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}},
-      llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));
-
-  ctaLayout *= identityND(
-      S("warp"), mma.getWarpsPerCTA(),
-      llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank))), dimNames);
+      ArrayRef(orderedDimNames).take_front(2));
+  assert(getWarpOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
+  // FIXME(Lezcano). identityND should not have an `order` param as it's
+  // redundant with the order of the out dims.
+  ctaLayout *=
+      identityND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder(), dimNames);
 
   return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
 }
```
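
Reading the bases (with the convention that the first, fastest out dim is dim1/N and the second is dim0/M): the register bases `{1, 0}` and `{0, 8}` step one column and eight rows, while the lane bases step columns 2 and 4 and rows 1, 2 and 4. Together they tile a 16x8 fragment with four register values per lane, which is the standard Ampere `mma.m16n8` accumulator layout; the asserts simply pin down that `getOrder` and `getWarpOrder` still return the row-major order this construction was written against.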
```diff
@@ -323,10 +338,14 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
   ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")),
                                         S("register"), S("dim1"));
 
-  // Expand the `warp` dimension according to warpsPerCTA.
-  //
-  // It's weird that this is order [0,1] when MMAv2's warpsPerCTA is [1,0], but
-  // this really does seem to be correct.
+  // The order given by choosing (`dim1`, `dim0`) is [1, 0], that is, N-major.
+  // Since the warpOrder needs to be M-major, we need to transpose the out
+  // dimensions AND transpose the order.
+  // FIXME(Lezcano). identityND should not have an `order` param as it's
+  //                 redundant. The order is already given by the order of the
+  //                 out dims, and if it has an order, it shouldn't change the
+  //                 order of the out dims.
+  assert(getWarpOrder(mma) == SmallVector<unsigned>({0, 1}));
   ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1},
                           {S("dim0"), S("dim1")})
                    .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
```
```diff
@@ -844,18 +863,24 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
 
 LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
                                      DotOperandEncodingAttr dot) {
-  // TODO,BE. Implement ampereMMA in terms of this one
+  // Note that, even though MMAv2 looks similar to this layout, they are only
+  // the same at the register and lane level. The treatment of warps differs!
   int rank = shape.size();
   auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
   int kWidth = dot.getKWidth();
   bool isA = dot.getOpIdx() == 0;
 
-  assert(mma.isAmpere());
   assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
          (rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
+  assert(mma.isAmpere());
 
   MLIRContext *ctx = mma.getContext();
-  SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
+  // A and B have kMajor order
+  assert(getOrder(dot) ==
+         getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));
+
+  auto kMajorDims =
+      permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot));
 
   // Implement A. For B transpose in the end
   std::vector<std::vector<int32_t>> registers;
```
```diff
@@ -882,24 +907,51 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
   }
   registers.push_back({i, 0});
 
-  if (!isA) {
-    for (auto &r : registers) {
-      std::swap(r[0], r[1]);
-    }
-    for (auto &l : lanes) {
-      std::swap(l[0], l[1]);
+  LinearLayout ctaLayout({{S("register"), registers}, {S("lane"), lanes}},
+                         ArrayRef(kMajorDims).take_front(2));
+
+  // Let warpsPerCTAMma = {2, 2}, then
+  // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
+  // assume warpOrder = {0, 1}
+  // Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that
+  // the C is owned as per the following layout:
+  // C: 0 | 1
+  //    - | -
+  //    2 | 3
+  // In order to be able to compute C, we need the following warp tiling of
+  // A and B:
+  // A: 0 1 | 0 1    B: 0 2 | 1 3
+  //    - - | - -       - - | - -
+  //    2 3 | 2 3       0 2 | 1 3
+  // In particular, for A and B we need to broadcast along K
+
+  assert(mma.getWarpOrder() == getMatrixOrder(rank, /*rowMajor=*/true));
+  auto warpsPerCTAMma = mma.getWarpsPerCTA();
+  std::vector<std::vector<int32_t>> warps;
+  if (isA) {
+    for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
+      warps.push_back({0, 0});
+    }
+    for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
+      warps.push_back({0, i});
+    }
+  } else {
+    for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
+      warps.push_back({0, i});
+    }
+    for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
+      warps.push_back({0, 0});
+    }
+  }
+  if (rank == 3) {
+    for (auto &w : warps) {
+      w.push_back(0);
     }
   }
 
-  LinearLayout ctaLayout(
-      {{S("register"), registers}, {S("lane"), lanes}},
-      llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));
-
-  auto order = dot.getCTAOrder();
-  assert(order[0] == rank - 1 && order[1] == rank - 2);
-  ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);
+  ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims);
 
-  return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
+  return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
 }
 
 std::optional<LinearLayout>
```
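
To make the K-broadcast concrete, here is a minimal standalone sketch of the warp-basis construction for the A operand with `warpsPerCTAMma = {2, 2}` (plain `std::vector` stand-ins; not the actual Triton code):

```cpp
#include <cassert>
#include <vector>

int main() {
  std::vector<int> warpsPerCTAMma = {2, 2};
  std::vector<std::vector<int>> warps;
  // A operand (isA == true). Out dims are kMajor, i.e. (K, M) for rank 2,
  // so a basis {k, m} moves k steps along K and m steps along M.
  for (int i = 1; i < warpsPerCTAMma[1]; i *= 2)
    warps.push_back({0, 0}); // warp bit along N: zero basis = broadcast
  for (int i = 1; i < warpsPerCTAMma[0]; i *= 2)
    warps.push_back({0, i}); // warp bit along M: step one warp tile in M
  assert((warps == std::vector<std::vector<int>>{{0, 0}, {0, 1}}));
}
```

A zero basis vector means that flipping that warp bit does not move within A at all, i.e. the same A tile is owned by both warps, which is exactly the "A: 0 1 | 0 1" broadcast in the comment above.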
```diff
@@ -908,7 +960,7 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
     return mfmaDotToLinearLayout(*this, shape);
   } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
-    if (mma.getVersionMajor() == 2 && mma.getVersionMinor() == 0) {
+    if (mma.isAmpere()) {
       return ampereDotToLinearLayout(shape, *this);
     }
  } else if (auto dpasLayout =
```
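
Assuming `isAmpere()` only checks `getVersionMajor() == 2`, this relaxes the previous condition by also accepting version-2 encodings with a nonzero minor version, so every Ampere MMA dot operand now takes the linear-layout path.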

lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp

Lines changed: 49 additions & 3 deletions
```diff
@@ -1,4 +1,5 @@
 #include "mlir/IR/Dominance.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/Passes.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
```
```diff
@@ -14,8 +15,52 @@ namespace gpu {
 #define GEN_PASS_DEF_TRITONGPUCOMBINETENSORSELECTANDIF
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
 
-// Return true if the select could be merged into the If without breaking SSA
-// rules.
+/// The user of the select may be inside either the ThenRegion or the
+/// ElseRegion of the scf.if, so canonicalize such users of the select first.
+static void canonicalizeSelectUsersInSCFIf(ModuleOp input) {
+  llvm::MapVector<std::pair<Value, Value>, SmallVector<Operation *>>
+      usersNeedreplaced;
+  input.walk([&](arith::SelectOp selectOp) {
+    auto *parentBlock = selectOp->getBlock();
+    Value condition = selectOp.getOperand(0);
+    Value trueVal = selectOp.getOperand(1);
+    Value falseVal = selectOp.getOperand(2);
+    Value resVal = selectOp.getResult();
+    for (auto *condUser : condition.getUsers()) {
+      if (!llvm::isa<scf::IfOp>(condUser))
+        continue;
+      scf::IfOp ifOp = llvm::cast<scf::IfOp>(condUser);
+      for (auto *resUser : resVal.getUsers()) {
+        if (ifOp->isProperAncestor(resUser)) {
+          if (ifOp.getThenRegion().findAncestorOpInRegion(*resUser) !=
+              nullptr) {
+            // The user is inside the ThenRegion of the scf.if.
+            usersNeedreplaced[std::make_pair(resVal, trueVal)].push_back(
+                resUser);
+          } else {
+            // The user is inside the ElseRegion of the scf.if.
+            usersNeedreplaced[std::make_pair(resVal, falseVal)].push_back(
+                resUser);
+          }
+        }
+      }
+    }
+  });
+
+  // Replace the operand of each recorded user.
+  for (auto [replacedSrcAndDst, users] :
+       llvm::make_early_inc_range(usersNeedreplaced)) {
+    Value srcVal = replacedSrcAndDst.first;
+    Value dstVal = replacedSrcAndDst.second;
+    for (Operation *user : llvm::make_early_inc_range(users)) {
+      srcVal.replaceUsesWithIf(
+          dstVal, [&](OpOperand &use) { return use.getOwner() == user; });
+    }
+  }
+}
+
+/// Return true if the select could be merged into the If without breaking SSA
+/// rules.
 static bool canMergeIntoIf(arith::SelectOp selectOp, scf::IfOp ifOp,
                            DominanceInfo &dom) {
   // If needs to be dominated by the select.
```
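
Concretely: given `%r = arith.select %cond, %a, %b` and an `scf.if %cond` whose then-region uses `%r`, the walk records that such a use may read `%a` directly (inside the then-region `%cond` is known to be true), while uses in the else-region may read `%b`; the second loop then performs exactly those targeted replacements via `replaceUsesWithIf`. After this canonicalization the remaining uses of the select sit outside the `scf.if`, which is the situation `canMergeIntoIf` checks for.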
```diff
@@ -38,10 +83,11 @@ class CombineTensorSelectAndIfPass
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp m = getOperation();
-    DominanceInfo dom(m);
+    canonicalizeSelectUsersInSCFIf(m);
 
     // Go over the arith.select ops, look if there is an if
     // with the same condition.
+    DominanceInfo dom(m);
     llvm::MapVector<scf::IfOp, SmallVector<arith::SelectOp>> selectToIf;
     m.walk([&](arith::SelectOp selectOp) {
       // Look if there is an if in the same block, with the same condition.
```
