diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp index 28ed1e520ce52..71591074cea50 100644 --- a/llvm/lib/IR/AutoUpgrade.cpp +++ b/llvm/lib/IR/AutoUpgrade.cpp @@ -1450,6 +1450,7 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn, .Case("popc.ll", true) .Case("h2f", true) .Case("swap.lo.hi.b64", true) + .Case("tanh.approx.f32", true) .Default(false); if (Expand) { @@ -2543,6 +2544,12 @@ static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI, MDNode *MD = MDNode::get(Builder.getContext(), {}); LD->setMetadata(LLVMContext::MD_invariant_load, MD); return LD; + } else if (Name == "tanh.approx.f32") { + // nvvm.tanh.approx.f32 -> afn llvm.tanh.f32 + FastMathFlags FMF; + FMF.setApproxFunc(); + Rep = Builder.CreateUnaryIntrinsic(Intrinsic::tanh, CI->getArgOperand(0), + FMF); } else if (Name == "barrier0" || Name == "barrier.n" || Name == "bar.sync") { Value *Arg = Name.ends_with('0') ? Builder.getInt32(0) : CI->getArgOperand(0); diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 7883acca6d8b3..e26c07902f2c4 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -952,10 +952,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, // promoted to f32. v2f16 is expanded to f16, which is then promoted // to f32. for (const auto &Op : - {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS}) { + {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FTANH}) { setOperationAction(Op, MVT::f16, Promote); setOperationAction(Op, MVT::f32, Legal); - setOperationAction(Op, MVT::f64, Legal); + // only div/rem/sqrt are legal for f64 + if (Op == ISD::FDIV || Op == ISD::FREM || Op == ISD::FSQRT) { + setOperationAction(Op, MVT::f64, Legal); + } setOperationAction(Op, {MVT::v2f16, MVT::v2bf16, MVT::v2f32}, Expand); setOperationAction(Op, MVT::bf16, Promote); AddPromotedToType(Op, MVT::bf16, MVT::f32); diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index b5df4c6de7fd8..442b900d280b2 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -1234,7 +1234,7 @@ defm FMA_F32 : FMA; defm FMA_F32x2 : FMA; defm FMA_F64 : FMA; -// sin/cos +// sin/cos/tanh class UnaryOpAllowsApproxFn : PatFrag<(ops node:$A), @@ -1250,6 +1250,10 @@ def COS_APPROX_f32 : BasicFlagsNVPTXInst<(outs B32:$dst), (ins B32:$src), (ins FTZFlag:$ftz), "cos.approx$ftz.f32", [(set f32:$dst, (UnaryOpAllowsApproxFn f32:$src))]>; +def TANH_APPROX_f32 : + BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "tanh.approx.f32", + [(set f32:$dst, (UnaryOpAllowsApproxFn f32:$src))]>, + Requires<[hasPTX<70>, hasSM<75>]>; //----------------------------------- // Bitwise operations diff --git a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll index a17f11a680aa2..362586af4f9b7 100644 --- a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll +++ b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll @@ -17,6 +17,8 @@ declare float @llvm.nvvm.fabs.f(float) declare float @llvm.nvvm.fabs.ftz.f(float) declare double @llvm.nvvm.fabs.d(double) +declare float @llvm.nvvm.tanh.approx.f32(float) + declare i16 @llvm.nvvm.max.s(i16, i16) declare i32 @llvm.nvvm.max.i(i32, i32) declare i64 @llvm.nvvm.max.ll(i64, i64) @@ -138,6 +140,13 @@ define void @fabs(float %a, double %b) { ret void } +; CHECK-LABEL: @tanh +define void @tanh(float %a) { +; CHECK: call afn float @llvm.tanh.f32(float %a) + %r1 = call float @llvm.nvvm.tanh.approx.f32(float %a) + ret void +} + ; CHECK-LABEL: @min_max define void @min_max(i16 %a1, i16 %a2, i32 %b1, i32 %b2, i64 %c1, i64 %c2) { ; CHECK: [[maxs:%[a-zA-Z0-9.]+]] = icmp sge i16 %a1, %a2 diff --git a/llvm/test/CodeGen/NVPTX/tanhf.ll b/llvm/test/CodeGen/NVPTX/tanhf.ll new file mode 100644 index 0000000000000..6f4eb222e0b38 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/tanhf.ll @@ -0,0 +1,40 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc < %s -mcpu=sm_75 -mattr=+ptx70 | FileCheck %s +; RUN: %if ptxas-11.0 %{ llc < %s -mcpu=sm_75 -mattr=+ptx70 | %ptxas-verify -arch=sm_75 %} + +target triple = "nvptx64-nvidia-cuda" + +define float @test1(float %in) local_unnamed_addr { +; CHECK-LABEL: test1( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [test1_param_0]; +; CHECK-NEXT: tanh.approx.f32 %r2, %r1; +; CHECK-NEXT: st.param.b32 [func_retval0], %r2; +; CHECK-NEXT: ret; + %call = call afn float @llvm.tanh.f32(float %in) + ret float %call +} + +define half @test2(half %in) local_unnamed_addr { +; CHECK-LABEL: test2( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .b32 %r<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test2_param_0]; +; CHECK-NEXT: cvt.f32.f16 %r1, %rs1; +; CHECK-NEXT: tanh.approx.f32 %r2, %r1; +; CHECK-NEXT: cvt.rn.f16.f32 %rs2, %r2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs2; +; CHECK-NEXT: ret; + %call = call afn half @llvm.tanh.f16(half %in) + ret half %call +} + +declare float @llvm.tanh.f32(float) +declare half @llvm.tanh.f16(half) +