Skip to content

Commit 032900e

Browse files
authored
[DirectX] Add DXIL validation of llvm.loop metadata (#164292)
This pr adds the equivalent validation of `llvm.loop` metadata that is [done in DXC](https://github.com/microsoft/DirectXShaderCompiler/blob/8f21027f2ad5dcfa63a275cbd278691f2c8fad33/lib/DxilValidation/DxilValidation.cpp#L3010). This is done as follows: - Add `llvm.loop` to the metadata allow-list in `DXILTranslateMetadata` - Iterate through all `llvm.loop` metadata nodes and strip all incompatible ones - Raise an error for ill-formed nodes that are compatible with DXIL Resolves: #137387
1 parent b17f1fd commit 032900e

File tree

7 files changed

+443
-70
lines changed

7 files changed

+443
-70
lines changed

llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp

Lines changed: 156 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ using namespace llvm;
3636
using namespace llvm::dxil;
3737

3838
namespace {
39-
/// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
40-
/// for TranslateMetadata pass
41-
class DiagnosticInfoTranslateMD : public DiagnosticInfo {
39+
40+
/// A simple wrapper of DiagnosticInfo that generates module-level diagnostic
41+
/// for the DXILValidateMetadata pass
42+
class DiagnosticInfoValidateMD : public DiagnosticInfo {
4243
private:
4344
const Twine &Msg;
4445
const Module &Mod;
@@ -47,16 +48,26 @@ class DiagnosticInfoTranslateMD : public DiagnosticInfo {
4748
/// \p M is the module for which the diagnostic is being emitted. \p Msg is
4849
/// the message to show. Note that this class does not copy this message, so
4950
/// this reference must be valid for the whole life time of the diagnostic.
50-
DiagnosticInfoTranslateMD(const Module &M,
51-
const Twine &Msg LLVM_LIFETIME_BOUND,
52-
DiagnosticSeverity Severity = DS_Error)
51+
DiagnosticInfoValidateMD(const Module &M,
52+
const Twine &Msg LLVM_LIFETIME_BOUND,
53+
DiagnosticSeverity Severity = DS_Error)
5354
: DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
5455

5556
void print(DiagnosticPrinter &DP) const override {
5657
DP << Mod.getName() << ": " << Msg << '\n';
5758
}
5859
};
5960

61+
static void reportError(Module &M, Twine Message,
62+
DiagnosticSeverity Severity = DS_Error) {
63+
M.getContext().diagnose(DiagnosticInfoValidateMD(M, Message, Severity));
64+
}
65+
66+
static void reportLoopError(Module &M, Twine Message,
67+
DiagnosticSeverity Severity = DS_Error) {
68+
reportError(M, Twine("Invalid \"llvm.loop\" metadata: ") + Message, Severity);
69+
}
70+
6071
enum class EntryPropsTag {
6172
ShaderFlags = 0,
6273
GSState,
@@ -314,25 +325,122 @@ static void translateBranchMetadata(Module &M, Instruction *BBTerminatorInst) {
314325
BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
315326
}
316327

317-
static std::array<unsigned, 6> getCompatibleInstructionMDs(llvm::Module &M) {
328+
// Determines if the metadata node will be compatible with DXIL's loop metadata
329+
// representation.
330+
//
331+
// Reports an error for compatible metadata that is ill-formed.
332+
static bool isLoopMDCompatible(Module &M, Metadata *MD) {
333+
// DXIL only accepts the following loop hints:
334+
std::array<StringLiteral, 3> ValidHintNames = {"llvm.loop.unroll.count",
335+
"llvm.loop.unroll.disable",
336+
"llvm.loop.unroll.full"};
337+
338+
MDNode *HintMD = dyn_cast<MDNode>(MD);
339+
if (!HintMD || HintMD->getNumOperands() == 0)
340+
return false;
341+
342+
auto *HintStr = dyn_cast<MDString>(HintMD->getOperand(0));
343+
if (!HintStr)
344+
return false;
345+
346+
if (!llvm::is_contained(ValidHintNames, HintStr->getString()))
347+
return false;
348+
349+
auto ValidCountNode = [](MDNode *CountMD) -> bool {
350+
if (CountMD->getNumOperands() == 2)
351+
if (auto *Count = dyn_cast<ConstantAsMetadata>(CountMD->getOperand(1)))
352+
if (isa<ConstantInt>(Count->getValue()))
353+
return true;
354+
return false;
355+
};
356+
357+
if (HintStr->getString() == "llvm.loop.unroll.count") {
358+
if (!ValidCountNode(HintMD)) {
359+
reportLoopError(M, "\"llvm.loop.unroll.count\" must have 2 operands and "
360+
"the second must be a constant integer");
361+
return false;
362+
}
363+
} else if (HintMD->getNumOperands() != 1) {
364+
reportLoopError(
365+
M, "\"llvm.loop.unroll.disable\" and \"llvm.loop.unroll.full\" "
366+
"must be provided as a single operand");
367+
return false;
368+
}
369+
370+
return true;
371+
}
372+
373+
static void translateLoopMetadata(Module &M, Instruction *I, MDNode *BaseMD) {
374+
// A distinct node has the self-referential form: !0 = !{ !0, ... }
375+
auto IsDistinctNode = [](MDNode *Node) -> bool {
376+
return Node && Node->getNumOperands() != 0 && Node == Node->getOperand(0);
377+
};
378+
379+
// Set metadata to null to remove empty/ill-formed metadata from instruction
380+
if (BaseMD->getNumOperands() == 0 || !IsDistinctNode(BaseMD))
381+
return I->setMetadata("llvm.loop", nullptr);
382+
383+
// It is valid to have a chain of self-refential loop metadata nodes, as
384+
// below. We will collapse these into just one when we reconstruct the
385+
// metadata.
386+
//
387+
// Eg:
388+
// !0 = !{!0, !1}
389+
// !1 = !{!1, !2}
390+
// !2 = !{!"llvm.loop.unroll.disable"}
391+
//
392+
// So, traverse down a potential self-referential chain
393+
while (1 < BaseMD->getNumOperands() &&
394+
IsDistinctNode(dyn_cast<MDNode>(BaseMD->getOperand(1))))
395+
BaseMD = dyn_cast<MDNode>(BaseMD->getOperand(1));
396+
397+
// To reconstruct a distinct node we create a temporary node that we will
398+
// then update to create a self-reference.
399+
llvm::TempMDTuple TempNode = llvm::MDNode::getTemporary(M.getContext(), {});
400+
SmallVector<Metadata *> CompatibleOperands = {TempNode.get()};
401+
402+
// Iterate and reconstruct the metadata nodes that contains any hints,
403+
// stripping any unrecognized metadata.
404+
ArrayRef<MDOperand> Operands = BaseMD->operands();
405+
for (auto &Op : Operands.drop_front())
406+
if (isLoopMDCompatible(M, Op.get()))
407+
CompatibleOperands.push_back(Op.get());
408+
409+
if (2 < CompatibleOperands.size())
410+
reportLoopError(M, "Provided conflicting hints");
411+
412+
MDNode *CompatibleLoopMD = MDNode::get(M.getContext(), CompatibleOperands);
413+
TempNode->replaceAllUsesWith(CompatibleLoopMD);
414+
415+
I->setMetadata("llvm.loop", CompatibleLoopMD);
416+
}
417+
418+
using InstructionMDList = std::array<unsigned, 7>;
419+
420+
static InstructionMDList getCompatibleInstructionMDs(llvm::Module &M) {
318421
return {
319422
M.getMDKindID("dx.nonuniform"), M.getMDKindID("dx.controlflow.hints"),
320423
M.getMDKindID("dx.precise"), llvm::LLVMContext::MD_range,
321-
llvm::LLVMContext::MD_alias_scope, llvm::LLVMContext::MD_noalias};
424+
llvm::LLVMContext::MD_alias_scope, llvm::LLVMContext::MD_noalias,
425+
M.getMDKindID("llvm.loop")};
322426
}
323427

324428
static void translateInstructionMetadata(Module &M) {
325429
// construct allowlist of valid metadata node kinds
326-
std::array<unsigned, 6> DXILCompatibleMDs = getCompatibleInstructionMDs(M);
430+
InstructionMDList DXILCompatibleMDs = getCompatibleInstructionMDs(M);
431+
unsigned char MDLoopKind = M.getContext().getMDKindID("llvm.loop");
327432

328433
for (Function &F : M) {
329434
for (BasicBlock &BB : F) {
330435
// This needs to be done first so that "hlsl.controlflow.hints" isn't
331-
// removed in the whitelist below
436+
// removed in the allow-list below
332437
if (auto *I = BB.getTerminator())
333438
translateBranchMetadata(M, I);
334439

335440
for (auto &I : make_early_inc_range(BB)) {
441+
if (isa<BranchInst>(I))
442+
if (MDNode *LoopMD = I.getMetadata(MDLoopKind))
443+
translateLoopMetadata(M, &I, LoopMD);
336444
I.dropUnknownNonDebugMetadata(DXILCompatibleMDs);
337445
}
338446
}
@@ -389,31 +497,23 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
389497
uint64_t CombinedMask = ShaderFlags.getCombinedFlags();
390498
EntryFnMDNodes.emplace_back(
391499
emitTopLevelLibraryNode(M, ResourceMD, CombinedMask));
392-
} else if (MMDI.EntryPropertyVec.size() > 1) {
393-
M.getContext().diagnose(DiagnosticInfoTranslateMD(
394-
M, "Non-library shader: One and only one entry expected"));
395-
}
500+
} else if (1 < MMDI.EntryPropertyVec.size())
501+
reportError(M, "Non-library shader: One and only one entry expected");
396502

397503
for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
398-
const ComputedShaderFlags &EntrySFMask =
399-
ShaderFlags.getFunctionFlags(EntryProp.Entry);
400-
401-
// If ShaderProfile is Library, mask is already consolidated in the
402-
// top-level library node. Hence it is not emitted.
403504
uint64_t EntryShaderFlags = 0;
404505
if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
405-
EntryShaderFlags = EntrySFMask;
406-
if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
407-
M.getContext().diagnose(DiagnosticInfoTranslateMD(
408-
M,
409-
"Shader stage '" +
410-
Twine(getShortShaderStage(EntryProp.ShaderStage) +
411-
"' for entry '" + Twine(EntryProp.Entry->getName()) +
412-
"' different from specified target profile '" +
413-
Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
414-
"'"))));
415-
}
506+
EntryShaderFlags = ShaderFlags.getFunctionFlags(EntryProp.Entry);
507+
if (EntryProp.ShaderStage != MMDI.ShaderProfile)
508+
reportError(
509+
M, "Shader stage '" +
510+
Twine(getShortShaderStage(EntryProp.ShaderStage)) +
511+
"' for entry '" + Twine(EntryProp.Entry->getName()) +
512+
"' different from specified target profile '" +
513+
Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
514+
"'"));
416515
}
516+
417517
EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
418518
EntryShaderFlags,
419519
MMDI.ShaderProfile));
@@ -454,45 +554,34 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
454554
return PreservedAnalyses::all();
455555
}
456556

457-
namespace {
458-
class DXILTranslateMetadataLegacy : public ModulePass {
459-
public:
460-
static char ID; // Pass identification, replacement for typeid
461-
explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {}
462-
463-
StringRef getPassName() const override { return "DXIL Translate Metadata"; }
464-
465-
void getAnalysisUsage(AnalysisUsage &AU) const override {
466-
AU.addRequired<DXILResourceTypeWrapperPass>();
467-
AU.addRequired<DXILResourceWrapperPass>();
468-
AU.addRequired<ShaderFlagsAnalysisWrapper>();
469-
AU.addRequired<DXILMetadataAnalysisWrapperPass>();
470-
AU.addRequired<RootSignatureAnalysisWrapper>();
471-
472-
AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
473-
AU.addPreserved<DXILResourceBindingWrapperPass>();
474-
AU.addPreserved<DXILResourceWrapperPass>();
475-
AU.addPreserved<RootSignatureAnalysisWrapper>();
476-
AU.addPreserved<ShaderFlagsAnalysisWrapper>();
477-
}
557+
void DXILTranslateMetadataLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
558+
AU.addRequired<DXILResourceTypeWrapperPass>();
559+
AU.addRequired<DXILResourceWrapperPass>();
560+
AU.addRequired<ShaderFlagsAnalysisWrapper>();
561+
AU.addRequired<DXILMetadataAnalysisWrapperPass>();
562+
AU.addRequired<RootSignatureAnalysisWrapper>();
563+
564+
AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
565+
AU.addPreserved<DXILResourceBindingWrapperPass>();
566+
AU.addPreserved<DXILResourceWrapperPass>();
567+
AU.addPreserved<RootSignatureAnalysisWrapper>();
568+
AU.addPreserved<ShaderFlagsAnalysisWrapper>();
569+
}
478570

479-
bool runOnModule(Module &M) override {
480-
DXILResourceMap &DRM =
481-
getAnalysis<DXILResourceWrapperPass>().getResourceMap();
482-
DXILResourceTypeMap &DRTM =
483-
getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
484-
const ModuleShaderFlags &ShaderFlags =
485-
getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
486-
dxil::ModuleMetadataInfo MMDI =
487-
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
488-
489-
translateGlobalMetadata(M, DRM, DRTM, ShaderFlags, MMDI);
490-
translateInstructionMetadata(M);
491-
return true;
492-
}
493-
};
571+
bool DXILTranslateMetadataLegacy::runOnModule(Module &M) {
572+
DXILResourceMap &DRM =
573+
getAnalysis<DXILResourceWrapperPass>().getResourceMap();
574+
DXILResourceTypeMap &DRTM =
575+
getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
576+
const ModuleShaderFlags &ShaderFlags =
577+
getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
578+
dxil::ModuleMetadataInfo MMDI =
579+
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
494580

495-
} // namespace
581+
translateGlobalMetadata(M, DRM, DRTM, ShaderFlags, MMDI);
582+
translateInstructionMetadata(M);
583+
return true;
584+
}
496585

497586
char DXILTranslateMetadataLegacy::ID = 0;
498587

llvm/lib/Target/DirectX/DXILTranslateMetadata.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define LLVM_TARGET_DIRECTX_DXILTRANSLATEMETADATA_H
1111

1212
#include "llvm/IR/PassManager.h"
13+
#include "llvm/Pass.h"
1314

1415
namespace llvm {
1516

@@ -20,6 +21,22 @@ class DXILTranslateMetadata : public PassInfoMixin<DXILTranslateMetadata> {
2021
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
2122
};
2223

24+
/// Wrapper pass for the legacy pass manager.
25+
///
26+
/// This is required because the passes that will depend on this are codegen
27+
/// passes which run through the legacy pass manager.
28+
class DXILTranslateMetadataLegacy : public ModulePass {
29+
public:
30+
static char ID; // Pass identification, replacement for typeid
31+
explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {}
32+
33+
StringRef getPassName() const override { return "DXIL Translate Metadata"; }
34+
35+
void getAnalysisUsage(AnalysisUsage &AU) const override;
36+
37+
bool runOnModule(Module &M) override;
38+
};
39+
2340
} // namespace llvm
2441

2542
#endif // LLVM_TARGET_DIRECTX_DXILTRANSLATEMETADATA_H

0 commit comments

Comments
 (0)