-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[NVPTX] Rework and cleanup FTZ ISel #146410
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[NVPTX] Rework and cleanup FTZ ISel #146410
Conversation
@llvm/pr-subscribers-backend-nvptx Author: Alex MacLean (AlexMaclean) ChangesThis change cleans up DAG-to-DAG instruction selection around FTZ and SETP comparison mode. Largely this is non-functional though support for Patch is 346.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146410.diff 35 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 28f6968ee6caf..4ba3d0f1eaccd 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -154,73 +154,114 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
llvm_unreachable("Invalid conversion modifier");
}
+void NVPTXInstPrinter::printFTZFlag(const MCInst *MI, int OpNum, raw_ostream &O) {
+ const MCOperand &MO = MI->getOperand(OpNum);
+ const int Imm = MO.getImm();
+ if (Imm)
+ O << ".ftz";
+}
+
void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier) {
const MCOperand &MO = MI->getOperand(OpNum);
int64_t Imm = MO.getImm();
- if (Modifier == "ftz") {
- // FTZ flag
- if (Imm & NVPTX::PTXCmpMode::FTZ_FLAG)
- O << ".ftz";
- return;
- } else if (Modifier == "base") {
- switch (Imm & NVPTX::PTXCmpMode::BASE_MASK) {
+ if (Modifier == "FCmp") {
+ switch (Imm) {
default:
return;
case NVPTX::PTXCmpMode::EQ:
- O << ".eq";
+ O << "eq";
return;
case NVPTX::PTXCmpMode::NE:
- O << ".ne";
+ O << "ne";
return;
case NVPTX::PTXCmpMode::LT:
- O << ".lt";
+ O << "lt";
return;
case NVPTX::PTXCmpMode::LE:
- O << ".le";
+ O << "le";
return;
case NVPTX::PTXCmpMode::GT:
- O << ".gt";
+ O << "gt";
return;
case NVPTX::PTXCmpMode::GE:
- O << ".ge";
- return;
- case NVPTX::PTXCmpMode::LO:
- O << ".lo";
- return;
- case NVPTX::PTXCmpMode::LS:
- O << ".ls";
- return;
- case NVPTX::PTXCmpMode::HI:
- O << ".hi";
- return;
- case NVPTX::PTXCmpMode::HS:
- O << ".hs";
+ O << "ge";
return;
case NVPTX::PTXCmpMode::EQU:
- O << ".equ";
+ O << "equ";
return;
case NVPTX::PTXCmpMode::NEU:
- O << ".neu";
+ O << "neu";
return;
case NVPTX::PTXCmpMode::LTU:
- O << ".ltu";
+ O << "ltu";
return;
case NVPTX::PTXCmpMode::LEU:
- O << ".leu";
+ O << "leu";
return;
case NVPTX::PTXCmpMode::GTU:
- O << ".gtu";
+ O << "gtu";
return;
case NVPTX::PTXCmpMode::GEU:
- O << ".geu";
+ O << "geu";
return;
case NVPTX::PTXCmpMode::NUM:
- O << ".num";
+ O << "num";
return;
case NVPTX::PTXCmpMode::NotANumber:
- O << ".nan";
+ O << "nan";
+ return;
+ }
+ }
+ if (Modifier == "ICmp") {
+ switch (Imm) {
+ default:
+ llvm_unreachable("Invalid ICmp mode");
+ case NVPTX::PTXCmpMode::EQ:
+ O << "eq";
+ return;
+ case NVPTX::PTXCmpMode::NE:
+ O << "ne";
+ return;
+ case NVPTX::PTXCmpMode::LT:
+ case NVPTX::PTXCmpMode::LTU:
+ O << "lt";
+ return;
+ case NVPTX::PTXCmpMode::LE:
+ case NVPTX::PTXCmpMode::LEU:
+ O << "le";
+ return;
+ case NVPTX::PTXCmpMode::GT:
+ case NVPTX::PTXCmpMode::GTU:
+ O << "gt";
+ return;
+ case NVPTX::PTXCmpMode::GE:
+ case NVPTX::PTXCmpMode::GEU:
+ O << "ge";
+ return;
+
+ }
+ }
+ if (Modifier == "IType") {
+ switch (Imm) {
+ default:
+ llvm_unreachable("Invalid ICmp mode");
+ case NVPTX::PTXCmpMode::EQ:
+ case NVPTX::PTXCmpMode::NE:
+ O << "b";
+ return;
+ case NVPTX::PTXCmpMode::LT:
+ case NVPTX::PTXCmpMode::LE:
+ case NVPTX::PTXCmpMode::GT:
+ case NVPTX::PTXCmpMode::GE:
+ O << "s";
+ return;
+ case NVPTX::PTXCmpMode::LTU:
+ case NVPTX::PTXCmpMode::LEU:
+ case NVPTX::PTXCmpMode::GTU:
+ case NVPTX::PTXCmpMode::GEU:
+ O << "u";
return;
}
}
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 6189284e8a58c..193c436939f66 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -54,6 +54,7 @@ class NVPTXInstPrinter : public MCInstPrinter {
void printCTAGroup(const MCInst *MI, int OpNum, raw_ostream &O);
void printCallOperand(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier = {});
+ void printFTZFlag(const MCInst *MI, int OpNum, raw_ostream &O);
};
}
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index b7fd7090299a9..cfe2c8e7d2ed1 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -19,7 +19,7 @@
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Target/TargetMachine.h"
-
+#include "llvm/CodeGen/ISDOpcodes.h"
namespace llvm {
class FunctionPass;
class MachineFunctionPass;
@@ -218,28 +218,21 @@ enum CvtMode {
/// PTXCmpMode - Comparison mode enumeration
namespace PTXCmpMode {
enum CmpMode {
- EQ = 0,
- NE,
- LT,
- LE,
- GT,
- GE,
- LO,
- LS,
- HI,
- HS,
- EQU,
- NEU,
- LTU,
- LEU,
- GTU,
- GEU,
- NUM,
+ EQ = ISD::SETEQ,
+ NE = ISD::SETNE,
+ LT = ISD::SETLT,
+ LE = ISD::SETLE,
+ GT = ISD::SETGT,
+ GE = ISD::SETGE,
+ EQU = ISD::SETUEQ,
+ NEU = ISD::SETUNE,
+ LTU = ISD::SETULT,
+ LEU = ISD::SETULE,
+ GTU = ISD::SETUGT,
+ GEU = ISD::SETUGE,
+ NUM = ISD::SETO,
// NAN is a MACRO
- NotANumber,
-
- BASE_MASK = 0xFF,
- FTZ_FLAG = 0x100
+ NotANumber = ISD::SETUO,
};
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 5631342ecc13e..75461a4b6213a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -363,23 +363,29 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
// Map ISD:CONDCODE value to appropriate CmpMode expected by
// NVPTXInstPrinter::printCmpMode()
-static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
+SDValue NVPTXDAGToDAGISel::getPTXCmpMode(const CondCodeSDNode &CondCode) {
using NVPTX::PTXCmpMode::CmpMode;
- unsigned PTXCmpMode = [](ISD::CondCode CC) {
+ const unsigned PTXCmpMode = [](ISD::CondCode CC) {
switch (CC) {
default:
llvm_unreachable("Unexpected condition code.");
case ISD::SETOEQ:
+ case ISD::SETEQ:
return CmpMode::EQ;
case ISD::SETOGT:
+ case ISD::SETGT:
return CmpMode::GT;
case ISD::SETOGE:
+ case ISD::SETGE:
return CmpMode::GE;
case ISD::SETOLT:
+ case ISD::SETLT:
return CmpMode::LT;
case ISD::SETOLE:
+ case ISD::SETLE:
return CmpMode::LE;
case ISD::SETONE:
+ case ISD::SETNE:
return CmpMode::NE;
case ISD::SETO:
return CmpMode::NUM;
@@ -397,45 +403,29 @@ static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
return CmpMode::LEU;
case ISD::SETUNE:
return CmpMode::NEU;
- case ISD::SETEQ:
- return CmpMode::EQ;
- case ISD::SETGT:
- return CmpMode::GT;
- case ISD::SETGE:
- return CmpMode::GE;
- case ISD::SETLT:
- return CmpMode::LT;
- case ISD::SETLE:
- return CmpMode::LE;
- case ISD::SETNE:
- return CmpMode::NE;
}
}(CondCode.get());
-
- if (FTZ)
- PTXCmpMode |= NVPTX::PTXCmpMode::FTZ_FLAG;
-
- return PTXCmpMode;
+ return CurDAG->getTargetConstant(PTXCmpMode, SDLoc(), MVT::i32);
}
bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
- unsigned PTXCmpMode =
- getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
+ SDValue PTXCmpMode =
+ getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)));
SDLoc DL(N);
SDNode *SetP = CurDAG->getMachineNode(
- NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
- N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
+ NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, {N->getOperand(0),
+ N->getOperand(1), PTXCmpMode, CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, DL, MVT::i1)});
ReplaceNode(N, SetP);
return true;
}
bool NVPTXDAGToDAGISel::SelectSETP_BF16X2(SDNode *N) {
- unsigned PTXCmpMode =
- getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
+ SDValue PTXCmpMode =
+ getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)));
SDLoc DL(N);
SDNode *SetP = CurDAG->getMachineNode(
- NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
- N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
+ NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1, {N->getOperand(0),
+ N->getOperand(1), PTXCmpMode, CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, DL, MVT::i1)});
ReplaceNode(N, SetP);
return true;
}
@@ -1953,7 +1943,7 @@ bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
llvm_unreachable("Unexpected opcode");
};
- int Opcode = IsVec ? NVPTX::BFMA16x2rrr : NVPTX::BFMA16rrr;
+ int Opcode = IsVec ? NVPTX::FMA_BF16x2rrr : NVPTX::FMA_BF16rrr;
MachineSDNode *FMA = CurDAG->getMachineNode(Opcode, DL, VT, Operands);
ReplaceNode(N, FMA);
return true;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 0e4dec1adca67..b314c4ccefe8b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -104,12 +104,11 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
}
bool SelectADDR(SDValue Addr, SDValue &Base, SDValue &Offset);
+ SDValue getPTXCmpMode(const CondCodeSDNode &CondCode);
SDValue selectPossiblyImm(SDValue V);
bool ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const;
- static unsigned GetConvertOpcode(MVT DestTy, MVT SrcTy, LoadSDNode *N);
-
// Returns the Memory Order and Scope that the PTX memory instruction should
// use, and inserts appropriate fence instruction before the memory
// instruction, if needed to implement the instructions memory order. Required
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 1a2515b7f66f3..9ef9ce3b7bb8d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -68,48 +68,28 @@ def CvtMode : Operand<i32> {
let PrintMethod = "printCvtMode";
}
+// FTZ flag
+
+def FTZ : PatLeaf<(i1 1)>;
+def NoFTZ : PatLeaf<(i1 0)>;
+
+def getFTZFlag : SDNodeXForm<imm, [{
+ (void)N;
+ return CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, SDLoc(), MVT::i1);
+}]>;
+
+def FTZFlag : OperandWithDefaultOps<i1, (ops (getFTZFlag (i1 0)))> {
+ let PrintMethod = "printFTZFlag";
+}
+
// Compare modes
// These must match the enum in NVPTX.h
-def CmpEQ : PatLeaf<(i32 0)>;
-def CmpNE : PatLeaf<(i32 1)>;
-def CmpLT : PatLeaf<(i32 2)>;
-def CmpLE : PatLeaf<(i32 3)>;
-def CmpGT : PatLeaf<(i32 4)>;
-def CmpGE : PatLeaf<(i32 5)>;
-def CmpLO : PatLeaf<(i32 6)>;
-def CmpLS : PatLeaf<(i32 7)>;
-def CmpHI : PatLeaf<(i32 8)>;
-def CmpHS : PatLeaf<(i32 9)>;
-def CmpEQU : PatLeaf<(i32 10)>;
-def CmpNEU : PatLeaf<(i32 11)>;
-def CmpLTU : PatLeaf<(i32 12)>;
-def CmpLEU : PatLeaf<(i32 13)>;
-def CmpGTU : PatLeaf<(i32 14)>;
-def CmpGEU : PatLeaf<(i32 15)>;
-def CmpNUM : PatLeaf<(i32 16)>;
-def CmpNAN : PatLeaf<(i32 17)>;
-
-def CmpEQ_FTZ : PatLeaf<(i32 0x100)>;
-def CmpNE_FTZ : PatLeaf<(i32 0x101)>;
-def CmpLT_FTZ : PatLeaf<(i32 0x102)>;
-def CmpLE_FTZ : PatLeaf<(i32 0x103)>;
-def CmpGT_FTZ : PatLeaf<(i32 0x104)>;
-def CmpGE_FTZ : PatLeaf<(i32 0x105)>;
-def CmpEQU_FTZ : PatLeaf<(i32 0x10A)>;
-def CmpNEU_FTZ : PatLeaf<(i32 0x10B)>;
-def CmpLTU_FTZ : PatLeaf<(i32 0x10C)>;
-def CmpLEU_FTZ : PatLeaf<(i32 0x10D)>;
-def CmpGTU_FTZ : PatLeaf<(i32 0x10E)>;
-def CmpGEU_FTZ : PatLeaf<(i32 0x10F)>;
-def CmpNUM_FTZ : PatLeaf<(i32 0x110)>;
-def CmpNAN_FTZ : PatLeaf<(i32 0x111)>;
+def CmpEQ : PatLeaf<(i32 17)>;
+def CmpNE : PatLeaf<(i32 22)>;
def CmpMode : Operand<i32> {
let PrintMethod = "printCmpMode";
}
-def VecElement : Operand<i32> {
- let PrintMethod = "printVecElement";
-}
// PRMT modes
// These must match the enum in NVPTX.h
@@ -152,8 +132,6 @@ def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
-def True : Predicate<"true">;
-
class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
@@ -198,7 +176,7 @@ def RI64 : Operand<Any>;
// Utility class to wrap up information about a register and DAG type for more
// convenient iteration and parameterization
-class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm, SDNode imm_node,
+class RegTyInfo<ValueType ty, NVPTXRegClass rc, string str, Operand imm, SDNode imm_node,
bit supports_imm = 1> {
ValueType Ty = ty;
NVPTXRegClass RC = rc;
@@ -206,20 +184,21 @@ class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm, SDNode imm_node,
SDNode ImmNode = imm_node;
bit SupportsImm = supports_imm;
int Size = ty.Size;
+ string Str = str;
}
-def I1RT : RegTyInfo<i1, B1, i1imm, imm>;
-def I16RT : RegTyInfo<i16, B16, i16imm, imm>;
-def I32RT : RegTyInfo<i32, B32, i32imm, imm>;
-def I64RT : RegTyInfo<i64, B64, i64imm, imm>;
+def I1RT : RegTyInfo<i1, B1, "pred", i1imm, imm>;
+def I16RT : RegTyInfo<i16, B16, "b16", i16imm, imm>;
+def I32RT : RegTyInfo<i32, B32, "b32", i32imm, imm>;
+def I64RT : RegTyInfo<i64, B64, "b64", i64imm, imm>;
-def F32RT : RegTyInfo<f32, B32, f32imm, fpimm>;
-def F64RT : RegTyInfo<f64, B64, f64imm, fpimm>;
-def F16RT : RegTyInfo<f16, B16, f16imm, fpimm, supports_imm = 0>;
-def BF16RT : RegTyInfo<bf16, B16, bf16imm, fpimm, supports_imm = 0>;
+def F32RT : RegTyInfo<f32, B32, "f32", f32imm, fpimm>;
+def F64RT : RegTyInfo<f64, B64, "f64", f64imm, fpimm>;
+def F16RT : RegTyInfo<f16, B16, "f16", f16imm, fpimm, supports_imm = 0>;
+def BF16RT : RegTyInfo<bf16, B16, "bf16", bf16imm, fpimm, supports_imm = 0>;
-def F16X2RT : RegTyInfo<v2f16, B32, ?, ?, supports_imm = 0>;
-def BF16X2RT : RegTyInfo<v2bf16, B32, ?, ?, supports_imm = 0>;
+def F16X2RT : RegTyInfo<v2f16, B32, "f16x2", ?, ?, supports_imm = 0>;
+def BF16X2RT : RegTyInfo<v2bf16, B32, "bf16x2", ?, ?, supports_imm = 0>;
// This class provides a basic wrapper around an NVPTXInst that abstracts the
@@ -321,76 +300,57 @@ multiclass ADD_SUB_INT_CARRY<string op_str, SDNode op_node, bit commutative> {
// Also defines ftz (flush subnormal inputs and results to sign-preserving
// zero) variants for fp32 functions.
multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
+ defvar nan_str = !if(NaN, ".NaN", "");
if !not(NaN) then {
- def f64rr :
+ def _f64_rr :
BasicNVPTXInst<(outs B64:$dst),
(ins B64:$a, B64:$b),
OpcStr # ".f64",
[(set f64:$dst, (OpNode f64:$a, f64:$b))]>;
- def f64ri :
+ def _f64_ri :
BasicNVPTXInst<(outs B64:$dst),
(ins B64:$a, f64imm:$b),
OpcStr # ".f64",
[(set f64:$dst, (OpNode f64:$a, fpimm:$b))]>;
}
- def f32rr_ftz :
- BasicNVPTXInst<(outs B32:$dst),
+ def _f32_rr :
+ BasicFlagsNVPTXInst<(outs B32:$dst),
(ins B32:$a, B32:$b),
- OpcStr # ".ftz.f32",
- [(set f32:$dst, (OpNode f32:$a, f32:$b))]>,
- Requires<[doF32FTZ]>;
- def f32ri_ftz :
- BasicNVPTXInst<(outs B32:$dst),
- (ins B32:$a, f32imm:$b),
- OpcStr # ".ftz.f32",
- [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
- Requires<[doF32FTZ]>;
- def f32rr :
- BasicNVPTXInst<(outs B32:$dst),
- (ins B32:$a, B32:$b),
- OpcStr # ".f32",
+ (ins FTZFlag:$ftz),
+ OpcStr # "$ftz" # nan_str # ".f32",
[(set f32:$dst, (OpNode f32:$a, f32:$b))]>;
- def f32ri :
- BasicNVPTXInst<(outs B32:$dst),
+ def _f32_ri :
+ BasicFlagsNVPTXInst<(outs B32:$dst),
(ins B32:$a, f32imm:$b),
- OpcStr # ".f32",
+ (ins FTZFlag:$ftz),
+ OpcStr # "$ftz" # nan_str # ".f32",
[(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>;
- def f16rr_ftz :
- BasicNVPTXInst<(outs B16:$dst),
- (ins B16:$a, B16:$b),
- OpcStr # ".ftz.f16",
- [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
- Requires<[useFP16Math, doF32FTZ]>;
- def f16rr :
- BasicNVPTXInst<(outs B16:$dst),
+ def _f16_rr :
+ BasicFlagsNVPTXInst<(outs B16:$dst),
(ins B16:$a, B16:$b),
- OpcStr # ".f16",
+ (ins FTZFlag:$ftz),
+ OpcStr # "$ftz" # nan_str # ".f16",
[(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
- Requires<[useFP16Math, hasSM<80>, hasPTX<70>]>;
+ Requires<[useFP16Math]>;
- def f16x2rr_ftz :
- BasicNVPTXInst<(outs B32:$dst),
- (ins B32:$a, B32:$b),
- OpcStr # ".ftz.f16x2",
- [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
- Requires<[useFP16Math, hasSM<80>, hasPTX<70>, doF32FTZ]>;
- def f16x2rr :
- BasicNVPTXInst<(outs B32:$dst),
+ def _f16x2_rr :
+ BasicFlagsNVPTXInst<(outs B32:$dst),
(ins B32:$a, B32:$b),
- OpcStr # ".f16x2",
+ (ins FTZFlag:$ftz),
+ OpcStr # "$ftz" # nan_str # ".f16x2",
[(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
Requires<[useFP16Math, hasSM<80>, hasPTX<70>]>;
- def bf16rr :
+ def _bf16_rr :
BasicNVPTXInst<(outs B16:$dst),
(ins B16:$a, B16:$b),
- OpcStr # ".bf16",
+ OpcStr # nan_str # ".bf16",
[(set bf16:$dst, (OpNode bf16:$a, bf16:$b))]>,
Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
- def bf16x2rr :
+ def _bf16x2_rr :
BasicNVPTXInst<(outs B32:$dst),
(ins B32:$a, B32:$b),
- OpcStr # ".bf16x2",
+ OpcStr # nan_str # ".bf16x2",
[(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
}
@@ -415,52 +375,31 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
(ins B64:$a, f64imm:$b),
op_str # ".f64",
[(set f64:$dst, (op_pat f64:$a, fpimm:$b))]>;
- def f32rr_ftz :
- BasicNVPTXInst<(outs B32:$dst),
- (ins B32:$a, B32:$b),
- op_str # ".ftz.f32",
- [(set f32:$dst, (op_pat f32:$a, f32:$b))]>,
- Requires<[doF32FTZ]>;
- def f32ri_ftz :
- BasicNVPTXInst<(outs B32:$dst),
- (ins B32:$a, f32imm:$b),
- op_str # ".ftz.f32",
- [(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>,
- Requires<[doF32FTZ]>;
def f32rr :
- BasicNVPTXInst<(outs B32:$dst),
+ BasicFlagsNVPTXInst<(outs B32:$dst),
(ins B32:$a, B32:$b),
- op_str # ".f32",
+ (ins FTZFlag:$ftz),
+ op_str # "$ftz.f32",
[(set f32:$dst, (op_pat f32:$a, f32:$b))]>;
def f32ri :
- BasicNVPTXInst<(outs B32:$dst),
+ BasicFlagsNVPTXInst<(outs B32:$dst),
(ins B32:$a, f32imm:$b),
- op_str # ".f32",
+ (ins FTZFlag:$ftz),
+ op_str # "$ftz.f32",
[(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>;
- def f16rr_ftz :
- BasicNVPTXInst<(outs B16:$dst),
- (ins B16:$a, B16:$b),
- op_str # ".ftz.f16",
- [(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
- Requires<[useFP16Math, doF32FTZ]>;
def f16rr :
- BasicNVPTXInst<(outs B16:$dst),
+ BasicFlagsNVPTXInst<(outs B16:$dst),
(ins B16:$a, B16:$b),
- op_str # ".f16",
+ (ins FTZFlag:$ftz),
+ op_str # "$ftz.f16",
[(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
Requires<[useFP16Math]>;
-
- def f16x2rr_ftz :
- BasicNVPTXInst<(outs B32:$dst),
- (...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
8080007
to
e161048
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PTXCmpMode |= NVPTX::PTXCmpMode::FTZ_FLAG; | ||
|
||
return PTXCmpMode; | ||
return CurDAG->getTargetConstant(PTXCmpMode, SDLoc(), MVT::i32); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return CurDAG->getTargetConstant(PTXCmpMode, SDLoc(), MVT::i32); | |
return CurDAG->getTargetConstant(PTXCmpMode, SDLoc(&CondCode), MVT::i32); |
?
@@ -397,45 +403,29 @@ static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) { | |||
return CmpMode::LEU; | |||
case ISD::SETUNE: | |||
return CmpMode::NEU; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given how you've defined CmpMode
in terms of SDNode opcodes, couldn't you just return the opcode for these cases?
This change cleans up DAG-to-DAG instruction selection around FTZ and SETP comparison mode. Largely this is non-functional though support for
{sin.cos}.approx.ftz.f32
is added.