
[BUG] Const operators don't work with argmax #878

@deanljohnson

Description


Describe the Bug
Using argmax on an operator whose element type is const (for example, a tensor created from a const float*) fails to compile.

To Reproduce

  1. Call make_tensor with a const T* pointer
  2. Call argmax on the resulting tensor

This will produce an error such as:

[build] /usr/local/include/matx/transforms/reduce.h(1173): error: expression must be a modifiable lvalue
[build]         smem[threadIdx.x / 32] = in_val;
[build]         ^
[build]           detected during:
[build]             instantiation of "void matx::detail::matxReduceKernel(OutType, InType, ReduceOp, matx::index_t) [with OutType=matx::detail::tensor_impl_t<float, 0, matx::tensor_desc_t<cuda::std::__4::array<long long, 0UL>, cuda::std::__4::array<long long, 0UL>, 0>>, InType=matx::detail::tensor_impl_t<const float, 1, matx::tensor_desc_t<cuda::std::__4::array<long long, 1UL>, cuda::std::__4::array<long long, 1UL>, 1>>, ReduceOp=matx::detail::reduceOpMax<float>]" at line 1395
[build]             instantiation of "void matx::reduce(OutType, TensorIndexType, const InType &, ReduceOp, cudaStream_t, __nv_bool) [with OutType=matx::detail::tensor_impl_t<float, 0, matx::tensor_desc_t<cuda::std::__4::array<long long, 0UL>, cuda::std::__4::array<long long, 0UL>, 0>>, TensorIndexType=matx::detail::tensor_impl_t<matx::index_t, 0, matx::tensor_desc_t<cuda::std::__4::array<long long, 0UL>, cuda::std::__4::array<long long, 0UL>, 0>>, InType=matx::detail::tensor_impl_t<const float, 1, matx::tensor_desc_t<cuda::std::__4::array<long long, 1UL>, cuda::std::__4::array<long long, 1UL>, 1>>, ReduceOp=matx::detail::reduceOpMax<float>, <unnamed>=true]" at line 2088
[build]             instantiation of "void matx::argmax_impl(OutType, TensorIndexType &, const InType &, matx::cudaExecutor) [with OutType=matx::detail::tensor_impl_t<float, 0, matx::tensor_desc_t<cuda::std::__4::array<long long, 0UL>, cuda::std::__4::array<long long, 0UL>, 0>>, TensorIndexType=matx::detail::tensor_impl_t<matx::index_t, 0, matx::tensor_desc_t<cuda::std::__4::array<long long, 0UL>, cuda::std::__4::array<long long, 0UL>, 0>>, InType=matx::detail::tensor_impl_t<const float, 1, matx::tensor_desc_t<cuda::std::__4::array<long long, 1UL>, cuda::std::__4::array<long long, 1UL>, 1>>]" at line 69 of /usr/local/include/matx/operators/argmax.h
[build]             instantiation of "void matx::detail::ArgMaxOp<OpA, ORank>::Exec(Out &&, Executor &&) const [with OpA=const matx::tensor_t<const float, 1, matx::basic_storage<matx::raw_pointer_buffer<const float, matx::matx_allocator<const float>>>, matx::tensor_desc_t<cuda::std::__4::array<long long, 1UL>, cuda::std::__4::array<long long, 1UL>, 1>> &, ORank=0, Out=cuda::std::__4::tuple<matx::detail::tensor_impl_t<float, 0, matx::tensor_desc_t<cuda::std::__4::array<long long, 0UL>, cuda::std::__4::array<long long, 0UL>, 0>>, matx::detail::tensor_impl_t<matx::index_t, 0, matx::tensor_desc_t<cuda::std::__4::array<long long, 0UL>, cuda::std::__4::array<long long, 0UL>, 0>>, matx::detail::ArgMaxOp<const matx::tensor_t<const float, 1, matx::basic_storage<matx::raw_pointer_buffer<const float, matx::matx_allocator<const float>>>, matx::tensor_desc_t<cuda::std::__4::array<long long, 1UL>, cuda::std::__4::array<long long, 1UL>, 1>> &, 0>> &, Executor=matx::cudaExecutor &]" at line 100 of /usr/local/include/matx/core/tie.h
[build]             instantiation of "void matx::mtie<Ts...>::Exec(Executor &&) [with Ts=<matx::detail::tensor_impl_t<float, 0, matx::tensor_desc_t<cuda::std::__4::array<long long, 0UL>, cuda::std::__4::array<long long, 0UL>, 0>>, matx::detail::tensor_impl_t<matx::index_t, 0, matx::tensor_desc_t<cuda::std::__4::array<long long, 0UL>, cuda::std::__4::array<long long, 0UL>, 0>>, matx::detail::ArgMaxOp<const matx::tensor_t<const float, 1, matx::basic_storage<matx::raw_pointer_buffer<const float, matx::matx_allocator<const float>>>, matx::tensor_desc_t<cuda::std::__4::array<long long, 1UL>, cuda::std::__4::array<long long, 1UL>, 1>> &, 0>>, Executor=matx::cudaExecutor &]" at line 73 of /usr/local/include/matx/operators/base_operator.h
[build]             instantiation of "void matx::BaseOp<T>::run(Ex &&) [with T=matx::mtie<matx::detail::tensor_impl_t<float, 0, matx::tensor_desc_t<cuda::std::__4::array<long long, 0UL>, cuda::std::__4::array<long long, 0UL>, 0>>, matx::detail::tensor_impl_t<matx::index_t, 0, matx::tensor_desc_t<cuda::std::__4::array<long long, 0UL>, cuda::std::__4::array<long long, 0UL>, 0>>, matx::detail::ArgMaxOp<const matx::tensor_t<const float, 1, matx::basic_storage<matx::raw_pointer_buffer<const float, matx::matx_allocator<const float>>>, matx::tensor_desc_t<cuda::std::__4::array<long long, 1UL>, cuda::std::__4::array<long long, 1UL>, 1>> &, 0>>, Ex=matx::cudaExecutor]" at line 112 of /usr/local/include/matx/operators/base_operator.h
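The error points at an assignment into the reduction kernel's shared-memory scratch buffer (smem). A plausible explanation, which is an assumption and not something the log confirms, is that the buffer's element type is taken directly from the input's value_type. With InType=tensor_impl_t<const float, ...>, that would make the buffer itself const and therefore not assignable. The host-only C++ sketch below reproduces that pattern and the usual remedy, stripping cv-qualifiers with std::remove_cv_t before declaring scratch storage. A plain array stands in for __shared__ memory, the per-warp indexing is omitted, and the names toy_input and toy_argmax are hypothetical, not MatX code.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Hypothetical input operator whose element type may be const-qualified,
// standing in for tensor_impl_t<const float, 1, ...> from the log.
template <typename T>
struct toy_input {
  using value_type = T;  // may be const float
  T* data;
  std::size_t size;
  std::remove_cv_t<T> operator()(std::size_t i) const { return data[i]; }
};

// Hypothetical argmax mirroring the failing pattern. Declaring the scratch
// buffer as `typename InType::value_type scratch[N];` would give
// `const float scratch[N]` for a const input, which cannot be written.
// Stripping cv-qualifiers keeps the buffer writable for both cases.
template <std::size_t N, typename InType>
std::size_t toy_argmax(const InType& in) {
  using scratch_t = std::remove_cv_t<typename InType::value_type>;
  scratch_t scratch[N];  // stand-in for __shared__ scratch; must be writable
  const std::size_t n = in.size < N ? in.size : N;
  for (std::size_t i = 0; i < n; ++i) {
    scratch[i] = in(i);  // the write that fails if scratch_t is const float
  }
  std::size_t best = 0;  // returns 0 for an empty input
  for (std::size_t i = 1; i < n; ++i) {
    if (scratch[i] > scratch[best]) {
      best = i;  // strict > keeps the first index on ties
    }
  }
  return best;
}
```

With this shape, toy_argmax compiles and returns the same result for toy_input<const float> and toy_input<float>.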

Expected Behavior
I would expect argmax to work with both const and non-const inputs.

Code Snippets

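#include "matx.h"

const float* gpu_data = ...;                            // device pointer to 10 floats
auto t_data = matx::make_tensor(gpu_data, {10});        // element type becomes const float
auto data_mag = matx::make_tensor<float>({});           // 0-D output: maximum value
auto data_idx = matx::make_tensor<matx::index_t>({});   // 0-D output: index of the maximum
(matx::mtie(data_mag, data_idx) = matx::argmax(t_data)).run(stream);  // fails to compile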
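In the snippet above, t_data ends up as tensor_t<const float, 1, ...> (see the ArgMaxOp instantiation in the log), presumably because the element type is deduced from the const float* argument. The host-only sketch below illustrates that deduction with a hypothetical toy_make_tensor and toy_tensor; these are illustrative stand-ins, not matx::make_tensor or matx::tensor_t.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Hypothetical stand-in for a 1-D tensor view; NOT matx::tensor_t.
template <typename T>
struct toy_tensor {
  using value_type = T;  // keeps any const qualifier from the pointer
  T* data;
  std::size_t size;
};

// Hypothetical stand-in for make_tensor(ptr, shape): T is deduced from the
// pointer argument, so passing a const float* yields T = const float.
template <typename T>
toy_tensor<T> toy_make_tensor(T* ptr, std::size_t n) {
  return toy_tensor<T>{ptr, n};
}

// True when a tensor's element type is const-qualified, i.e. when code that
// declares writable storage from value_type would need std::remove_cv_t.
template <typename Tensor>
constexpr bool has_const_elements() {
  return std::is_const<typename Tensor::value_type>::value;
}
```

Passing a const array yields a toy_tensor<const float>, while a mutable array yields toy_tensor<float>, so any downstream code that builds writable storage from value_type has to handle the const case explicitly.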

System Details

  • OS: Rocky 9
  • CUDA version: 12.6, 12.3
  • g++ version: 11.4.1

Additional Context
None
