Skip to content

[NVPTX] Rework and cleanup FTZ ISel #146410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 74 additions & 33 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,73 +154,114 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
llvm_unreachable("Invalid conversion modifier");
}

void NVPTXInstPrinter::printFTZFlag(const MCInst *MI, int OpNum,
raw_ostream &O) {
const MCOperand &MO = MI->getOperand(OpNum);
const int Imm = MO.getImm();
if (Imm)
O << ".ftz";
}

void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier) {
const MCOperand &MO = MI->getOperand(OpNum);
int64_t Imm = MO.getImm();

if (Modifier == "ftz") {
// FTZ flag
if (Imm & NVPTX::PTXCmpMode::FTZ_FLAG)
O << ".ftz";
return;
} else if (Modifier == "base") {
switch (Imm & NVPTX::PTXCmpMode::BASE_MASK) {
if (Modifier == "FCmp") {
switch (Imm) {
default:
return;
case NVPTX::PTXCmpMode::EQ:
O << ".eq";
O << "eq";
return;
case NVPTX::PTXCmpMode::NE:
O << ".ne";
O << "ne";
return;
case NVPTX::PTXCmpMode::LT:
O << ".lt";
O << "lt";
return;
case NVPTX::PTXCmpMode::LE:
O << ".le";
O << "le";
return;
case NVPTX::PTXCmpMode::GT:
O << ".gt";
O << "gt";
return;
case NVPTX::PTXCmpMode::GE:
O << ".ge";
return;
case NVPTX::PTXCmpMode::LO:
O << ".lo";
return;
case NVPTX::PTXCmpMode::LS:
O << ".ls";
return;
case NVPTX::PTXCmpMode::HI:
O << ".hi";
return;
case NVPTX::PTXCmpMode::HS:
O << ".hs";
O << "ge";
return;
case NVPTX::PTXCmpMode::EQU:
O << ".equ";
O << "equ";
return;
case NVPTX::PTXCmpMode::NEU:
O << ".neu";
O << "neu";
return;
case NVPTX::PTXCmpMode::LTU:
O << ".ltu";
O << "ltu";
return;
case NVPTX::PTXCmpMode::LEU:
O << ".leu";
O << "leu";
return;
case NVPTX::PTXCmpMode::GTU:
O << ".gtu";
O << "gtu";
return;
case NVPTX::PTXCmpMode::GEU:
O << ".geu";
O << "geu";
return;
case NVPTX::PTXCmpMode::NUM:
O << ".num";
O << "num";
return;
case NVPTX::PTXCmpMode::NotANumber:
O << ".nan";
O << "nan";
return;
}
}
if (Modifier == "ICmp") {
switch (Imm) {
default:
llvm_unreachable("Invalid ICmp mode");
case NVPTX::PTXCmpMode::EQ:
O << "eq";
return;
case NVPTX::PTXCmpMode::NE:
O << "ne";
return;
case NVPTX::PTXCmpMode::LT:
case NVPTX::PTXCmpMode::LTU:
O << "lt";
return;
case NVPTX::PTXCmpMode::LE:
case NVPTX::PTXCmpMode::LEU:
O << "le";
return;
case NVPTX::PTXCmpMode::GT:
case NVPTX::PTXCmpMode::GTU:
O << "gt";
return;
case NVPTX::PTXCmpMode::GE:
case NVPTX::PTXCmpMode::GEU:
O << "ge";
return;
}
}
if (Modifier == "IType") {
switch (Imm) {
default:
llvm_unreachable("Invalid IType");
case NVPTX::PTXCmpMode::EQ:
case NVPTX::PTXCmpMode::NE:
O << "b";
return;
case NVPTX::PTXCmpMode::LT:
case NVPTX::PTXCmpMode::LE:
case NVPTX::PTXCmpMode::GT:
case NVPTX::PTXCmpMode::GE:
O << "s";
return;
case NVPTX::PTXCmpMode::LTU:
case NVPTX::PTXCmpMode::LEU:
case NVPTX::PTXCmpMode::GTU:
case NVPTX::PTXCmpMode::GEU:
O << "u";
return;
}
}
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class NVPTXInstPrinter : public MCInstPrinter {
void printCTAGroup(const MCInst *MI, int OpNum, raw_ostream &O);
void printCallOperand(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier = {});
void printFTZFlag(const MCInst *MI, int OpNum, raw_ostream &O);
};

}
Expand Down
37 changes: 15 additions & 22 deletions llvm/lib/Target/NVPTX/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
#ifndef LLVM_LIB_TARGET_NVPTX_NVPTX_H
#define LLVM_LIB_TARGET_NVPTX_NVPTX_H

#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Target/TargetMachine.h"

namespace llvm {
class FunctionPass;
class MachineFunctionPass;
Expand Down Expand Up @@ -218,28 +218,21 @@ enum CvtMode {
/// PTXCmpMode - Comparison mode enumeration
namespace PTXCmpMode {
enum CmpMode {
EQ = 0,
NE,
LT,
LE,
GT,
GE,
LO,
LS,
HI,
HS,
EQU,
NEU,
LTU,
LEU,
GTU,
GEU,
NUM,
EQ = ISD::SETEQ,
NE = ISD::SETNE,
LT = ISD::SETLT,
LE = ISD::SETLE,
GT = ISD::SETGT,
GE = ISD::SETGE,
EQU = ISD::SETUEQ,
NEU = ISD::SETUNE,
LTU = ISD::SETULT,
LEU = ISD::SETULE,
GTU = ISD::SETUGT,
GEU = ISD::SETUGE,
NUM = ISD::SETO,
// NAN is a macro, so this enumerator is spelled NotANumber instead.
NotANumber,

BASE_MASK = 0xFF,
FTZ_FLAG = 0x100
NotANumber = ISD::SETUO,
};
}

Expand Down
46 changes: 18 additions & 28 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,23 +363,29 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {

// Map an ISD::CondCode value to the appropriate CmpMode expected by
// NVPTXInstPrinter::printCmpMode()
static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
SDValue NVPTXDAGToDAGISel::getPTXCmpMode(const CondCodeSDNode &CondCode) {
using NVPTX::PTXCmpMode::CmpMode;
unsigned PTXCmpMode = [](ISD::CondCode CC) {
const unsigned PTXCmpMode = [](ISD::CondCode CC) {
switch (CC) {
default:
llvm_unreachable("Unexpected condition code.");
case ISD::SETOEQ:
case ISD::SETEQ:
return CmpMode::EQ;
case ISD::SETOGT:
case ISD::SETGT:
return CmpMode::GT;
case ISD::SETOGE:
case ISD::SETGE:
return CmpMode::GE;
case ISD::SETOLT:
case ISD::SETLT:
return CmpMode::LT;
case ISD::SETOLE:
case ISD::SETLE:
return CmpMode::LE;
case ISD::SETONE:
case ISD::SETNE:
return CmpMode::NE;
case ISD::SETO:
return CmpMode::NUM;
Expand All @@ -397,45 +403,29 @@ static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
return CmpMode::LEU;
case ISD::SETUNE:
return CmpMode::NEU;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given how you've defined CmpMode in terms of SDNode opcodes, couldn't you just return the opcode for these cases?

case ISD::SETEQ:
return CmpMode::EQ;
case ISD::SETGT:
return CmpMode::GT;
case ISD::SETGE:
return CmpMode::GE;
case ISD::SETLT:
return CmpMode::LT;
case ISD::SETLE:
return CmpMode::LE;
case ISD::SETNE:
return CmpMode::NE;
}
}(CondCode.get());

if (FTZ)
PTXCmpMode |= NVPTX::PTXCmpMode::FTZ_FLAG;

return PTXCmpMode;
return CurDAG->getTargetConstant(PTXCmpMode, SDLoc(), MVT::i32);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return CurDAG->getTargetConstant(PTXCmpMode, SDLoc(), MVT::i32);
return CurDAG->getTargetConstant(PTXCmpMode, SDLoc(&CondCode), MVT::i32);

?

}

bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
unsigned PTXCmpMode =
getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
SDValue PTXCmpMode = getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)));
SDLoc DL(N);
SDNode *SetP = CurDAG->getMachineNode(
NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1,
{N->getOperand(0), N->getOperand(1), PTXCmpMode,
CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, DL, MVT::i1)});
ReplaceNode(N, SetP);
return true;
}

bool NVPTXDAGToDAGISel::SelectSETP_BF16X2(SDNode *N) {
unsigned PTXCmpMode =
getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
SDValue PTXCmpMode = getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)));
SDLoc DL(N);
SDNode *SetP = CurDAG->getMachineNode(
NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1,
{N->getOperand(0), N->getOperand(1), PTXCmpMode,
CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, DL, MVT::i1)});
ReplaceNode(N, SetP);
return true;
}
Expand Down Expand Up @@ -1953,7 +1943,7 @@ bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
llvm_unreachable("Unexpected opcode");
};

int Opcode = IsVec ? NVPTX::BFMA16x2rrr : NVPTX::BFMA16rrr;
int Opcode = IsVec ? NVPTX::FMA_BF16x2rrr : NVPTX::FMA_BF16rrr;
MachineSDNode *FMA = CurDAG->getMachineNode(Opcode, DL, VT, Operands);
ReplaceNode(N, FMA);
return true;
Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,11 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
}

bool SelectADDR(SDValue Addr, SDValue &Base, SDValue &Offset);
SDValue getPTXCmpMode(const CondCodeSDNode &CondCode);
SDValue selectPossiblyImm(SDValue V);

bool ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const;

static unsigned GetConvertOpcode(MVT DestTy, MVT SrcTy, LoadSDNode *N);

// Returns the Memory Order and Scope that the PTX memory instruction should
// use, and inserts appropriate fence instruction before the memory
// instruction, if needed to implement the instructions memory order. Required
Expand Down
Loading
Loading