Skip to content

Commit 4fed562

Browse files
authored
Remove const qualifier on input to thrust iterator (#879)
1 parent 4378aaf commit 4fed562

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

include/matx/core/iterator.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -317,13 +317,16 @@ struct RandomOperatorOutputIterator {
317317
template <typename OperatorType, bool ConvertType = true>
318318
struct RandomOperatorThrustIterator {
319319
using self_type = RandomOperatorThrustIterator<OperatorType, ConvertType>;
320-
using value_type = typename std::conditional_t<ConvertType, detail::convert_matx_type_t<typename OperatorType::value_type>, typename OperatorType::value_type>;
320+
using const_strip_type = remove_cvref_t<typename OperatorType::value_type>;
321+
using value_type = typename std::conditional_t<ConvertType,
322+
detail::convert_matx_type_t<const_strip_type>,
323+
const_strip_type>;
321324
// using stride_type = std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
322325
// index_t>;
323326
using stride_type = index_t;
324-
using pointer = value_type*;
325-
using reference = value_type&;
326-
using const_reference = value_type&;
327+
using pointer = cuda::std::remove_const_t<value_type>*;
328+
using reference = cuda::std::remove_const_t<value_type>&;
329+
using const_reference = cuda::std::remove_const_t<value_type>&;
327330
using iterator_category = std::random_access_iterator_tag;
328331
using difference_type = index_t;
329332
using OperatorBaseType = typename detail::base_type_t<OperatorType>;
@@ -344,14 +347,14 @@ struct RandomOperatorThrustIterator {
344347
[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ reference operator*() const
345348
{
346349
if constexpr (OperatorType::Rank() == 0) {
347-
auto &tmp = t_.operator()();
350+
auto &tmp = const_cast<const_strip_type&>(t_.operator()());
348351
return tmp;
349352
}
350353
else {
351354
auto arrs = detail::GetIdxFromAbs(t_, offset_);
352355

353356
return cuda::std::apply([&](auto &&...args) -> reference {
354-
auto &tmp = t_.operator()(args...);
357+
auto &tmp = const_cast<const_strip_type&>(t_.operator()(args...));
355358
return tmp;
356359
}, arrs);
357360
}

0 commit comments

Comments
 (0)