@@ -530,30 +530,38 @@ struct PrefetchOpConversion
530
530
};
531
531
opIdx = getOpIdx ();
532
532
533
- auto repetitions = dpasLayout.getDPASRepetitions (tensorShape, opIdx);
534
- // getDPASRepetitions always return rank 3 size.
533
+ SmallVector<int64_t > repetitions =
534
+ dpasLayout.getDPASRepetitions (tensorShape, opIdx);
535
+ assert (repetitions.size () == 3 &&
536
+ " getDPASRepetitions always return rank 3 size" );
535
537
SmallVector<unsigned > numReps{repetitions.begin () + 1 , repetitions.end ()};
536
538
SmallVector<int64_t , 2 > shardTensorShape;
537
- if (opIdx == DpasEncodingAttr::OpIdx::OperandA) {
539
+ switch (opIdx) {
540
+ case DpasEncodingAttr::OpIdx::OperandA: {
538
541
auto opAShape = dpasLayout.getShapeA ();
539
542
shardTensorShape = {std::min<unsigned >(tensorShape[0 ], opAShape[0 ]),
540
543
tensorShape[1 ]};
541
544
warpsPerCTA[1 ] = 1 ;
542
545
repCluster[1 ] = 1 ;
543
546
numReps[1 ] = 1 ;
544
- } else {
547
+ } break ;
548
+ case DpasEncodingAttr::OpIdx::OperandB: {
545
549
auto opBShape = dpasLayout.getShapeB ();
546
550
shardTensorShape = {tensorShape[0 ],
547
551
std::min<unsigned >(tensorShape[1 ], opBShape[1 ])};
548
552
warpsPerCTA[0 ] = 1 ;
549
553
repCluster[0 ] = 1 ;
550
554
numReps[0 ] = 1 ;
555
+ } break ;
556
+ case DpasEncodingAttr::OpIdx::OperandC: {
557
+ llvm_unreachable (" unexpected OpIdx::OperandC" );
558
+ } break ;
551
559
}
552
560
553
561
auto ptrType = cast<PointerType>(tensorOfPointers.getElementType ());
554
562
Type elementType = ptrType.getPointeeType ();
555
- RankedTensorType tensorType = RankedTensorType::get (
556
- shardTensorShape, elementType, tensorOfPointers.getEncoding ());
563
+ auto tensorType = RankedTensorType::get (shardTensorShape, elementType,
564
+ tensorOfPointers.getEncoding ());
557
565
558
566
SmallVector<unsigned , 2 > prefetchShape =
559
567
get2DPrefetchShapePerWarp (tensorType);
@@ -562,6 +570,7 @@ struct PrefetchOpConversion
562
570
unsigned maskConstancyHor = std::numeric_limits<unsigned >::max (),
563
571
maskConstancyVer = std::numeric_limits<unsigned >::max ();
564
572
if (mask) {
573
+ // A scalar mask needs no constancy check; only tensor masks are analyzed.
565
574
if (auto maskTy = dyn_cast_or_null<RankedTensorType>(mask.getType ())) {
566
575
auto axisInfo = const_cast <triton::intel::ModuleAxisInfoAnalysis &>(
567
576
axisAnalysisPass)
@@ -574,9 +583,6 @@ struct PrefetchOpConversion
574
583
maskConstancyVer = 1 ;
575
584
}
576
585
}
577
- /* else {
578
- // scalar mask. No need to check the constancy.
579
- }*/
580
586
}
581
587
prefetchShape = {std::min<unsigned >(prefetchShape[0 ], maskConstancyVer),
582
588
std::min<unsigned >(prefetchShape[1 ], maskConstancyHor)};
0 commit comments