Skip to content

Fixed StridedBufferView::to_numpy on Metal devices #207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion slangpy/tests/slangpy_tests/test_buffer_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_to_numpy(
strides = Shape(unravelled_shape).calc_contiguous_strides()
byte_strides = tuple(s * np_dtype.itemsize for s in strides)

ndarray = buffer.to_numpy()
ndarray = np.ascontiguousarray(buffer.to_numpy())
assert ndarray.shape == unravelled_shape
assert ndarray.strides == byte_strides
assert ndarray.dtype == np_dtype
Expand Down
61 changes: 58 additions & 3 deletions src/slangpy_ext/utils/slangpystridedbufferview.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cstdint>
#include <initializer_list>
#include <vector>
#include "nanobind.h"

#include "sgl/device/device.h"
#include "sgl/device/command.h"
#include "sgl/device/buffer_cursor.h"

#include "sgl/device/reflection.h"
#include "utils/slangpybuffer.h"

namespace sgl::slangpy {
Expand Down Expand Up @@ -64,6 +67,21 @@ ref<NativeSlangType> innermost_type(ref<NativeSlangType> type)
return result;
}

/// Collect the chain of nested types starting at `type`, outermost first.
///
/// The walk repeatedly follows element_type() and appends each type it visits.
/// It stops once element_type() yields a null reference or hands back the very
/// same type, so the last entry is the type at which the descent ended.
std::vector<ref<NativeSlangType>> type_stack(ref<NativeSlangType> type)
{
    std::vector<ref<NativeSlangType>> chain;
    ref<NativeSlangType> node = type;
    for (;;) {
        chain.push_back(node);
        ref<NativeSlangType> next = node->element_type();
        // No further nesting (null) or a self-referential element type: done.
        if (!next || next == node)
            return chain;
        node = next;
    }
}

StridedBufferView::StridedBufferView(Device* device, const StridedBufferViewDesc& desc, ref<Buffer> storage)
{
if (!storage) {
Expand Down Expand Up @@ -324,7 +342,6 @@ static nb::ndarray<Framework> to_ndarray(void* data, nb::handle owner, const Str
// Buffer with shape (5, ) of struct Foo { ... } -> ndarray of shape (5, sizeof(Foo)) and dtype uint8
bool is_scalar = innermost_layout->type()->kind() == TypeReflection::Kind::scalar;
auto dtype_shape = desc.dtype->get_shape();
auto dtype_strides = dtype_shape.calc_contiguous_strides();

size_t innermost_size = is_scalar ? innermost_layout->stride() : 1;
TypeReflection::ScalarType scalar_type
Expand All @@ -339,9 +356,13 @@ static nb::ndarray<Framework> to_ndarray(void* data, nb::handle owner, const Str
sizes.push_back(desc.shape[i]);
strides.push_back(desc.strides[i] * dtype_size / innermost_size);
}
// Use cursor reflection to calculate dtype stride.
ref<NativeSlangType> curr_type = desc.dtype;
for (size_t i = 0; i < dtype_shape.size(); ++i) {
sizes.push_back(dtype_shape[i]);
strides.push_back(dtype_strides[i]);
curr_type = curr_type->element_type();
auto dtype_stride = curr_type->buffer_type_layout()->stride() / innermost_size;
strides.push_back(dtype_stride);
}
// If the innermost dtype is not a scalar, add one innermost dimension over
// the bytes of the element
Expand Down Expand Up @@ -388,6 +409,8 @@ nb::ndarray<nb::pytorch> StridedBufferView::to_torch() const

void StridedBufferView::copy_from_numpy(nb::ndarray<nb::numpy> data)
{
// StridedBufferView::is_contiguous() == true does not necessarily mean the internal buffer is contiguous in memory
// (the 4-element alignment requirement on Metal can break this contiguity).
SGL_CHECK(is_ndarray_contiguous(data), "Source Numpy array must be contiguous");
SGL_CHECK(is_contiguous(), "Destination buffer view must be contiguous");

Expand All @@ -397,7 +420,39 @@ void StridedBufferView::copy_from_numpy(nb::ndarray<nb::numpy> data)
size_t buffer_size = m_storage->size() - byte_offset;
SGL_CHECK(data_size <= buffer_size, "Numpy array is larger than the buffer ({} > {})", data_size, buffer_size);

m_storage->set_data(data.data(), data_size, byte_offset);
// At this point, the only way the stride of a contiguous buffer view can be broken is Metal buffer alignment
// in the second-to-last dimension (vector or matrix dtypes).
auto kind = desc().dtype->buffer_type_layout()->kind();
if (kind != TypeReflection::Kind::vector && kind != TypeReflection::Kind::matrix) {
m_storage->set_data(data.data(), data_size, byte_offset);
return;
}
// Look up the layout strides of the innermost and second-innermost types of the dtype.
auto stack = type_stack(desc().dtype);
ref<NativeSlangType> innermost = stack[stack.size() - 1];
ref<TypeLayoutReflection> innermost_layout = innermost->buffer_type_layout();
size_t innermost_size = innermost_layout->stride();
ref<NativeSlangType> second_innermost = stack[stack.size() - 2];
ref<TypeLayoutReflection> second_innermost_layout = second_innermost->buffer_type_layout();
size_t second_innermost_size = second_innermost_layout->stride();
// Alignment fits.
if (second_innermost_size == second_innermost_layout->type()->col_count() * innermost_size) {
m_storage->set_data(data.data(), data_size, byte_offset);
return;
}

// Copy with local buffer.
std::vector<uint8_t> buffer(buffer_size);
auto actual_size = second_innermost_layout->type()->col_count() * innermost_size;
// Write element.
for (size_t i = 0; i < data_size / actual_size; i++) {
std::memcpy(
buffer.data() + i * second_innermost_size,
static_cast<uint8_t*>(data.data()) + i * actual_size,
actual_size
);
}
m_storage->set_data(buffer.data(), buffer.size(), byte_offset);
}

} // namespace sgl::slangpy
Expand Down