Skip to content

Commit 1a01686

Browse files
committed
[NVPTX] support VECREDUCE_SEQ ops and remove option
1 parent 59a0fc2 commit 1a01686

File tree

2 files changed

+44
-30
lines changed

2 files changed

+44
-30
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
876876
if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
877877
EltVT == MVT::f64) {
878878
setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
879+
ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
879880
ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
880881
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
881882
VT, Custom);
@@ -2270,12 +2271,19 @@ static SDValue BuildTreeReduction(
22702271
/// max3/min3 when the target supports them.
22712272
SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22722273
SelectionDAG &DAG) const {
2273-
if (DisableFOpTreeReduce)
2274-
return SDValue();
2275-
22762274
SDLoc DL(Op);
22772275
const SDNodeFlags Flags = Op->getFlags();
2278-
const SDValue &Vector = Op.getOperand(0);
2276+
SDValue Vector;
2277+
SDValue Accumulator;
2278+
if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
2279+
Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
2280+
// special case with accumulator as first arg
2281+
Accumulator = Op.getOperand(0);
2282+
Vector = Op.getOperand(1);
2283+
} else {
2284+
// default case
2285+
Vector = Op.getOperand(0);
2286+
}
22792287
EVT EltTy = Vector.getValueType().getVectorElementType();
22802288
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
22812289
STI.getPTXVersion() >= 88;
@@ -2287,10 +2295,12 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22872295

22882296
switch (Op->getOpcode()) {
22892297
case ISD::VECREDUCE_FADD:
2298+
case ISD::VECREDUCE_SEQ_FADD:
22902299
ScalarOps = {{ISD::FADD, 2}};
22912300
IsReassociatable = false;
22922301
break;
22932302
case ISD::VECREDUCE_FMUL:
2303+
case ISD::VECREDUCE_SEQ_FMUL:
22942304
ScalarOps = {{ISD::FMUL, 2}};
22952305
IsReassociatable = false;
22962306
break;
@@ -2369,11 +2379,13 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
23692379
}
23702380

23712381
// Lower to tree reduction.
2372-
if (IsReassociatable || Flags.hasAllowReassociation())
2382+
if (IsReassociatable || Flags.hasAllowReassociation()) {
2383+
// we don't expect an accumulator for reassociatable vector reduction ops
2384+
assert(!Accumulator && "unexpected accumulator");
23732385
return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
2386+
}
23742387

23752388
// Lower to sequential reduction.
2376-
SDValue Accumulator;
23772389
for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
23782390
assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
23792391
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
@@ -3234,6 +3246,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
32343246
return LowerCONCAT_VECTORS(Op, DAG);
32353247
case ISD::VECREDUCE_FADD:
32363248
case ISD::VECREDUCE_FMUL:
3249+
case ISD::VECREDUCE_SEQ_FADD:
3250+
case ISD::VECREDUCE_SEQ_FMUL:
32373251
case ISD::VECREDUCE_FMAX:
32383252
case ISD::VECREDUCE_FMIN:
32393253
case ISD::VECREDUCE_FMAXIMUM:

llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,19 @@ define half @reduce_fadd_half(<8 x half> %in) {
2323
; CHECK-EMPTY:
2424
; CHECK-NEXT: // %bb.0:
2525
; CHECK-NEXT: ld.param.v4.b32 {%r1, %r2, %r3, %r4}, [reduce_fadd_half_param_0];
26-
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r4;
27-
; CHECK-NEXT: mov.b32 {%rs3, %rs4}, %r3;
28-
; CHECK-NEXT: mov.b32 {%rs5, %rs6}, %r2;
29-
; CHECK-NEXT: mov.b32 {%rs7, %rs8}, %r1;
30-
; CHECK-NEXT: mov.b16 %rs9, 0x0000;
31-
; CHECK-NEXT: add.rn.f16 %rs10, %rs7, %rs9;
32-
; CHECK-NEXT: add.rn.f16 %rs11, %rs10, %rs8;
33-
; CHECK-NEXT: add.rn.f16 %rs12, %rs11, %rs5;
34-
; CHECK-NEXT: add.rn.f16 %rs13, %rs12, %rs6;
35-
; CHECK-NEXT: add.rn.f16 %rs14, %rs13, %rs3;
36-
; CHECK-NEXT: add.rn.f16 %rs15, %rs14, %rs4;
37-
; CHECK-NEXT: add.rn.f16 %rs16, %rs15, %rs1;
38-
; CHECK-NEXT: add.rn.f16 %rs17, %rs16, %rs2;
26+
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r1;
27+
; CHECK-NEXT: mov.b16 %rs3, 0x0000;
28+
; CHECK-NEXT: add.rn.f16 %rs4, %rs1, %rs3;
29+
; CHECK-NEXT: add.rn.f16 %rs5, %rs4, %rs2;
30+
; CHECK-NEXT: mov.b32 {%rs6, %rs7}, %r2;
31+
; CHECK-NEXT: add.rn.f16 %rs8, %rs5, %rs6;
32+
; CHECK-NEXT: add.rn.f16 %rs9, %rs8, %rs7;
33+
; CHECK-NEXT: mov.b32 {%rs10, %rs11}, %r3;
34+
; CHECK-NEXT: add.rn.f16 %rs12, %rs9, %rs10;
35+
; CHECK-NEXT: add.rn.f16 %rs13, %rs12, %rs11;
36+
; CHECK-NEXT: mov.b32 {%rs14, %rs15}, %r4;
37+
; CHECK-NEXT: add.rn.f16 %rs16, %rs13, %rs14;
38+
; CHECK-NEXT: add.rn.f16 %rs17, %rs16, %rs15;
3939
; CHECK-NEXT: st.param.b16 [func_retval0], %rs17;
4040
; CHECK-NEXT: ret;
4141
%res = call half @llvm.vector.reduce.fadd(half 0.0, <8 x half> %in)
@@ -174,17 +174,17 @@ define half @reduce_fmul_half(<8 x half> %in) {
174174
; CHECK-EMPTY:
175175
; CHECK-NEXT: // %bb.0:
176176
; CHECK-NEXT: ld.param.v4.b32 {%r1, %r2, %r3, %r4}, [reduce_fmul_half_param_0];
177-
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r4;
178-
; CHECK-NEXT: mov.b32 {%rs3, %rs4}, %r3;
179-
; CHECK-NEXT: mov.b32 {%rs5, %rs6}, %r2;
180-
; CHECK-NEXT: mov.b32 {%rs7, %rs8}, %r1;
181-
; CHECK-NEXT: mul.rn.f16 %rs9, %rs7, %rs8;
182-
; CHECK-NEXT: mul.rn.f16 %rs10, %rs9, %rs5;
183-
; CHECK-NEXT: mul.rn.f16 %rs11, %rs10, %rs6;
184-
; CHECK-NEXT: mul.rn.f16 %rs12, %rs11, %rs3;
185-
; CHECK-NEXT: mul.rn.f16 %rs13, %rs12, %rs4;
186-
; CHECK-NEXT: mul.rn.f16 %rs14, %rs13, %rs1;
187-
; CHECK-NEXT: mul.rn.f16 %rs15, %rs14, %rs2;
177+
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r2;
178+
; CHECK-NEXT: mov.b32 {%rs3, %rs4}, %r1;
179+
; CHECK-NEXT: mul.rn.f16 %rs5, %rs3, %rs4;
180+
; CHECK-NEXT: mul.rn.f16 %rs6, %rs5, %rs1;
181+
; CHECK-NEXT: mul.rn.f16 %rs7, %rs6, %rs2;
182+
; CHECK-NEXT: mov.b32 {%rs8, %rs9}, %r3;
183+
; CHECK-NEXT: mul.rn.f16 %rs10, %rs7, %rs8;
184+
; CHECK-NEXT: mul.rn.f16 %rs11, %rs10, %rs9;
185+
; CHECK-NEXT: mov.b32 {%rs12, %rs13}, %r4;
186+
; CHECK-NEXT: mul.rn.f16 %rs14, %rs11, %rs12;
187+
; CHECK-NEXT: mul.rn.f16 %rs15, %rs14, %rs13;
188188
; CHECK-NEXT: st.param.b16 [func_retval0], %rs15;
189189
; CHECK-NEXT: ret;
190190
%res = call half @llvm.vector.reduce.fmul(half 1.0, <8 x half> %in)

0 commit comments

Comments
 (0)