Fix: print subbyte<T> compilation error #2783
Open
+0
−6
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Bug Fix
Which component has the problem?
CuTe C++
Describe the bug
When printing a tensor of subbyte
<cutlass::float_e2m1_t>created bymake_fragment_like, a compilation error occurs.Steps/Code to reproduce bug
With repo at commit:
a2439551c765c5393aebe557ee75d3a0412d2211Outputs
Compilation error:
error: more than one instance of overloaded function "cuda_kernel::print" matches the argument list:
function template "void cute::print(const cute::subbyte_reference &)" (declared at line 370 of ../third_party/cutlass/include/cute/container/array_subbyte.hpp)
function template "void cute::print(cute::subbyte_reference)" (declared at line 198 of ../third_party/cutlass/include/cute/container/array_subbyte.hpp)
argument types are: (cute::subbyte_referencecutlass::float_e2m1_t)
print(A_tensor_fp4(0)); print("\n");
Expected behavior
compile pass & result correct.
Environment details
Additional context
The two overloads are indistinguishable at the "pass-by-value / pass-by-const-reference" level. Changing one of them to "accept only rvalues" allows the compiler to make a unique distinction: "passing an lvalue invokes the const& version, while passing an rvalue invokes the && version."
With the rvalue overload added, the code now compiles successfully and produces the expected results.
Could you please take a look and let me know your thoughts?
Thanks!
@ccecka @thakkarV