
Commit 503bcbe

alexbaden authored and chengjunlu committed
[Intel] Rework load-store redundant data masking 5/5 (remove debug code and old implementation)
1 parent 945aaf5 commit 503bcbe

File tree

1 file changed (+1, -106 lines)


third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 106 deletions
@@ -23,24 +23,6 @@ using namespace mlir::triton::gpu::intel;
 
 namespace {
 
-// Returns a Value for the format string, which you can reuse. Writes the byte
-// count for the string to |formatStrByteCount| if not null.
-Value llPrintf(StringRef msg, ValueRange args, ArrayRef<bool> isSigned,
-               ConversionPatternRewriter &rewriter,
-               const triton::intel::TargetInfo &targetInfo,
-               int *formatStrByteCount = nullptr) {
-  assert(!msg.empty() && "printf with empty string not supported");
-  llvm::SmallString<64> msgNewline(msg);
-  msgNewline.push_back('\n');
-  msgNewline.push_back('\0');
-  Value msgValue = targetInfo.getGlobalStringStart(
-      rewriter.getUnknownLoc(), rewriter, "printfFormat_", msgNewline,
-      /*addressSpace=*/TritonGEN::kUniformConstant);
-  targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args,
-                    isSigned);
-  return msgValue;
-}
-
 Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) {
   auto tb = TritonLLVMOpBuilder(loc, rewriter);
   if (a && b) {
@@ -81,71 +63,6 @@ Value emitRedundantThreadPredicate(
   return pred;
 }
 
-// Return the mask for the unique data accessed by given tensor type.
-// Used to mask out the redundant data accessed by threads.
-Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
-                        Location loc,
-                        const triton::intel::TargetInfo &targetInfo) {
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-
-  Value mask = b.true_val();
-  auto clusterCTAId = targetInfo.getClusterCTAId(rewriter, loc);
-
-  auto tensorTy = dyn_cast<RankedTensorType>(valueTy);
-  if (tensorTy) {
-    // To remove this use, port https://github.com/triton-lang/triton/pull/5432
-    // to the INTELGPU dialect
-    auto layout = cast<DistributedEncodingTrait>(tensorTy.getEncoding());
-    auto shape = tensorTy.getShape();
-    auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
-    auto kLane = StringAttr::get(rewriter.getContext(), "lane");
-    auto kWarp = StringAttr::get(rewriter.getContext(), "warp");
-    auto maskLane =
-        std::get<1>(delinearize(rewriter, loc, layout, shape, kLane, laneId));
-    auto maskWarp =
-        std::get<1>(delinearize(rewriter, loc, layout, shape, kWarp, warpId));
-    mask = b.and_(maskLane, maskWarp);
-
-    // Do not write duplicated data when multicast is enabled
-    if (triton::gpu::getNumCTAs(layout) > 1) {
-      auto _0 = b.i32_val(0);
-      auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout);
-      auto CTASplitNum = triton::gpu::getCTASplitNum(layout);
-      auto CTAOrder = triton::gpu::getCTAOrder(layout);
-
-      auto multiDimClusterCTAId =
-          delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
-
-      auto rank = tensorTy.getRank();
-      for (unsigned dim = 0; dim < rank; ++dim) {
-        // Skip when multicast is not enabled in this dimension
-        if (CTAsPerCGA[dim] == CTASplitNum[dim])
-          continue;
-        // This wrapping rule must be consistent with emitCTAOffsetForLayout
-        unsigned splitNum = std::min<unsigned>(shape[dim], CTASplitNum[dim]);
-        Value repId = b.udiv(multiDimClusterCTAId[dim], b.i32_val(splitNum));
-        // Consider the example where CTAsPerCGA = [4] and CTASplitNum = [2]:
-        // CTA0 and CTA2 holds data of block0,
-        // CTA1 and CTA3 holds data of block1.
-        // Only CTA0 and CTA1 are expected to write while CTA2 and CTA3 should
-        // be masked. We add the following mask:
-        // multiDimClusterCTAId[dim] / splitNum == 0
-        // Actually in all existing cases of multicast, splitNum is always 1.
-        // The mask is equivalent to:
-        // multiDimClusterCTAId[dim] == 0
-        mask = b.and_(mask, b.icmp_eq(repId, _0));
-      }
-    }
-  } else {
-    // If the tensor is not ranked, then it is a scalar and only thread 0 of
-    // CTA0 can write
-    mask = b.and_(mask, b.icmp_eq(clusterCTAId, b.i32_val(0)));
-    auto tid = getThreadId(rewriter, loc);
-    mask = b.and_(mask, b.icmp_eq(tid, b.i32_val(0)));
-  }
-  return mask;
-}
-
 /// Holds the values related to a block pointer.
 /// It includes the base pointer, base width and height, row and column
 /// stride, and offset base for X and Y.
@@ -2220,12 +2137,6 @@ struct StoreOpConversion
            valueElems.size() == maskElems.size() && "Mask size mismatch");
 
     auto freeVarMasks = getFreeVariableMasks(valueTy);
-#if 0
-    for (auto mask : freeVarMasks) {
-      llvm::errs() << mask.first << " = " << mask.second << " ("
-                   << std::to_string(1 << mask.second) << ")\n";
-    }
-#endif
     Value threadPred =
         emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);
     uint32_t regMask = freeVarMasks[str_attr("register")];
@@ -2291,22 +2202,6 @@ struct StoreOpConversion
        vecWord = b.insert_element(vecTy, vecWord, llWord, b.i32_val(index));
      }
 
-#if 0
-      auto vecTestElem = b.extract_element(valArgTy, vecWord, b.i32_val(0));
-      auto testElemBitcast = b.bitcast(vecTestElem, wordTy);
-      auto testElemVal =
-          b.extract_element(valueElemTy, testElemBitcast, b.i32_val(0));
-
-      Value addrElem = b.bitcast(ptrElems[vecStart], ptr_ty(ctx, 1 /*global*/));
-
-      auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
-      llPrintf("warp %d lane %d mask %d & %d = %d addr %p val %f vec %d",
-               {warpId, laneId, threadPred, maskElems[vecStart], maskVal, addrElem,
-                testElemVal, b.i32_val(vecStart)},
-               {true, true, true, true, true, true, true, true}, rewriter,
-               targetInfo);
-#endif
-
      // Create a predicated store operation.
      LLVM::intel::createPredicatedBlock(rewriter, loc, maskVal, [&] {
        Value addrElem =
@@ -2680,7 +2575,7 @@ struct AtomicRMWOpConversion
        }
      }
 
-      ret = endBlock ? endBlock->getArgument(0) : ret;
+      ret = endBlock ? endBlock->getArgument(0) : ret;
      assert(ret);
      Type retType = (!tensorTy || vec == 1) ? valueElemTy : vecTy;
      ret = b.bitcast(ret, retType);
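
For context on the removed redundantDataMask helper: its comment works through the multicast case CTAsPerCGA = [4], CTASplitNum = [2], where CTA0/CTA2 and CTA1/CTA3 hold the same data blocks and only CTA0 and CTA1 may write. The hypothetical, standalone C++ sketch below (illustration only, not part of this patch; the tensor extent shapeDim = 128 is an assumed value) replays that masking rule on the host:

// Hypothetical host-side sketch of the multicast masking rule documented in
// the removed redundantDataMask (illustration only, not from the patch).
#include <algorithm>
#include <cstdio>

int main() {
  const unsigned CTAsPerCGA = 4;   // CTAs along this dimension (from the comment)
  const unsigned CTASplitNum = 2;  // distinct data blocks along this dimension
  const unsigned shapeDim = 128;   // assumed tensor extent in this dimension
  // Wrapping rule from the removed code: splitNum = min(shape[dim], CTASplitNum[dim]).
  const unsigned splitNum = std::min(shapeDim, CTASplitNum);
  for (unsigned ctaId = 0; ctaId < CTAsPerCGA; ++ctaId) {
    // Mask condition from the removed code: multiDimClusterCTAId[dim] / splitNum == 0.
    const unsigned repId = ctaId / splitNum;
    std::printf("CTA%u: repId=%u -> %s\n", ctaId, repId,
                repId == 0 ? "writes" : "masked (redundant copy)");
  }
  return 0;
}

Compiled and run, it reports CTA0 and CTA1 as writers and CTA2 and CTA3 as masked redundant copies, matching the example in the removed comment.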

0 commit comments
