Skip to content

Commit

Permalink
fix(mssql): support .cache() for caching tables
Browse files: browse the repository at this point in the history
  • Loading branch information
jcrist authored and cpcloud committed Aug 27, 2024
1 parent afba988 commit 1de2f45
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 24 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -839,7 +839,7 @@ def _finalize_cached_table(self, name: str) -> None:
raise

def _create_cached_table(self, name: str, expr: ir.Table) -> ir.Table:
return self.create_table(name, expr, temp=True)
return self.create_table(name, expr, schema=expr.schema(), temp=True)

def _drop_cached_table(self, name: str) -> None:
self.drop_table(name, force=True)
Expand Down
11 changes: 7 additions & 4 deletions ibis/backends/mssql/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -692,15 +692,18 @@ def create_table(
new = raw_this.sql(self.dialect)
cur.execute(f"EXEC sp_rename '{old}', '{new}'")

if temp:
# If a temporary table, amend the output name/catalog/db accordingly
name = "##" + name
catalog = "tempdb"
db = "dbo"

if schema is None:
# Clean up temporary memtable if we've created one
# for in-memory reads
if temp_memtable_view is not None:
self.drop_table(temp_memtable_view)
return self.table(
"##" * temp + name,
database=("tempdb" * temp or catalog, "dbo" * temp or db),
)
return self.table(name, database=(catalog, db))

# preserve the input schema if it was provided
return ops.DatabaseTable(
Expand Down
13 changes: 13 additions & 0 deletions ibis/backends/mssql/tests/test_client.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -216,6 +216,19 @@ def test_create_temp_table_from_obj(con):
con.drop_table("fuhreal")


@pytest.mark.parametrize("explicit_schema", [False, True])
def test_create_temp_table_from_expression(con, explicit_schema, temp_table):
t = ibis.memtable(
{"x": [1, 2, 3], "y": ["a", "b", "c"]}, schema={"x": "int64", "y": "str"}
)
t2 = con.create_table(
temp_table, t, temp=True, schema=t.schema() if explicit_schema else None
)
res = con.to_pandas(t.order_by("y"))
sol = con.to_pandas(t2.order_by("y"))
assert res.equals(sol)


def test_from_url():
user = MSSQL_USER
password = MSSQL_PASS
Expand Down
27 changes: 8 additions & 19 deletions ibis/backends/tests/test_expr_caching.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,10 +11,6 @@


@mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"])
@mark.never(
["mssql"],
reason="mssql supports support temporary tables through naming conventions",
)
@mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
@pytest.mark.never(
["risingwave"],
Expand All @@ -27,15 +23,12 @@ def test_persist_expression(backend, alltypes):
)
persisted_table = non_persisted_table.cache()
backend.assert_frame_equal(
non_persisted_table.to_pandas(), persisted_table.to_pandas()
non_persisted_table.order_by("id").to_pandas(),
persisted_table.order_by("id").to_pandas(),
)


@mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"])
@mark.never(
["mssql"],
reason="mssql supports support temporary tables through naming conventions",
)
@mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
@pytest.mark.never(
["risingwave"],
Expand All @@ -48,16 +41,13 @@ def test_persist_expression_contextmanager(backend, con, alltypes):
)
with non_cached_table.cache() as cached_table:
backend.assert_frame_equal(
non_cached_table.to_pandas(), cached_table.to_pandas()
non_cached_table.order_by("id").to_pandas(),
cached_table.order_by("id").to_pandas(),
)
assert non_cached_table.op() not in con._cache_op_to_entry


@mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"])
@mark.never(
["mssql"],
reason="mssql supports support temporary tables through naming conventions",
)
@pytest.mark.never(
["risingwave"],
raises=com.UnsupportedOperationError,
Expand All @@ -81,7 +71,10 @@ def test_persist_expression_multiple_refs(backend, con, alltypes):
op = non_cached_table.op()
cached_table = non_cached_table.cache()

backend.assert_frame_equal(non_cached_table.to_pandas(), cached_table.to_pandas())
backend.assert_frame_equal(
non_cached_table.order_by("id").to_pandas(),
cached_table.order_by("id").to_pandas(),
)

name = cached_table.op().name
nested_cached_table = non_cached_table.cache()
Expand All @@ -104,10 +97,6 @@ def test_persist_expression_multiple_refs(backend, con, alltypes):


@mark.notimpl(["datafusion", "flink", "impala", "trino", "druid"])
@mark.never(
["mssql"],
reason="mssql supports support temporary tables through naming conventions",
)
@mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
@pytest.mark.never(
["risingwave"],
Expand Down

0 comments on commit 1de2f45

Please sign in to comment.