Skip to content

Commit 5369e1b

Browse files
committed
Address PR review: fix DXC codegen bugs
- Remove dead duplicate strict_float handler (bad merge artifact); the ScopedValue<emit_precise> approach is the correct implementation - Fix strict_fma to emit mad() for float/float16 (HLSL fma() is double-only; mad() is the correct float FMA intrinsic) - Add hlsl_reinterpret_name() helper so 'as' casts always emit the correct HLSL name (asfloat/asint/asuint/asfloat16) regardless of SM level; fixes asfloat32_t -> asfloat regression at SM >= 6.2 - Fix DXC global name clashes: prefix all arg names with the kernel name via RenameKernelArgs IRMutator; rename cbuffer per-kernel to avoid redeclaration errors in multi-kernel pipelines
1 parent 263effe commit 5369e1b

File tree

1 file changed

+76
-75
lines changed

1 file changed

+76
-75
lines changed

src/CodeGen_D3D12Compute_Dev.cpp

Lines changed: 76 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -529,76 +529,21 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) {
529529
Expr equiv = Call::make(op->type, "pow", op->args, Call::PureExtern);
530530
equiv.accept(this);
531531
} else if (op->is_strict_float_intrinsic()) {
532+
// Emit with HLSL 'precise' qualifier, which prevents the compiler from
533+
// reordering or fusing these operations (e.g. reassociating additions or
534+
// collapsing separate mul+add into an FMA).
535+
// Note: HLSL fma() is double-only; for float use mad() instead.
532536
ScopedValue old_emit_precise(emit_precise, true);
533537
Expr equiv = op->is_intrinsic(Call::strict_fma) ?
534-
Call::make(op->type, "fma", op->args, Call::PureExtern) :
538+
Call::make(op->type,
539+
op->type.bits() == 64 ? "fma" : "mad",
540+
op->args, Call::PureExtern) :
535541
unstrictify_float(op);
536542
equiv.accept(this);
537543
} else if (op->is_intrinsic(Call::round)) {
538544
// HLSL's round intrinsic has the correct semantics for our rounding.
539545
Expr equiv = Call::make(op->type, "round", op->args, Call::PureExtern);
540546
equiv.accept(this);
541-
} else if (op->is_strict_float_intrinsic()) {
542-
// Emit strict float intrinsics as HLSL 'precise'-qualified temporaries.
543-
// The 'precise' qualifier prevents the HLSL compiler from reordering or
544-
// fusing these operations (e.g. forming FMAs, reassociating additions).
545-
// We use a "precise:" cache key prefix so that strict and non-strict
546-
// evaluations of the same sub-expression don't share a variable.
547-
string rhs;
548-
if (op->is_intrinsic(Call::strict_fma)) {
549-
string a = print_expr(op->args[0]);
550-
string b = print_expr(op->args[1]);
551-
string c = print_expr(op->args[2]);
552-
rhs = "mad(" + a + ", " + b + ", " + c + ")";
553-
} else if (op->is_intrinsic(Call::strict_add)) {
554-
string a = print_expr(op->args[0]);
555-
string b = print_expr(op->args[1]);
556-
rhs = a + " + " + b;
557-
} else if (op->is_intrinsic(Call::strict_sub)) {
558-
string a = print_expr(op->args[0]);
559-
string b = print_expr(op->args[1]);
560-
rhs = a + " - " + b;
561-
} else if (op->is_intrinsic(Call::strict_mul)) {
562-
string a = print_expr(op->args[0]);
563-
string b = print_expr(op->args[1]);
564-
rhs = a + " * " + b;
565-
} else if (op->is_intrinsic(Call::strict_div)) {
566-
string a = print_expr(op->args[0]);
567-
string b = print_expr(op->args[1]);
568-
rhs = a + " / " + b;
569-
} else if (op->is_intrinsic(Call::strict_min)) {
570-
string a = print_expr(op->args[0]);
571-
string b = print_expr(op->args[1]);
572-
rhs = "min(" + a + ", " + b + ")";
573-
} else if (op->is_intrinsic(Call::strict_max)) {
574-
string a = print_expr(op->args[0]);
575-
string b = print_expr(op->args[1]);
576-
rhs = "max(" + a + ", " + b + ")";
577-
} else if (op->is_intrinsic(Call::strict_lt)) {
578-
string a = print_expr(op->args[0]);
579-
string b = print_expr(op->args[1]);
580-
rhs = a + " < " + b;
581-
} else if (op->is_intrinsic(Call::strict_le)) {
582-
string a = print_expr(op->args[0]);
583-
string b = print_expr(op->args[1]);
584-
rhs = a + " <= " + b;
585-
} else if (op->is_intrinsic(Call::strict_eq)) {
586-
string a = print_expr(op->args[0]);
587-
string b = print_expr(op->args[1]);
588-
rhs = a + " == " + b;
589-
} else {
590-
internal_assert(op->is_intrinsic(Call::strict_cast));
591-
rhs = "(" + print_type(op->type) + ")(" + print_expr(op->args[0]) + ")";
592-
}
593-
const string key = "precise:" + rhs;
594-
const auto it = cache.find(key);
595-
if (it == cache.end()) {
596-
id = unique_name('_');
597-
stream << get_indent() << "precise " << print_type(op->type) << " " << id << " = " << rhs << ";\n";
598-
cache[key] = id;
599-
} else {
600-
id = it->second;
601-
}
602547
} else {
603548
CodeGen_GPU_C::visit(op);
604549
}
@@ -628,6 +573,15 @@ string hex_literal(T value) {
628573
return hex.str();
629574
}
630575

576+
// Return the HLSL bit-reinterpret intrinsic name for a given type.
577+
// These names are fixed regardless of SM level (legacy aliases always work).
578+
string hlsl_reinterpret_name(Type t) {
579+
if (t.is_float()) {
580+
return t.bits() == 16 ? "asfloat16" : "asfloat";
581+
}
582+
return t.is_int() ? "asint" : "asuint";
583+
}
584+
631585
} // namespace
632586

633587
struct StoragePackUnpack {
@@ -707,7 +661,7 @@ struct StoragePackUnpack {
707661
// the smallest type granularity in HLSL SM 5.1 allows is 32bit types):
708662
if (op->type.bits() == 32) {
709663
// loading a 32bit word? great! just reinterpret as float/int/uint
710-
rhs << "as" << cg.print_type(op->type.element_of())
664+
rhs << hlsl_reinterpret_name(op->type.element_of())
711665
<< "("
712666
<< cg.print_name(op->name)
713667
<< "[" << cg.print_expr(op->index) << "]"
@@ -771,7 +725,7 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Load *op) {
771725
if (promoted_type != op->type) {
772726
shared_promotion_required = true;
773727
// NOTE(marcos): might need to resort to StoragePackUnpack::unpack_load() here
774-
promotion_str = "as" + print_type(promoted_type);
728+
promotion_str = hlsl_reinterpret_name(promoted_type);
775729
}
776730
}
777731

@@ -1046,7 +1000,7 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Store *op) {
10461000
if (promoted_type != op->value.type()) {
10471001
shared_promotion_required = true;
10481002
// NOTE(marcos): might need to resort to StoragePackUnpack::pack_store() here
1049-
promotion_str = "as" + print_type(promoted_type);
1003+
promotion_str = hlsl_reinterpret_name(promoted_type);
10501004
}
10511005
}
10521006

@@ -1697,9 +1651,53 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
16971651
const bool use_dxc = (sm >= 60);
16981652

16991653
if (use_dxc) {
1700-
// DXC (SM 6.x) does not support FXC-style resource/uniform parameters to entry
1701-
// functions. Declare all resources globally with explicit register bindings and
1702-
// put scalar uniforms in a constant buffer. The runtime binds:
1654+
// DXC (SM 6.x): resources must be at global module scope, not function
1655+
// parameters. When multiple kernels share a module, arg names can clash.
1656+
// Prefix every arg name with the kernel name to guarantee uniqueness.
1657+
std::map<string, string> dxc_renames;
1658+
for (const auto &arg : args) {
1659+
dxc_renames[arg.name] = name + "_" + arg.name;
1660+
}
1661+
1662+
// Mutate Load/Store/Variable nodes in the body to use the prefixed names.
1663+
class RenameKernelArgs : public IRMutator {
1664+
using IRMutator::visit;
1665+
const std::map<string, string> &renames;
1666+
Expr visit(const Load *op) override {
1667+
auto it = renames.find(op->name);
1668+
if (it != renames.end()) {
1669+
return Load::make(op->type, it->second,
1670+
mutate(op->index), op->image, op->param,
1671+
mutate(op->predicate), op->alignment);
1672+
}
1673+
return IRMutator::visit(op);
1674+
}
1675+
Stmt visit(const Store *op) override {
1676+
auto it = renames.find(op->name);
1677+
if (it != renames.end()) {
1678+
return Store::make(it->second, mutate(op->value),
1679+
mutate(op->index), op->param,
1680+
mutate(op->predicate), op->alignment);
1681+
}
1682+
return IRMutator::visit(op);
1683+
}
1684+
Expr visit(const Variable *op) override {
1685+
auto it = renames.find(op->name);
1686+
if (it != renames.end()) {
1687+
return Variable::make(op->type, it->second,
1688+
op->image, op->param, op->reduction_domain);
1689+
}
1690+
return IRMutator::visit(op);
1691+
}
1692+
1693+
public:
1694+
RenameKernelArgs(const std::map<string, string> &r)
1695+
: renames(r) {}
1696+
};
1697+
s = RenameKernelArgs(dxc_renames)(s);
1698+
1699+
// Declare all resources globally with explicit register bindings and
1700+
// put scalar uniforms in a per-kernel constant buffer. The runtime binds:
17031701
// - scalar args → cbuffer at register(b0)
17041702
// - buffer args → UAV at register(u0), register(u1), ...
17051703
int uav_index = 0;
@@ -1709,28 +1707,30 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
17091707
has_scalars = true;
17101708
continue;
17111709
}
1710+
const string &pname = dxc_renames.at(arg.name);
17121711
if (arg.memory_type == MemoryType::GPUTexture) {
17131712
int dims = arg.dimensions;
17141713
internal_assert(dims >= 1 && dims <= 3) << "D3D12Compute texture must have 1-3 dimensions\n";
17151714
stream << "RWTexture" << dims << "D"
17161715
<< "<" << print_type(arg.type) << ">"
1717-
<< " " << print_name(arg.name)
1716+
<< " " << print_name(pname)
17181717
<< " : register(u" << uav_index++ << ");\n";
17191718
} else {
17201719
stream << "RWBuffer"
17211720
<< "<" << print_type(arg.type) << ">"
1722-
<< " " << print_name(arg.name)
1721+
<< " " << print_name(pname)
17231722
<< " : register(u" << uav_index++ << ");\n";
17241723
}
17251724
Allocation alloc;
17261725
alloc.type = arg.type;
1727-
allocations.push(arg.name, alloc);
1726+
allocations.push(pname, alloc);
17281727
}
17291728
if (has_scalars) {
1730-
stream << "cbuffer _halide_uniform_args : register(b0) {\n";
1729+
stream << "cbuffer " << name << "_uniforms : register(b0) {\n";
17311730
for (const auto &arg : args) {
17321731
if (!arg.is_buffer) {
1733-
stream << " " << print_type(arg.type) << " " << print_name(arg.name) << ";\n";
1732+
const string &pname = dxc_renames.at(arg.name);
1733+
stream << " " << print_type(arg.type) << " " << print_name(pname) << ";\n";
17341734
}
17351735
}
17361736
stream << "};\n";
@@ -1791,9 +1791,10 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
17911791
close_scope("kernel " + name);
17921792

17931793
for (const auto &arg : args) {
1794-
// Remove buffer arguments from allocation scope
1794+
// Remove buffer arguments from allocation scope.
1795+
// DXC allocations were pushed under prefixed names; FXC under original names.
17951796
if (arg.is_buffer) {
1796-
allocations.pop(arg.name);
1797+
allocations.pop(use_dxc ? (name + "_" + arg.name) : arg.name);
17971798
}
17981799
}
17991800

0 commit comments

Comments
 (0)