Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/ffi/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
/*! \brief TVM FFI minor version. */
#define TVM_FFI_VERSION_MINOR 1
/*! \brief TVM FFI patch version. */
#define TVM_FFI_VERSION_PATCH 9
#define TVM_FFI_VERSION_PATCH 8
// NOLINTEND(modernize-macro-to-enum)

#ifdef __cplusplus
Expand Down
10 changes: 8 additions & 2 deletions python/tvm_ffi/cython/tvm_ffi_python_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,16 @@ class TVMFFIPyCallManager {
return -1;
}
}

if (ctx.dlpack_c_exchange_api != nullptr &&
prev_tensor_allocator != ctx.dlpack_c_exchange_api->managed_tensor_allocator) {
c_api_ret_code[0] =
TVMFFIEnvSetDLPackManagedTensorAllocator(prev_tensor_allocator, 0, nullptr);
// note: we cannot set the error value to c_api_ret_code[0] here because it
// will be overwritten by the error value from the function call
if (TVMFFIEnvSetDLPackManagedTensorAllocator(prev_tensor_allocator, 0, nullptr) != 0) {
PyErr_SetString(PyExc_RuntimeError, "Failed to recover DLPack managed tensor allocator");
return -1;
}
// return error after
if (c_api_ret_code[0] != 0) return 0;
}
if (optional_out_ctx_dlpack_api != nullptr && ctx.dlpack_c_exchange_api != nullptr) {
Expand Down
3 changes: 2 additions & 1 deletion src/ffi/extra/env_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ class EnvContext {
int write_to_global_context,
DLPackManagedTensorAllocator* opt_out_original_allocator) {
if (opt_out_original_allocator != nullptr) {
*opt_out_original_allocator = GetDLPackManagedTensorAllocator();
// only returns the cached local allocator and ignore global allocator
*opt_out_original_allocator = dlpack_allocator_;
}
if (write_to_global_context != 0) {
GlobalTensorAllocator() = allocator;
Expand Down
16 changes: 15 additions & 1 deletion tests/python/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

from types import ModuleType
from typing import Any, NamedTuple
from typing import Any, NamedTuple, NoReturn

import numpy.typing as npt
import pytest
Expand Down Expand Up @@ -78,6 +78,20 @@ def test_tensor_auto_dlpack() -> None:
np.testing.assert_equal(y.numpy(), x.numpy())


@pytest.mark.skipif(torch is None, reason="Fast torch dlpack importer is not enabled")
def test_tensor_auto_dlpack_with_error() -> None:
    """Regression test: a Python exception raised inside a converted callable must
    propagate through the DLPack allocator set/restore path without being swallowed
    or replaced by an allocator-recovery error."""
    assert torch is not None
    x = torch.arange(128)

    def raise_torch_error(x: Any) -> NoReturn:
        raise ValueError("error XYZ")

    f = tvm_ffi.convert(raise_torch_error)
    # match= pins the original message so we know *this* error propagated,
    # not an unrelated ValueError raised elsewhere in the call path.
    with pytest.raises(ValueError, match="error XYZ"):
        # pass in torch argument to trigger the error in set-allocator path
        f(x)


def test_tensor_class_override() -> None:
class MyTensor(tvm_ffi.Tensor):
pass
Expand Down