@@ -23,24 +23,6 @@ using namespace mlir::triton::gpu::intel;
 
 namespace {
 
-// Returns a Value for the format string, which you can reuse. Writes the byte
-// count for the string to |formatStrByteCount| if not null.
-Value llPrintf(StringRef msg, ValueRange args, ArrayRef<bool> isSigned,
-               ConversionPatternRewriter &rewriter,
-               const triton::intel::TargetInfo &targetInfo,
-               int *formatStrByteCount = nullptr) {
-  assert(!msg.empty() && "printf with empty string not supported");
-  llvm::SmallString<64> msgNewline(msg);
-  msgNewline.push_back('\n');
-  msgNewline.push_back('\0');
-  Value msgValue = targetInfo.getGlobalStringStart(
-      rewriter.getUnknownLoc(), rewriter, "printfFormat_", msgNewline,
-      /*addressSpace=*/TritonGEN::kUniformConstant);
-  targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args,
-                    isSigned);
-  return msgValue;
-}
-
 Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) {
   auto tb = TritonLLVMOpBuilder(loc, rewriter);
   if (a && b) {
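For reference, the removed `llPrintf` helper emitted a device-side printf: it appends `\n` and a NUL terminator to the format string, materializes the string as a global constant via `targetInfo.getGlobalStringStart`, and forwards the arguments to `targetInfo.printf`. A minimal call-site sketch, mirroring the debug block removed from `StoreOpConversion` further down (it assumes `rewriter`, `loc`, and the Intel `targetInfo` of the surrounding lowering are in scope):

```cpp
// Sketch only: print each thread's warp and lane id from device code.
auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
llPrintf("warp %d lane %d", {warpId, laneId},
         /*isSigned=*/{true, true}, rewriter, targetInfo);
```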
@@ -81,71 +63,6 @@ Value emitRedundantThreadPredicate(
   return pred;
 }
 
-// Return the mask for the unique data accessed by given tensor type.
-// Used to mask out the redundant data accessed by threads.
-Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
-                        Location loc,
-                        const triton::intel::TargetInfo &targetInfo) {
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-
-  Value mask = b.true_val();
-  auto clusterCTAId = targetInfo.getClusterCTAId(rewriter, loc);
-
-  auto tensorTy = dyn_cast<RankedTensorType>(valueTy);
-  if (tensorTy) {
-    // To remove this use, port https://github.com/triton-lang/triton/pull/5432
-    // to the INTELGPU dialect
-    auto layout = cast<DistributedEncodingTrait>(tensorTy.getEncoding());
-    auto shape = tensorTy.getShape();
-    auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
-    auto kLane = StringAttr::get(rewriter.getContext(), "lane");
-    auto kWarp = StringAttr::get(rewriter.getContext(), "warp");
-    auto maskLane =
-        std::get<1>(delinearize(rewriter, loc, layout, shape, kLane, laneId));
-    auto maskWarp =
-        std::get<1>(delinearize(rewriter, loc, layout, shape, kWarp, warpId));
-    mask = b.and_(maskLane, maskWarp);
-
-    // Do not write duplicated data when multicast is enabled
-    if (triton::gpu::getNumCTAs(layout) > 1) {
-      auto _0 = b.i32_val(0);
-      auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout);
-      auto CTASplitNum = triton::gpu::getCTASplitNum(layout);
-      auto CTAOrder = triton::gpu::getCTAOrder(layout);
-
-      auto multiDimClusterCTAId =
-          delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
-
-      auto rank = tensorTy.getRank();
-      for (unsigned dim = 0; dim < rank; ++dim) {
-        // Skip when multicast is not enabled in this dimension
-        if (CTAsPerCGA[dim] == CTASplitNum[dim])
-          continue;
-        // This wrapping rule must be consistent with emitCTAOffsetForLayout
-        unsigned splitNum = std::min<unsigned>(shape[dim], CTASplitNum[dim]);
-        Value repId = b.udiv(multiDimClusterCTAId[dim], b.i32_val(splitNum));
-        // Consider the example where CTAsPerCGA = [4] and CTASplitNum = [2]:
-        // CTA0 and CTA2 holds data of block0,
-        // CTA1 and CTA3 holds data of block1.
-        // Only CTA0 and CTA1 are expected to write while CTA2 and CTA3 should
-        // be masked. We add the following mask:
-        //   multiDimClusterCTAId[dim] / splitNum == 0
-        // Actually in all existing cases of multicast, splitNum is always 1.
-        // The mask is equivalent to:
-        //   multiDimClusterCTAId[dim] == 0
-        mask = b.and_(mask, b.icmp_eq(repId, _0));
-      }
-    }
-  } else {
-    // If the tensor is not ranked, then it is a scalar and only thread 0 of
-    // CTA0 can write
-    mask = b.and_(mask, b.icmp_eq(clusterCTAId, b.i32_val(0)));
-    auto tid = getThreadId(rewriter, loc);
-    mask = b.and_(mask, b.icmp_eq(tid, b.i32_val(0)));
-  }
-  return mask;
-}
-
 /// Holds the values related to a block pointer.
 /// It includes the base pointer, base width and height, row and column
 /// stride, and offset base for X and Y.
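The multicast wrapping rule in the removed `redundantDataMask` can be checked in isolation. Below is a small standalone sketch of the CTAsPerCGA = [4], CTASplitNum = [2] example described in the removed comments, using plain integers instead of LLVM values (`ctaWrites` is an illustrative name, not a helper from this file):

```cpp
#include <algorithm>
#include <cstdio>

// A CTA writes along a dimension only if its replica id in that dimension is
// 0, mirroring: repId = ctaIdInDim / min(shape[dim], CTASplitNum[dim]) == 0.
static bool ctaWrites(unsigned ctaIdInDim, unsigned ctasPerCGA,
                      unsigned ctaSplitNum, unsigned shapeInDim) {
  if (ctasPerCGA == ctaSplitNum)
    return true; // no multicast along this dimension
  unsigned splitNum = std::min(shapeInDim, ctaSplitNum);
  return ctaIdInDim / splitNum == 0;
}

int main() {
  // CTAsPerCGA = [4], CTASplitNum = [2]: CTA0/CTA2 hold block0 and CTA1/CTA3
  // hold block1, so only CTA0 and CTA1 should write.
  for (unsigned cta = 0; cta < 4; ++cta)
    std::printf("CTA%u writes: %s\n", cta,
                ctaWrites(cta, 4, 2, /*shapeInDim=*/128) ? "yes" : "no");
  return 0;
}
```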
@@ -2220,12 +2137,6 @@ struct StoreOpConversion
            valueElems.size() == maskElems.size() && "Mask size mismatch");
 
     auto freeVarMasks = getFreeVariableMasks(valueTy);
-#if 0
-    for (auto mask : freeVarMasks) {
-      llvm::errs() << mask.first << " = " << mask.second << " ("
-                   << std::to_string(1 << mask.second) << ")\n";
-    }
-#endif
     Value threadPred =
         emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);
     uint32_t regMask = freeVarMasks[str_attr("register")];
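For context on the path that remains: `getFreeVariableMasks` returns, per hardware index ("register", "lane", "warp", ...), a bitmask of id bits along which the stored data is duplicated, and `emitRedundantThreadPredicate` keeps only the threads whose id is zero in those bits. A minimal standalone sketch of that selection rule, under that assumed mask semantics (plain integers and illustrative names, not the actual helpers):

```cpp
#include <cstdint>
#include <cstdio>

// Assumption: a set bit in a free-variable mask means the data is duplicated
// along that bit of the thread id, so only ids with all such bits clear are
// treated as holding the canonical copy.
static bool holdsCanonicalCopy(uint32_t laneId, uint32_t warpId,
                               uint32_t laneFreeMask, uint32_t warpFreeMask) {
  return (laneId & laneFreeMask) == 0 && (warpId & warpFreeMask) == 0;
}

int main() {
  // Illustrative masks: data is broadcast across the low two lane-id bits,
  // so only lanes 0, 4, 8, ... of each warp would perform the store.
  const uint32_t laneFreeMask = 0b11, warpFreeMask = 0;
  for (uint32_t lane = 0; lane < 8; ++lane)
    std::printf("lane %u stores: %d\n", lane,
                holdsCanonicalCopy(lane, /*warpId=*/0, laneFreeMask,
                                   warpFreeMask));
  return 0;
}
```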
@@ -2291,22 +2202,6 @@ struct StoreOpConversion
         vecWord = b.insert_element(vecTy, vecWord, llWord, b.i32_val(index));
       }
 
-#if 0
-      auto vecTestElem = b.extract_element(valArgTy, vecWord, b.i32_val(0));
-      auto testElemBitcast = b.bitcast(vecTestElem, wordTy);
-      auto testElemVal =
-          b.extract_element(valueElemTy, testElemBitcast, b.i32_val(0));
-
-      Value addrElem = b.bitcast(ptrElems[vecStart], ptr_ty(ctx, 1 /*global*/));
-
-      auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
-      llPrintf("warp %d lane %d mask %d & %d = %d addr %p val %f vec %d",
-               {warpId, laneId, threadPred, maskElems[vecStart], maskVal, addrElem,
-                testElemVal, b.i32_val(vecStart)},
-               {true, true, true, true, true, true, true, true}, rewriter,
-               targetInfo);
-#endif
-
       // Create a predicated store operation.
       LLVM::intel::createPredicatedBlock(rewriter, loc, maskVal, [&] {
         Value addrElem =
@@ -2680,7 +2575,7 @@ struct AtomicRMWOpConversion
       }
     }
 
-    ret = endBlock ? endBlock->getArgument(0) : ret;
+    ret = endBlock ? endBlock->getArgument(0) : ret;
     assert(ret);
     Type retType = (!tensorTy || vec == 1) ? valueElemTy : vecTy;
     ret = b.bitcast(ret, retType);