@@ -573,6 +573,10 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) {
573573 Expr equiv = Call::make (op->type , " pow" , op->args , Call::PureExtern);
574574 equiv.accept (this );
575575 } else if (op->is_strict_float_intrinsic ()) {
576+ // Emit with HLSL 'precise' qualifier, which prevents the compiler from
577+ // reordering or fusing these operations (e.g. reassociating additions or
578+ // collapsing separate mul+add into an FMA).
579+ // Note: HLSL fma() is double-only; for float use mad() instead.
576580 ScopedValue old_emit_precise (emit_precise, true );
577581 Expr equiv = op->is_intrinsic (Call::strict_fma) ?
578582 Call::make (op->type , op->type .bits () == 64 ? " fma" : " mad" , op->args , Call::PureExtern) :
@@ -611,6 +615,15 @@ string hex_literal(T value) {
611615 return hex.str ();
612616}
613617
618+ // Return the HLSL bit-reinterpret intrinsic name for a given type.
619+ // These names are fixed regardless of SM level (legacy aliases always work).
620+ string hlsl_reinterpret_name (Type t) {
621+ if (t.is_float ()) {
622+ return t.bits () == 16 ? " asfloat16" : " asfloat" ;
623+ }
624+ return t.is_int () ? " asint" : " asuint" ;
625+ }
626+
614627} // namespace
615628
616629struct StoragePackUnpack {
@@ -690,7 +703,7 @@ struct StoragePackUnpack {
690703 // the smallest type granularity in HLSL SM 5.1 allows is 32bit types):
691704 if (op->type .bits () == 32 ) {
692705 // loading a 32bit word? great! just reinterpret as float/int/uint
693- rhs << " as " << cg. print_type (op->type .element_of ())
706+ rhs << hlsl_reinterpret_name (op->type .element_of ())
694707 << " ("
695708 << cg.print_name (op->name )
696709 << " [" << cg.print_expr (op->index ) << " ]"
@@ -754,7 +767,7 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Load *op) {
754767 if (promoted_type != op->type ) {
755768 shared_promotion_required = true ;
756769 // NOTE(marcos): might need to resort to StoragePackUnpack::unpack_load() here
757- promotion_str = " as " + print_type (promoted_type);
770+ promotion_str = hlsl_reinterpret_name (promoted_type);
758771 }
759772 }
760773
@@ -1029,7 +1042,7 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Store *op) {
10291042 if (promoted_type != op->value .type ()) {
10301043 shared_promotion_required = true ;
10311044 // NOTE(marcos): might need to resort to StoragePackUnpack::pack_store() here
1032- promotion_str = " as " + print_type (promoted_type);
1045+ promotion_str = hlsl_reinterpret_name (promoted_type);
10331046 }
10341047 }
10351048
@@ -1680,9 +1693,53 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
16801693 const bool use_dxc = (sm >= 60 );
16811694
16821695 if (use_dxc) {
1683- // DXC (SM 6.x) does not support FXC-style resource/uniform parameters to entry
1684- // functions. Declare all resources globally with explicit register bindings and
1685- // put scalar uniforms in a constant buffer. The runtime binds:
1696+ // DXC (SM 6.x): resources must be at global module scope, not function
1697+ // parameters. When multiple kernels share a module, arg names can clash.
1698+ // Prefix every arg name with the kernel name to guarantee uniqueness.
1699+ std::map<string, string> dxc_renames;
1700+ for (const auto &arg : args) {
1701+ dxc_renames[arg.name ] = name + " _" + arg.name ;
1702+ }
1703+
1704+ // Mutate Load/Store/Variable nodes in the body to use the prefixed names.
1705+ class RenameKernelArgs : public IRMutator {
1706+ using IRMutator::visit;
1707+ const std::map<string, string> &renames;
1708+ Expr visit (const Load *op) override {
1709+ auto it = renames.find (op->name );
1710+ if (it != renames.end ()) {
1711+ return Load::make (op->type , it->second ,
1712+ mutate (op->index ), op->image , op->param ,
1713+ mutate (op->predicate ), op->alignment );
1714+ }
1715+ return IRMutator::visit (op);
1716+ }
1717+ Stmt visit (const Store *op) override {
1718+ auto it = renames.find (op->name );
1719+ if (it != renames.end ()) {
1720+ return Store::make (it->second , mutate (op->value ),
1721+ mutate (op->index ), op->param ,
1722+ mutate (op->predicate ), op->alignment );
1723+ }
1724+ return IRMutator::visit (op);
1725+ }
1726+ Expr visit (const Variable *op) override {
1727+ auto it = renames.find (op->name );
1728+ if (it != renames.end ()) {
1729+ return Variable::make (op->type , it->second ,
1730+ op->image , op->param , op->reduction_domain );
1731+ }
1732+ return IRMutator::visit (op);
1733+ }
1734+
1735+ public:
1736+ RenameKernelArgs (const std::map<string, string> &r)
1737+ : renames(r) {}
1738+ };
1739+ s = RenameKernelArgs (dxc_renames)(s);
1740+
1741+ // Declare all resources globally with explicit register bindings and
1742+ // put scalar uniforms in a per-kernel constant buffer. The runtime binds:
16861743 // - scalar args → cbuffer at register(b0)
16871744 // - buffer args → UAV at register(u0), register(u1), ...
16881745 int uav_index = 0 ;
@@ -1692,28 +1749,30 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
16921749 has_scalars = true ;
16931750 continue ;
16941751 }
1752+ const string &pname = dxc_renames.at (arg.name );
16951753 if (arg.memory_type == MemoryType::GPUTexture) {
16961754 int dims = arg.dimensions ;
16971755 internal_assert (dims >= 1 && dims <= 3 ) << " D3D12Compute texture must have 1-3 dimensions\n " ;
16981756 stream << " RWTexture" << dims << " D"
16991757 << " <" << print_type (arg.type ) << " >"
1700- << " " << print_name (arg. name )
1758+ << " " << print_name (pname )
17011759 << " : register(u" << uav_index++ << " );\n " ;
17021760 } else {
17031761 stream << " RWBuffer"
17041762 << " <" << print_type (arg.type ) << " >"
1705- << " " << print_name (arg. name )
1763+ << " " << print_name (pname )
17061764 << " : register(u" << uav_index++ << " );\n " ;
17071765 }
17081766 Allocation alloc;
17091767 alloc.type = arg.type ;
1710- allocations.push (arg. name , alloc);
1768+ allocations.push (pname , alloc);
17111769 }
17121770 if (has_scalars) {
1713- stream << " cbuffer _halide_uniform_args : register(b0) {\n " ;
1771+ stream << " cbuffer " << name << " _uniforms : register(b0) {\n " ;
17141772 for (const auto &arg : args) {
17151773 if (!arg.is_buffer ) {
1716- stream << " " << print_type (arg.type ) << " " << print_name (arg.name ) << " ;\n " ;
1774+ const string &pname = dxc_renames.at (arg.name );
1775+ stream << " " << print_type (arg.type ) << " " << print_name (pname) << " ;\n " ;
17171776 }
17181777 }
17191778 stream << " };\n " ;
@@ -1774,9 +1833,10 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
17741833 close_scope (" kernel " + name);
17751834
17761835 for (const auto &arg : args) {
1777- // Remove buffer arguments from allocation scope
1836+ // Remove buffer arguments from allocation scope.
1837+ // DXC allocations were pushed under prefixed names; FXC under original names.
17781838 if (arg.is_buffer ) {
1779- allocations.pop (arg.name );
1839+ allocations.pop (use_dxc ? (name + " _ " + arg. name ) : arg.name );
17801840 }
17811841 }
17821842
0 commit comments