Skip to content

Commit

Permalink
refactor: simplify caching implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist authored and cpcloud committed Aug 27, 2024
1 parent 335a538 commit afba988
Show file tree
Hide file tree
Showing 11 changed files with 93 additions and 152 deletions.
119 changes: 72 additions & 47 deletions ibis/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@
import importlib.metadata
import keyword
import re
import sys
import urllib.parse
import weakref
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple

import ibis
import ibis.common.exceptions as exc
import ibis.config
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis import util
from ibis.common.caching import RefCountedCache

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, MutableMapping
Expand Down Expand Up @@ -777,7 +778,74 @@ def drop_schema(
self.drop_database(name=name, catalog=database, force=force)


class BaseBackend(abc.ABC, _FileIOHandler):
class CacheEntry(NamedTuple):
orig_op: ops.Relation
cached_op_ref: weakref.ref[ops.Relation]
finalizer: weakref.finalize


class CacheHandler:
"""A mixin for handling `.cache()`/`CachedTable` operations."""

def __init__(self):
self._cache_name_to_entry = {}
self._cache_op_to_entry = {}

def _cached_table(self, table: ir.Table) -> ir.CachedTable:
"""Convert a Table to a CachedTable.
Parameters
----------
table
Table expression to cache
Returns
-------
Table
Cached table
"""
entry = self._cache_op_to_entry.get(table.op())
if entry is None or (cached_op := entry.cached_op_ref()) is None:
cached_op = self._create_cached_table(util.gen_name("cached"), table).op()
entry = CacheEntry(
table.op(),
weakref.ref(cached_op),
weakref.finalize(
cached_op, self._finalize_cached_table, cached_op.name
),
)
self._cache_op_to_entry[table.op()] = entry
self._cache_name_to_entry[cached_op.name] = entry
return ir.CachedTable(cached_op)

def _finalize_cached_table(self, name: str) -> None:
"""Release a cached table given its name.
This is a no-op if the cached table is already released.
Parameters
----------
name
The name of the cached table.
"""
if (entry := self._cache_name_to_entry.pop(name, None)) is not None:
self._cache_op_to_entry.pop(entry.orig_op)
entry.finalizer.detach()
try:
self._drop_cached_table(name)
except Exception:
# suppress exceptions during interpreter shutdown
if not sys.is_finalizing():
raise

def _create_cached_table(self, name: str, expr: ir.Table) -> ir.Table:
return self.create_table(name, expr, temp=True)

def _drop_cached_table(self, name: str) -> None:
self.drop_table(name, force=True)


class BaseBackend(abc.ABC, _FileIOHandler, CacheHandler):
"""Base backend class.
All Ibis backends must subclass this class and implement all the
Expand All @@ -794,12 +862,7 @@ def __init__(self, *args, **kwargs):
self._con_args: tuple[Any] = args
self._con_kwargs: dict[str, Any] = kwargs
self._can_reconnect: bool = True
# expression cache
self._query_cache = RefCountedCache(
populate=self._load_into_cache,
lookup=lambda name: self.table(name).op(),
finalize=self._clean_up_cached_table,
)
super().__init__()

@property
@abc.abstractmethod
Expand Down Expand Up @@ -1225,44 +1288,6 @@ def has_operation(cls, operation: type[ops.Value]) -> bool:
f"{cls.name} backend has not implemented `has_operation` API"
)

def _cached(self, expr: ir.Table):
"""Cache the provided expression.
All subsequent operations on the returned expression will be performed on the cached data.
Parameters
----------
expr
Table expression to cache
Returns
-------
Expr
Cached table
"""
op = expr.op()
if (result := self._query_cache.get(op)) is None:
result = self._query_cache.store(expr)
return ir.CachedTable(result)

def _release_cached(self, expr: ir.CachedTable) -> None:
"""Releases the provided cached expression.
Parameters
----------
expr
Cached expression to release
"""
self._query_cache.release(expr.op().name)

def _load_into_cache(self, name, expr):
raise NotImplementedError(self.name)

def _clean_up_cached_table(self, name):
raise NotImplementedError(self.name)

def _transpile_sql(self, query: str, *, dialect: str | None = None) -> str:
# only transpile if dialect was passed
if dialect is None:
Expand Down
9 changes: 1 addition & 8 deletions ibis/backends/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,6 @@ class Backend(SQLBackend, CanCreateDatabase, CanCreateSchema):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.__session_dataset: bq.DatasetReference | None = None
self._query_cache.lookup = lambda name: self.table(
name,
database=(self._session_dataset.project, self._session_dataset.dataset_id),
).op()

@property
def _session_dataset(self):
Expand Down Expand Up @@ -1137,10 +1133,7 @@ def drop_view(
)
self.raw_sql(stmt.sql(self.name))

def _load_into_cache(self, name, expr):
self.create_table(name, expr, schema=expr.schema(), temp=True)

def _clean_up_cached_table(self, name):
def _drop_cached_table(self, name):
self.drop_table(
name,
database=(self._session_dataset.project, self._session_dataset.dataset_id),
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,5 +179,5 @@ def _convert_object(self, obj) -> dd.DataFrame:
pandas_df = super()._convert_object(obj)
return dd.from_pandas(pandas_df, npartitions=1)

def _load_into_cache(self, name, expr):
self.create_table(name, self.compile(expr).persist())
def _create_cached_table(self, name, expr):
return self.create_table(name, self.compile(expr).persist())
2 changes: 1 addition & 1 deletion ibis/backends/oracle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,5 +647,5 @@ def _clean_up_tmp_table(self, name: str) -> None:
with contextlib.suppress(oracledb.DatabaseError):
bind.execute(f'DROP TABLE "{name}"')

def _clean_up_cached_table(self, name):
def _drop_cached_table(self, name):
self._clean_up_tmp_table(name)
6 changes: 3 additions & 3 deletions ibis/backends/pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def _get_operations(cls):
def has_operation(cls, operation: type[ops.Value]) -> bool:
return operation in cls._get_operations()

def _clean_up_cached_table(self, name):
def _drop_cached_table(self, name):
del self.dictionary[name]

def to_pyarrow(
Expand Down Expand Up @@ -328,8 +328,8 @@ def execute(self, query, params=None, limit="default", **kwargs):

return PandasExecutor.execute(query.op(), backend=self, params=params)

def _load_into_cache(self, name, expr):
self.create_table(name, expr.execute())
def _create_cached_table(self, name, expr):
return self.create_table(name, expr.execute())


@lazy_singledispatch
Expand Down
6 changes: 3 additions & 3 deletions ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,10 +552,10 @@ def to_pyarrow_batches(
table = self._to_pyarrow_table(expr, params=params, limit=limit, **kwargs)
return table.to_reader(chunk_size)

def _load_into_cache(self, name, expr):
self.create_table(name, self.compile(expr).cache())
def _create_cached_table(self, name, expr):
return self.create_table(name, self.compile(expr).cache())

def _clean_up_cached_table(self, name):
def _drop_cached_table(self, name):
self.drop_table(name, force=True)


Expand Down
5 changes: 3 additions & 2 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,16 +704,17 @@ def compute_stats(
)
return self.raw_sql(f"ANALYZE TABLE {table} COMPUTE STATISTICS{maybe_noscan}")

def _load_into_cache(self, name, expr):
def _create_cached_table(self, name, expr):
query = self.compile(expr)
t = self._session.sql(query).cache()
assert t.is_cached
t.createOrReplaceTempView(name)
# store the underlying spark dataframe so we can release memory when
# asked to, instead of when the session ends
self._cached_dataframes[name] = t
return self.table(name)

def _clean_up_cached_table(self, name):
def _drop_cached_table(self, name):
self._session.catalog.dropTempView(name)
t = self._cached_dataframes.pop(name)
assert t.is_cached
Expand Down
6 changes: 0 additions & 6 deletions ibis/backends/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,6 @@ def drop_view(
with self._safe_raw_sql(src):
pass

def _load_into_cache(self, name, expr):
self.create_table(name, expr, schema=expr.schema(), temp=True)

def _clean_up_cached_table(self, name):
self.drop_table(name, force=True)

def execute(
self,
expr: ir.Expr,
Expand Down
10 changes: 5 additions & 5 deletions ibis/backends/tests/test_expr_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_persist_expression_contextmanager(backend, con, alltypes):
backend.assert_frame_equal(
non_cached_table.to_pandas(), cached_table.to_pandas()
)
assert non_cached_table.op() not in con._query_cache.cache
assert non_cached_table.op() not in con._cache_op_to_entry


@mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"])
Expand Down Expand Up @@ -89,15 +89,15 @@ def test_persist_expression_multiple_refs(backend, con, alltypes):
# cached tables are identical and reusing the same op
assert cached_table.op() is nested_cached_table.op()
# table is cached
assert op in con._query_cache.cache
assert op in con._cache_op_to_entry

# deleting the first reference, leaves table in cache
del nested_cached_table
assert op in con._query_cache.cache
assert op in con._cache_op_to_entry

# deleting the last reference, releases table from cache
del cached_table
assert op not in con._query_cache.cache
assert op not in con._cache_op_to_entry

# assert that table has been dropped
assert name not in con.list_tables()
Expand Down Expand Up @@ -143,7 +143,7 @@ def test_persist_expression_release(con, alltypes):
cached_table = non_cached_table.cache()
cached_table.release()

assert non_cached_table.op() not in con._query_cache.cache
assert non_cached_table.op() not in con._cache_op_to_entry

# a second release does not hurt
cached_table.release()
Expand Down
74 changes: 1 addition & 73 deletions ibis/common/caching.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from __future__ import annotations

import functools
import sys
from collections import namedtuple
from typing import TYPE_CHECKING, Any
from weakref import finalize, ref
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Callable
Expand All @@ -25,72 +22,3 @@ def wrapper(*args, **kwargs):
return result

return wrapper


CacheEntry = namedtuple("CacheEntry", ["name", "ref", "finalizer"])


class RefCountedCache:
"""A cache with implicitly reference-counted values.
We could implement `MutableMapping`, but the `__setitem__` implementation
doesn't make sense and the `len` and `__iter__` methods aren't used.
We can implement that interface if and when we need to.
"""

def __init__(
self,
*,
populate: Callable[[str, Any], None],
lookup: Callable[[str], Any],
finalize: Callable[[Any], None],
) -> None:
self.populate = populate
self.lookup = lookup
self.finalize = finalize

self.cache: dict[Any, CacheEntry] = dict()

def get(self, key, default=None):
if (entry := self.cache.get(key)) is not None:
op = entry.ref()
return op if op is not None else default
return default

def __getitem__(self, key):
op = self.cache[key].ref()
if op is None:
raise KeyError(key)
return op

def store(self, input):
"""Compute and store a reference to `key`."""
from ibis.util import gen_name

key = input.op()
name = gen_name("cache")
self.populate(name, input)
cached = self.lookup(name)
finalizer = finalize(cached, self._release, key)

self.cache[key] = CacheEntry(name, ref(cached), finalizer)

return cached

def release(self, name: str) -> None:
# Could be sped up with an inverse dictionary
for key, entry in self.cache.items():
if entry.name == name:
self._release(key)
return

def _release(self, key) -> None:
entry = self.cache.pop(key)
try:
self.finalize(entry.name)
except Exception:
# suppress exceptions during interpreter shutdown
if not sys.is_finalizing():
raise
entry.finalizer.detach()
Loading

0 comments on commit afba988

Please sign in to comment.