Skip to content

Commit 37f4d92

Browse files
committed
Fix uint16 cast errors
1 parent 2486419 commit 37f4d92

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

src/CodeGen_D3D12Compute_Dev.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,11 @@ string hlsl_reinterpret_name(Type t) {
620620
if (t.is_float()) {
621621
return t.bits() == 16 ? "asfloat16" : "asfloat";
622622
}
623+
// 16-bit integer variants (asuint16/asint16) require SM 6.2+, but that is
624+
// the same requirement as using float16_t/uint16_t scalars at all.
625+
if (t.bits() == 16) {
626+
return t.is_int() ? "asint16" : "asuint16";
627+
}
623628
return t.is_int() ? "asint" : "asuint";
624629
}
625630

@@ -756,11 +761,12 @@ struct StoragePackUnpack {
756761
void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Load *op) {
757762
user_assert(is_const_one(op->predicate)) << "Predicated load is not supported inside D3D12Compute kernel.\n";
758763

759-
// elements in a threadgroup shared buffer are always 32bits:
760-
// must reinterpret (and maybe unpack) bits.
764+
// SM 5.1 groupshared buffers are always 32-bit; sub-32-bit loads need a
765+
// bit-reinterpret cast. SM 6.2+ supports 16-bit natively — no cast needed.
766+
const int sm = target.get_d3d12compute_capability_lower_bound();
761767
bool shared_promotion_required = false;
762768
string promotion_str = "";
763-
if (groupshared_allocations.contains(op->name)) {
769+
if (groupshared_allocations.contains(op->name) && sm < 62) {
764770
internal_assert(allocations.contains(op->name));
765771
Type promoted_type = op->type.with_bits(32).with_lanes(1);
766772
if (promoted_type != op->type) {
@@ -1030,11 +1036,12 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Store *op) {
10301036

10311037
Type value_type = op->value.type();
10321038

1033-
// elements in a threadgroup shared buffer are always 32bits:
1034-
// must reinterpret (and maybe pack) bits.
1039+
// SM 5.1 groupshared buffers are always 32-bit; sub-32-bit stores need a
1040+
// bit-reinterpret cast. SM 6.2+ supports 16-bit natively — no cast needed.
1041+
const int sm = target.get_d3d12compute_capability_lower_bound();
10351042
bool shared_promotion_required = false;
10361043
string promotion_str = "";
1037-
if (groupshared_allocations.contains(op->name)) {
1044+
if (groupshared_allocations.contains(op->name) && sm < 62) {
10381045
const auto *alloc = allocations.find(op->name);
10391046
internal_assert(alloc);
10401047
Type promoted_type = alloc->type;
@@ -1533,18 +1540,19 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
15331540
FindSharedAllocationsAndUniquify fsa;
15341541
s = fsa(s);
15351542

1543+
const int sm = target.get_d3d12compute_capability_lower_bound();
15361544
uint32_t total_shared_bytes = 0;
15371545
for (const Stmt &sop : fsa.allocs) {
15381546
const Allocate *op = sop.as<Allocate>();
15391547
internal_assert(op->extents.size() == 1);
15401548
internal_assert(op->type.lanes() == 1);
1541-
// In D3D12/HLSL, only 32bit types (int/uint/float) are supported (even
1542-
// though things are changing with newer shader models). Since there is
1543-
// no uint8 type, we'll have to emulate it with 32bit types...
1544-
// This will also require pack/unpack logic with bit-masking and aliased
1545-
// type reinterpretation via asfloat()/asuint() in the shader code... :(
1549+
// SM 5.1 only supports 32-bit types in groupshared memory; promote
1550+
// sub-32-bit types and use bit reinterpretation on load/store.
1551+
// SM 6.2+ supports 16-bit types natively — no promotion needed.
15461552
Type smem_type = op->type;
1547-
smem_type.with_bits(32);
1553+
if (sm < 62) {
1554+
smem_type = smem_type.with_bits(32);
1555+
}
15481556
stream << "groupshared"
15491557
<< " " << print_type(smem_type)
15501558
<< " " << print_name(op->name);
@@ -1614,7 +1622,6 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
16141622
};
16151623
FindThreadGroupSize ftg;
16161624
s.accept(&ftg);
1617-
const int sm = target.get_d3d12compute_capability_lower_bound();
16181625
const bool use_dxc = (sm >= 60);
16191626

16201627
if (use_dxc) {

0 commit comments

Comments
 (0)