From 4dbc5cc56481d7aa81fd2ea2040f3abf13a09a48 Mon Sep 17 00:00:00 2001 From: Jhalak Patel Date: Tue, 8 Oct 2024 11:34:51 -0700 Subject: [PATCH] [API/MemRef] Implement canonical stride validation for MemRefValue creation (#252) Add optional stride validation in `MemRefValue::create` to compute canonical stride and compare against given strides while creaing a memref view from DLPack tensors. We need to handle special cases for zero-sized and unit-sized dimensions since frameworks deal with them arbitrarily while converting to the corresponding DLPack tensor. Add Python tests to verify both canonical and non-canonical stride validation. --- .../include/mlir-executor-c/Runtime/Runtime.h | 6 +- .../include/mlir-executor/Runtime/API/API.h | 9 ++- .../executor/lib/CAPI/Runtime/Runtime.cpp | 12 ++-- .../executor/lib/Runtime/API/API.cpp | 63 +++++++++++++++++-- .../python/bindings/Runtime/RuntimePyBind.cpp | 31 +++++---- .../test_create_memref.py | 51 +++++++++++++++ 6 files changed, 145 insertions(+), 27 deletions(-) diff --git a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h index 345412aee..7dd277969 100644 --- a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h +++ b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h @@ -130,7 +130,8 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtMemRefCreate(MTRT_RuntimeClient client, MTRT_PointerType pointerKind, int64_t bitsPerElement, int64_t rank, const int64_t *shape, const int64_t *strides, MTRT_Device device, MTRT_Stream stream, - MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result); + MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result, + bool assertCanonicalStrides = false); /// Creates an externally managed MemRef value. The caller provides all the /// metadata for the MemRef including the shape, strides (in elements), pointer, @@ -142,7 +143,8 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtMemRefCreateExternal( MTRT_RuntimeClient client, MTRT_PointerType pointerKind, int64_t bitsPerElement, uintptr_t ptr, int64_t offset, int64_t rank, const int64_t *shape, const int64_t *strides, MTRT_Device device, - MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result); + MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result, + bool assertCanonicalStrides = false); /// Destroys `MTRT_MemRefValue` in a potentially asynchronous manner. /// If `buffer` is a device buffer, device memory is freed in the stream diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h index 2df42a178..eb6adab4a 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h @@ -647,7 +647,8 @@ class MemRefValue : public RuntimeValue { int64_t bitsPerElement, uintptr_t ptr, int64_t offset, llvm::ArrayRef shape, llvm::ArrayRef strides, std::optional device, - std::optional scalarType); + std::optional scalarType, + std::optional assertCanonicalStrides = {}); mlirtrt::runtime::PointerType getBufferKind() { return addressSpace; } int64_t getElementBitWidth() const { return bitsPerElement; } @@ -917,7 +918,8 @@ class RuntimeClient { llvm::ArrayRef shape, llvm::ArrayRef strides, std::optional device = {}, std::optional stream = {}, - std::optional scalarType = {}); + std::optional scalarType = {}, + std::optional assertCanonicalStrides = {}); StatusOr> createExternalMemRef(PointerType addressSpace, int64_t bitsPerElement, @@ -925,7 +927,8 @@ class RuntimeClient { llvm::ArrayRef shape, llvm::ArrayRef strides, std::optional device = {}, - std::optional scalarType = {}); + std::optional scalarType = {}, + std::optional assertCanonicalStrides = {}); /// Frees the memory in `value`. The `stream` may optionally be provided /// for resources that can be deallocated asynchronously. diff --git a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp index 8b4e208e8..17fffbdaa 100644 --- a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp +++ b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp @@ -231,7 +231,8 @@ MTRT_Status mtrtMemRefCreate(MTRT_RuntimeClient client, MTRT_PointerType pointerKind, int64_t bitsPerElement, int64_t rank, const int64_t *shape, const int64_t *strides, MTRT_Device device, MTRT_Stream stream, - MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result) { + MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result, + bool assertCanonicalStrides) { StatusOr> bufferImpl = unwrap(client)->allocateMemRef( unwrap(pointerKind), bitsPerElement, @@ -244,7 +245,8 @@ mtrtMemRefCreate(MTRT_RuntimeClient client, MTRT_PointerType pointerKind, : std::optional(unwrap(stream)->getRawStream()), scalarType != MTRT_ScalarTypeCode::MTRT_ScalarTypeCode_unknown ? std::optional(ScalarType(unwrap(scalarType))) - : std::nullopt); + : std::nullopt, + std::optional(assertCanonicalStrides)); if (bufferImpl.isError()) return wrap(bufferImpl.getStatus()); @@ -257,7 +259,8 @@ MTRT_Status mtrtMemRefCreateExternal( MTRT_RuntimeClient client, MTRT_PointerType pointerKind, int64_t bitsPerElement, uintptr_t ptr, int64_t offset, int64_t rank, const int64_t *shape, const int64_t *strides, MTRT_Device device, - MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result) { + MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result, + bool assertCanonicalStrides) { StatusOr> bufferImpl = unwrap(client)->createExternalMemRef( unwrap(pointerKind), bitsPerElement, ptr, offset, @@ -267,7 +270,8 @@ MTRT_Status mtrtMemRefCreateExternal( : std::optional(unwrap(device)), scalarType == MTRT_ScalarTypeCode_unknown ? std::nullopt - : std::optional(ScalarType(unwrap(scalarType)))); + : std::optional(ScalarType(unwrap(scalarType))), + std::optional(assertCanonicalStrides)); if (bufferImpl.isError()) return wrap(bufferImpl.getStatus()); diff --git a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp index effc7f34b..e1205f21a 100644 --- a/mlir-tensorrt/executor/lib/Runtime/API/API.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/API/API.cpp @@ -671,12 +671,50 @@ static StatusOr getFootprintInBytes(llvm::ArrayRef shape, return sizeBytes; } +static llvm::SmallVector getCanonicalStride(const llvm::ArrayRef& shape) { + if (shape.empty()) + return {}; + + llvm::SmallVector canonicalStride(shape.size(), 1); + int64_t cumulativeProduct = 1; + + for (int64_t dimIndex = shape.size() - 1; dimIndex >= 0; --dimIndex) { + bool isFirstZeroDim = (shape[dimIndex] == 0 && dimIndex != static_cast(shape.size()) - 1); + // For dimensions with size 0 or 1, the stride can be arbitrary. + // We set it to 1 here, but other values would also be valid. + if (isFirstZeroDim || shape[dimIndex] == 1) + canonicalStride[dimIndex] = 1; + else + canonicalStride[dimIndex] = cumulativeProduct; + // For zero-sized dimensions (except the last one), we don't update the cumulative product + // This allows for consistent handling of zero-sized dimensions across different frameworks + cumulativeProduct *= isFirstZeroDim ? 1 : shape[dimIndex]; + } + + return canonicalStride; +} + +static bool areStridesEquivalent(llvm::ArrayRef shape, + llvm::ArrayRef stride, + llvm::ArrayRef expectedStride) { + if (shape.size() != stride.size() || shape.size() != expectedStride.size()) + return false; + + for (size_t i = 0; i < shape.size(); ++i) + // Allow arbitrary strides for dimensions with size 0 or 1 + // This accounts for discrepancies in how different frameworks handle these cases + if (stride[i] != expectedStride[i] && shape[i] != 0 && shape[i] != 1) + return false; + + return true; +} + StatusOr> MemRefValue::create( RuntimeClient *client, mlirtrt::runtime::PointerType addressSpace, int64_t bitsPerElement, uintptr_t ptr, int64_t offset, llvm::ArrayRef shape, llvm::ArrayRef strides, - std::optional device, - std::optional scalarType) { + std::optional device, std::optional scalarType, + std::optional assertCanonicalStrides) { if (!client) return getInvalidArgStatus("a valid RuntimeClient must be provided to " "create a tracked MemRef object"); @@ -691,6 +729,19 @@ StatusOr> MemRefValue::create( return getInvalidArgStatus("a specific device must be provided for MemRefs " "that are device-visible"); + // Check if given strides match canonical stride + if (assertCanonicalStrides && *assertCanonicalStrides) { + llvm::SmallVector canonicalStride = getCanonicalStride(shape); + if (!strides.empty() && + !areStridesEquivalent(shape, strides, canonicalStride)) { + std::string errorMsg = + llvm::formatv("Given strides [{0}] do not match canonical strides " + "[{1}] for shape [{2}]", + strides, canonicalStride, shape); + return getInvalidArgStatus(errorMsg.c_str()); + } + } + return std::unique_ptr( new MemRefValue(client, addressSpace, bitsPerElement, ptr, offset, shape, strides, device, scalarType)); @@ -777,7 +828,7 @@ StatusOr> RuntimeClient::allocateMemRef( PointerType addressSpace, int64_t bitsPerElement, llvm::ArrayRef shape, llvm::ArrayRef strides, std::optional device, std::optional stream, - std::optional scalarType) { + std::optional scalarType, std::optional assertCanonicalStrides) { if (addressSpace == PointerType::device || addressSpace == PointerType::unified) { if (!device || !*device) @@ -800,7 +851,7 @@ StatusOr> RuntimeClient::allocateMemRef( // Create the descriptor. StatusOr> bufferImpl = MemRefValue::create(this, addressSpace, bitsPerElement, allocation->ptr, - 0, shape, strides, device, scalarType); + 0, shape, strides, device, scalarType, assertCanonicalStrides); if (bufferImpl.isError()) return bufferImpl.getStatus(); @@ -811,11 +862,11 @@ StatusOr> RuntimeClient::createExternalMemRef( PointerType addressSpace, int64_t bitsPerElement, uintptr_t ptr, int64_t offset, llvm::ArrayRef shape, llvm::ArrayRef strides, std::optional device, - std::optional scalarType) { + std::optional scalarType, std::optional assertCanonicalStrides) { // Create the descriptor. StatusOr> memref = MemRefValue::create(this, addressSpace, bitsPerElement, ptr, offset, - shape, strides, device, scalarType); + shape, strides, device, scalarType, assertCanonicalStrides); if (!memref.isOk()) return memref.getStatus(); diff --git a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp index 200a7ebda..3a98d9621 100644 --- a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp +++ b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp @@ -313,7 +313,8 @@ static std::unique_ptr createMemRef( } static std::unique_ptr -createMemRefViewFromDLPack(PyRuntimeClient &client, py::capsule capsule) { +createMemRefViewFromDLPack(PyRuntimeClient &client, py::capsule capsule, + std::optional assertCanonicalStrides) { DLManagedTensor *managedTensor = static_cast( PyCapsule_GetPointer(capsule.ptr(), "dltensor")); @@ -368,14 +369,16 @@ createMemRefViewFromDLPack(PyRuntimeClient &client, py::capsule capsule) { } if (data) { - s = mtrtMemRefCreateExternal(client, addressSpace, bytesPerElement * 8, - reinterpret_cast(data), offset, - rank, shape, strides, device, elementType, - &result); + s = mtrtMemRefCreateExternal( + client, addressSpace, bytesPerElement * 8, + reinterpret_cast(data), offset, rank, shape, strides, device, + elementType, &result, + assertCanonicalStrides ? *assertCanonicalStrides : false); } else { - s = mtrtMemRefCreate(client, addressSpace, bytesPerElement * 8, rank, shape, - strides, device, mtrtStreamGetNull(), elementType, - &result); + s = mtrtMemRefCreate( + client, addressSpace, bytesPerElement * 8, rank, shape, strides, device, + mtrtStreamGetNull(), elementType, &result, + assertCanonicalStrides ? *assertCanonicalStrides : false); } THROW_IF_MTRT_ERROR(s); @@ -788,11 +791,15 @@ PYBIND11_MODULE(_api, m) { "returns a new memref and allocates uninitialized backing storage") .def( "create_memref_view_from_dlpack", - [](PyRuntimeClient &self, py::capsule capsule) { - return createMemRefViewFromDLPack(self, capsule).release(); + [](PyRuntimeClient &self, py::capsule capsule, + std::optional assertCanonicalStrides) { + return createMemRefViewFromDLPack(self, capsule, + assertCanonicalStrides) + .release(); }, - py::arg("dltensor") = py::none(), py::keep_alive<0, 1>(), - py::keep_alive<0, 2>()) + py::arg("dltensor") = py::none(), + py::arg("assert_canonical_strides") = py::none(), + py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) .def( "create_device_memref_view", [](PyRuntimeClient &self, uintptr_t ptr, std::vector shape, diff --git a/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_create_memref.py b/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_create_memref.py index 3a7779eb6..ef55148bd 100644 --- a/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_create_memref.py +++ b/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_create_memref.py @@ -514,3 +514,54 @@ def create_dangling_memref(): # CHECK-LABEL: Test memref maintains data's lifetime # CHECK-NEXT: -- Inner scope: np.from_dlpack(): [1 2] # CHECK-NEXT: -- Outer scope: np.from_dlpack(): [1 2] + + +def check_non_canonical_stride(client, assert_canonical_strides): + try: + t = cp.arange(12, dtype=cp.float32).reshape(3, 4) + a = cp.transpose(t) + memref = client.create_memref_view_from_dlpack( + a.__dlpack__(), assert_canonical_strides + ) + except Exception as e: + print(f"Received error message: {str(e)}") + + +def check_canonical_stride(client, assert_canonical_strides): + try: + t = cp.arange(12, dtype=cp.float32).reshape(3, 4) + memref = client.create_memref_view_from_dlpack( + t.__dlpack__(), assert_canonical_strides + ) + except Exception as e: + print(f"Received error message: {str(e)}") + + +def test_memref_strides(): + print("Testing non-canonical stride: assert_canonical_strides = True") + non_canonical_result = check_non_canonical_stride( + client, assert_canonical_strides=True + ) + + print("Testing non-canonical stride: assert_canonical_strides = False") + non_canonical_result = check_non_canonical_stride( + client, assert_canonical_strides=False + ) + + print("Testing canonical stride: assert_canonical_strides = True") + canonical_result = check_canonical_stride(client, assert_canonical_strides=True) + + print("Testing canonical stride: assert_canonical_strides = False") + canonical_result = check_canonical_stride(client, assert_canonical_strides=False) + + +print("Test memref strides") +test_memref_strides() + +# CHECK-LABEL: Test memref strides +# CHECK-NEXT: Testing non-canonical stride: assert_canonical_strides = True +# CHECK-NEXT: Received error message: InvalidArgument: InvalidArgument: +# CHECK-SAME: Given strides [1, 4] do not match canonical strides [3, 1] for shape [4, 3] +# CHECK-NEXT: Testing non-canonical stride: assert_canonical_strides = False +# CHECK-NEXT: Testing canonical stride: assert_canonical_strides = True +# CHECK-NEXT: Testing canonical stride: assert_canonical_strides = False