Skip to content

Commit 75eb507

Browse files
address review comment
1 parent 33016db commit 75eb507

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -534,30 +534,38 @@ struct PrefetchOpConversion
534534
};
535535
opIdx = getOpIdx();
536536

537-
auto repetitions = dpasLayout.getDPASRepetitions(tensorShape, opIdx);
538-
// getDPASRepetitions always returns a rank-3 size.
537+
SmallVector<int64_t> repetitions =
538+
dpasLayout.getDPASRepetitions(tensorShape, opIdx);
539+
assert(repetitions.size() == 3 &&
540+
"getDPASRepetitions always return rank 3 size");
539541
SmallVector<unsigned> numReps{repetitions.begin() + 1, repetitions.end()};
540542
SmallVector<int64_t, 2> shardTensorShape;
541-
if (opIdx == DpasEncodingAttr::OpIdx::OperandA) {
543+
switch (opIdx) {
544+
case DpasEncodingAttr::OpIdx::OperandA: {
542545
auto opAShape = dpasLayout.getShapeA();
543546
shardTensorShape = {std::min<unsigned>(tensorShape[0], opAShape[0]),
544547
tensorShape[1]};
545548
warpsPerCTA[1] = 1;
546549
repCluster[1] = 1;
547550
numReps[1] = 1;
548-
} else {
551+
} break;
552+
case DpasEncodingAttr::OpIdx::OperandB: {
549553
auto opBShape = dpasLayout.getShapeB();
550554
shardTensorShape = {tensorShape[0],
551555
std::min<unsigned>(tensorShape[1], opBShape[1])};
552556
warpsPerCTA[0] = 1;
553557
repCluster[0] = 1;
554558
numReps[0] = 1;
559+
} break;
560+
case DpasEncodingAttr::OpIdx::OperandC: {
561+
llvm_unreachable("unexpected OpIdx::OperandC");
562+
} break;
555563
}
556564

557565
auto ptrType = cast<PointerType>(tensorOfPointers.getElementType());
558566
Type elementType = ptrType.getPointeeType();
559-
RankedTensorType tensorType = RankedTensorType::get(
560-
shardTensorShape, elementType, tensorOfPointers.getEncoding());
567+
auto tensorType = RankedTensorType::get(shardTensorShape, elementType,
568+
tensorOfPointers.getEncoding());
561569

562570
SmallVector<unsigned, 2> prefetchShape =
563571
get2DPrefetchShapePerWarp(tensorType);
@@ -566,6 +574,7 @@ struct PrefetchOpConversion
566574
unsigned maskConstancyHor = std::numeric_limits<unsigned>::max(),
567575
maskConstancyVer = std::numeric_limits<unsigned>::max();
568576
if (mask) {
577+
// No need to check the constancy of a scalar mask.
569578
if (auto maskTy = dyn_cast_or_null<RankedTensorType>(mask.getType())) {
570579
auto axisInfo = const_cast<triton::intel::ModuleAxisInfoAnalysis &>(
571580
axisAnalysisPass)
@@ -578,9 +587,6 @@ struct PrefetchOpConversion
578587
maskConstancyVer = 1;
579588
}
580589
}
581-
/*else {
582-
// scalar mask. No need to check the constancy.
583-
}*/
584590
}
585591
prefetchShape = {std::min<unsigned>(prefetchShape[0], maskConstancyVer),
586592
std::min<unsigned>(prefetchShape[1], maskConstancyHor)};

0 commit comments

Comments
 (0)