@@ -36,9 +36,10 @@ using namespace llvm;
3636using namespace llvm ::dxil;
3737
3838namespace {
39- // / A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
40- // / for TranslateMetadata pass
41- class DiagnosticInfoTranslateMD : public DiagnosticInfo {
39+
40+ // / A simple wrapper of DiagnosticInfo that generates module-level diagnostic
41+ // / for the DXILValidateMetadata pass
42+ class DiagnosticInfoValidateMD : public DiagnosticInfo {
4243private:
4344 const Twine &Msg;
4445 const Module &Mod;
@@ -47,16 +48,26 @@ class DiagnosticInfoTranslateMD : public DiagnosticInfo {
4748 // / \p M is the module for which the diagnostic is being emitted. \p Msg is
4849 // / the message to show. Note that this class does not copy this message, so
4950 // / this reference must be valid for the whole life time of the diagnostic.
50- DiagnosticInfoTranslateMD (const Module &M,
51- const Twine &Msg LLVM_LIFETIME_BOUND,
52- DiagnosticSeverity Severity = DS_Error)
51+ DiagnosticInfoValidateMD (const Module &M,
52+ const Twine &Msg LLVM_LIFETIME_BOUND,
53+ DiagnosticSeverity Severity = DS_Error)
5354 : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
5455
5556 void print (DiagnosticPrinter &DP) const override {
5657 DP << Mod.getName () << " : " << Msg << ' \n ' ;
5758 }
5859};
5960
61+ static void reportError (Module &M, Twine Message,
62+ DiagnosticSeverity Severity = DS_Error) {
63+ M.getContext ().diagnose (DiagnosticInfoValidateMD (M, Message, Severity));
64+ }
65+
66+ static void reportLoopError (Module &M, Twine Message,
67+ DiagnosticSeverity Severity = DS_Error) {
68+ reportError (M, Twine (" Invalid \" llvm.loop\" metadata: " ) + Message, Severity);
69+ }
70+
6071enum class EntryPropsTag {
6172 ShaderFlags = 0 ,
6273 GSState,
@@ -314,25 +325,122 @@ static void translateBranchMetadata(Module &M, Instruction *BBTerminatorInst) {
314325 BBTerminatorInst->setMetadata (" hlsl.controlflow.hint" , nullptr );
315326}
316327
317- static std::array<unsigned , 6 > getCompatibleInstructionMDs (llvm::Module &M) {
328+ // Determines if the metadata node will be compatible with DXIL's loop metadata
329+ // representation.
330+ //
331+ // Reports an error for compatible metadata that is ill-formed.
332+ static bool isLoopMDCompatible (Module &M, Metadata *MD) {
333+ // DXIL only accepts the following loop hints:
334+ std::array<StringLiteral, 3 > ValidHintNames = {" llvm.loop.unroll.count" ,
335+ " llvm.loop.unroll.disable" ,
336+ " llvm.loop.unroll.full" };
337+
338+ MDNode *HintMD = dyn_cast<MDNode>(MD);
339+ if (!HintMD || HintMD->getNumOperands () == 0 )
340+ return false ;
341+
342+ auto *HintStr = dyn_cast<MDString>(HintMD->getOperand (0 ));
343+ if (!HintStr)
344+ return false ;
345+
346+ if (!llvm::is_contained (ValidHintNames, HintStr->getString ()))
347+ return false ;
348+
349+ auto ValidCountNode = [](MDNode *CountMD) -> bool {
350+ if (CountMD->getNumOperands () == 2 )
351+ if (auto *Count = dyn_cast<ConstantAsMetadata>(CountMD->getOperand (1 )))
352+ if (isa<ConstantInt>(Count->getValue ()))
353+ return true ;
354+ return false ;
355+ };
356+
357+ if (HintStr->getString () == " llvm.loop.unroll.count" ) {
358+ if (!ValidCountNode (HintMD)) {
359+ reportLoopError (M, " \" llvm.loop.unroll.count\" must have 2 operands and "
360+ " the second must be a constant integer" );
361+ return false ;
362+ }
363+ } else if (HintMD->getNumOperands () != 1 ) {
364+ reportLoopError (
365+ M, " \" llvm.loop.unroll.disable\" and \" llvm.loop.unroll.full\" "
366+ " must be provided as a single operand" );
367+ return false ;
368+ }
369+
370+ return true ;
371+ }
372+
373+ static void translateLoopMetadata (Module &M, Instruction *I, MDNode *BaseMD) {
374+ // A distinct node has the self-referential form: !0 = !{ !0, ... }
375+ auto IsDistinctNode = [](MDNode *Node) -> bool {
376+ return Node && Node->getNumOperands () != 0 && Node == Node->getOperand (0 );
377+ };
378+
379+ // Set metadata to null to remove empty/ill-formed metadata from instruction
380+ if (BaseMD->getNumOperands () == 0 || !IsDistinctNode (BaseMD))
381+ return I->setMetadata (" llvm.loop" , nullptr );
382+
383+ // It is valid to have a chain of self-refential loop metadata nodes, as
384+ // below. We will collapse these into just one when we reconstruct the
385+ // metadata.
386+ //
387+ // Eg:
388+ // !0 = !{!0, !1}
389+ // !1 = !{!1, !2}
390+ // !2 = !{!"llvm.loop.unroll.disable"}
391+ //
392+ // So, traverse down a potential self-referential chain
393+ while (1 < BaseMD->getNumOperands () &&
394+ IsDistinctNode (dyn_cast<MDNode>(BaseMD->getOperand (1 ))))
395+ BaseMD = dyn_cast<MDNode>(BaseMD->getOperand (1 ));
396+
397+ // To reconstruct a distinct node we create a temporary node that we will
398+ // then update to create a self-reference.
399+ llvm::TempMDTuple TempNode = llvm::MDNode::getTemporary (M.getContext (), {});
400+ SmallVector<Metadata *> CompatibleOperands = {TempNode.get ()};
401+
402+ // Iterate and reconstruct the metadata nodes that contains any hints,
403+ // stripping any unrecognized metadata.
404+ ArrayRef<MDOperand> Operands = BaseMD->operands ();
405+ for (auto &Op : Operands.drop_front ())
406+ if (isLoopMDCompatible (M, Op.get ()))
407+ CompatibleOperands.push_back (Op.get ());
408+
409+ if (2 < CompatibleOperands.size ())
410+ reportLoopError (M, " Provided conflicting hints" );
411+
412+ MDNode *CompatibleLoopMD = MDNode::get (M.getContext (), CompatibleOperands);
413+ TempNode->replaceAllUsesWith (CompatibleLoopMD);
414+
415+ I->setMetadata (" llvm.loop" , CompatibleLoopMD);
416+ }
417+
418+ using InstructionMDList = std::array<unsigned , 7 >;
419+
420+ static InstructionMDList getCompatibleInstructionMDs (llvm::Module &M) {
318421 return {
319422 M.getMDKindID (" dx.nonuniform" ), M.getMDKindID (" dx.controlflow.hints" ),
320423 M.getMDKindID (" dx.precise" ), llvm::LLVMContext::MD_range,
321- llvm::LLVMContext::MD_alias_scope, llvm::LLVMContext::MD_noalias};
424+ llvm::LLVMContext::MD_alias_scope, llvm::LLVMContext::MD_noalias,
425+ M.getMDKindID (" llvm.loop" )};
322426}
323427
324428static void translateInstructionMetadata (Module &M) {
325429 // construct allowlist of valid metadata node kinds
326- std::array<unsigned , 6 > DXILCompatibleMDs = getCompatibleInstructionMDs (M);
430+ InstructionMDList DXILCompatibleMDs = getCompatibleInstructionMDs (M);
431+ unsigned char MDLoopKind = M.getContext ().getMDKindID (" llvm.loop" );
327432
328433 for (Function &F : M) {
329434 for (BasicBlock &BB : F) {
330435 // This needs to be done first so that "hlsl.controlflow.hints" isn't
331- // removed in the whitelist below
436+ // removed in the allow-list below
332437 if (auto *I = BB.getTerminator ())
333438 translateBranchMetadata (M, I);
334439
335440 for (auto &I : make_early_inc_range (BB)) {
441+ if (isa<BranchInst>(I))
442+ if (MDNode *LoopMD = I.getMetadata (MDLoopKind))
443+ translateLoopMetadata (M, &I, LoopMD);
336444 I.dropUnknownNonDebugMetadata (DXILCompatibleMDs);
337445 }
338446 }
@@ -389,31 +497,23 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
389497 uint64_t CombinedMask = ShaderFlags.getCombinedFlags ();
390498 EntryFnMDNodes.emplace_back (
391499 emitTopLevelLibraryNode (M, ResourceMD, CombinedMask));
392- } else if (MMDI.EntryPropertyVec .size () > 1 ) {
393- M.getContext ().diagnose (DiagnosticInfoTranslateMD (
394- M, " Non-library shader: One and only one entry expected" ));
395- }
500+ } else if (1 < MMDI.EntryPropertyVec .size ())
501+ reportError (M, " Non-library shader: One and only one entry expected" );
396502
397503 for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec ) {
398- const ComputedShaderFlags &EntrySFMask =
399- ShaderFlags.getFunctionFlags (EntryProp.Entry );
400-
401- // If ShaderProfile is Library, mask is already consolidated in the
402- // top-level library node. Hence it is not emitted.
403504 uint64_t EntryShaderFlags = 0 ;
404505 if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
405- EntryShaderFlags = EntrySFMask;
406- if (EntryProp.ShaderStage != MMDI.ShaderProfile ) {
407- M.getContext ().diagnose (DiagnosticInfoTranslateMD (
408- M,
409- " Shader stage '" +
410- Twine (getShortShaderStage (EntryProp.ShaderStage ) +
411- " ' for entry '" + Twine (EntryProp.Entry ->getName ()) +
412- " ' different from specified target profile '" +
413- Twine (Triple::getEnvironmentTypeName (MMDI.ShaderProfile ) +
414- " '" ))));
415- }
506+ EntryShaderFlags = ShaderFlags.getFunctionFlags (EntryProp.Entry );
507+ if (EntryProp.ShaderStage != MMDI.ShaderProfile )
508+ reportError (
509+ M, " Shader stage '" +
510+ Twine (getShortShaderStage (EntryProp.ShaderStage )) +
511+ " ' for entry '" + Twine (EntryProp.Entry ->getName ()) +
512+ " ' different from specified target profile '" +
513+ Twine (Triple::getEnvironmentTypeName (MMDI.ShaderProfile ) +
514+ " '" ));
416515 }
516+
417517 EntryFnMDNodes.emplace_back (emitEntryMD (EntryProp, Signatures, ResourceMD,
418518 EntryShaderFlags,
419519 MMDI.ShaderProfile ));
@@ -454,45 +554,34 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
454554 return PreservedAnalyses::all ();
455555}
456556
457- namespace {
458- class DXILTranslateMetadataLegacy : public ModulePass {
459- public:
460- static char ID; // Pass identification, replacement for typeid
461- explicit DXILTranslateMetadataLegacy () : ModulePass(ID) {}
462-
463- StringRef getPassName () const override { return " DXIL Translate Metadata" ; }
464-
465- void getAnalysisUsage (AnalysisUsage &AU) const override {
466- AU.addRequired <DXILResourceTypeWrapperPass>();
467- AU.addRequired <DXILResourceWrapperPass>();
468- AU.addRequired <ShaderFlagsAnalysisWrapper>();
469- AU.addRequired <DXILMetadataAnalysisWrapperPass>();
470- AU.addRequired <RootSignatureAnalysisWrapper>();
471-
472- AU.addPreserved <DXILMetadataAnalysisWrapperPass>();
473- AU.addPreserved <DXILResourceBindingWrapperPass>();
474- AU.addPreserved <DXILResourceWrapperPass>();
475- AU.addPreserved <RootSignatureAnalysisWrapper>();
476- AU.addPreserved <ShaderFlagsAnalysisWrapper>();
477- }
557+ void DXILTranslateMetadataLegacy::getAnalysisUsage (AnalysisUsage &AU) const {
558+ AU.addRequired <DXILResourceTypeWrapperPass>();
559+ AU.addRequired <DXILResourceWrapperPass>();
560+ AU.addRequired <ShaderFlagsAnalysisWrapper>();
561+ AU.addRequired <DXILMetadataAnalysisWrapperPass>();
562+ AU.addRequired <RootSignatureAnalysisWrapper>();
563+
564+ AU.addPreserved <DXILMetadataAnalysisWrapperPass>();
565+ AU.addPreserved <DXILResourceBindingWrapperPass>();
566+ AU.addPreserved <DXILResourceWrapperPass>();
567+ AU.addPreserved <RootSignatureAnalysisWrapper>();
568+ AU.addPreserved <ShaderFlagsAnalysisWrapper>();
569+ }
478570
479- bool runOnModule (Module &M) override {
480- DXILResourceMap &DRM =
481- getAnalysis<DXILResourceWrapperPass>().getResourceMap ();
482- DXILResourceTypeMap &DRTM =
483- getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap ();
484- const ModuleShaderFlags &ShaderFlags =
485- getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags ();
486- dxil::ModuleMetadataInfo MMDI =
487- getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata ();
488-
489- translateGlobalMetadata (M, DRM, DRTM, ShaderFlags, MMDI);
490- translateInstructionMetadata (M);
491- return true ;
492- }
493- };
571+ bool DXILTranslateMetadataLegacy::runOnModule (Module &M) {
572+ DXILResourceMap &DRM =
573+ getAnalysis<DXILResourceWrapperPass>().getResourceMap ();
574+ DXILResourceTypeMap &DRTM =
575+ getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap ();
576+ const ModuleShaderFlags &ShaderFlags =
577+ getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags ();
578+ dxil::ModuleMetadataInfo MMDI =
579+ getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata ();
494580
495- } // namespace
581+ translateGlobalMetadata (M, DRM, DRTM, ShaderFlags, MMDI);
582+ translateInstructionMetadata (M);
583+ return true ;
584+ }
496585
497586char DXILTranslateMetadataLegacy::ID = 0 ;
498587
0 commit comments