Skip to content

[NVPTX] Use cvt.sat to lower min/max clamping to i8 and i16 ranges #143016

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/SDPatternMatch.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetCallingConv.h"
Expand Down Expand Up @@ -74,6 +75,7 @@
#define DEBUG_TYPE "nvptx-lower"

using namespace llvm;
using namespace llvm::SDPatternMatch;

static cl::opt<bool> sched4reg(
"nvptx-sched4reg",
Expand Down Expand Up @@ -659,6 +661,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::BR_CC, VT, Expand);
}

setOperationAction({ISD::TRUNCATE_SSAT_S, ISD::TRUNCATE_SSAT_U}, MVT::i16,
Legal);
setOperationAction({ISD::TRUNCATE_SSAT_S, ISD::TRUNCATE_SSAT_U}, MVT::i8,
Custom);

// Some SIGN_EXTEND_INREG can be done using cvt instruction.
// For others we will expand to a SHL/SRA pair.
setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal);
Expand Down Expand Up @@ -836,7 +843,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// We have some custom DAG combine patterns for these nodes
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::SMIN,
ISD::SMAX});

// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
Expand Down Expand Up @@ -1081,6 +1089,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::PseudoUseParam)
MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
MAKE_CASE(NVPTXISD::BUILD_VECTOR)
MAKE_CASE(NVPTXISD::TRUNCATE_SSAT_U_I8)
MAKE_CASE(NVPTXISD::TRUNCATE_SSAT_S_I8)
MAKE_CASE(NVPTXISD::RETURN)
MAKE_CASE(NVPTXISD::CallSeqBegin)
MAKE_CASE(NVPTXISD::CallSeqEnd)
Expand Down Expand Up @@ -5667,6 +5677,52 @@ static SDValue combineADDRSPACECAST(SDNode *N,
return SDValue();
}

/// Fold a signed clamp to the full range of a narrower integer type into a
/// saturating truncate followed by an extension back to the original type.
///
/// Matches `smin(smax(X, Floor), Ceil)` and `smax(smin(X, Ceil), Floor)` on
/// i16/i32/i64 values where [Floor, Ceil] is exactly either the signed range
/// or the unsigned range [0, UMAX] of i16 or i8. The replacement is
/// `sext(TRUNCATE_SSAT_S X)` or `zext(TRUNCATE_SSAT_U X)` respectively.
///
/// \param N   The SMIN or SMAX node being combined.
/// \param DCI Combiner state; supplies the DAG and the current combine level.
/// \returns the replacement value, or an empty SDValue if nothing matched.
static SDValue combineMINMAX(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
  const EVT VT = N->getValueType(0);
  if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
    return SDValue();

  SDValue Val;
  APInt Ceil, Floor;
  const bool IsClamp =
      sd_match(N, m_SMin(m_SMax(m_Value(Val), m_ConstInt(Floor)),
                         m_ConstInt(Ceil))) ||
      sd_match(N, m_SMax(m_SMin(m_Value(Val), m_ConstInt(Ceil)),
                         m_ConstInt(Floor)));
  if (!IsClamp)
    return SDValue();

  SelectionDAG &DAG = DCI.DAG;
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  const unsigned BitWidth = VT.getSizeInBits();
  SDLoc DL(N);

  const auto TryToLowerAsSaturatedConversion = [&](MVT DestVT) {
    const unsigned DestBitWidth = DestVT.getSizeInBits();

    // If the destination is greater than or equal to the source, the min/max
    // comparisons are meaningless/impossible.
    if (DestBitWidth >= BitWidth)
      return SDValue();

    // An illegal destination type (i8) is only turned into something
    // selectable by the custom hook that runs during type legalization. Once
    // that phase is over, creating a node of that type would leave an illegal
    // type in the DAG, so give up on this destination instead.
    if (!DCI.isBeforeLegalize() && !TLI.isTypeLegal(DestVT))
      return SDValue();

    bool IsSigned;
    if (Ceil == APInt::getSignedMaxValue(DestBitWidth).sext(BitWidth) &&
        Floor == APInt::getSignedMinValue(DestBitWidth).sext(BitWidth))
      IsSigned = true;
    else if (Ceil == APInt::getMaxValue(DestBitWidth).zext(BitWidth) &&
             Floor == APInt::getMinValue(BitWidth))
      IsSigned = false;
    else
      return SDValue();

    // Both opcodes take a signed source; the suffix selects whether the
    // result saturates to the signed or to the unsigned destination range.
    const unsigned Opcode =
        IsSigned ? ISD::TRUNCATE_SSAT_S : ISD::TRUNCATE_SSAT_U;
    SDValue Trunc = DAG.getNode(Opcode, DL, DestVT, Val);
    return DAG.getExtOrTrunc(IsSigned, Trunc, DL, VT);
  };

  for (const auto DestVT : {MVT::i16, MVT::i8})
    if (auto Res = TryToLowerAsSaturatedConversion(DestVT))
      return Res;

  return SDValue();
}

SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
Expand All @@ -5685,6 +5741,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::UREM:
case ISD::SREM:
return PerformREMCombine(N, DCI, OptLevel);
case ISD::SMIN:
case ISD::SMAX:
return combineMINMAX(N, DCI);
case ISD::SETCC:
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
case NVPTXISD::StoreRetval:
Expand Down Expand Up @@ -6045,6 +6104,20 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
Results.push_back(NewValue.getValue(3));
}

/// Replace the result of an ISD::TRUNCATE_SSAT_S / ISD::TRUNCATE_SSAT_U node
/// with an equivalent built from the NVPTX-specific i8 saturating-truncate
/// node, which yields its value in an i16.
///
/// The i16 result is tagged with AssertSext/AssertZext (value type i8) and
/// then truncated to i8; that truncate is the single value appended to
/// \p Results.
static void replaceTruncateSSat(SDNode *N, SelectionDAG &DAG,
                                SmallVectorImpl<SDValue> &Results) {
  SDLoc DL(N);

  // Choose the target node and the matching extension assertion together.
  unsigned SatOpcode;
  unsigned AssertOpcode;
  if (N->getOpcode() == ISD::TRUNCATE_SSAT_S) {
    SatOpcode = NVPTXISD::TRUNCATE_SSAT_S_I8;
    AssertOpcode = ISD::AssertSext;
  } else {
    SatOpcode = NVPTXISD::TRUNCATE_SSAT_U_I8;
    AssertOpcode = ISD::AssertZext;
  }

  SDValue Saturated = DAG.getNode(SatOpcode, DL, MVT::i16, N->getOperand(0));
  SDValue Asserted = DAG.getNode(AssertOpcode, DL, MVT::i16, Saturated,
                                 DAG.getValueType(MVT::i8));
  SDValue Narrowed = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, Asserted);

  Results.push_back(Narrowed);
}

void NVPTXTargetLowering::ReplaceNodeResults(
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
switch (N->getOpcode()) {
Expand All @@ -6062,6 +6135,10 @@ void NVPTXTargetLowering::ReplaceNodeResults(
case ISD::CopyFromReg:
ReplaceCopyFromReg_128(N, DAG, Results);
return;
case ISD::TRUNCATE_SSAT_U:
case ISD::TRUNCATE_SSAT_S:
replaceTruncateSSat(N, DAG, Results);
return;
}
}

Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ enum NodeType : unsigned {
/// converting it to a vector.
UNPACK_VECTOR,

/// These nodes are equivalent to the corresponding ISD nodes except that
/// they truncate to an i8 output and then sign or zero extend that value back
/// to i16. This is a workaround for the fact that NVPTX does not consider
/// i8 to be a legal type. TODO: consider making i8 legal and removing these.
TRUNCATE_SSAT_U_I8,
TRUNCATE_SSAT_S_I8,

FCOPYSIGN,
DYNAMIC_STACKALLOC,
STACKRESTORE,
Expand Down
19 changes: 19 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -2649,6 +2649,25 @@ def : Pat<(i1 (trunc i32:$a)), (SETP_b32ri (ANDb32ri $a, 1), 0, CmpNE)>;
// truncate i16
def : Pat<(i1 (trunc i16:$a)), (SETP_b16ri (ANDb16ri $a, 1), 0, CmpNE)>;

// truncate ssat
//
// Both truncssat_s and truncssat_u take a *signed* source; the suffix only
// selects whether the result saturates to the signed or to the unsigned range
// of the destination. Every pattern below therefore uses a signed-source cvt
// (cvt.sat.{s,u}N.sM). Selecting an unsigned-source cvt for the _u forms would
// saturate negative inputs to the destination maximum instead of to zero.
//
// The *_i8 nodes are the NVPTX-specific forms that produce their 8-bit result
// in an i16.
def SDTTruncSatI8Op : SDTypeProfile<1, 1, [SDTCisInt<1>, SDTCisVT<0, i16>]>;
def truncssat_s_i8 : SDNode<"NVPTXISD::TRUNCATE_SSAT_S_I8", SDTTruncSatI8Op>;
def truncssat_u_i8 : SDNode<"NVPTXISD::TRUNCATE_SSAT_U_I8", SDTTruncSatI8Op>;

def : Pat<(i16 (truncssat_s i32:$a)), (CVT_s16_s32 $a, CvtSAT)>;
def : Pat<(i16 (truncssat_s i64:$a)), (CVT_s16_s64 $a, CvtSAT)>;

def : Pat<(i16 (truncssat_u i32:$a)), (CVT_u16_s32 $a, CvtSAT)>;
def : Pat<(i16 (truncssat_u i64:$a)), (CVT_u16_s64 $a, CvtSAT)>;

def : Pat<(truncssat_s_i8 i16:$a), (CVT_s8_s16 $a, CvtSAT)>;
def : Pat<(truncssat_s_i8 i32:$a), (CVT_s8_s32 $a, CvtSAT)>;
def : Pat<(truncssat_s_i8 i64:$a), (CVT_s8_s64 $a, CvtSAT)>;

def : Pat<(truncssat_u_i8 i16:$a), (CVT_u8_s16 $a, CvtSAT)>;
def : Pat<(truncssat_u_i8 i32:$a), (CVT_u8_s32 $a, CvtSAT)>;
def : Pat<(truncssat_u_i8 i64:$a), (CVT_u8_s64 $a, CvtSAT)>;


// sext_inreg
def : Pat<(sext_inreg i16:$a, i8), (CVT_INREG_s16_s8 $a)>;
def : Pat<(sext_inreg i32:$a, i8), (CVT_INREG_s32_s8 $a)>;
Expand Down
177 changes: 177 additions & 0 deletions llvm/test/CodeGen/NVPTX/trunc-sat.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}

target triple = "nvptx-unknown-cuda"


; Signed clamp of an i64 to [0, 65535] (smax with 0, then smin with 65535).
; The assertions expect one cvt.sat.u16.s64 followed by a zero-extension back
; to 64 bits.
define i64 @trunc_ssat_i64_u16(i64 %a) {
; CHECK-LABEL: trunc_ssat_i64_u16(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_u16_param_0];
; CHECK-NEXT: cvt.sat.u16.s64 %rs1, %rd1;
; CHECK-NEXT: cvt.u64.u16 %rd2, %rs1;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-NEXT: ret;
%v1 = call i64 @llvm.smax.i64(i64 %a, i64 0)
%v2 = call i64 @llvm.smin.i64(i64 %v1, i64 65535)
ret i64 %v2
}

; Signed clamp of an i32 to [0, 65535] (smax with 0, then smin with 65535).
; The assertions expect one cvt.sat.u16.s32 followed by a zero-extension back
; to 32 bits.
define i32 @trunc_ssat_i32_u16(i32 %a) {
; CHECK-LABEL: trunc_ssat_i32_u16(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_u16_param_0];
; CHECK-NEXT: cvt.sat.u16.s32 %rs1, %r1;
; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-NEXT: ret;
%v1 = call i32 @llvm.smax.i32(i32 %a, i32 0)
%v2 = call i32 @llvm.smin.i32(i32 %v1, i32 65535)
ret i32 %v2
}

; Signed clamp of an i64 to [-32768, 32767] (smax with -32768, then smin with
; 32767). The assertions expect one cvt.sat.s16.s64 followed by a
; sign-extension back to 64 bits.
define i64 @trunc_ssat_i64_s16(i64 %a) {
; CHECK-LABEL: trunc_ssat_i64_s16(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_s16_param_0];
; CHECK-NEXT: cvt.sat.s16.s64 %rs1, %rd1;
; CHECK-NEXT: cvt.s64.s16 %rd2, %rs1;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-NEXT: ret;
%v1 = call i64 @llvm.smax.i64(i64 %a, i64 -32768)
%v2 = call i64 @llvm.smin.i64(i64 %v1, i64 32767)
ret i64 %v2
}

; Signed clamp of an i32 to [-32768, 32767] (smax with -32768, then smin with
; 32767). The assertions expect one cvt.sat.s16.s32 followed by a
; sign-extension back to 32 bits.
define i32 @trunc_ssat_i32_s16(i32 %a) {
; CHECK-LABEL: trunc_ssat_i32_s16(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_s16_param_0];
; CHECK-NEXT: cvt.sat.s16.s32 %rs1, %r1;
; CHECK-NEXT: cvt.s32.s16 %r2, %rs1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-NEXT: ret;
%v1 = call i32 @llvm.smax.i32(i32 %a, i32 -32768)
%v2 = call i32 @llvm.smin.i32(i32 %v1, i32 32767)
ret i32 %v2
}

; Signed clamp of an i64 to [0, 255] (smax with 0, then smin with 255).
; NOTE(review): the expected cvt.sat.u8.u64 reads its source as unsigned, so a
; negative %a would saturate to 255 rather than the 0 this IR computes. The
; u16 tests in this file expect the signed-source form (cvt.sat.u16.s64);
; cvt.sat.u8.s64 looks like the intended instruction here - confirm and
; regenerate the assertions.
define i64 @trunc_ssat_i64_u8(i64 %a) {
; CHECK-LABEL: trunc_ssat_i64_u8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_u8_param_0];
; CHECK-NEXT: cvt.sat.u8.u64 %rs1, %rd1;
; CHECK-NEXT: cvt.u64.u16 %rd2, %rs1;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-NEXT: ret;
%v1 = call i64 @llvm.smax.i64(i64 %a, i64 0)
%v2 = call i64 @llvm.smin.i64(i64 %v1, i64 255)
ret i64 %v2
}

; Signed clamp of an i32 to [0, 255] (smax with 0, then smin with 255).
; NOTE(review): the expected cvt.sat.u8.u32 reads its source as unsigned, so a
; negative %a would saturate to 255 rather than the 0 this IR computes. The
; u16 tests in this file expect the signed-source form (cvt.sat.u16.s32);
; cvt.sat.u8.s32 looks like the intended instruction here - confirm and
; regenerate the assertions.
define i32 @trunc_ssat_i32_u8(i32 %a) {
; CHECK-LABEL: trunc_ssat_i32_u8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_u8_param_0];
; CHECK-NEXT: cvt.sat.u8.u32 %rs1, %r1;
; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-NEXT: ret;
%v1 = call i32 @llvm.smax.i32(i32 %a, i32 0)
%v2 = call i32 @llvm.smin.i32(i32 %v1, i32 255)
ret i32 %v2
}

; Signed clamp of an i16 to [0, 255] (smax with 0, then smin with 255).
; NOTE(review): the expected cvt.sat.u8.u16 reads its source as unsigned, so a
; negative %a would saturate to 255 rather than the 0 this IR computes.
; cvt.sat.u8.s16 looks like the intended instruction here - confirm and
; regenerate the assertions.
define i16 @trunc_ssat_i16_u8(i16 %a) {
; CHECK-LABEL: trunc_ssat_i16_u8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<3>;
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b16 %rs1, [trunc_ssat_i16_u8_param_0];
; CHECK-NEXT: cvt.sat.u8.u16 %rs2, %rs1;
; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
; CHECK-NEXT: ret;
%v1 = call i16 @llvm.smax.i16(i16 %a, i16 0)
%v2 = call i16 @llvm.smin.i16(i16 %v1, i16 255)
ret i16 %v2
}

; Signed clamp of an i64 to [-128, 127] (smax with -128, then smin with 127).
; The assertions expect one cvt.sat.s8.s64 followed by a sign-extension back
; to 64 bits.
define i64 @trunc_ssat_i64_s8(i64 %a) {
; CHECK-LABEL: trunc_ssat_i64_s8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_s8_param_0];
; CHECK-NEXT: cvt.sat.s8.s64 %rs1, %rd1;
; CHECK-NEXT: cvt.s64.s16 %rd2, %rs1;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-NEXT: ret;
%v1 = call i64 @llvm.smax.i64(i64 %a, i64 -128)
%v2 = call i64 @llvm.smin.i64(i64 %v1, i64 127)
ret i64 %v2
}

; Signed clamp of an i32 to [-128, 127] (smax with -128, then smin with 127).
; The assertions expect one cvt.sat.s8.s32 followed by a sign-extension back
; to 32 bits.
define i32 @trunc_ssat_i32_s8(i32 %a) {
; CHECK-LABEL: trunc_ssat_i32_s8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_s8_param_0];
; CHECK-NEXT: cvt.sat.s8.s32 %rs1, %r1;
; CHECK-NEXT: cvt.s32.s16 %r2, %rs1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-NEXT: ret;
%v1 = call i32 @llvm.smax.i32(i32 %a, i32 -128)
%v2 = call i32 @llvm.smin.i32(i32 %v1, i32 127)
ret i32 %v2
}

; Signed clamp of an i16 to [-128, 127] (smax with -128, then smin with 127).
; The assertions expect one cvt.sat.s8.s16; the i16 result is then widened
; with cvt.u32.u16 and stored to the 32-bit return slot.
define i16 @trunc_ssat_i16_s8(i16 %a) {
; CHECK-LABEL: trunc_ssat_i16_s8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<3>;
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b16 %rs1, [trunc_ssat_i16_s8_param_0];
; CHECK-NEXT: cvt.sat.s8.s16 %rs2, %rs1;
; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
; CHECK-NEXT: ret;
%v1 = call i16 @llvm.smax.i16(i16 %a, i16 -128)
%v2 = call i16 @llvm.smin.i16(i16 %v1, i16 127)
ret i16 %v2
}

Loading