Skip to content

Commit

Permalink
fix(mssql): support .cache() for caching tables
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Aug 27, 2024
1 parent 466b9c5 commit 5ff9536
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 30 deletions.
13 changes: 7 additions & 6 deletions ibis/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,16 +807,17 @@ def _cached_table(self, table: ir.Table) -> ir.CachedTable:
"""
entry = self._cache_op_to_entry.get(table.op())
if entry is None or (cached_op := entry.cached_op_ref()) is None:
name = util.gen_name("cached")
cached_op = self._create_cached_table(name, table).op()
cached_op = self._create_cached_table(util.gen_name("cached"), table).op()
entry = _CacheEntry(
name,
cached_op.name,
table.op(),
weakref.ref(cached_op),
weakref.finalize(cached_op, self._finalize_cached_table, name),
weakref.finalize(
cached_op, self._finalize_cached_table, cached_op.name
),
)
self._cache_op_to_entry[table.op()] = entry
self._cache_name_to_entry[name] = entry
self._cache_name_to_entry[cached_op.name] = entry
return ir.CachedTable(cached_op)

def _finalize_cached_table(self, name: str) -> None:
Expand All @@ -840,7 +841,7 @@ def _finalize_cached_table(self, name: str) -> None:
raise

Check warning on line 841 in ibis/backends/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/__init__.py#L841

Added line #L841 was not covered by tests

def _create_cached_table(self, name: str, expr: ir.Table) -> ir.Table:
return self.create_table(name, expr, temp=True)
return self.create_table(name, expr, schema=expr.schema(), temp=True)

def _drop_cached_table(self, name: str) -> None:
self.drop_table(name, force=True)
Expand Down
11 changes: 7 additions & 4 deletions ibis/backends/mssql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,15 +692,18 @@ def create_table(
new = raw_this.sql(self.dialect)
cur.execute(f"EXEC sp_rename '{old}', '{new}'")

if temp:
# If a temporary table, amend the output name/catalog/db accordingly
name = "##" + name
catalog = "tempdb"
db = "dbo"

Check warning on line 699 in ibis/backends/mssql/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/mssql/__init__.py#L697-L699

Added lines #L697 - L699 were not covered by tests

if schema is None:
# Clean up temporary memtable if we've created one
# for in-memory reads
if temp_memtable_view is not None:
self.drop_table(temp_memtable_view)
return self.table(
"##" * temp + name,
database=("tempdb" * temp or catalog, "dbo" * temp or db),
)
return self.table(name, database=(catalog, db))

Check warning on line 706 in ibis/backends/mssql/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/mssql/__init__.py#L706

Added line #L706 was not covered by tests

# preserve the input schema if it was provided
return ops.DatabaseTable(
Expand Down
13 changes: 13 additions & 0 deletions ibis/backends/mssql/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,19 @@ def test_create_temp_table_from_obj(con):
con.drop_table("fuhreal")


@pytest.mark.parametrize("explicit_schema", [False, True])

Check warning on line 219 in ibis/backends/mssql/tests/test_client.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/mssql/tests/test_client.py#L219

Added line #L219 was not covered by tests
def test_create_temp_table_from_expression(con, explicit_schema, temp_table):
t = ibis.memtable(

Check warning on line 221 in ibis/backends/mssql/tests/test_client.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/mssql/tests/test_client.py#L221

Added line #L221 was not covered by tests
{"x": [1, 2, 3], "y": ["a", "b", "c"]}, schema={"x": "int64", "y": "str"}
)
t2 = con.create_table(

Check warning on line 224 in ibis/backends/mssql/tests/test_client.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/mssql/tests/test_client.py#L224

Added line #L224 was not covered by tests
temp_table, t, temp=True, schema=t.schema() if explicit_schema else None
)
res = t.order_by("y").to_pandas()
sol = t2.order_by("y").to_pandas()
assert res.equals(sol)

Check warning on line 229 in ibis/backends/mssql/tests/test_client.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/mssql/tests/test_client.py#L227-L229

Added lines #L227 - L229 were not covered by tests


def test_from_url():
user = MSSQL_USER
password = MSSQL_PASS
Expand Down
29 changes: 9 additions & 20 deletions ibis/backends/tests/test_expr_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,24 @@


@mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"])
@mark.never(
["mssql"],
reason="mssql supports support temporary tables through naming conventions",
)
@mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
@pytest.mark.never(
["risingwave"],
raises=com.UnsupportedOperationError,
reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
)
def test_persist_expression(backend, alltypes):
def test_persist_expression_foo(backend, alltypes):
non_persisted_table = alltypes.mutate(
test_column=ibis.literal("calculation"), other_calc=ibis.literal("xyz")
)
persisted_table = non_persisted_table.cache()
backend.assert_frame_equal(
non_persisted_table.to_pandas(), persisted_table.to_pandas()
non_persisted_table.order_by("id").to_pandas(),
persisted_table.order_by("id").to_pandas(),
)


@mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"])
@mark.never(
["mssql"],
reason="mssql supports support temporary tables through naming conventions",
)
@mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
@pytest.mark.never(
["risingwave"],
Expand All @@ -48,16 +41,13 @@ def test_persist_expression_contextmanager(backend, con, alltypes):
)
with non_cached_table.cache() as cached_table:
backend.assert_frame_equal(
non_cached_table.to_pandas(), cached_table.to_pandas()
non_cached_table.order_by("id").to_pandas(),
cached_table.order_by("id").to_pandas(),
)
assert non_cached_table.op() not in con._cache_op_to_entry


@mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"])
@mark.never(
["mssql"],
reason="mssql supports support temporary tables through naming conventions",
)
@pytest.mark.never(
["risingwave"],
raises=com.UnsupportedOperationError,
Expand All @@ -81,7 +71,10 @@ def test_persist_expression_multiple_refs(backend, con, alltypes):
op = non_cached_table.op()
cached_table = non_cached_table.cache()

backend.assert_frame_equal(non_cached_table.to_pandas(), cached_table.to_pandas())
backend.assert_frame_equal(
non_cached_table.order_by("id").to_pandas(),
cached_table.order_by("id").to_pandas(),
)

name = cached_table.op().name
nested_cached_table = non_cached_table.cache()
Expand All @@ -104,10 +97,6 @@ def test_persist_expression_multiple_refs(backend, con, alltypes):


@mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"])
@mark.never(
["mssql"],
reason="mssql supports support temporary tables through naming conventions",
)
@mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
@pytest.mark.never(
["risingwave"],
Expand Down

0 comments on commit 5ff9536

Please sign in to comment.