diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h index a98e46c587273..7d8a0c4ce8e45 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -2030,6 +2030,11 @@ class SelectionDAG { LLVM_ABI SDValue foldConstantFPMath(unsigned Opcode, const SDLoc &DL, EVT VT, ArrayRef Ops); + /// Fold a BUILD_VECTOR of constants/undefs into a BUILD_VECTOR of + /// constants/undefs whose elements have the destination type DstEltVT. + LLVM_ABI SDValue FoldConstantBuildVector(BuildVectorSDNode *BV, + const SDLoc &DL, EVT DstEltVT); + /// Constant fold a setcc to true or false. LLVM_ABI SDValue FoldSetCC(EVT VT, SDValue N1, SDValue N2, ISD::CondCode Cond, const SDLoc &dl); diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index d4ad4d3a09381..a10f9afcf96c3 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -638,7 +638,6 @@ namespace { SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex); SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex); SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex); - SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT); SDValue BuildSDIV(SDNode *N); SDValue BuildSDIVPow2(SDNode *N); SDValue BuildUDIV(SDNode *N); @@ -16431,8 +16430,8 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { TLI.isTypeLegal(VT.getVectorElementType()))) && N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() && cast(N0)->isConstant()) - return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(), - VT.getVectorElementType()); + return DAG.FoldConstantBuildVector(cast(N0), SDLoc(N), + VT.getVectorElementType()); // If the input is a constant, let getNode fold it. 
if (isIntOrFPConstant(N0)) { @@ -16825,83 +16824,6 @@ SDValue DAGCombiner::visitFREEZE(SDNode *N) { return DAG.getNode(N0.getOpcode(), DL, N0->getVTList(), Ops, SafeFlags); } -/// We know that BV is a build_vector node with Constant, ConstantFP or Undef -/// operands. DstEltVT indicates the destination element value type. -SDValue DAGCombiner:: -ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) { - EVT SrcEltVT = BV->getValueType(0).getVectorElementType(); - - // If this is already the right type, we're done. - if (SrcEltVT == DstEltVT) return SDValue(BV, 0); - - unsigned SrcBitSize = SrcEltVT.getSizeInBits(); - unsigned DstBitSize = DstEltVT.getSizeInBits(); - - // If this is a conversion of N elements of one type to N elements of another - // type, convert each element. This handles FP<->INT cases. - if (SrcBitSize == DstBitSize) { - SmallVector Ops; - for (SDValue Op : BV->op_values()) { - // If the vector element type is not legal, the BUILD_VECTOR operands - // are promoted and implicitly truncated. Make that explicit here. - if (Op.getValueType() != SrcEltVT) - Op = DAG.getNode(ISD::TRUNCATE, SDLoc(BV), SrcEltVT, Op); - Ops.push_back(DAG.getBitcast(DstEltVT, Op)); - AddToWorklist(Ops.back().getNode()); - } - EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, - BV->getValueType(0).getVectorNumElements()); - return DAG.getBuildVector(VT, SDLoc(BV), Ops); - } - - // Otherwise, we're growing or shrinking the elements. To avoid having to - // handle annoying details of growing/shrinking FP values, we convert them to - // int first. - if (SrcEltVT.isFloatingPoint()) { - // Convert the input float vector to a int vector where the elements are the - // same sizes. - EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcEltVT.getSizeInBits()); - BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, IntVT).getNode(); - SrcEltVT = IntVT; - } - - // Now we know the input is an integer vector. 
If the output is a FP type, - // convert to integer first, then to FP of the right size. - if (DstEltVT.isFloatingPoint()) { - EVT TmpVT = EVT::getIntegerVT(*DAG.getContext(), DstEltVT.getSizeInBits()); - SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, TmpVT).getNode(); - - // Next, convert to FP elements of the same size. - return ConstantFoldBITCASTofBUILD_VECTOR(Tmp, DstEltVT); - } - - // Okay, we know the src/dst types are both integers of differing types. - assert(SrcEltVT.isInteger() && DstEltVT.isInteger()); - - // TODO: Should ConstantFoldBITCASTofBUILD_VECTOR always take a - // BuildVectorSDNode? - auto *BVN = cast(BV); - - // Extract the constant raw bit data. - BitVector UndefElements; - SmallVector RawBits; - bool IsLE = DAG.getDataLayout().isLittleEndian(); - if (!BVN->getConstantRawBits(IsLE, DstBitSize, RawBits, UndefElements)) - return SDValue(); - - SDLoc DL(BV); - SmallVector Ops; - for (unsigned I = 0, E = RawBits.size(); I != E; ++I) { - if (UndefElements[I]) - Ops.push_back(DAG.getUNDEF(DstEltVT)); - else - Ops.push_back(DAG.getConstant(RawBits[I], DL, DstEltVT)); - } - - EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size()); - return DAG.getBuildVector(VT, DL, Ops); -} - // Returns true if floating point contraction is allowed on the FMUL-SDValue // `N` static bool isContractableFMUL(const TargetOptions &Options, SDValue N) { diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 2a3c8e2b011ad..fed910fb425ab 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -7280,6 +7280,78 @@ SDValue SelectionDAG::foldConstantFPMath(unsigned Opcode, const SDLoc &DL, return SDValue(); } +SDValue SelectionDAG::FoldConstantBuildVector(BuildVectorSDNode *BV, + const SDLoc &DL, EVT DstEltVT) { + EVT SrcEltVT = BV->getValueType(0).getVectorElementType(); + + // If this is already the right type, we're done. 
+ if (SrcEltVT == DstEltVT) + return SDValue(BV, 0); + + unsigned SrcBitSize = SrcEltVT.getSizeInBits(); + unsigned DstBitSize = DstEltVT.getSizeInBits(); + + // If this is a conversion of N elements of one type to N elements of another + // type, convert each element. This handles FP<->INT cases. + if (SrcBitSize == DstBitSize) { + SmallVector Ops; + for (SDValue Op : BV->op_values()) { + // If the vector element type is not legal, the BUILD_VECTOR operands + // are promoted and implicitly truncated. Make that explicit here. + if (Op.getValueType() != SrcEltVT) + Op = getNode(ISD::TRUNCATE, DL, SrcEltVT, Op); + Ops.push_back(getBitcast(DstEltVT, Op)); + } + EVT VT = EVT::getVectorVT(*getContext(), DstEltVT, + BV->getValueType(0).getVectorNumElements()); + return getBuildVector(VT, DL, Ops); + } + + // Otherwise, we're growing or shrinking the elements. To avoid having to + // handle annoying details of growing/shrinking FP values, we convert them to + // int first. + if (SrcEltVT.isFloatingPoint()) { + // Convert the input float vector to an int vector where the elements are + // the same size. + EVT IntEltVT = EVT::getIntegerVT(*getContext(), SrcEltVT.getSizeInBits()); + if (SDValue Tmp = FoldConstantBuildVector(BV, DL, IntEltVT)) + return FoldConstantBuildVector(cast(Tmp), DL, + DstEltVT); + return SDValue(); + } + + // Now we know the input is an integer vector. If the output is an FP type, + // convert to integer first, then to FP of the right size. + if (DstEltVT.isFloatingPoint()) { + EVT IntEltVT = EVT::getIntegerVT(*getContext(), DstEltVT.getSizeInBits()); + if (SDValue Tmp = FoldConstantBuildVector(BV, DL, IntEltVT)) + return FoldConstantBuildVector(cast(Tmp), DL, + DstEltVT); + return SDValue(); + } + + // Okay, we know the src/dst types are both integers of differing types. + assert(SrcEltVT.isInteger() && DstEltVT.isInteger()); + + // Extract the constant raw bit data. 
+ BitVector UndefElements; + SmallVector RawBits; + bool IsLE = getDataLayout().isLittleEndian(); + if (!BV->getConstantRawBits(IsLE, DstBitSize, RawBits, UndefElements)) + return SDValue(); + + SmallVector Ops; + for (unsigned I = 0, E = RawBits.size(); I != E; ++I) { + if (UndefElements[I]) + Ops.push_back(getUNDEF(DstEltVT)); + else + Ops.push_back(getConstant(RawBits[I], DL, DstEltVT)); + } + + EVT VT = EVT::getVectorVT(*getContext(), DstEltVT, Ops.size()); + return getBuildVector(VT, DL, Ops); +} + SDValue SelectionDAG::getAssertAlign(const SDLoc &DL, SDValue Val, Align A) { assert(Val.getValueType().isInteger() && "Invalid AssertAlign!"); diff --git a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.AFLCustomIRMutator.opt.ll b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.AFLCustomIRMutator.opt.ll index a319f1260d870..85dd275207902 100644 --- a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.AFLCustomIRMutator.opt.ll +++ b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.AFLCustomIRMutator.opt.ll @@ -5,18 +5,19 @@ define amdgpu_kernel void @test_iglp_opt_rev_mfma_gemm(<1 x i64> %L1) { ; GCN-LABEL: test_iglp_opt_rev_mfma_gemm: ; GCN: ; %bb.0: ; %entry ; GCN-NEXT: v_mov_b32_e32 v32, 0 +; GCN-NEXT: ds_read_b128 v[0:3], v32 ; GCN-NEXT: s_load_dwordx2 s[0:1], s[8:9], 0x0 ; GCN-NEXT: ds_read_b128 v[28:31], v32 offset:112 ; GCN-NEXT: ds_read_b128 v[24:27], v32 offset:96 ; GCN-NEXT: ds_read_b128 v[20:23], v32 offset:80 ; GCN-NEXT: ds_read_b128 v[16:19], v32 offset:64 -; GCN-NEXT: ds_read_b128 v[0:3], v32 ; GCN-NEXT: ds_read_b128 v[4:7], v32 offset:16 ; GCN-NEXT: ds_read_b128 v[8:11], v32 offset:32 ; GCN-NEXT: ds_read_b128 v[12:15], v32 offset:48 -; GCN-NEXT: v_mov_b32_e32 v34, 0 -; GCN-NEXT: v_mov_b32_e32 v35, v34 ; GCN-NEXT: s_waitcnt lgkmcnt(0) +; GCN-NEXT: ds_write_b128 v32, v[0:3] +; GCN-NEXT: v_mov_b32_e32 v0, 0 +; GCN-NEXT: v_mov_b32_e32 v1, v0 ; GCN-NEXT: s_cmp_lg_u64 s[0:1], 0 ; GCN-NEXT: ; iglp_opt mask(0x00000001) ; GCN-NEXT: ds_write_b128 v32, v[28:31] offset:112 @@ -26,8 +27,7 
@@ define amdgpu_kernel void @test_iglp_opt_rev_mfma_gemm(<1 x i64> %L1) { ; GCN-NEXT: ds_write_b128 v32, v[12:15] offset:48 ; GCN-NEXT: ds_write_b128 v32, v[8:11] offset:32 ; GCN-NEXT: ds_write_b128 v32, v[4:7] offset:16 -; GCN-NEXT: ds_write_b128 v32, v[0:3] -; GCN-NEXT: ds_write_b64 v32, v[34:35] +; GCN-NEXT: ds_write_b64 v32, v[0:1] ; GCN-NEXT: s_endpgm entry: call void @llvm.amdgcn.iglp.opt(i32 1)