Skip to content

Commit 598c59f

Browse files
committed
[LV] Vectorize FMax w/o fast-math flags.
Add a new recurrence kind for FMax reductions without fast-math flags and a corresponding VPlan transform tries to vectorize without fast-math flags. To do so, a new FindFirstIV reduction is added that tracks the first indices that contain the maximum values. This serves as tie breaker if the partial reduction vector contains NaNs or signed zeros. After the loop, the first index is used to retrieve the final max value after vectorization from the vector containing the partial maximum values
1 parent ce81eb5 commit 598c59f

17 files changed

+569
-90
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ enum class RecurKind {
4747
FMul, ///< Product of floats.
4848
FMin, ///< FP min implemented in terms of select(cmp()).
4949
FMax, ///< FP max implemented in terms of select(cmp()).
50+
FMaxNoFMFs, ///< FP max implemented in terms of select(cmp()), but without
51+
///any fast-math flags. Users need to handle NaNs and signed zeros when generating code.
5052
FMinimum, ///< FP min with llvm.minimum semantics
5153
FMaximum, ///< FP max with llvm.maximum semantics
5254
FMinimumNum, ///< FP min with llvm.minimumnum semantics
@@ -250,8 +252,9 @@ class RecurrenceDescriptor {
250252
/// Returns true if the recurrence kind is a floating-point min/max kind.
251253
static bool isFPMinMaxRecurrenceKind(RecurKind Kind) {
252254
return Kind == RecurKind::FMin || Kind == RecurKind::FMax ||
253-
Kind == RecurKind::FMinimum || Kind == RecurKind::FMaximum ||
254-
Kind == RecurKind::FMinimumNum || Kind == RecurKind::FMaximumNum;
255+
Kind == RecurKind::FMaxNoFMFs || Kind == RecurKind::FMinimum ||
256+
Kind == RecurKind::FMaximum || Kind == RecurKind::FMinimumNum ||
257+
Kind == RecurKind::FMaximumNum;
255258
}
256259

257260
/// Returns true if the recurrence kind is any min/max kind.

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,8 @@ RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,
817817
if (match(I, m_OrdOrUnordFMin(m_Value(), m_Value())))
818818
return InstDesc(Kind == RecurKind::FMin, I);
819819
if (match(I, m_OrdOrUnordFMax(m_Value(), m_Value())))
820-
return InstDesc(Kind == RecurKind::FMax, I);
820+
return InstDesc(Kind == RecurKind::FMax || Kind == RecurKind::FMaxNoFMFs,
821+
I);
821822
if (match(I, m_FMinNum(m_Value(), m_Value())))
822823
return InstDesc(Kind == RecurKind::FMin, I);
823824
if (match(I, m_FMaxNum(m_Value(), m_Value())))
@@ -939,10 +940,15 @@ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
939940
m_Intrinsic<Intrinsic::minimumnum>(m_Value(), m_Value())) ||
940941
match(I, m_Intrinsic<Intrinsic::maximumnum>(m_Value(), m_Value()));
941942
};
942-
if (isIntMinMaxRecurrenceKind(Kind) ||
943-
(HasRequiredFMF() && isFPMinMaxRecurrenceKind(Kind)))
943+
if (isIntMinMaxRecurrenceKind(Kind))
944944
return isMinMaxPattern(I, Kind, Prev);
945-
else if (isFMulAddIntrinsic(I))
945+
if (isFPMinMaxRecurrenceKind(Kind)) {
946+
if (HasRequiredFMF())
947+
return isMinMaxPattern(I, Kind, Prev);
948+
if ((Kind == RecurKind::FMax || Kind == RecurKind::FMaxNoFMFs) &&
949+
isMinMaxPattern(I, Kind, Prev).isRecurrence())
950+
return InstDesc(I, RecurKind::FMaxNoFMFs);
951+
} else if (isFMulAddIntrinsic(I))
946952
return InstDesc(Kind == RecurKind::FMulAdd, I,
947953
I->hasAllowReassoc() ? nullptr : I);
948954
return InstDesc(false, I);
@@ -1205,6 +1211,7 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
12051211
case RecurKind::UMin:
12061212
return Instruction::ICmp;
12071213
case RecurKind::FMax:
1214+
case RecurKind::FMaxNoFMFs:
12081215
case RecurKind::FMin:
12091216
case RecurKind::FMaximum:
12101217
case RecurKind::FMinimum:

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,7 @@ constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
937937
return Intrinsic::vector_reduce_umax;
938938
case RecurKind::UMin:
939939
return Intrinsic::vector_reduce_umin;
940+
case RecurKind::FMaxNoFMFs:
940941
case RecurKind::FMax:
941942
return Intrinsic::vector_reduce_fmax;
942943
case RecurKind::FMin:
@@ -1085,6 +1086,7 @@ CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) {
10851086
case RecurKind::FMin:
10861087
return CmpInst::FCMP_OLT;
10871088
case RecurKind::FMax:
1089+
case RecurKind::FMaxNoFMFs:
10881090
return CmpInst::FCMP_OGT;
10891091
// We do not add FMinimum/FMaximum recurrence kind here since there is no
10901092
// equivalent predicate which compares signed zeroes according to the
@@ -1307,6 +1309,7 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
13071309
case RecurKind::UMax:
13081310
case RecurKind::UMin:
13091311
case RecurKind::FMax:
1312+
case RecurKind::FMaxNoFMFs:
13101313
case RecurKind::FMin:
13111314
case RecurKind::FMinimum:
13121315
case RecurKind::FMaximum:

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4459,8 +4459,13 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
44594459
ElementCount VF) const {
44604460
// Cross iteration phis such as reductions need special handling and are
44614461
// currently unsupported.
4462-
if (any_of(OrigLoop->getHeader()->phis(),
4463-
[&](PHINode &Phi) { return Legal->isFixedOrderRecurrence(&Phi); }))
4462+
if (any_of(OrigLoop->getHeader()->phis(), [&](PHINode &Phi) {
4463+
return Legal->isFixedOrderRecurrence(&Phi) ||
4464+
(Legal->isReductionVariable(&Phi) &&
4465+
Legal->getReductionVars()
4466+
.find(&Phi)
4467+
->second.getRecurrenceKind() == RecurKind::FMaxNoFMFs);
4468+
}))
44644469
return false;
44654470

44664471
// Phis with uses outside of the loop require special handling and are
@@ -8909,6 +8914,9 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
89098914

89108915
// Adjust the recipes for any inloop reductions.
89118916
adjustRecipesForReductions(Plan, RecipeBuilder, Range.Start);
8917+
if (!VPlanTransforms::runPass(
8918+
VPlanTransforms::handleFMaxReductionsWithoutFastMath, *Plan))
8919+
return nullptr;
89128920

89138921
// Transform recipes to abstract recipes if it is legal and beneficial and
89148922
// clamp the range for better cost estimation.

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23423,6 +23423,7 @@ class HorizontalReduction {
2342323423
case RecurKind::FindFirstIVUMin:
2342423424
case RecurKind::FindLastIVSMax:
2342523425
case RecurKind::FindLastIVUMax:
23426+
case RecurKind::FMaxNoFMFs:
2342623427
case RecurKind::FMaximumNum:
2342723428
case RecurKind::FMinimumNum:
2342823429
case RecurKind::None:

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,10 @@ class VPInstruction : public VPRecipeWithIRFlags,
980980
ReductionStartVector,
981981
// Creates a step vector starting from 0 to VF with a step of 1.
982982
StepVector,
983-
983+
/// Extracts a single lane (first operand) from a set of vector operands.
984+
/// The lane specifies an index into a vector formed by combining all vector
985+
/// operands (all operands after the first one).
986+
ExtractLane,
984987
};
985988

986989
private:

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
8585
return ResTy;
8686
}
8787
case Instruction::ICmp:
88+
case Instruction::FCmp:
8889
case VPInstruction::ActiveLaneMask:
8990
assert(inferScalarType(R->getOperand(0)) ==
9091
inferScalarType(R->getOperand(1)) &&
@@ -110,6 +111,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
110111
case VPInstruction::BuildStructVector:
111112
case VPInstruction::BuildVector:
112113
return SetResultTyFromOp();
114+
case VPInstruction::ExtractLane:
115+
return inferScalarType(R->getOperand(1));
113116
case VPInstruction::FirstActiveLane:
114117
return Type::getIntNTy(Ctx, 64);
115118
case VPInstruction::ExtractLastElement:

llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#define DEBUG_TYPE "vplan"
2525

2626
using namespace llvm;
27+
using namespace VPlanPatternMatch;
2728

2829
namespace {
2930
// Class that is used to build the plain CFG for the incoming IR.
@@ -426,7 +427,6 @@ static void createLoopRegion(VPlan &Plan, VPBlockBase *HeaderVPB) {
426427
static void addCanonicalIVRecipes(VPlan &Plan, VPBasicBlock *HeaderVPBB,
427428
VPBasicBlock *LatchVPBB, Type *IdxTy,
428429
DebugLoc DL) {
429-
using namespace VPlanPatternMatch;
430430
Value *StartIdx = ConstantInt::get(IdxTy, 0);
431431
auto *StartV = Plan.getOrAddLiveIn(StartIdx);
432432

@@ -589,3 +589,114 @@ void VPlanTransforms::createLoopRegions(VPlan &Plan) {
589589
TopRegion->setName("vector loop");
590590
TopRegion->getEntryBasicBlock()->setName("vector.body");
591591
}
592+
593+
bool VPlanTransforms::handleFMaxReductionsWithoutFastMath(VPlan &Plan) {
594+
VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion();
595+
VPReductionPHIRecipe *RedPhiR = nullptr;
596+
VPRecipeWithIRFlags *MinMaxOp = nullptr;
597+
VPWidenIntOrFpInductionRecipe *WideIV = nullptr;
598+
599+
// Check if there are any FMaxNoFMFs reductions using wide selects that we can
600+
// fix up. To do so, we also need a wide canonical IV to keep track of the
601+
// indices of the max values.
602+
for (auto &R : LoopRegion->getEntryBasicBlock()->phis()) {
603+
// We need a wide canonical IV
604+
if (auto *CurIV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&R)) {
605+
if (!CurIV->isCanonical())
606+
continue;
607+
WideIV = CurIV;
608+
continue;
609+
}
610+
611+
// And a single FMaxNoFMFs reduction phi.
612+
// TODO: Support FMin reductions as well.
613+
auto *CurRedPhiR = dyn_cast<VPReductionPHIRecipe>(&R);
614+
if (!CurRedPhiR)
615+
continue;
616+
if (RedPhiR)
617+
return false;
618+
if (CurRedPhiR->getRecurrenceKind() != RecurKind::FMaxNoFMFs ||
619+
CurRedPhiR->isInLoop() || CurRedPhiR->isOrdered())
620+
continue;
621+
RedPhiR = CurRedPhiR;
622+
623+
// MaxOp feeding the reduction phi must be a select (either wide or a
624+
// replicate recipe), where the phi is the last operand, and the compare
625+
// predicate is strict. This ensures NaNs won't get propagated unless the
626+
// initial value is NaN
627+
VPRecipeBase *Inc = RedPhiR->getBackedgeValue()->getDefiningRecipe();
628+
auto *RepR = dyn_cast<VPReplicateRecipe>(Inc);
629+
if (!isa<VPWidenSelectRecipe>(Inc) &&
630+
!(RepR && (isa<SelectInst>(RepR->getUnderlyingInstr()))))
631+
return false;
632+
633+
MinMaxOp = cast<VPRecipeWithIRFlags>(Inc);
634+
auto *Cmp = cast<VPRecipeWithIRFlags>(MinMaxOp->getOperand(0));
635+
if (MinMaxOp->getOperand(1) == RedPhiR ||
636+
!CmpInst::isStrictPredicate(Cmp->getPredicate()))
637+
return false;
638+
}
639+
640+
// Nothing to do.
641+
if (!RedPhiR)
642+
return true;
643+
644+
// A wide canonical IV is currently required.
645+
// TODO: Create an induction if no suitable existing one is available.
646+
if (!WideIV)
647+
return false;
648+
649+
// Create a reduction that tracks the first indices where the latest maximum
650+
// value has been selected. This is later used to select the max value from
651+
// the partial reductions in a way that correctly handles signed zeros and
652+
// NaNs in the input.
653+
// Note that we do not need to check if the induction may hit the sentinel
654+
// value. If the sentinel value gets hit, the final reduction value is at the
655+
// last index or the maximum was never set and all lanes contain the start
656+
// value. In either case, the correct value is selected.
657+
unsigned IVWidth =
658+
VPTypeAnalysis(Plan).inferScalarType(WideIV)->getScalarSizeInBits();
659+
LLVMContext &Ctx = Plan.getScalarHeader()->getIRBasicBlock()->getContext();
660+
VPValue *UMinSentinel =
661+
Plan.getOrAddLiveIn(ConstantInt::get(Ctx, APInt::getMaxValue(IVWidth)));
662+
auto *IdxPhi = new VPReductionPHIRecipe(nullptr, RecurKind::FindFirstIVUMin,
663+
*UMinSentinel, false, false, 1);
664+
IdxPhi->insertBefore(RedPhiR);
665+
auto *MinIdxSel = new VPInstruction(
666+
Instruction::Select, {MinMaxOp->getOperand(0), WideIV, IdxPhi});
667+
MinIdxSel->insertAfter(MinMaxOp);
668+
IdxPhi->addOperand(MinIdxSel);
669+
670+
// Find the first index of with the maximum value. This is used to extract the
671+
// lane with the final max value and is needed to handle signed zeros and NaNs
672+
// in the input.
673+
auto *MiddleVPBB = Plan.getMiddleBlock();
674+
auto *OrigRdxResult = cast<VPSingleDefRecipe>(&MiddleVPBB->front());
675+
VPBuilder Builder(OrigRdxResult->getParent(),
676+
std::next(OrigRdxResult->getIterator()));
677+
678+
// Create mask for lanes that have the max value and use it to mask out
679+
// indices that don't contain maximum values.
680+
auto *MaskFinalMaxValue = Builder.createNaryOp(
681+
Instruction::FCmp, {OrigRdxResult->getOperand(1), OrigRdxResult},
682+
VPIRFlags(CmpInst::FCMP_OEQ));
683+
auto *IndicesWithMaxValue = Builder.createNaryOp(
684+
Instruction::Select, {MaskFinalMaxValue, MinIdxSel, UMinSentinel});
685+
auto *FirstMaxIdx = Builder.createNaryOp(
686+
VPInstruction::ComputeFindIVResult,
687+
{IdxPhi, WideIV->getStartValue(), UMinSentinel, IndicesWithMaxValue});
688+
// Convert the index of the first max value to an index in the vector lanes of
689+
// the partial reduction results. This ensures we select the first max value
690+
// and acts as a tie-breaker if the partial reductions contain signed zeros.
691+
auto *FirstMaxLane =
692+
Builder.createNaryOp(Instruction::URem, {FirstMaxIdx, &Plan.getVFxUF()});
693+
694+
// Extract the final max value and update the users.
695+
auto *Res = Builder.createNaryOp(
696+
VPInstruction::ExtractLane, {FirstMaxLane, OrigRdxResult->getOperand(1)});
697+
OrigRdxResult->replaceUsesWithIf(Res,
698+
[MaskFinalMaxValue](VPUser &U, unsigned) {
699+
return &U != MaskFinalMaxValue;
700+
});
701+
return true;
702+
}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
585585
Value *Op = State.get(getOperand(0), vputils::onlyFirstLaneUsed(this));
586586
return Builder.CreateFreeze(Op, Name);
587587
}
588+
case Instruction::FCmp:
588589
case Instruction::ICmp: {
589590
bool OnlyFirstLaneUsed = vputils::onlyFirstLaneUsed(this);
590591
Value *A = State.get(getOperand(0), OnlyFirstLaneUsed);
@@ -595,7 +596,8 @@ Value *VPInstruction::generate(VPTransformState &State) {
595596
llvm_unreachable("should be handled by VPPhi::execute");
596597
}
597598
case Instruction::Select: {
598-
bool OnlyFirstLaneUsed = vputils::onlyFirstLaneUsed(this);
599+
bool OnlyFirstLaneUsed =
600+
State.VF.isScalar() || vputils::onlyFirstLaneUsed(this);
599601
Value *Cond = State.get(getOperand(0), OnlyFirstLaneUsed);
600602
Value *Op1 = State.get(getOperand(1), OnlyFirstLaneUsed);
601603
Value *Op2 = State.get(getOperand(2), OnlyFirstLaneUsed);
@@ -858,7 +860,30 @@ Value *VPInstruction::generate(VPTransformState &State) {
858860
Value *Res = State.get(getOperand(0));
859861
for (VPValue *Op : drop_begin(operands()))
860862
Res = Builder.CreateOr(Res, State.get(Op));
861-
return Builder.CreateOrReduce(Res);
863+
return Res->getType()->isIntegerTy(1) ? Res : Builder.CreateOrReduce(Res);
864+
}
865+
case VPInstruction::ExtractLane: {
866+
Value *LaneToExtract = State.get(getOperand(0), true);
867+
Type *IdxTy = State.TypeAnalysis.inferScalarType(getOperand(0));
868+
Value *Res = nullptr;
869+
Value *RuntimeVF = getRuntimeVF(State.Builder, IdxTy, State.VF);
870+
871+
for (unsigned Idx = 1; Idx != getNumOperands(); ++Idx) {
872+
Value *VectorStart =
873+
Builder.CreateMul(RuntimeVF, ConstantInt::get(IdxTy, Idx - 1));
874+
Value *VectorIdx = Builder.CreateSub(LaneToExtract, VectorStart);
875+
Value *Ext = State.VF.isScalar()
876+
? State.get(getOperand(Idx))
877+
: Builder.CreateExtractElement(
878+
State.get(getOperand(Idx)), VectorIdx);
879+
if (Res) {
880+
Value *Cmp = Builder.CreateICmpUGE(LaneToExtract, VectorStart);
881+
Res = Builder.CreateSelect(Cmp, Ext, Res);
882+
} else {
883+
Res = Ext;
884+
}
885+
}
886+
return Res;
862887
}
863888
case VPInstruction::FirstActiveLane: {
864889
if (getNumOperands() == 1) {
@@ -984,7 +1009,8 @@ bool VPInstruction::isVectorToScalar() const {
9841009
getOpcode() == VPInstruction::ComputeAnyOfResult ||
9851010
getOpcode() == VPInstruction::ComputeFindIVResult ||
9861011
getOpcode() == VPInstruction::ComputeReductionResult ||
987-
getOpcode() == VPInstruction::AnyOf;
1012+
getOpcode() == VPInstruction::AnyOf ||
1013+
getOpcode() == VPInstruction::ExtractLane;
9881014
}
9891015

9901016
bool VPInstruction::isSingleScalar() const {
@@ -1031,6 +1057,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
10311057
switch (getOpcode()) {
10321058
case Instruction::ExtractElement:
10331059
case Instruction::Freeze:
1060+
case Instruction::FCmp:
10341061
case Instruction::ICmp:
10351062
case Instruction::Select:
10361063
case VPInstruction::AnyOf:
@@ -1066,6 +1093,7 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
10661093
return Op == getOperand(1);
10671094
case Instruction::PHI:
10681095
return true;
1096+
case Instruction::FCmp:
10691097
case Instruction::ICmp:
10701098
case Instruction::Select:
10711099
case Instruction::Or:
@@ -1098,6 +1126,7 @@ bool VPInstruction::onlyFirstPartUsed(const VPValue *Op) const {
10981126
switch (getOpcode()) {
10991127
default:
11001128
return false;
1129+
case Instruction::FCmp:
11011130
case Instruction::ICmp:
11021131
case Instruction::Select:
11031132
return vputils::onlyFirstPartUsed(this);
@@ -1782,7 +1811,7 @@ bool VPIRFlags::flagsValidForOpcode(unsigned Opcode) const {
17821811
return Opcode == Instruction::ZExt;
17831812
break;
17841813
case OperationType::Cmp:
1785-
return Opcode == Instruction::ICmp;
1814+
return Opcode == Instruction::FCmp || Opcode == Instruction::ICmp;
17861815
case OperationType::Other:
17871816
return true;
17881817
}

llvm/lib/Transforms/Vectorize/VPlanTransforms.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ struct VPlanTransforms {
190190
VPBasicBlock *LatchVPBB,
191191
VFRange &Range);
192192

193+
static bool handleFMaxReductionsWithoutFastMath(VPlan &Plan);
194+
193195
/// Replace loop regions with explicit CFG.
194196
static void dissolveLoopRegions(VPlan &Plan);
195197

0 commit comments

Comments
 (0)