
[NVPTX] add combiner rule for final packed op in reduction #143943

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
114 changes: 108 additions & 6 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -852,6 +852,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
if (STI.allowFP16Math() || STI.hasBF16Math())
setTargetDAGCombine(ISD::SETCC);

// Combine reduction operations on packed types (e.g. fadd.f16x2) with vector
// shuffles when one of their lanes is a no-op.
if (STI.allowFP16Math() || STI.hasBF16Math())
// Already added above: FADD, ADD, AND.
setTargetDAGCombine({ISD::FMUL, ISD::FMINIMUM, ISD::FMAXIMUM, ISD::UMIN,
ISD::UMAX, ISD::SMIN, ISD::SMAX, ISD::OR, ISD::XOR});

// Promote fp16 arithmetic if fp16 hardware isn't available or the
// user passed --nvptx-no-fp16-math. The flag is useful because,
// although sm_53+ GPUs have some sort of FP16 support in
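
To make the motivation concrete: reducing v4f16<a b c d> takes one packed add with both lanes live, then a final packed add whose second lane is never read. Below is a standalone numeric sketch of that shape (plain C++ with illustrative values, not LLVM code):

#include <cstdio>

// Model a v4f16 horizontal add as two packed (2-lane) steps, using float
// lanes for simplicity. Only lane 0 of the final step is ever read, so its
// lane 1 is a no-op (the wasted work this patch's combine removes).
int main() {
  float V[4] = {1.0f, 2.0f, 3.0f, 4.0f};    // <a b c d>
  float V1[2] = {V[0] + V[2], V[1] + V[3]}; // packed add: <a+c b+d>
  // Default expansion: shuffle<1,u>, a second packed add, extract lane 0.
  // The combine below emits the equivalent scalar add directly:
  float Result = V1[1] + V1[0];             // (b+d) + (a+c) == 10.0
  std::printf("%f\n", Result);
  return 0;
}
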
@@ -5069,20 +5076,102 @@ static SDValue PerformStoreRetvalCombine(SDNode *N) {
return PerformStoreCombineHelper(N, 2, 0);
}

/// For vector reductions, the final result needs to be a scalar. The default
/// expansion will use packed ops (e.g. fadd.f16x2) even for the final operation.
/// This requires a packed operation where one of the lanes is undef.
///
/// Example: lowering of vecreduce_fadd(V) where V = v4f16<a b c d>
///
/// v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d> (== <a+c b+d>)
/// v2: v2f16 = vector_shuffle<1,u> v1, undef:v2f16 (== <b+d undef>)
/// v3: v2f16 = fadd reassoc v2, v1 (== <b+d+a+c undef>)
/// vR: f16 = extractelt v3, 0
///
/// We wish to replace vR, v3, and v2 with:
/// vR: f16 = fadd reassoc (extractelt v1, 1), (extractelt v1, 0)
///
/// ...so that we get:
/// v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d> (== <a+c b+d>)
/// s1: f16 = extractelt v1, 1
/// s2: f16 = extractelt v1, 0
/// vR: f16 = fadd reassoc s1, s2 (== a+c+b+d)
///
/// So for this example, this rule will replace v3 and v2, returning a vector
/// with the result in lane 0 and an undef in lane 1, which we expect will be
/// folded into the extractelt in vR.
static SDValue PerformPackedOpCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
// Convert:
// (fop.x2 (vector_shuffle<i,u> A), B) -> ((fop A:i, B:0), undef)
// ...or...
// (fop.x2 (vector_shuffle<u,i> A), B) -> (undef, (fop A:i, B:1))
// ...where i is a valid index and u is poison.
const EVT VectorVT = N->getValueType(0);
if (!Isv2x16VT(VectorVT))
return SDValue();

SDLoc DL(N);

SDValue ShufOp = N->getOperand(0);
SDValue VectOp = N->getOperand(1);
bool Swapped = false;

// Canonicalize the shuffle to operand 0.
if (VectOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
std::swap(ShufOp, VectOp);
Swapped = true;
}

if (ShufOp.getOpcode() != ISD::VECTOR_SHUFFLE)
return SDValue();

auto *ShuffleOp = cast<ShuffleVectorSDNode>(ShufOp);
int LiveLane; // exclusively live lane
for (LiveLane = 0; LiveLane < 2; ++LiveLane) {
// Check if the current lane is live and the other lane is dead.
if (ShuffleOp->getMaskElt(LiveLane) != PoisonMaskElem &&
ShuffleOp->getMaskElt(!LiveLane) == PoisonMaskElem)
break;
}
if (LiveLane == 2)
return SDValue();

int ElementIdx = ShuffleOp->getMaskElt(LiveLane);
const EVT ScalarVT = VectorVT.getScalarType();
SDValue Lanes[2] = {};
for (auto [LaneID, LaneVal] : enumerate(Lanes)) {
if (LaneID == (unsigned)LiveLane) {
SDValue Operands[2] = {
DCI.DAG.getExtractVectorElt(DL, ScalarVT, ShufOp.getOperand(0),
ElementIdx),
DCI.DAG.getExtractVectorElt(DL, ScalarVT, VectOp, LiveLane)};
// Preserve the original operand order.
if (Swapped)
std::swap(Operands[0], Operands[1]);
LaneVal = DCI.DAG.getNode(N->getOpcode(), DL, ScalarVT, Operands);
} else {
LaneVal = DCI.DAG.getUNDEF(ScalarVT);
}
}
return DCI.DAG.getBuildVector(VectorVT, DL, Lanes);
}
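
For reviewers who want to poke at the lane-selection logic in isolation, here is a self-contained C++ sketch of the scan above; PoisonMaskElem is re-declared locally with LLVM's -1 encoding, and the helper name is ours:

#include <cstdio>
#include <optional>

// Local stand-in for llvm::PoisonMaskElem (poison mask elements are -1).
constexpr int PoisonMaskElem = -1;

// Mirror of the loop in PerformPackedOpCombine: return the exclusively live
// lane of a 2-lane shuffle mask, or nullopt when the mask is not <i,u>/<u,i>.
static std::optional<int> exclusiveLiveLane(const int Mask[2]) {
  for (int Lane = 0; Lane < 2; ++Lane)
    if (Mask[Lane] != PoisonMaskElem && Mask[!Lane] == PoisonMaskElem)
      return Lane;
  return std::nullopt;
}

int main() {
  const int LoLive[2] = {1, PoisonMaskElem};                // <1,u> -> lane 0
  const int HiLive[2] = {PoisonMaskElem, 0};                // <u,0> -> lane 1
  const int BothLive[2] = {0, 1};                           // no combine
  const int NoneLive[2] = {PoisonMaskElem, PoisonMaskElem}; // no combine

  std::printf("%d\n", exclusiveLiveLane(LoLive).value_or(-1));   // 0
  std::printf("%d\n", exclusiveLiveLane(HiLive).value_or(-1));   // 1
  std::printf("%d\n", exclusiveLiveLane(BothLive).value_or(-1)); // -1
  std::printf("%d\n", exclusiveLiveLane(NoneLive).value_or(-1)); // -1
  return 0;
}

The live lane then determines both which source element feeds the scalar op and which output lane receives the result; the other lane stays undef so the subsequent extractelt folds away.
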

/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
///
static SDValue PerformADDCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);

// Route vector types to the packed-op combine; only scalar i32 is handled below.
EVT VT = N0.getValueType();
if (VT.isVector())
return PerformPackedOpCombine(N, DCI);
if (VT != MVT::i32)
return SDValue();

if (OptLevel == CodeGenOptLevel::None)
return SDValue();

// First try with the default operand order.
@@ -5102,7 +5191,10 @@ static SDValue PerformFADDCombine(SDNode *N,
SDValue N1 = N->getOperand(1);

EVT VT = N0.getValueType();
if (VT.isVector())
return PerformPackedOpCombine(N, DCI);

if (!(VT == MVT::f32 || VT == MVT::f64))
return SDValue();

// First try with the default operand order.
@@ -5205,7 +5297,7 @@ static SDValue PerformANDCombine(SDNode *N,
DCI.CombineTo(N, Val, AddTo);
}

return PerformPackedOpCombine(N, DCI);
}

static SDValue PerformREMCombine(SDNode *N,
@@ -5686,6 +5778,16 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformADDCombine(N, DCI, OptLevel);
case ISD::FADD:
return PerformFADDCombine(N, DCI, OptLevel);
case ISD::FMUL:
case ISD::FMINIMUM:
case ISD::FMAXIMUM:
case ISD::UMIN:
case ISD::UMAX:
case ISD::SMIN:
case ISD::SMAX:
case ISD::OR:
case ISD::XOR:
return PerformPackedOpCombine(N, DCI);
case ISD::MUL:
return PerformMULCombine(N, DCI, OptLevel);
case ISD::SHL: