Skip to content

Commit bab0632

Browse files
[python] Link lifetimes of SOMAArray and ManagedQuery (#3516) (#3522)
Co-authored-by: nguyenv <[email protected]>
1 parent c757020 commit bab0632

File tree

4 files changed

+90
-55
lines changed

4 files changed

+90
-55
lines changed

apis/python/src/tiledbsoma/_dense_nd_array.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from ._arrow_types import pyarrow_to_carrow_type
2323
from ._common_nd_array import NDArray
2424
from ._exception import SOMAError, map_exception_for_create
25-
from ._read_iters import TableReadIter
25+
from ._read_iters import ManagedQuery, TableReadIter
2626
from ._tdb_handles import DenseNDArrayWrapper
2727
from ._types import OpenTimestamp, Slice
2828
from ._util import dense_indices_to_shape
@@ -313,11 +313,11 @@ def write(
313313
input = np.ascontiguousarray(input)
314314
order = clib.ResultOrder.rowmajor
315315

316-
mq = clib.ManagedQuery(clib_handle, clib_handle.context())
317-
mq.set_layout(order)
318-
_util._set_coords(mq, clib_handle, new_coords)
319-
mq.set_soma_data(input)
320-
mq.submit_write()
316+
mq = ManagedQuery(self, platform_config)
317+
mq._handle.set_layout(order)
318+
_util._set_coords(mq, new_coords)
319+
mq._handle.set_soma_data(input)
320+
mq._handle.submit_write()
321321

322322
tiledb_write_options = TileDBWriteOptions.from_platform_config(platform_config)
323323
if tiledb_write_options.consolidate_and_vacuum:

apis/python/src/tiledbsoma/_read_iters.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
cast,
2424
)
2525

26+
import attrs
2627
import numpy as np
2728
import numpy.typing as npt
2829
import pyarrow as pa
@@ -470,6 +471,27 @@ def _cs_reader(
470471
yield sp, indices
471472

472473

474+
@attrs.define(frozen=True)
475+
class ManagedQuery:
476+
"""Keep the lifetime of the SOMAArray tethered to ManagedQuery."""
477+
478+
_array: SOMAArray
479+
_platform_config: options.PlatformConfig | None
480+
_handle: clib.ManagedQuery = attrs.field(init=False)
481+
482+
def __attrs_post_init__(self) -> None:
483+
clib_handle = self._array._handle._handle
484+
485+
if self._platform_config is not None:
486+
cfg = clib_handle.context().config()
487+
cfg.update(self._platform_config)
488+
ctx = clib.SOMAContext(cfg)
489+
else:
490+
ctx = clib_handle.context()
491+
492+
object.__setattr__(self, "_handle", clib.ManagedQuery(clib_handle, ctx))
493+
494+
473495
class SparseTensorReadIterBase(somacore.ReadIter[_RT], metaclass=abc.ABCMeta):
474496
"""Private implementation class"""
475497

@@ -487,27 +509,18 @@ def __init__(
487509
self.result_order = result_order
488510
self.platform_config = platform_config
489511

490-
clib_handle = array._handle._handle
512+
self.mq = ManagedQuery(array, platform_config)
491513

492-
if platform_config is not None:
493-
cfg = clib_handle.context().config()
494-
cfg.update(platform_config)
495-
ctx = clib.SOMAContext(cfg)
496-
else:
497-
ctx = clib_handle.context()
514+
self.mq._handle.set_layout(result_order)
498515

499-
self.mq = clib.ManagedQuery(clib_handle, ctx)
500-
501-
self.mq.set_layout(result_order)
502-
503-
_util._set_coords(self.mq, clib_handle, coords)
516+
_util._set_coords(self.mq, coords)
504517

505518
@abc.abstractmethod
506519
def _from_table(self, arrow_table: pa.Table) -> _RT:
507520
raise NotImplementedError()
508521

509522
def __next__(self) -> _RT:
510-
return self._from_table(self.mq.next())
523+
return self._from_table(self.mq._handle.next())
511524

512525
def concat(self) -> _RT:
513526
"""Returns all the requested data in a single operation.
@@ -556,27 +569,22 @@ def __init__(
556569
):
557570
clib_handle = array._handle._handle
558571

559-
if platform_config is not None:
560-
cfg = clib_handle.context().config()
561-
cfg.update(platform_config)
562-
ctx = clib.SOMAContext(cfg)
563-
else:
564-
ctx = clib_handle.context()
572+
self.mq = ManagedQuery(array, platform_config)
565573

566-
self.mq = clib.ManagedQuery(clib_handle, ctx)
567-
568-
self.mq.set_layout(result_order)
574+
self.mq._handle.set_layout(result_order)
569575

570576
if column_names is not None:
571-
self.mq.select_columns(list(column_names))
577+
self.mq._handle.select_columns(list(column_names))
572578

573579
if value_filter is not None:
574-
self.mq.set_condition(QueryCondition(value_filter), clib_handle.schema)
580+
self.mq._handle.set_condition(
581+
QueryCondition(value_filter), clib_handle.schema
582+
)
575583

576-
_util._set_coords(self.mq, clib_handle, coords)
584+
_util._set_coords(self.mq, coords)
577585

578586
def __next__(self) -> pa.Table:
579-
return self.mq.next()
587+
return self.mq._handle.next()
580588

581589

582590
def _coords_strider(

apis/python/src/tiledbsoma/_util.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from concurrent.futures import Future
1414
from itertools import zip_longest
1515
from typing import (
16+
TYPE_CHECKING,
1617
Any,
1718
Dict,
1819
List,
@@ -38,6 +39,9 @@
3839
_DictFilterSpec,
3940
)
4041

42+
if TYPE_CHECKING:
43+
from ._read_iters import ManagedQuery
44+
4145
_JSONFilter = Union[str, Dict[str, Union[str, Union[int, float]]]]
4246
_JSONFilterList = Union[str, List[_JSONFilter]]
4347

@@ -465,48 +469,44 @@ def _cast_domainish(domainish: List[Any]) -> Tuple[Tuple[object, object], ...]:
465469
return tuple(result)
466470

467471

468-
def _set_coords(
469-
mq: clib.ManagedQuery, sarr: clib.SOMAArray, coords: options.SparseNDCoords
470-
) -> None:
472+
def _set_coords(mq: ManagedQuery, coords: options.SparseNDCoords) -> None:
471473
if not is_nonstringy_sequence(coords):
472474
raise TypeError(
473475
f"coords type {type(coords)} must be a regular sequence,"
474476
" not str or bytes"
475477
)
476478

477-
if len(coords) > len(sarr.dimension_names):
479+
if len(coords) > len(mq._array._handle._handle.dimension_names):
478480
raise ValueError(
479481
f"coords ({len(coords)} elements) must be shorter than ndim"
480-
f" ({len(sarr.dimension_names)})"
482+
f" ({len(mq._array._handle._handle.dimension_names)})"
481483
)
482484

483485
for i, coord in enumerate(coords):
484-
_set_coord(i, mq, sarr, coord)
486+
_set_coord(i, mq, coord)
485487

486488

487-
def _set_coord(
488-
dim_idx: int, mq: clib.ManagedQuery, sarr: clib.SOMAArray, coord: object
489-
) -> None:
489+
def _set_coord(dim_idx: int, mq: ManagedQuery, coord: object) -> None:
490490
if coord is None:
491491
return
492492

493-
dim = sarr.schema.field(dim_idx)
494-
dom = _cast_domainish(sarr.domain())[dim_idx]
493+
dim = mq._array._handle._handle.schema.field(dim_idx)
494+
dom = _cast_domainish(mq._array._handle._handle.domain())[dim_idx]
495495

496496
if isinstance(coord, (str, bytes)):
497-
mq.set_dim_points_string_or_bytes(dim.name, [coord])
497+
mq._handle.set_dim_points_string_or_bytes(dim.name, [coord])
498498
return
499499

500500
if isinstance(coord, (pa.Array, pa.ChunkedArray)):
501-
mq.set_dim_points_arrow(dim.name, coord)
501+
mq._handle.set_dim_points_arrow(dim.name, coord)
502502
return
503503

504504
if isinstance(coord, (Sequence, np.ndarray)):
505505
_set_coord_by_py_seq_or_np_array(mq, dim, coord)
506506
return
507507

508508
if isinstance(coord, int):
509-
mq.set_dim_points_int64(dim.name, [coord])
509+
mq._handle.set_dim_points_int64(dim.name, [coord])
510510
return
511511

512512
# Note: slice(None, None) matches the is_slice_of part, unless we also check
@@ -521,11 +521,11 @@ def _set_coord(
521521
if coord.stop is None:
522522
# There's no way to specify "to infinity" for strings.
523523
# We have to get the nonempty domain and use that as the end.
524-
ned = _cast_domainish(sarr.non_empty_domain())
524+
ned = _cast_domainish(mq._array._handle._handle.non_empty_domain())
525525
_, stop = ned[dim_idx]
526526
else:
527527
stop = coord.stop
528-
mq.set_dim_ranges_string_or_bytes(dim.name, [(start, stop)])
528+
mq._handle.set_dim_ranges_string_or_bytes(dim.name, [(start, stop)])
529529
return
530530

531531
# Note: slice(None, None) matches the is_slice_of part, unless we also check
@@ -548,7 +548,7 @@ def _set_coord(
548548
else:
549549
istop = ts_dom[1].as_py()
550550

551-
mq.set_dim_ranges_int64(dim.name, [(istart, istop)])
551+
mq._handle.set_dim_ranges_int64(dim.name, [(istart, istop)])
552552
return
553553

554554
if isinstance(coord, slice):
@@ -562,7 +562,7 @@ def _set_coord(
562562

563563

564564
def _set_coord_by_py_seq_or_np_array(
565-
mq: clib.ManagedQuery, dim: pa.Field, coord: object
565+
mq: ManagedQuery, dim: pa.Field, coord: object
566566
) -> None:
567567
if isinstance(coord, np.ndarray):
568568
if coord.ndim != 1:
@@ -571,7 +571,7 @@ def _set_coord_by_py_seq_or_np_array(
571571
)
572572

573573
try:
574-
set_dim_points = getattr(mq, f"set_dim_points_{dim.type}")
574+
set_dim_points = getattr(mq._handle, f"set_dim_points_{dim.type}")
575575
except AttributeError:
576576
# We have to handle this type specially below
577577
pass
@@ -580,7 +580,7 @@ def _set_coord_by_py_seq_or_np_array(
580580
return
581581

582582
if pa_types_is_string_or_bytes(dim.type):
583-
mq.set_dim_points_string_or_bytes(dim.name, coord)
583+
mq._handle.set_dim_points_string_or_bytes(dim.name, coord)
584584
return
585585

586586
if pa.types.is_timestamp(dim.type):
@@ -591,14 +591,14 @@ def _set_coord_by_py_seq_or_np_array(
591591
icoord = [
592592
int(e.astype("int64")) if isinstance(e, np.datetime64) else e for e in coord
593593
]
594-
mq.set_dim_points_int64(dim.name, icoord)
594+
mq._handle.set_dim_points_int64(dim.name, icoord)
595595
return
596596

597597
raise ValueError(f"unhandled type {dim.type} for index column named {dim.name}")
598598

599599

600600
def _set_coord_by_numeric_slice(
601-
mq: clib.ManagedQuery, dim: pa.Field, dom: Tuple[object, object], coord: Slice[Any]
601+
mq: ManagedQuery, dim: pa.Field, dom: Tuple[object, object], coord: Slice[Any]
602602
) -> None:
603603
try:
604604
lo_hi = slice_to_numeric_range(coord, dom)
@@ -609,7 +609,7 @@ def _set_coord_by_numeric_slice(
609609
return
610610

611611
try:
612-
set_dim_range = getattr(mq, f"set_dim_ranges_{dim.type}")
612+
set_dim_range = getattr(mq._handle, f"set_dim_ranges_{dim.type}")
613613
set_dim_range(dim.name, [lo_hi])
614614
return
615615
except AttributeError:

apis/python/tests/test_sparse_nd_array.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1952,3 +1952,30 @@ def test_pass_configs(tmp_path):
19521952
}
19531953
).tables()
19541954
)
1955+
1956+
1957+
def test_iter(tmp_path: pathlib.Path):
1958+
arrow_tensor = create_random_tensor("table", (1,), np.float32(), density=1)
1959+
1960+
with soma.SparseNDArray.create(
1961+
tmp_path.as_uri(), type=pa.float64(), shape=(1,)
1962+
) as write_arr:
1963+
write_arr.write(arrow_tensor)
1964+
1965+
# Verify that the SOMAArray stays open as long as the ManagedQuery
1966+
# (i.e., `next`) is still active
1967+
a = soma.open(tmp_path.as_uri(), mode="r").read().tables()
1968+
assert next(a)
1969+
with pytest.raises(StopIteration):
1970+
next(a)
1971+
1972+
# Open two instances of the same array. Iterating through one should not
1973+
# affect the other
1974+
a = soma.open(tmp_path.as_uri(), mode="r").read().tables()
1975+
b = soma.open(tmp_path.as_uri(), mode="r").read().tables()
1976+
assert next(a)
1977+
assert next(b)
1978+
with pytest.raises(StopIteration):
1979+
next(a)
1980+
with pytest.raises(StopIteration):
1981+
next(b)

0 commit comments

Comments
 (0)