diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d6a134d9abafd..b16a2842ed28e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -28,6 +28,7 @@
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineMemOperand.h"
+#include "llvm/CodeGen/SDPatternMatch.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/CodeGen/TargetCallingConv.h"
@@ -74,6 +75,7 @@
 #define DEBUG_TYPE "nvptx-lower"
 
 using namespace llvm;
+using namespace llvm::SDPatternMatch;
 
 static cl::opt<bool> sched4reg(
     "nvptx-sched4reg",
@@ -659,6 +661,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
     setOperationAction(ISD::BR_CC, VT, Expand);
   }
 
+  setOperationAction({ISD::TRUNCATE_SSAT_S, ISD::TRUNCATE_SSAT_U}, MVT::i16,
+                     Legal);
+  setOperationAction({ISD::TRUNCATE_SSAT_S, ISD::TRUNCATE_SSAT_U}, MVT::i8,
+                     Custom);
+
   // Some SIGN_EXTEND_INREG can be done using cvt instruction.
   // For others we will expand to a SHL/SRA pair.
   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal);
@@ -836,7 +843,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
                        ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
-                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
+                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::SMIN,
+                       ISD::SMAX});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -1081,6 +1089,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::PseudoUseParam)
     MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
     MAKE_CASE(NVPTXISD::BUILD_VECTOR)
+    MAKE_CASE(NVPTXISD::TRUNCATE_SSAT_U_I8)
+    MAKE_CASE(NVPTXISD::TRUNCATE_SSAT_S_I8)
     MAKE_CASE(NVPTXISD::RETURN)
     MAKE_CASE(NVPTXISD::CallSeqBegin)
     MAKE_CASE(NVPTXISD::CallSeqEnd)
@@ -5667,6 +5677,52 @@ static SDValue combineADDRSPACECAST(SDNode *N,
   return SDValue();
 }
 
+static SDValue combineMINMAX(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+
+  EVT VT = N->getValueType(0);
+  if (!(VT == MVT::i32 || VT == MVT::i64 || VT == MVT::i16))
+    return SDValue();
+
+  SDValue Val;
+  APInt Ceil, Floor;
+  if (!(sd_match(N, m_SMin(m_SMax(m_Value(Val), m_ConstInt(Floor)),
+                           m_ConstInt(Ceil))) ||
+        sd_match(N, m_SMax(m_SMin(m_Value(Val), m_ConstInt(Ceil)),
+                           m_ConstInt(Floor)))))
+    return SDValue();
+
+  const unsigned BitWidth = VT.getSizeInBits();
+  SDLoc DL(N);
+  const auto TryToLowerAsSaturatedConversion = [&](MVT DestVT) {
+    const unsigned DestBitWidth = DestVT.getSizeInBits();
+
+    // If the destination is greater than or equal to the source, the min/max
+    // comparisons are meaningless/impossible.
+    if (DestBitWidth >= BitWidth)
+      return SDValue();
+
+    bool IsSigned;
+    if (Ceil == APInt::getSignedMaxValue(DestBitWidth).sext(BitWidth) &&
+        Floor == APInt::getSignedMinValue(DestBitWidth).sext(BitWidth))
+      IsSigned = true;
+    else if (Ceil == APInt::getMaxValue(DestBitWidth).zext(BitWidth) &&
+             Floor == APInt::getMinValue(BitWidth))
+      IsSigned = false;
+    else
+      return SDValue();
+
+    unsigned Opcode = IsSigned ? ISD::TRUNCATE_SSAT_S : ISD::TRUNCATE_SSAT_U;
+    SDValue Trunc = DCI.DAG.getNode(Opcode, DL, DestVT, Val);
+    return DCI.DAG.getExtOrTrunc(IsSigned, Trunc, DL, VT);
+  };
+
+  for (const auto DestVT : {MVT::i16, MVT::i8})
+    if (auto Res = TryToLowerAsSaturatedConversion(DestVT))
+      return Res;
+
+  return SDValue();
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5685,6 +5741,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::UREM:
   case ISD::SREM:
     return PerformREMCombine(N, DCI, OptLevel);
+  case ISD::SMIN:
+  case ISD::SMAX:
+    return combineMINMAX(N, DCI);
   case ISD::SETCC:
     return PerformSETCCCombine(N, DCI, STI.getSmVersion());
   case NVPTXISD::StoreRetval:
@@ -6045,6 +6104,20 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
   Results.push_back(NewValue.getValue(3));
 }
 
+static void replaceTruncateSSat(SDNode *N, SelectionDAG &DAG,
+                                SmallVectorImpl<SDValue> &Results) {
+  SDLoc DL(N);
+
+  const bool IsSigned = N->getOpcode() == ISD::TRUNCATE_SSAT_S;
+  const unsigned Opcode =
+      IsSigned ? NVPTXISD::TRUNCATE_SSAT_S_I8 : NVPTXISD::TRUNCATE_SSAT_U_I8;
+  SDValue NewTrunc = DAG.getNode(Opcode, DL, MVT::i16, N->getOperand(0));
+  SDValue Assert = DAG.getNode(IsSigned ? ISD::AssertSext : ISD::AssertZext, DL,
+                               MVT::i16, NewTrunc, DAG.getValueType(MVT::i8));
+
+  Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, Assert));
+}
+
 void NVPTXTargetLowering::ReplaceNodeResults(
     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
   switch (N->getOpcode()) {
@@ -6062,6 +6135,10 @@ void NVPTXTargetLowering::ReplaceNodeResults(
   case ISD::CopyFromReg:
     ReplaceCopyFromReg_128(N, DAG, Results);
     return;
+  case ISD::TRUNCATE_SSAT_U:
+  case ISD::TRUNCATE_SSAT_S:
+    replaceTruncateSSat(N, DAG, Results);
+    return;
   }
 }
 
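Note: the bounds combineMINMAX accepts are exactly the signed or unsigned extremes of the destination type, extended to the source width. A minimal standalone sketch of that check (illustration only, not part of the patch; main, the widths, and the constants are names chosen for the example):

    // Signed-i8 saturation bounds for an i32 source, computed the same way
    // combineMINMAX computes them with APInt.
    #include "llvm/ADT/APInt.h"
    #include <cassert>

    int main() {
      using llvm::APInt;
      const unsigned BitWidth = 32;    // source width (VT)
      const unsigned DestBitWidth = 8; // destination width (DestVT)
      // Constants as written in smin(smax(x, Floor), Ceil) for [-128, 127].
      const APInt Ceil(BitWidth, 127, /*isSigned=*/true);
      const APInt Floor(BitWidth, -128, /*isSigned=*/true);
      // The signed-saturation recognition used by the combine.
      assert(Ceil == APInt::getSignedMaxValue(DestBitWidth).sext(BitWidth));
      assert(Floor == APInt::getSignedMinValue(DestBitWidth).sext(BitWidth));
    }
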
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 8d71022a1f102..1bd8ffb65c501 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -72,6 +72,13 @@ enum NodeType : unsigned {
   /// converting it to a vector.
   UNPACK_VECTOR,
 
+  /// These nodes are equivalent to the corresponding ISD nodes except that
+  /// they truncate to an i8 output and then sign or zero extend that value
+  /// back to i16. This is a workaround for the fact that NVPTX does not
+  /// consider i8 to be a legal type. TODO: consider making i8 legal and removing these.
+  TRUNCATE_SSAT_U_I8,
+  TRUNCATE_SSAT_S_I8,
+
   FCOPYSIGN,
   DYNAMIC_STACKALLOC,
   STACKRESTORE,
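The semantics these helper nodes are documented to have, modeled as plain scalar code (a sketch following the comment's description, not part of the patch; the function name is hypothetical):

    // NVPTXISD::TRUNCATE_SSAT_S_I8, scalar model: clamp a wider signed value
    // to the i8 range, then sign-extend the 8-bit result back to i16, since
    // i8 is not a legal NVPTX type.
    #include <algorithm>
    #include <cstdint>

    int16_t trunc_ssat_s_i8_model(int32_t X) {
      const int32_t Clamped =
          std::clamp(X, int32_t(INT8_MIN), int32_t(INT8_MAX));
      return static_cast<int8_t>(Clamped); // widened back to i16 on return
    }
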
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index b646d39194c7e..e26a05a112e09 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -2649,6 +2649,25 @@
 def : Pat<(i1 (trunc i32:$a)), (SETP_b32ri (ANDb32ri $a, 1), 0, CmpNE)>;
 // truncate i16
 def : Pat<(i1 (trunc i16:$a)), (SETP_b16ri (ANDb16ri $a, 1), 0, CmpNE)>;
 
+// truncate ssat
+def SDTTruncSatI8Op : SDTypeProfile<1, 1, [SDTCisInt<1>, SDTCisVT<0, i16>]>;
+def truncssat_s_i8 : SDNode<"NVPTXISD::TRUNCATE_SSAT_S_I8", SDTTruncSatI8Op>;
+def truncssat_u_i8 : SDNode<"NVPTXISD::TRUNCATE_SSAT_U_I8", SDTTruncSatI8Op>;
+
+def : Pat<(i16 (truncssat_s i32:$a)), (CVT_s16_s32 $a, CvtSAT)>;
+def : Pat<(i16 (truncssat_s i64:$a)), (CVT_s16_s64 $a, CvtSAT)>;
+
+def : Pat<(i16 (truncssat_u i32:$a)), (CVT_u16_s32 $a, CvtSAT)>;
+def : Pat<(i16 (truncssat_u i64:$a)), (CVT_u16_s64 $a, CvtSAT)>;
+
+def : Pat<(truncssat_s_i8 i16:$a), (CVT_s8_s16 $a, CvtSAT)>;
+def : Pat<(truncssat_s_i8 i32:$a), (CVT_s8_s32 $a, CvtSAT)>;
+def : Pat<(truncssat_s_i8 i64:$a), (CVT_s8_s64 $a, CvtSAT)>;
+
+def : Pat<(truncssat_u_i8 i16:$a), (CVT_u8_u16 $a, CvtSAT)>;
+def : Pat<(truncssat_u_i8 i32:$a), (CVT_u8_u32 $a, CvtSAT)>;
+def : Pat<(truncssat_u_i8 i64:$a), (CVT_u8_u64 $a, CvtSAT)>;
+
 // sext_inreg
 def : Pat<(sext_inreg i16:$a, i8), (CVT_INREG_s16_s8 $a)>;
 def : Pat<(sext_inreg i32:$a, i8), (CVT_INREG_s32_s8 $a)>;
diff --git a/llvm/test/CodeGen/NVPTX/trunc-sat.ll b/llvm/test/CodeGen/NVPTX/trunc-sat.ll
new file mode 100644
index 0000000000000..e77e7691c9cb7
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/trunc-sat.ll
@@ -0,0 +1,177 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
+
+target triple = "nvptx-unknown-cuda"
+
+
+define i64 @trunc_ssat_i64_u16(i64 %a) {
+; CHECK-LABEL: trunc_ssat_i64_u16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [trunc_ssat_i64_u16_param_0];
+; CHECK-NEXT:    cvt.sat.u16.s64 %rs1, %rd1;
+; CHECK-NEXT:    cvt.u64.u16 %rd2, %rs1;
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT:    ret;
+  %v1 = call i64 @llvm.smax.i64(i64 %a, i64 0)
+  %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 65535)
+  ret i64 %v2
+}
+
+define i32 @trunc_ssat_i32_u16(i32 %a) {
+; CHECK-LABEL: trunc_ssat_i32_u16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [trunc_ssat_i32_u16_param_0];
+; CHECK-NEXT:    cvt.sat.u16.s32 %rs1, %r1;
+; CHECK-NEXT:    cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT:    ret;
+  %v1 = call i32 @llvm.smax.i32(i32 %a, i32 0)
+  %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 65535)
+  ret i32 %v2
+}
+
+define i64 @trunc_ssat_i64_s16(i64 %a) {
+; CHECK-LABEL: trunc_ssat_i64_s16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [trunc_ssat_i64_s16_param_0];
+; CHECK-NEXT:    cvt.sat.s16.s64 %rs1, %rd1;
+; CHECK-NEXT:    cvt.s64.s16 %rd2, %rs1;
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT:    ret;
+  %v1 = call i64 @llvm.smax.i64(i64 %a, i64 -32768)
+  %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 32767)
+  ret i64 %v2
+}
+
+define i32 @trunc_ssat_i32_s16(i32 %a) {
+; CHECK-LABEL: trunc_ssat_i32_s16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [trunc_ssat_i32_s16_param_0];
+; CHECK-NEXT:    cvt.sat.s16.s32 %rs1, %r1;
+; CHECK-NEXT:    cvt.s32.s16 %r2, %rs1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT:    ret;
+  %v1 = call i32 @llvm.smax.i32(i32 %a, i32 -32768)
+  %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 32767)
+  ret i32 %v2
+}
+
+define i64 @trunc_ssat_i64_u8(i64 %a) {
+; CHECK-LABEL: trunc_ssat_i64_u8(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [trunc_ssat_i64_u8_param_0];
+; CHECK-NEXT:    cvt.sat.u8.u64 %rs1, %rd1;
+; CHECK-NEXT:    cvt.u64.u16 %rd2, %rs1;
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT:    ret;
+  %v1 = call i64 @llvm.smax.i64(i64 %a, i64 0)
+  %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 255)
+  ret i64 %v2
+}
+
+define i32 @trunc_ssat_i32_u8(i32 %a) {
+; CHECK-LABEL: trunc_ssat_i32_u8(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [trunc_ssat_i32_u8_param_0];
+; CHECK-NEXT:    cvt.sat.u8.u32 %rs1, %r1;
+; CHECK-NEXT:    cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT:    ret;
+  %v1 = call i32 @llvm.smax.i32(i32 %a, i32 0)
+  %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 255)
+  ret i32 %v2
+}
+
+define i16 @trunc_ssat_i16_u8(i16 %a) {
+; CHECK-LABEL: trunc_ssat_i16_u8(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [trunc_ssat_i16_u8_param_0];
+; CHECK-NEXT:    cvt.sat.u8.u16 %rs2, %rs1;
+; CHECK-NEXT:    cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %v1 = call i16 @llvm.smax.i16(i16 %a, i16 0)
+  %v2 = call i16 @llvm.smin.i16(i16 %v1, i16 255)
+  ret i16 %v2
+}
+
+define i64 @trunc_ssat_i64_s8(i64 %a) {
+; CHECK-LABEL: trunc_ssat_i64_s8(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [trunc_ssat_i64_s8_param_0];
+; CHECK-NEXT:    cvt.sat.s8.s64 %rs1, %rd1;
+; CHECK-NEXT:    cvt.s64.s16 %rd2, %rs1;
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT:    ret;
+  %v1 = call i64 @llvm.smax.i64(i64 %a, i64 -128)
+  %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 127)
+  ret i64 %v2
+}
+
+define i32 @trunc_ssat_i32_s8(i32 %a) {
+; CHECK-LABEL: trunc_ssat_i32_s8(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [trunc_ssat_i32_s8_param_0];
+; CHECK-NEXT:    cvt.sat.s8.s32 %rs1, %r1;
+; CHECK-NEXT:    cvt.s32.s16 %r2, %rs1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT:    ret;
+  %v1 = call i32 @llvm.smax.i32(i32 %a, i32 -128)
+  %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 127)
+  ret i32 %v2
+}
+
+define i16 @trunc_ssat_i16_s8(i16 %a) {
+; CHECK-LABEL: trunc_ssat_i16_s8(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [trunc_ssat_i16_s8_param_0];
+; CHECK-NEXT:    cvt.sat.s8.s16 %rs2, %rs1;
+; CHECK-NEXT:    cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %v1 = call i16 @llvm.smax.i16(i16 %a, i16 -128)
+  %v2 = call i16 @llvm.smin.i16(i16 %v1, i16 127)
+  ret i16 %v2
+}
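For reference, the unsigned tests above clamp a signed source into an unsigned range, which is why the i16 patterns select cvt.sat with a signed source type (e.g. cvt.sat.u16.s32). A scalar model of trunc_ssat_i32_u16 (illustration only; the function name is hypothetical):

    // smax(x, 0) followed by smin(-, 65535) on a signed i32, i.e. the
    // ISD::TRUNCATE_SSAT_U saturation to the u16 range.
    #include <algorithm>
    #include <cstdint>

    uint16_t trunc_ssat_i32_u16_model(int32_t X) {
      return static_cast<uint16_t>(std::clamp(X, 0, 65535));
    }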