Skip to content

Commit 58223e6

Browse files
committed
Address PR review: fix DXC codegen bugs
- Remove dead duplicate strict_float handler (bad merge artifact); the ScopedValue<emit_precise> approach is the correct implementation - Fix strict_fma to emit mad() for float/float16 (HLSL fma() is double-only; mad() is the correct float FMA intrinsic) - Add hlsl_reinterpret_name() helper so 'as' casts always emit the correct HLSL name (asfloat/asint/asuint/asfloat16) regardless of SM level; fixes asfloat32_t -> asfloat regression at SM >= 6.2 - Fix DXC global name clashes: prefix all arg names with the kernel name via RenameKernelArgs IRMutator; rename cbuffer per-kernel to avoid redeclaration errors in multi-kernel pipelines
1 parent 2eba11b commit 58223e6

File tree

1 file changed

+73
-13
lines changed

1 file changed

+73
-13
lines changed

src/CodeGen_D3D12Compute_Dev.cpp

Lines changed: 73 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,10 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) {
573573
Expr equiv = Call::make(op->type, "pow", op->args, Call::PureExtern);
574574
equiv.accept(this);
575575
} else if (op->is_strict_float_intrinsic()) {
576+
// Emit with HLSL 'precise' qualifier, which prevents the compiler from
577+
// reordering or fusing these operations (e.g. reassociating additions or
578+
// collapsing separate mul+add into an FMA).
579+
// Note: HLSL fma() is double-only; for float use mad() instead.
576580
ScopedValue old_emit_precise(emit_precise, true);
577581
Expr equiv = op->is_intrinsic(Call::strict_fma) ?
578582
Call::make(op->type, op->type.bits() == 64 ? "fma" : "mad", op->args, Call::PureExtern) :
@@ -611,6 +615,15 @@ string hex_literal(T value) {
611615
return hex.str();
612616
}
613617

618+
// Return the HLSL bit-reinterpret intrinsic name for a given type.
619+
// These names are fixed regardless of SM level (legacy aliases always work).
620+
string hlsl_reinterpret_name(Type t) {
621+
if (t.is_float()) {
622+
return t.bits() == 16 ? "asfloat16" : "asfloat";
623+
}
624+
return t.is_int() ? "asint" : "asuint";
625+
}
626+
614627
} // namespace
615628

616629
struct StoragePackUnpack {
@@ -690,7 +703,7 @@ struct StoragePackUnpack {
690703
// the smallest type granularity in HLSL SM 5.1 allows is 32bit types):
691704
if (op->type.bits() == 32) {
692705
// loading a 32bit word? great! just reinterpret as float/int/uint
693-
rhs << "as" << cg.print_type(op->type.element_of())
706+
rhs << hlsl_reinterpret_name(op->type.element_of())
694707
<< "("
695708
<< cg.print_name(op->name)
696709
<< "[" << cg.print_expr(op->index) << "]"
@@ -754,7 +767,7 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Load *op) {
754767
if (promoted_type != op->type) {
755768
shared_promotion_required = true;
756769
// NOTE(marcos): might need to resort to StoragePackUnpack::unpack_load() here
757-
promotion_str = "as" + print_type(promoted_type);
770+
promotion_str = hlsl_reinterpret_name(promoted_type);
758771
}
759772
}
760773

@@ -1029,7 +1042,7 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Store *op) {
10291042
if (promoted_type != op->value.type()) {
10301043
shared_promotion_required = true;
10311044
// NOTE(marcos): might need to resort to StoragePackUnpack::pack_store() here
1032-
promotion_str = "as" + print_type(promoted_type);
1045+
promotion_str = hlsl_reinterpret_name(promoted_type);
10331046
}
10341047
}
10351048

@@ -1680,9 +1693,53 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
16801693
const bool use_dxc = (sm >= 60);
16811694

16821695
if (use_dxc) {
1683-
// DXC (SM 6.x) does not support FXC-style resource/uniform parameters to entry
1684-
// functions. Declare all resources globally with explicit register bindings and
1685-
// put scalar uniforms in a constant buffer. The runtime binds:
1696+
// DXC (SM 6.x): resources must be at global module scope, not function
1697+
// parameters. When multiple kernels share a module, arg names can clash.
1698+
// Prefix every arg name with the kernel name to guarantee uniqueness.
1699+
std::map<string, string> dxc_renames;
1700+
for (const auto &arg : args) {
1701+
dxc_renames[arg.name] = name + "_" + arg.name;
1702+
}
1703+
1704+
// Mutate Load/Store/Variable nodes in the body to use the prefixed names.
1705+
class RenameKernelArgs : public IRMutator {
1706+
using IRMutator::visit;
1707+
const std::map<string, string> &renames;
1708+
Expr visit(const Load *op) override {
1709+
auto it = renames.find(op->name);
1710+
if (it != renames.end()) {
1711+
return Load::make(op->type, it->second,
1712+
mutate(op->index), op->image, op->param,
1713+
mutate(op->predicate), op->alignment);
1714+
}
1715+
return IRMutator::visit(op);
1716+
}
1717+
Stmt visit(const Store *op) override {
1718+
auto it = renames.find(op->name);
1719+
if (it != renames.end()) {
1720+
return Store::make(it->second, mutate(op->value),
1721+
mutate(op->index), op->param,
1722+
mutate(op->predicate), op->alignment);
1723+
}
1724+
return IRMutator::visit(op);
1725+
}
1726+
Expr visit(const Variable *op) override {
1727+
auto it = renames.find(op->name);
1728+
if (it != renames.end()) {
1729+
return Variable::make(op->type, it->second,
1730+
op->image, op->param, op->reduction_domain);
1731+
}
1732+
return IRMutator::visit(op);
1733+
}
1734+
1735+
public:
1736+
RenameKernelArgs(const std::map<string, string> &r)
1737+
: renames(r) {}
1738+
};
1739+
s = RenameKernelArgs(dxc_renames)(s);
1740+
1741+
// Declare all resources globally with explicit register bindings and
1742+
// put scalar uniforms in a per-kernel constant buffer. The runtime binds:
16861743
// - scalar args → cbuffer at register(b0)
16871744
// - buffer args → UAV at register(u0), register(u1), ...
16881745
int uav_index = 0;
@@ -1692,28 +1749,30 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
16921749
has_scalars = true;
16931750
continue;
16941751
}
1752+
const string &pname = dxc_renames.at(arg.name);
16951753
if (arg.memory_type == MemoryType::GPUTexture) {
16961754
int dims = arg.dimensions;
16971755
internal_assert(dims >= 1 && dims <= 3) << "D3D12Compute texture must have 1-3 dimensions\n";
16981756
stream << "RWTexture" << dims << "D"
16991757
<< "<" << print_type(arg.type) << ">"
1700-
<< " " << print_name(arg.name)
1758+
<< " " << print_name(pname)
17011759
<< " : register(u" << uav_index++ << ");\n";
17021760
} else {
17031761
stream << "RWBuffer"
17041762
<< "<" << print_type(arg.type) << ">"
1705-
<< " " << print_name(arg.name)
1763+
<< " " << print_name(pname)
17061764
<< " : register(u" << uav_index++ << ");\n";
17071765
}
17081766
Allocation alloc;
17091767
alloc.type = arg.type;
1710-
allocations.push(arg.name, alloc);
1768+
allocations.push(pname, alloc);
17111769
}
17121770
if (has_scalars) {
1713-
stream << "cbuffer _halide_uniform_args : register(b0) {\n";
1771+
stream << "cbuffer " << name << "_uniforms : register(b0) {\n";
17141772
for (const auto &arg : args) {
17151773
if (!arg.is_buffer) {
1716-
stream << " " << print_type(arg.type) << " " << print_name(arg.name) << ";\n";
1774+
const string &pname = dxc_renames.at(arg.name);
1775+
stream << " " << print_type(arg.type) << " " << print_name(pname) << ";\n";
17171776
}
17181777
}
17191778
stream << "};\n";
@@ -1774,9 +1833,10 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
17741833
close_scope("kernel " + name);
17751834

17761835
for (const auto &arg : args) {
1777-
// Remove buffer arguments from allocation scope
1836+
// Remove buffer arguments from allocation scope.
1837+
// DXC allocations were pushed under prefixed names; FXC under original names.
17781838
if (arg.is_buffer) {
1779-
allocations.pop(arg.name);
1839+
allocations.pop(use_dxc ? (name + "_" + arg.name) : arg.name);
17801840
}
17811841
}
17821842

0 commit comments

Comments
 (0)