[NVPTX] Use cvt.sat to lower min/max clamping to i8 and i16 ranges #143016
Conversation
LGTM with a couple of nits.
    return DCI.DAG.getExtOrTrunc(IsSigned, Trunc, DL, VT);
  };

  if (VT != MVT::i16)
I'd fold it into MatchTruncSat as if (DestVT == VT) return SDNode; // can't use conversion on identical types. Or just mention it in the comment here. Otherwise it's not clear what the other possible types we may see here are, and why we exclude i16 from that set.
Yep, that is cleaner. I've moved the check.
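For reference, a rough sketch of how the folded check could look (illustrative only; names like VT, BitWidth, and DestVT come from combineMINMAX in this patch, and the committed version in the diff below guards on bit width rather than exact type equality):

  const auto TryToLowerAsSaturatedConversion = [&](MVT DestVT) -> SDValue {
    // A saturating conversion only helps when narrowing, so bail out when the
    // destination is not strictly smaller than the source type.
    if (DestVT.getSizeInBits() >= VT.getSizeInBits())
      return SDValue();
    // ... match the Ceil/Floor constants and emit TRUNCATE_SSAT_* as in the
    // full diff below ...
    return SDValue();
  };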
  const unsigned BitWidth = VT.getSizeInBits();
  SDLoc DL(N);
  auto MatchTuncSat = [&](MVT DestVT) {
Nit: Perhaps rename MatchTruncSat to TryToLowerAsSaturatedConversion, as we're not just matching the graph nodes, but also constructing their replacement.
Sounds good. Fixed
@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/143016.diff

4 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d6a134d9abafd..b16a2842ed28e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -28,6 +28,7 @@
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineMemOperand.h"
+#include "llvm/CodeGen/SDPatternMatch.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetCallingConv.h"
@@ -74,6 +75,7 @@
#define DEBUG_TYPE "nvptx-lower"
using namespace llvm;
+using namespace llvm::SDPatternMatch;
static cl::opt<bool> sched4reg(
"nvptx-sched4reg",
@@ -659,6 +661,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::BR_CC, VT, Expand);
}
+ setOperationAction({ISD::TRUNCATE_SSAT_S, ISD::TRUNCATE_SSAT_U}, MVT::i16,
+ Legal);
+ setOperationAction({ISD::TRUNCATE_SSAT_S, ISD::TRUNCATE_SSAT_U}, MVT::i8,
+ Custom);
+
// Some SIGN_EXTEND_INREG can be done using cvt instruction.
// For others we will expand to a SHL/SRA pair.
setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal);
@@ -836,7 +843,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// We have some custom DAG combine patterns for these nodes
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::SMIN,
+ ISD::SMAX});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -1081,6 +1089,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::PseudoUseParam)
MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
MAKE_CASE(NVPTXISD::BUILD_VECTOR)
+ MAKE_CASE(NVPTXISD::TRUNCATE_SSAT_U_I8)
+ MAKE_CASE(NVPTXISD::TRUNCATE_SSAT_S_I8)
MAKE_CASE(NVPTXISD::RETURN)
MAKE_CASE(NVPTXISD::CallSeqBegin)
MAKE_CASE(NVPTXISD::CallSeqEnd)
@@ -5667,6 +5677,52 @@ static SDValue combineADDRSPACECAST(SDNode *N,
return SDValue();
}
+static SDValue combineMINMAX(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+
+ EVT VT = N->getValueType(0);
+ if (!(VT == MVT::i32 || VT == MVT::i64 || VT == MVT::i16))
+ return SDValue();
+
+ SDValue Val;
+ APInt Ceil, Floor;
+ if (!(sd_match(N, m_SMin(m_SMax(m_Value(Val), m_ConstInt(Floor)),
+ m_ConstInt(Ceil))) ||
+ sd_match(N, m_SMax(m_SMin(m_Value(Val), m_ConstInt(Ceil)),
+ m_ConstInt(Floor)))))
+ return SDValue();
+
+ const unsigned BitWidth = VT.getSizeInBits();
+ SDLoc DL(N);
+ const auto TryToLowerAsSaturatedConversion = [&](MVT DestVT) {
+ const unsigned DestBitWidth = DestVT.getSizeInBits();
+
+ // If the destination is greater than or equal to the source, the min/max
+ // comparisons are meaningless/impossible.
+ if (DestBitWidth >= BitWidth)
+ return SDValue();
+
+ bool IsSigned;
+ if (Ceil == APInt::getSignedMaxValue(DestBitWidth).sext(BitWidth) &&
+ Floor == APInt::getSignedMinValue(DestBitWidth).sext(BitWidth))
+ IsSigned = true;
+ else if (Ceil == APInt::getMaxValue(DestBitWidth).zext(BitWidth) &&
+ Floor == APInt::getMinValue(BitWidth))
+ IsSigned = false;
+ else
+ return SDValue();
+
+ unsigned Opcode = IsSigned ? ISD::TRUNCATE_SSAT_S : ISD::TRUNCATE_SSAT_U;
+ SDValue Trunc = DCI.DAG.getNode(Opcode, DL, DestVT, Val);
+ return DCI.DAG.getExtOrTrunc(IsSigned, Trunc, DL, VT);
+ };
+
+ for (const auto DestVT : {MVT::i16, MVT::i8})
+ if (auto Res = TryToLowerAsSaturatedConversion(DestVT))
+ return Res;
+
+ return SDValue();
+}
+
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5685,6 +5741,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::UREM:
case ISD::SREM:
return PerformREMCombine(N, DCI, OptLevel);
+ case ISD::SMIN:
+ case ISD::SMAX:
+ return combineMINMAX(N, DCI);
case ISD::SETCC:
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
case NVPTXISD::StoreRetval:
@@ -6045,6 +6104,20 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
Results.push_back(NewValue.getValue(3));
}
+static void replaceTruncateSSat(SDNode *N, SelectionDAG &DAG,
+ SmallVectorImpl<SDValue> &Results) {
+ SDLoc DL(N);
+
+ const bool IsSigned = N->getOpcode() == ISD::TRUNCATE_SSAT_S;
+ const unsigned Opcode =
+ IsSigned ? NVPTXISD::TRUNCATE_SSAT_S_I8 : NVPTXISD::TRUNCATE_SSAT_U_I8;
+ SDValue NewTrunc = DAG.getNode(Opcode, DL, MVT::i16, N->getOperand(0));
+ SDValue Assert = DAG.getNode(IsSigned ? ISD::AssertSext : ISD::AssertZext, DL,
+ MVT::i16, NewTrunc, DAG.getValueType(MVT::i8));
+
+ Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, Assert));
+}
+
void NVPTXTargetLowering::ReplaceNodeResults(
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
switch (N->getOpcode()) {
@@ -6062,6 +6135,10 @@ void NVPTXTargetLowering::ReplaceNodeResults(
case ISD::CopyFromReg:
ReplaceCopyFromReg_128(N, DAG, Results);
return;
+ case ISD::TRUNCATE_SSAT_U:
+ case ISD::TRUNCATE_SSAT_S:
+ replaceTruncateSSat(N, DAG, Results);
+ return;
}
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 8d71022a1f102..1bd8ffb65c501 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -72,6 +72,13 @@ enum NodeType : unsigned {
/// converting it to a vector.
UNPACK_VECTOR,
+ /// These nodes are equivalent to the corresponding ISD nodes except that
+ /// they truncate to an i8 output and then sign or zero extend that value back
+ /// to i16. This is a workaround for the fact that NVPTX does not consider
+ /// i8 to be a legal type. TODO: consider making i8 legal and removing these.
+ TRUNCATE_SSAT_U_I8,
+ TRUNCATE_SSAT_S_I8,
+
FCOPYSIGN,
DYNAMIC_STACKALLOC,
STACKRESTORE,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index b646d39194c7e..e26a05a112e09 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -2649,6 +2649,25 @@ def : Pat<(i1 (trunc i32:$a)), (SETP_b32ri (ANDb32ri $a, 1), 0, CmpNE)>;
// truncate i16
def : Pat<(i1 (trunc i16:$a)), (SETP_b16ri (ANDb16ri $a, 1), 0, CmpNE)>;
+// truncate ssat
+def SDTTruncSatI8Op : SDTypeProfile<1, 1, [SDTCisInt<1>, SDTCisVT<0, i16>]>;
+def truncssat_s_i8 : SDNode<"NVPTXISD::TRUNCATE_SSAT_S_I8", SDTTruncSatI8Op>;
+def truncssat_u_i8 : SDNode<"NVPTXISD::TRUNCATE_SSAT_U_I8", SDTTruncSatI8Op>;
+
+def : Pat<(i16 (truncssat_s i32:$a)), (CVT_s16_s32 $a, CvtSAT)>;
+def : Pat<(i16 (truncssat_s i64:$a)), (CVT_s16_s64 $a, CvtSAT)>;
+
+def : Pat<(i16 (truncssat_u i32:$a)), (CVT_u16_s32 $a, CvtSAT)>;
+def : Pat<(i16 (truncssat_u i64:$a)), (CVT_u16_s64 $a, CvtSAT)>;
+
+def : Pat<(truncssat_s_i8 i16:$a), (CVT_s8_s16 $a, CvtSAT)>;
+def : Pat<(truncssat_s_i8 i32:$a), (CVT_s8_s32 $a, CvtSAT)>;
+def : Pat<(truncssat_s_i8 i64:$a), (CVT_s8_s64 $a, CvtSAT)>;
+
+def : Pat<(truncssat_u_i8 i16:$a), (CVT_u8_u16 $a, CvtSAT)>;
+def : Pat<(truncssat_u_i8 i32:$a), (CVT_u8_u32 $a, CvtSAT)>;
+def : Pat<(truncssat_u_i8 i64:$a), (CVT_u8_u64 $a, CvtSAT)>;
+
// sext_inreg
def : Pat<(sext_inreg i16:$a, i8), (CVT_INREG_s16_s8 $a)>;
def : Pat<(sext_inreg i32:$a, i8), (CVT_INREG_s32_s8 $a)>;
diff --git a/llvm/test/CodeGen/NVPTX/trunc-sat.ll b/llvm/test/CodeGen/NVPTX/trunc-sat.ll
new file mode 100644
index 0000000000000..e77e7691c9cb7
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/trunc-sat.ll
@@ -0,0 +1,177 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
+
+target triple = "nvptx-unknown-cuda"
+
+
+define i64 @trunc_ssat_i64_u16(i64 %a) {
+; CHECK-LABEL: trunc_ssat_i64_u16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_u16_param_0];
+; CHECK-NEXT: cvt.sat.u16.s64 %rs1, %rd1;
+; CHECK-NEXT: cvt.u64.u16 %rd2, %rs1;
+; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT: ret;
+ %v1 = call i64 @llvm.smax.i64(i64 %a, i64 0)
+ %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 65535)
+ ret i64 %v2
+}
+
+define i32 @trunc_ssat_i32_u16(i32 %a) {
+; CHECK-LABEL: trunc_ssat_i32_u16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_u16_param_0];
+; CHECK-NEXT: cvt.sat.u16.s32 %rs1, %r1;
+; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT: ret;
+ %v1 = call i32 @llvm.smax.i32(i32 %a, i32 0)
+ %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 65535)
+ ret i32 %v2
+}
+
+define i64 @trunc_ssat_i64_s16(i64 %a) {
+; CHECK-LABEL: trunc_ssat_i64_s16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_s16_param_0];
+; CHECK-NEXT: cvt.sat.s16.s64 %rs1, %rd1;
+; CHECK-NEXT: cvt.s64.s16 %rd2, %rs1;
+; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT: ret;
+ %v1 = call i64 @llvm.smax.i64(i64 %a, i64 -32768)
+ %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 32767)
+ ret i64 %v2
+}
+
+define i32 @trunc_ssat_i32_s16(i32 %a) {
+; CHECK-LABEL: trunc_ssat_i32_s16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_s16_param_0];
+; CHECK-NEXT: cvt.sat.s16.s32 %rs1, %r1;
+; CHECK-NEXT: cvt.s32.s16 %r2, %rs1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT: ret;
+ %v1 = call i32 @llvm.smax.i32(i32 %a, i32 -32768)
+ %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 32767)
+ ret i32 %v2
+}
+
+define i64 @trunc_ssat_i64_u8(i64 %a) {
+; CHECK-LABEL: trunc_ssat_i64_u8(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_u8_param_0];
+; CHECK-NEXT: cvt.sat.u8.u64 %rs1, %rd1;
+; CHECK-NEXT: cvt.u64.u16 %rd2, %rs1;
+; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT: ret;
+ %v1 = call i64 @llvm.smax.i64(i64 %a, i64 0)
+ %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 255)
+ ret i64 %v2
+}
+
+define i32 @trunc_ssat_i32_u8(i32 %a) {
+; CHECK-LABEL: trunc_ssat_i32_u8(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_u8_param_0];
+; CHECK-NEXT: cvt.sat.u8.u32 %rs1, %r1;
+; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT: ret;
+ %v1 = call i32 @llvm.smax.i32(i32 %a, i32 0)
+ %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 255)
+ ret i32 %v2
+}
+
+define i16 @trunc_ssat_i16_u8(i16 %a) {
+; CHECK-LABEL: trunc_ssat_i16_u8(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [trunc_ssat_i16_u8_param_0];
+; CHECK-NEXT: cvt.sat.u8.u16 %rs2, %rs1;
+; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %v1 = call i16 @llvm.smax.i16(i16 %a, i16 0)
+ %v2 = call i16 @llvm.smin.i16(i16 %v1, i16 255)
+ ret i16 %v2
+}
+
+define i64 @trunc_ssat_i64_s8(i64 %a) {
+; CHECK-LABEL: trunc_ssat_i64_s8(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_s8_param_0];
+; CHECK-NEXT: cvt.sat.s8.s64 %rs1, %rd1;
+; CHECK-NEXT: cvt.s64.s16 %rd2, %rs1;
+; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
+; CHECK-NEXT: ret;
+ %v1 = call i64 @llvm.smax.i64(i64 %a, i64 -128)
+ %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 127)
+ ret i64 %v2
+}
+
+define i32 @trunc_ssat_i32_s8(i32 %a) {
+; CHECK-LABEL: trunc_ssat_i32_s8(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_s8_param_0];
+; CHECK-NEXT: cvt.sat.s8.s32 %rs1, %r1;
+; CHECK-NEXT: cvt.s32.s16 %r2, %rs1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT: ret;
+ %v1 = call i32 @llvm.smax.i32(i32 %a, i32 -128)
+ %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 127)
+ ret i32 %v2
+}
+
+define i16 @trunc_ssat_i16_s8(i16 %a) {
+; CHECK-LABEL: trunc_ssat_i16_s8(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [trunc_ssat_i16_s8_param_0];
+; CHECK-NEXT: cvt.sat.s8.s16 %rs2, %rs1;
+; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %v1 = call i16 @llvm.smax.i16(i16 %a, i16 -128)
+ %v2 = call i16 @llvm.smin.i16(i16 %v1, i16 127)
+ ret i16 %v2
+}
+