@@ -620,6 +620,11 @@ string hlsl_reinterpret_name(Type t) {
620620 if (t.is_float ()) {
621621 return t.bits () == 16 ? " asfloat16" : " asfloat" ;
622622 }
623+ // 16-bit integer variants (asuint16/asint16) require SM 6.2+, but that is
624+ // the same requirement as using float16_t/uint16_t scalars at all.
625+ if (t.bits () == 16 ) {
626+ return t.is_int () ? " asint16" : " asuint16" ;
627+ }
623628 return t.is_int () ? " asint" : " asuint" ;
624629}
625630
@@ -756,11 +761,12 @@ struct StoragePackUnpack {
756761void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit (const Load *op) {
757762 user_assert (is_const_one (op->predicate )) << " Predicated load is not supported inside D3D12Compute kernel.\n " ;
758763
759- // elements in a threadgroup shared buffer are always 32bits:
760- // must reinterpret (and maybe unpack) bits.
764+ // Below SM 6.2, groupshared buffers are always 32-bit; sub-32-bit loads need a
765+ // bit-reinterpret cast. SM 6.2+ supports 16-bit natively — no cast needed.
766+ const int sm = target.get_d3d12compute_capability_lower_bound ();
761767 bool shared_promotion_required = false ;
762768 string promotion_str = " " ;
763- if (groupshared_allocations.contains (op->name )) {
769+ if (groupshared_allocations.contains (op->name ) && sm < 62 ) {
764770 internal_assert (allocations.contains (op->name ));
765771 Type promoted_type = op->type .with_bits (32 ).with_lanes (1 );
766772 if (promoted_type != op->type ) {
@@ -1030,11 +1036,12 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Store *op) {
10301036
10311037 Type value_type = op->value .type ();
10321038
1033- // elements in a threadgroup shared buffer are always 32bits:
1034- // must reinterpret (and maybe pack) bits.
1039+ // Below SM 6.2, groupshared buffers are always 32-bit; sub-32-bit stores need a
1040+ // bit-reinterpret cast. SM 6.2+ supports 16-bit natively — no cast needed.
1041+ const int sm = target.get_d3d12compute_capability_lower_bound ();
10351042 bool shared_promotion_required = false ;
10361043 string promotion_str = " " ;
1037- if (groupshared_allocations.contains (op->name )) {
1044+ if (groupshared_allocations.contains (op->name ) && sm < 62 ) {
10381045 const auto *alloc = allocations.find (op->name );
10391046 internal_assert (alloc);
10401047 Type promoted_type = alloc->type ;
@@ -1533,18 +1540,19 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
15331540 FindSharedAllocationsAndUniquify fsa;
15341541 s = fsa (s);
15351542
1543+ const int sm = target.get_d3d12compute_capability_lower_bound ();
15361544 uint32_t total_shared_bytes = 0 ;
15371545 for (const Stmt &sop : fsa.allocs ) {
15381546 const Allocate *op = sop.as <Allocate>();
15391547 internal_assert (op->extents .size () == 1 );
15401548 internal_assert (op->type .lanes () == 1 );
1541- // In D3D12/HLSL, only 32bit types (int/uint/float) are supported (even
1542- // though things are changing with newer shader models). Since there is
1543- // no uint8 type, we'll have to emulate it with 32bit types...
1544- // This will also require pack/unpack logic with bit-masking and aliased
1545- // type reinterpretation via asfloat()/asuint() in the shader code... :(
1549+ // Shader models below 6.2 only support 32-bit types in groupshared memory;
1550+ // promote sub-32-bit types and use bit reinterpretation on load/store.
1551+ // SM 6.2+ supports 16-bit types natively — no promotion needed.
15461552 Type smem_type = op->type ;
1547- smem_type.with_bits (32 );
1553+ if (sm < 62 ) {
1554+ smem_type = smem_type.with_bits (32 );
1555+ }
15481556 stream << " groupshared"
15491557 << " " << print_type (smem_type)
15501558 << " " << print_name (op->name );
@@ -1614,7 +1622,6 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
16141622 };
16151623 FindThreadGroupSize ftg;
16161624 s.accept (&ftg);
1617- const int sm = target.get_d3d12compute_capability_lower_bound ();
16181625 const bool use_dxc = (sm >= 60 );
16191626
16201627 if (use_dxc) {
0 commit comments