Adding structures needed for sparse support (#819)
cliffburdick authored Jan 14, 2025
1 parent 1bb2f9c commit f76060c
Showing 5 changed files with 135 additions and 84 deletions.
2 changes: 1 addition & 1 deletion include/matx/core/storage.h
@@ -189,7 +189,7 @@ namespace matx
*
* @return Size of allocation
*/
- size_t size() const
+ __MATX_INLINE__ size_t size() const
{
return size_;
}
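The only change to storage.h is tagging size() with __MATX_INLINE__. The macro's real definition lives elsewhere in MatX and may differ; a minimal stand-in, assuming the usual split between nvcc and host-only builds, would look like the following (EXAMPLE_INLINE and ExampleStorage are hypothetical names used only for illustration):

#include <cstddef>

// Hypothetical stand-in for the inline annotation used above; MatX's actual
// macro may differ. The point is that tiny accessors get force-inlined when
// compiled by nvcc and fall back to plain inline otherwise.
#ifdef __CUDACC__
  #define EXAMPLE_INLINE __forceinline__
#else
  #define EXAMPLE_INLINE inline
#endif

struct ExampleStorage {
  std::size_t size_{};
  EXAMPLE_INLINE std::size_t size() const { return size_; }  // mirrors the hunk above
};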
50 changes: 22 additions & 28 deletions include/matx/core/tensor.h
@@ -107,7 +107,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
* @param rhs Object to copy from
*/
__MATX_HOST__ tensor_t(tensor_t const &rhs) noexcept
- : detail::tensor_impl_t<T, RANK, Desc>{rhs.ldata_, rhs.desc_}, storage_(rhs.storage_)
+ : detail::tensor_impl_t<T, RANK, Desc>{rhs.Data(), rhs.desc_}, storage_(rhs.storage_)
{ }

/**
@@ -116,7 +116,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
* @param rhs Object to move from
*/
__MATX_HOST__ tensor_t(tensor_t &&rhs) noexcept
- : detail::tensor_impl_t<T, RANK, Desc>{rhs.ldata_, std::move(rhs.desc_)}, storage_(std::move(rhs.storage_))
+ : detail::tensor_impl_t<T, RANK, Desc>{rhs.Data(), std::move(rhs.desc_)}, storage_(std::move(rhs.storage_))
{ }


@@ -131,7 +131,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
*/
__MATX_HOST__ void Shallow(const self_type &rhs) noexcept
{
- this->ldata_ = rhs.ldata_;
+ this->SetData(rhs.Data());
storage_ = rhs.storage_;
this->desc_ = rhs.desc_;
}
@@ -149,7 +149,9 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
{
using std::swap;

- std::swap(lhs.ldata_, rhs.ldata_);
+ auto tmpdata = lhs.Data();
+ lhs.SetData(rhs.Data());
+ rhs.SetData(tmpdata);
swap(lhs.storage_, rhs.storage_);
swap(lhs.desc_, rhs.desc_);
}
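Every hunk below follows the same pattern: tensor_t stops touching the ldata_ member directly and goes through accessors on the tensor_impl_t base class instead, presumably to give the upcoming sparse storage room to manage the pointer differently behind the same interface. A minimal sketch of the accessor surface these call sites assume (example_tensor_impl is a hypothetical illustration, not the real base class):

// Hedged sketch of the accessors the call sites above rely on, assuming the
// dense case keeps a single raw pointer; the real tensor_impl_t may differ.
template <typename T>
class example_tensor_impl {
public:
  T *Data() const noexcept { return ldata_; }        // read the data pointer
  void SetData(T *data) noexcept { ldata_ = data; }  // install a new pointer
protected:
  T *ldata_ = nullptr;  // derived views no longer touch this member directly
};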
@@ -651,7 +653,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {

// Copy descriptor and call ctor with shape
Desc new_desc{std::forward<Shape>(shape)};
- return tensor_t<M, R, Storage, Desc>{storage_, std::move(new_desc), this->ldata_};
+ return tensor_t<M, R, Storage, Desc>{storage_, std::move(new_desc), this->Data()};
}

/**
@@ -710,7 +712,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
"To get a reshaped view the tensor must be compact");

DefaultDescriptor<tshape.size()> desc{std::move(tshape)};
- return tensor_t<T, NRANK, Storage, decltype(desc)>{storage_, std::move(desc), this->ldata_};
+ return tensor_t<T, NRANK, Storage, decltype(desc)>{storage_, std::move(desc), this->Data()};
}

/**
@@ -739,7 +741,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {

int dev;
cudaGetDevice(&dev);
- cudaMemPrefetchAsync(this->ldata_, this->desc_.TotalSize() * sizeof(T), dev, stream);
+ cudaMemPrefetchAsync(this->Data(), this->desc_.TotalSize() * sizeof(T), dev, stream);
}

/**
@@ -756,7 +758,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
{
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)

- cudaMemPrefetchAsync(this->ldata_, this->desc_.TotalSize() * sizeof(T), cudaCpuDeviceId,
+ cudaMemPrefetchAsync(this->Data(), this->desc_.TotalSize() * sizeof(T), cudaCpuDeviceId,
stream);
}
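Both prefetch helpers wrap cudaMemPrefetchAsync, which only does useful work on managed allocations. A usage sketch, assuming the default managed-memory make_tensor and the usual MatX executor syntax:

#include <matx.h>

// Move a managed tensor's pages to the GPU before compute, then back to the
// host before the CPU reads the results.
void prefetch_example(cudaStream_t stream) {
  auto t = matx::make_tensor<float>({1 << 20});  // managed memory by default

  t.PrefetchDevice(stream);        // pages migrate to the current device
  (t = t * 2.0f).run(stream);      // device-side work on the tensor

  t.PrefetchHost(stream);          // pages migrate back for host access
  cudaStreamSynchronize(stream);
}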

@@ -776,7 +778,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
static_assert(is_complex_v<T>, "RealView() only works with complex types");

using Type = typename U::value_type;
- Type *data = reinterpret_cast<Type *>(this->ldata_);
+ Type *data = reinterpret_cast<Type *>(this->Data());
cuda::std::array<typename Desc::stride_type, RANK> strides;

#pragma unroll
@@ -821,7 +823,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
static_assert(is_complex_v<T>, "ImagView() only works with complex types");

using Type = typename U::value_type;
- Type *data = reinterpret_cast<Type *>(this->ldata_) + 1;
+ Type *data = reinterpret_cast<Type *>(this->Data()) + 1;
cuda::std::array<stride_type, RANK> strides;
#pragma unroll
for (int i = 0; i < RANK; i++) {
@@ -859,7 +861,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)

auto new_desc = this->PermuteImpl(dims);
- return tensor_t<T, RANK, Storage, Desc>{storage_, std::move(new_desc), this->ldata_};
+ return tensor_t<T, RANK, Storage, Desc>{storage_, std::move(new_desc), this->Data()};
}
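Permute produces a reordered, strided view over the same storage; nothing is copied. A small usage sketch (standard MatX calls, written from memory rather than taken from this commit):

#include <matx.h>

// Transpose a 2D tensor as a zero-copy view by swapping its two dimensions.
void permute_example() {
  auto a  = matx::make_tensor<float>({3, 4});
  auto at = a.Permute({1, 0});  // logical shape {4, 3}, same underlying data

  // at(i, j) aliases a(j, i); writes through either view are visible in both.
}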


@@ -904,14 +906,6 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
return Permute(tdims);
}

- /**
- * Get the underlying local data pointer from the view
- *
- * @returns Underlying data pointer of type T
- *
- */
- __MATX_HOST__ __MATX_INLINE__ T *Data() const noexcept { return this->ldata_; }

/**
* Set the underlying data pointer from the view
*
@@ -933,7 +927,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
{
this->desc_.InitFromShape(std::forward<ShapeType>(shape));
storage_.SetData(data);
- this->ldata_ = data;
+ this->SetData(data);
}

/**
@@ -953,7 +947,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)

storage_.SetData(data);
- this->ldata_ = data;
+ this->SetData(data);
}
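Reset repoints an existing view at caller-owned memory without reallocating; only the pointer (and, in the other overload, the shape) changes. A usage sketch, assuming the buffer outlives the view and matches its element count:

#include <matx.h>

// Rebind a 1D view of 256 floats to a device buffer the caller manages.
void reset_example(float *device_buf /* >= 256 floats, caller-owned */) {
  auto t = matx::make_tensor<float>({256});
  t.Reset(device_buf);  // t now reads/writes device_buf; shape is unchanged
}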

/**
@@ -973,7 +967,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
Reset(T *const data, T *const ldata) noexcept
{
storage_.SetData(data);
- this->ldata_ = ldata;
+ this->SetData(data);
}


@@ -1035,7 +1029,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
OverlapView(const cuda::std::array<typename Desc::shape_type, N> &windows,
const cuda::std::array<typename Desc::stride_type, N> &strides) const {
auto new_desc = this->template OverlapViewImpl<N>(windows, strides);
- return tensor_t<T, RANK + 1, Storage, decltype(new_desc)>{storage_, std::move(new_desc), this->ldata_};
+ return tensor_t<T, RANK + 1, Storage, decltype(new_desc)>{storage_, std::move(new_desc), this->Data()};
}
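OverlapView turns a rank-R tensor into a rank-R+1 view of sliding windows, again without copying. A usage sketch that builds the window/stride arrays explicitly to match the signature above, assuming the usual (N - window) / stride + 1 window count:

#include <cuda/std/array>
#include <matx.h>

// Sliding windows over a 1D signal: length-4 windows hopping by 2 samples.
void overlap_example() {
  auto sig = matx::make_tensor<float>({16});

  cuda::std::array<matx::index_t, 1> window{4};
  cuda::std::array<matx::index_t, 1> stride{2};
  auto win = sig.OverlapView(window, stride);  // rank-2 view, 7 windows of 4; aliases sig
}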

/**
@@ -1069,7 +1063,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)

auto new_desc = this->template CloneImpl<N>(clones);
- return tensor_t<T, N, Storage, decltype(new_desc)>{storage_, std::move(new_desc), this->ldata_};
+ return tensor_t<T, N, Storage, decltype(new_desc)>{storage_, std::move(new_desc), this->Data()};
}
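Clone broadcasts a lower-rank tensor to a higher rank by giving the cloned dimensions a zero stride, so no data is duplicated. A usage sketch with the rank spelled out explicitly (matxKeepDim marks the dimension that keeps its real extent; calls written from memory):

#include <matx.h>

// View a 1D tensor of 3 elements as a 5x3 matrix whose rows all alias the
// same 3 values (the new dimension has stride 0).
void clone_example() {
  auto v = matx::make_tensor<float>({3});
  auto m = v.Clone<2>({5, matx::matxKeepDim});  // logical shape {5, 3}
}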

template <int N>
@@ -1080,7 +1074,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {

__MATX_INLINE__ __MATX_HOST__ bool IsManagedPointer() {
bool managed;
- const CUresult retval = cuPointerGetAttribute(&managed, CU_POINTER_ATTRIBUTE_IS_MANAGED, (CUdeviceptr)Data());
+ const CUresult retval = cuPointerGetAttribute(&managed, CU_POINTER_ATTRIBUTE_IS_MANAGED, (CUdeviceptr)this->Data());
MATX_ASSERT(retval == CUDA_SUCCESS, matxNotSupported);
return managed;
}
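IsManagedPointer asks the driver whether the view's pointer came from a managed allocation, which makes it a natural guard in front of the prefetch calls earlier in this file. A usage sketch (the memory-space argument to make_tensor is written from memory):

#include <matx.h>

// Only prefetch when the backing allocation is actually CUDA managed memory.
void guarded_prefetch(cudaStream_t stream) {
  auto t = matx::make_tensor<float>({4096}, matx::MATX_MANAGED_MEMORY);
  if (t.IsManagedPointer()) {
    t.PrefetchDevice(stream);
  }
}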
@@ -1454,12 +1448,12 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
int dev_ord;
void *data[2] = {&mem_type, &dev_ord};

- t->data = static_cast<void*>(this->ldata_);
+ t->data = static_cast<void*>(this->Data());
t->device.device_id = 0;

// Determine where this memory resides
- auto kind = GetPointerKind(this->ldata_);
- auto mem_res = cuPointerGetAttributes(sizeof(attr)/sizeof(attr[0]), attr, data, reinterpret_cast<CUdeviceptr>(this->ldata_));
+ auto kind = GetPointerKind(this->Data());
+ auto mem_res = cuPointerGetAttributes(sizeof(attr)/sizeof(attr[0]), attr, data, reinterpret_cast<CUdeviceptr>(this->Data()));
MATX_ASSERT_STR_EXP(mem_res, CUDA_SUCCESS, matxCudaError, "Error returned from cuPointerGetAttributes");
if (kind == MATX_INVALID_MEMORY) {
if (mem_type == CU_MEMORYTYPE_DEVICE) {
(The diffs for the remaining three changed files did not load in this capture.)
