diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h index a5a581ae..4da0d4d0 100644 --- a/include/tvm/ffi/c_api.h +++ b/include/tvm/ffi/c_api.h @@ -62,7 +62,7 @@ /*! \brief TVM FFI minor version. */ #define TVM_FFI_VERSION_MINOR 1 /*! \brief TVM FFI patch version. */ -#define TVM_FFI_VERSION_PATCH 9 +#define TVM_FFI_VERSION_PATCH 8 // NOLINTEND(modernize-macro-to-enum) #ifdef __cplusplus diff --git a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h index 93d6540a..666bb6e9 100644 --- a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h +++ b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h @@ -321,10 +321,16 @@ class TVMFFIPyCallManager { return -1; } } + if (ctx.dlpack_c_exchange_api != nullptr && prev_tensor_allocator != ctx.dlpack_c_exchange_api->managed_tensor_allocator) { - c_api_ret_code[0] =
 - TVMFFIEnvSetDLPackManagedTensorAllocator(prev_tensor_allocator, 0, nullptr); + // note: we cannot store the error value in c_api_ret_code[0] here because it + // would overwrite the error value from the original function call + if (TVMFFIEnvSetDLPackManagedTensorAllocator(prev_tensor_allocator, 0, nullptr) != 0) { + PyErr_SetString(PyExc_RuntimeError, "Failed to recover DLPack managed tensor allocator"); + return -1; + } + // now surface the original error from the function call, if any if (c_api_ret_code[0] != 0) return 0; } if (optional_out_ctx_dlpack_api != nullptr && ctx.dlpack_c_exchange_api != nullptr) { diff --git a/src/ffi/extra/env_context.cc b/src/ffi/extra/env_context.cc index 9b2fb252..95045d41 100644 --- a/src/ffi/extra/env_context.cc +++ b/src/ffi/extra/env_context.cc @@ -66,7 +66,8 @@ class EnvContext { int write_to_global_context, DLPackManagedTensorAllocator* opt_out_original_allocator) { if (opt_out_original_allocator != nullptr) { - *opt_out_original_allocator = GetDLPackManagedTensorAllocator(); + // only return the cached local allocator and ignore the global allocator + *opt_out_original_allocator = dlpack_allocator_; } if 
(write_to_global_context != 0) { GlobalTensorAllocator() = allocator; diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py index d091d859..0c45654c 100644 --- a/tests/python/test_tensor.py +++ b/tests/python/test_tensor.py @@ -18,7 +18,7 @@ from __future__ import annotations from types import ModuleType -from typing import Any, NamedTuple +from typing import Any, NamedTuple, NoReturn import numpy.typing as npt import pytest @@ -78,6 +78,20 @@ def test_tensor_auto_dlpack() -> None: np.testing.assert_equal(y.numpy(), x.numpy()) +@pytest.mark.skipif(torch is None, reason="Fast torch dlpack importer is not enabled") +def test_tensor_auto_dlpack_with_error() -> None: + assert torch is not None + x = torch.arange(128) + + def raise_torch_error(x: Any) -> NoReturn: + raise ValueError("error XYZ") + + f = tvm_ffi.convert(raise_torch_error) + with pytest.raises(ValueError): + # pass in a torch argument to trigger the error in the set allocator path + f(x) + + def test_tensor_class_override() -> None: class MyTensor(tvm_ffi.Tensor): pass