-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Open
Labels
Description
Which component has the problem?
Cutlass C++
Bug Report
There is an issue with the TMA copy implementation in the file cute/arch/copy_sm90_tma.hpp.
The following function is defined in a header such that it should be compilable for both host and device code. However, its implementation relies on device-only operations, leading to a compilation error.
Is this incorrect? Thanks!
CUTE_HOST_DEVICE static void
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
void * smem_ptr,
int32_t const& crd0)
{
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr);
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr);
#if defined(CUTE_ARCH_TMA_SM120_ENABLED)
asm volatile (
"cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint"
" [%0], [%1, {%3}], [%2], %4;"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "l"(cache_hint)
: "memory");
#else
asm volatile (
"cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
" [%0], [%1, {%3}], [%2], %4;"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "l"(cache_hint)
: "memory");
#endif
#else
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
}