@@ -529,76 +529,21 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) {
529529 Expr equiv = Call::make (op->type , " pow" , op->args , Call::PureExtern);
530530 equiv.accept (this );
531531 } else if (op->is_strict_float_intrinsic ()) {
532+ // Emit with HLSL 'precise' qualifier, which prevents the compiler from
533+ // reordering or fusing these operations (e.g. reassociating additions or
534+ // collapsing separate mul+add into an FMA).
535+ // Note: HLSL fma() is double-only; for float use mad() instead.
532536 ScopedValue old_emit_precise (emit_precise, true );
533537 Expr equiv = op->is_intrinsic (Call::strict_fma) ?
534- Call::make (op->type , " fma" , op->args , Call::PureExtern) :
538+ Call::make (op->type ,
539+ op->type .bits () == 64 ? " fma" : " mad" ,
540+ op->args , Call::PureExtern) :
535541 unstrictify_float (op);
536542 equiv.accept (this );
537543 } else if (op->is_intrinsic (Call::round)) {
538544 // HLSL's round intrinsic has the correct semantics for our rounding.
539545 Expr equiv = Call::make (op->type , " round" , op->args , Call::PureExtern);
540546 equiv.accept (this );
541- } else if (op->is_strict_float_intrinsic ()) {
542- // Emit strict float intrinsics as HLSL 'precise'-qualified temporaries.
543- // The 'precise' qualifier prevents the HLSL compiler from reordering or
544- // fusing these operations (e.g. forming FMAs, reassociating additions).
545- // We use a "precise:" cache key prefix so that strict and non-strict
546- // evaluations of the same sub-expression don't share a variable.
547- string rhs;
548- if (op->is_intrinsic (Call::strict_fma)) {
549- string a = print_expr (op->args [0 ]);
550- string b = print_expr (op->args [1 ]);
551- string c = print_expr (op->args [2 ]);
552- rhs = " mad(" + a + " , " + b + " , " + c + " )" ;
553- } else if (op->is_intrinsic (Call::strict_add)) {
554- string a = print_expr (op->args [0 ]);
555- string b = print_expr (op->args [1 ]);
556- rhs = a + " + " + b;
557- } else if (op->is_intrinsic (Call::strict_sub)) {
558- string a = print_expr (op->args [0 ]);
559- string b = print_expr (op->args [1 ]);
560- rhs = a + " - " + b;
561- } else if (op->is_intrinsic (Call::strict_mul)) {
562- string a = print_expr (op->args [0 ]);
563- string b = print_expr (op->args [1 ]);
564- rhs = a + " * " + b;
565- } else if (op->is_intrinsic (Call::strict_div)) {
566- string a = print_expr (op->args [0 ]);
567- string b = print_expr (op->args [1 ]);
568- rhs = a + " / " + b;
569- } else if (op->is_intrinsic (Call::strict_min)) {
570- string a = print_expr (op->args [0 ]);
571- string b = print_expr (op->args [1 ]);
572- rhs = " min(" + a + " , " + b + " )" ;
573- } else if (op->is_intrinsic (Call::strict_max)) {
574- string a = print_expr (op->args [0 ]);
575- string b = print_expr (op->args [1 ]);
576- rhs = " max(" + a + " , " + b + " )" ;
577- } else if (op->is_intrinsic (Call::strict_lt)) {
578- string a = print_expr (op->args [0 ]);
579- string b = print_expr (op->args [1 ]);
580- rhs = a + " < " + b;
581- } else if (op->is_intrinsic (Call::strict_le)) {
582- string a = print_expr (op->args [0 ]);
583- string b = print_expr (op->args [1 ]);
584- rhs = a + " <= " + b;
585- } else if (op->is_intrinsic (Call::strict_eq)) {
586- string a = print_expr (op->args [0 ]);
587- string b = print_expr (op->args [1 ]);
588- rhs = a + " == " + b;
589- } else {
590- internal_assert (op->is_intrinsic (Call::strict_cast));
591- rhs = " (" + print_type (op->type ) + " )(" + print_expr (op->args [0 ]) + " )" ;
592- }
593- const string key = " precise:" + rhs;
594- const auto it = cache.find (key);
595- if (it == cache.end ()) {
596- id = unique_name (' _' );
597- stream << get_indent () << " precise " << print_type (op->type ) << " " << id << " = " << rhs << " ;\n " ;
598- cache[key] = id;
599- } else {
600- id = it->second ;
601- }
602547 } else {
603548 CodeGen_GPU_C::visit (op);
604549 }
@@ -628,6 +573,15 @@ string hex_literal(T value) {
628573 return hex.str ();
629574}
630575
576+ // Return the HLSL bit-reinterpret intrinsic name for a given type.
577+ // These names are fixed regardless of SM level (legacy aliases always work).
578+ string hlsl_reinterpret_name (Type t) {
579+ if (t.is_float ()) {
580+ return t.bits () == 16 ? " asfloat16" : " asfloat" ;
581+ }
582+ return t.is_int () ? " asint" : " asuint" ;
583+ }
584+
631585} // namespace
632586
633587struct StoragePackUnpack {
@@ -707,7 +661,7 @@ struct StoragePackUnpack {
707661 // the smallest type granularity in HLSL SM 5.1 allows is 32bit types):
708662 if (op->type .bits () == 32 ) {
709663 // loading a 32bit word? great! just reinterpret as float/int/uint
710- rhs << " as " << cg. print_type (op->type .element_of ())
664+ rhs << hlsl_reinterpret_name (op->type .element_of ())
711665 << " ("
712666 << cg.print_name (op->name )
713667 << " [" << cg.print_expr (op->index ) << " ]"
@@ -771,7 +725,7 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Load *op) {
771725 if (promoted_type != op->type ) {
772726 shared_promotion_required = true ;
773727 // NOTE(marcos): might need to resort to StoragePackUnpack::unpack_load() here
774- promotion_str = " as " + print_type (promoted_type);
728+ promotion_str = hlsl_reinterpret_name (promoted_type);
775729 }
776730 }
777731
@@ -1046,7 +1000,7 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Store *op) {
10461000 if (promoted_type != op->value .type ()) {
10471001 shared_promotion_required = true ;
10481002 // NOTE(marcos): might need to resort to StoragePackUnpack::pack_store() here
1049- promotion_str = " as " + print_type (promoted_type);
1003+ promotion_str = hlsl_reinterpret_name (promoted_type);
10501004 }
10511005 }
10521006
@@ -1697,9 +1651,53 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
16971651 const bool use_dxc = (sm >= 60 );
16981652
16991653 if (use_dxc) {
1700- // DXC (SM 6.x) does not support FXC-style resource/uniform parameters to entry
1701- // functions. Declare all resources globally with explicit register bindings and
1702- // put scalar uniforms in a constant buffer. The runtime binds:
1654+ // DXC (SM 6.x): resources must be at global module scope, not function
1655+ // parameters. When multiple kernels share a module, arg names can clash.
1656+ // Prefix every arg name with the kernel name to guarantee uniqueness.
1657+ std::map<string, string> dxc_renames;
1658+ for (const auto &arg : args) {
1659+ dxc_renames[arg.name ] = name + " _" + arg.name ;
1660+ }
1661+
1662+ // Mutate Load/Store/Variable nodes in the body to use the prefixed names.
1663+ class RenameKernelArgs : public IRMutator {
1664+ using IRMutator::visit;
1665+ const std::map<string, string> &renames;
1666+ Expr visit (const Load *op) override {
1667+ auto it = renames.find (op->name );
1668+ if (it != renames.end ()) {
1669+ return Load::make (op->type , it->second ,
1670+ mutate (op->index ), op->image , op->param ,
1671+ mutate (op->predicate ), op->alignment );
1672+ }
1673+ return IRMutator::visit (op);
1674+ }
1675+ Stmt visit (const Store *op) override {
1676+ auto it = renames.find (op->name );
1677+ if (it != renames.end ()) {
1678+ return Store::make (it->second , mutate (op->value ),
1679+ mutate (op->index ), op->param ,
1680+ mutate (op->predicate ), op->alignment );
1681+ }
1682+ return IRMutator::visit (op);
1683+ }
1684+ Expr visit (const Variable *op) override {
1685+ auto it = renames.find (op->name );
1686+ if (it != renames.end ()) {
1687+ return Variable::make (op->type , it->second ,
1688+ op->image , op->param , op->reduction_domain );
1689+ }
1690+ return IRMutator::visit (op);
1691+ }
1692+
1693+ public:
1694+ RenameKernelArgs (const std::map<string, string> &r)
1695+ : renames(r) {}
1696+ };
1697+ s = RenameKernelArgs (dxc_renames)(s);
1698+
1699+ // Declare all resources globally with explicit register bindings and
1700+ // put scalar uniforms in a per-kernel constant buffer. The runtime binds:
17031701 // - scalar args → cbuffer at register(b0)
17041702 // - buffer args → UAV at register(u0), register(u1), ...
17051703 int uav_index = 0 ;
@@ -1709,28 +1707,30 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
17091707 has_scalars = true ;
17101708 continue ;
17111709 }
1710+ const string &pname = dxc_renames.at (arg.name );
17121711 if (arg.memory_type == MemoryType::GPUTexture) {
17131712 int dims = arg.dimensions ;
17141713 internal_assert (dims >= 1 && dims <= 3 ) << " D3D12Compute texture must have 1-3 dimensions\n " ;
17151714 stream << " RWTexture" << dims << " D"
17161715 << " <" << print_type (arg.type ) << " >"
1717- << " " << print_name (arg. name )
1716+ << " " << print_name (pname )
17181717 << " : register(u" << uav_index++ << " );\n " ;
17191718 } else {
17201719 stream << " RWBuffer"
17211720 << " <" << print_type (arg.type ) << " >"
1722- << " " << print_name (arg. name )
1721+ << " " << print_name (pname )
17231722 << " : register(u" << uav_index++ << " );\n " ;
17241723 }
17251724 Allocation alloc;
17261725 alloc.type = arg.type ;
1727- allocations.push (arg. name , alloc);
1726+ allocations.push (pname , alloc);
17281727 }
17291728 if (has_scalars) {
1730- stream << " cbuffer _halide_uniform_args : register(b0) {\n " ;
1729+ stream << " cbuffer " << name << " _uniforms : register(b0) {\n " ;
17311730 for (const auto &arg : args) {
17321731 if (!arg.is_buffer ) {
1733- stream << " " << print_type (arg.type ) << " " << print_name (arg.name ) << " ;\n " ;
1732+ const string &pname = dxc_renames.at (arg.name );
1733+ stream << " " << print_type (arg.type ) << " " << print_name (pname) << " ;\n " ;
17341734 }
17351735 }
17361736 stream << " };\n " ;
@@ -1791,9 +1791,10 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
17911791 close_scope (" kernel " + name);
17921792
17931793 for (const auto &arg : args) {
1794- // Remove buffer arguments from allocation scope
1794+ // Remove buffer arguments from allocation scope.
1795+ // DXC allocations were pushed under prefixed names; FXC under original names.
17951796 if (arg.is_buffer ) {
1796- allocations.pop (arg.name );
1797+ allocations.pop (use_dxc ? (name + " _ " + arg. name ) : arg.name );
17971798 }
17981799 }
17991800
0 commit comments