@@ -534,30 +534,38 @@ struct PrefetchOpConversion
     };
     opIdx = getOpIdx();

-    auto repetitions = dpasLayout.getDPASRepetitions(tensorShape, opIdx);
-    // getDPASRepetitions always return rank 3 size.
+    SmallVector<int64_t> repetitions =
+        dpasLayout.getDPASRepetitions(tensorShape, opIdx);
+    assert(repetitions.size() == 3 &&
+           "getDPASRepetitions always return rank 3 size");
     SmallVector<unsigned> numReps{repetitions.begin() + 1, repetitions.end()};
     SmallVector<int64_t, 2> shardTensorShape;
-    if (opIdx == DpasEncodingAttr::OpIdx::OperandA) {
+    switch (opIdx) {
+    case DpasEncodingAttr::OpIdx::OperandA: {
       auto opAShape = dpasLayout.getShapeA();
       shardTensorShape = {std::min<unsigned>(tensorShape[0], opAShape[0]),
                           tensorShape[1]};
       warpsPerCTA[1] = 1;
       repCluster[1] = 1;
       numReps[1] = 1;
-    } else {
+    } break;
+    case DpasEncodingAttr::OpIdx::OperandB: {
       auto opBShape = dpasLayout.getShapeB();
       shardTensorShape = {tensorShape[0],
                           std::min<unsigned>(tensorShape[1], opBShape[1])};
       warpsPerCTA[0] = 1;
       repCluster[0] = 1;
       numReps[0] = 1;
+    } break;
+    case DpasEncodingAttr::OpIdx::OperandC: {
+      llvm_unreachable("unexpected OpIdx::OperandC");
+    } break;
     }

     auto ptrType = cast<PointerType>(tensorOfPointers.getElementType());
     Type elementType = ptrType.getPointeeType();
-    RankedTensorType tensorType = RankedTensorType::get(
-        shardTensorShape, elementType, tensorOfPointers.getEncoding());
+    auto tensorType = RankedTensorType::get(shardTensorShape, elementType,
+                                            tensorOfPointers.getEncoding());

     SmallVector<unsigned, 2> prefetchShape =
         get2DPrefetchShapePerWarp(tensorType);
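The hunk above does two things: it pins down the rank-3 result of getDPASRepetitions with an assert instead of a comment, and it replaces the two-way if/else on the operand index with an exhaustive switch, so an unexpected accumulator operand (OperandC) trips llvm_unreachable instead of silently falling into the operand-B branch. Below is a minimal standalone sketch of that control flow in plain C++ (no MLIR/LLVM dependencies); OpIdx, the tile sizes, and shardShapeFor are simplified stand-ins invented for illustration, not the backend's real types or API.

// Standalone sketch (plain C++17). OpIdx, the tile sizes, and shardShapeFor
// are hypothetical stand-ins for DpasEncodingAttr::OpIdx, getShapeA()/getShapeB(),
// and the shard computation in the patch; only the control flow is mirrored.
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>

enum class OpIdx { OperandA, OperandB, OperandC };

std::array<int64_t, 2> shardShapeFor(OpIdx opIdx,
                                     std::array<int64_t, 2> tensorShape) {
  // Made-up per-operand tile sizes standing in for getShapeA()/getShapeB().
  constexpr std::array<int64_t, 2> opAShape{8, 16};
  constexpr std::array<int64_t, 2> opBShape{16, 16};

  switch (opIdx) {
  case OpIdx::OperandA:
    // Operand A: clamp dim 0 to the A tile height, keep the full dim 1.
    return {std::min(tensorShape[0], opAShape[0]), tensorShape[1]};
  case OpIdx::OperandB:
    // Operand B: keep the full dim 0, clamp dim 1 to the B tile width.
    return {tensorShape[0], std::min(tensorShape[1], opBShape[1])};
  case OpIdx::OperandC:
    // The prefetch lowering never sees the accumulator operand.
    assert(false && "unexpected OpIdx::OperandC");
    return {};
  }
  return {};
}

int main() {
  auto a = shardShapeFor(OpIdx::OperandA, {64, 32});
  auto b = shardShapeFor(OpIdx::OperandB, {64, 32});
  std::printf("A shard: %lldx%lld, B shard: %lldx%lld\n", (long long)a[0],
              (long long)a[1], (long long)b[0], (long long)b[1]);
}

Keeping each case a braced block followed by break, as the patch does, rules out fall-through, and a switch over the enum with no default lets -Wswitch flag any OpIdx enumerator added later that is not handled.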
@@ -566,6 +574,7 @@ struct PrefetchOpConversion
     unsigned maskConstancyHor = std::numeric_limits<unsigned>::max(),
              maskConstancyVer = std::numeric_limits<unsigned>::max();
     if (mask) {
+      // No need to check the constancy of scalar mask.
       if (auto maskTy = dyn_cast_or_null<RankedTensorType>(mask.getType())) {
         auto axisInfo = const_cast<triton::intel::ModuleAxisInfoAnalysis &>(
                             axisAnalysisPass)
@@ -578,9 +587,6 @@ struct PrefetchOpConversion
           maskConstancyVer = 1;
         }
       }
-      /* else {
-        // scalar mask. No need to check the constancy.
-      }*/
     }
     prefetchShape = {std::min<unsigned>(prefetchShape[0], maskConstancyVer),
                      std::min<unsigned>(prefetchShape[1], maskConstancyHor)};
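The last two hunks concern the mask handling: the commented-out else branch is dropped in favor of a comment inside if (mask) noting that a scalar mask needs no constancy check, while the per-warp 2D prefetch block is still clamped by the mask's constancy along each axis so that one prefetch never spans elements whose mask values could differ. A minimal sketch of that clamping, assuming plain C++ and made-up names (clampPrefetchShape is not the backend's API; maskConstancyVer/Hor model the AxisInfo constancy along rows/columns):

// Sketch of the clamping at the end of the hunk (plain C++, hypothetical
// helper): shrink the per-warp prefetch block to the mask's constancy.
#include <algorithm>
#include <array>
#include <cstdio>
#include <limits>

std::array<unsigned, 2>
clampPrefetchShape(std::array<unsigned, 2> prefetchShape,
                   unsigned maskConstancyVer, unsigned maskConstancyHor) {
  return {std::min(prefetchShape[0], maskConstancyVer),
          std::min(prefetchShape[1], maskConstancyHor)};
}

int main() {
  constexpr unsigned kUnbounded = std::numeric_limits<unsigned>::max();
  // Tensor mask constant over 1 row and 16 columns: the block shrinks.
  auto clamped = clampPrefetchShape({8, 32}, /*ver=*/1, /*hor=*/16);
  // Scalar mask: constancy stays unbounded, the block is untouched.
  auto untouched = clampPrefetchShape({8, 32}, kUnbounded, kUnbounded);
  std::printf("clamped: %ux%u, untouched: %ux%u\n", clamped[0], clamped[1],
              untouched[0], untouched[1]);
}

For example, an {8, 32} block with a mask constant over 1 row and 16 columns shrinks to {1, 16}; a scalar mask leaves both constancies at std::numeric_limits<unsigned>::max(), so the block is untouched, which is why the scalar case needs no explicit constancy check.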