diff --git a/gstaichi/program/program.h b/gstaichi/program/program.h index 69916c3d7f..7579ff4d0b 100644 --- a/gstaichi/program/program.h +++ b/gstaichi/program/program.h @@ -302,6 +302,10 @@ class TI_DLL_EXPORT Program { return program_impl_.get(); } + size_t get_num_ndarrays() const { + return ndarrays_.size(); + } + // TODO(zhanlue): Move these members and corresponding interfaces to // ProgramImpl Ideally, Program should serve as a pure interface class and all // the implementations should fall inside ProgramImpl diff --git a/gstaichi/python/export_lang.cpp b/gstaichi/python/export_lang.cpp index f8cdc6653f..f22ce7d277 100644 --- a/gstaichi/python/export_lang.cpp +++ b/gstaichi/python/export_lang.cpp @@ -370,6 +370,7 @@ void export_lang(py::module &m) { [](Program *program, SNode *snode, int element_ndim, int n, int m) { return field_to_dlpack(program, snode, element_ndim, n, m); }) + .def("_get_num_ndarrays", &Program::get_num_ndarrays) .def("config", &Program::compile_config, py::return_value_policy::reference) .def("sync_kernel_profiler", diff --git a/python/gstaichi/lang/_ndarray.py b/python/gstaichi/lang/_ndarray.py index 202988fe72..fb67581912 100644 --- a/python/gstaichi/lang/_ndarray.py +++ b/python/gstaichi/lang/_ndarray.py @@ -38,6 +38,14 @@ def __init__(self): # we register with runtime, in order to enable reset to work later impl.get_runtime().ndarrays.add(self) + def __del__(self): + if impl is not None and impl.get_runtime is not None and impl.get_runtime() is not None: + arr = getattr(self, "arr") + if arr is not None: + prog = impl.get_runtime()._prog + if prog is not None: + prog.delete_ndarray(arr) + def to_dlpack(self): return impl.get_runtime().prog.ndarray_to_dlpack(self, self.arr) @@ -270,12 +278,6 @@ def __init__(self, dtype, arr_shape): self.shape = tuple(self.arr.shape) self.element_type = dtype - def __del__(self): - if impl is not None and impl.get_runtime is not None and impl.get_runtime() is not None: - prog = impl.get_runtime()._prog - if prog is not None: - prog.delete_ndarray(self.arr) - @property def element_shape(self): return () diff --git a/tests/python/test_ndarray.py b/tests/python/test_ndarray.py index f1ea03cfcb..3e7ea672a5 100644 --- a/tests/python/test_ndarray.py +++ b/tests/python/test_ndarray.py @@ -1199,3 +1199,14 @@ def my_kernel({args}) -> None: my_kernel(*arg_objs_l) for i in range(num_args): assert arg_objs_l[i][0] == i + 1 + + +@pytest.mark.parametrize("dtype", [ti.i32, ti.types.vector(3, ti.f32), ti.types.matrix(2, 2, ti.f32)]) +@test_utils.test() +def test_ndarray_del(dtype) -> None: + def foo(): + nd = ti.ndarray(dtype, (1000,)) + assert ti.lang.impl.get_runtime().prog._get_num_ndarrays() == 1 + + foo() + assert ti.lang.impl.get_runtime().prog._get_num_ndarrays() == 0