Skip to content

Commit 78b26b8

Browse files
address review comment
1 parent d36a089 commit 78b26b8

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -530,30 +530,38 @@ struct PrefetchOpConversion
530530
};
531531
opIdx = getOpIdx();
532532

533-
auto repetitions = dpasLayout.getDPASRepetitions(tensorShape, opIdx);
534-
// getDPASRepetitions always return rank 3 size.
533+
SmallVector<int64_t> repetitions =
534+
dpasLayout.getDPASRepetitions(tensorShape, opIdx);
535+
assert(repetitions.size() == 3 &&
536+
"getDPASRepetitions always return rank 3 size");
535537
SmallVector<unsigned> numReps{repetitions.begin() + 1, repetitions.end()};
536538
SmallVector<int64_t, 2> shardTensorShape;
537-
if (opIdx == DpasEncodingAttr::OpIdx::OperandA) {
539+
switch (opIdx) {
540+
case DpasEncodingAttr::OpIdx::OperandA: {
538541
auto opAShape = dpasLayout.getShapeA();
539542
shardTensorShape = {std::min<unsigned>(tensorShape[0], opAShape[0]),
540543
tensorShape[1]};
541544
warpsPerCTA[1] = 1;
542545
repCluster[1] = 1;
543546
numReps[1] = 1;
544-
} else {
547+
} break;
548+
case DpasEncodingAttr::OpIdx::OperandB: {
545549
auto opBShape = dpasLayout.getShapeB();
546550
shardTensorShape = {tensorShape[0],
547551
std::min<unsigned>(tensorShape[1], opBShape[1])};
548552
warpsPerCTA[0] = 1;
549553
repCluster[0] = 1;
550554
numReps[0] = 1;
555+
} break;
556+
case DpasEncodingAttr::OpIdx::OperandC: {
557+
llvm_unreachable("unexpected OpIdx::OperandC");
558+
} break;
551559
}
552560

553561
auto ptrType = cast<PointerType>(tensorOfPointers.getElementType());
554562
Type elementType = ptrType.getPointeeType();
555-
RankedTensorType tensorType = RankedTensorType::get(
556-
shardTensorShape, elementType, tensorOfPointers.getEncoding());
563+
auto tensorType = RankedTensorType::get(shardTensorShape, elementType,
564+
tensorOfPointers.getEncoding());
557565

558566
SmallVector<unsigned, 2> prefetchShape =
559567
get2DPrefetchShapePerWarp(tensorType);
@@ -562,6 +570,7 @@ struct PrefetchOpConversion
562570
unsigned maskConstancyHor = std::numeric_limits<unsigned>::max(),
563571
maskConstancyVer = std::numeric_limits<unsigned>::max();
564572
if (mask) {
573+
// No need to check the constancy of scalar mask.
565574
if (auto maskTy = dyn_cast_or_null<RankedTensorType>(mask.getType())) {
566575
auto axisInfo = const_cast<triton::intel::ModuleAxisInfoAnalysis &>(
567576
axisAnalysisPass)
@@ -574,9 +583,6 @@ struct PrefetchOpConversion
574583
maskConstancyVer = 1;
575584
}
576585
}
577-
/*else {
578-
// scalar mask. No need to check the constancy.
579-
}*/
580586
}
581587
prefetchShape = {std::min<unsigned>(prefetchShape[0], maskConstancyVer),
582588
std::min<unsigned>(prefetchShape[1], maskConstancyHor)};

0 commit comments

Comments
 (0)