Skip to content

Commit 78c26cc

Browse files
committed
refactor: improve auto-registration logic for Arrow and DataFrame objects in SessionContext
1 parent dc06874 commit 78c26cc

File tree

2 files changed

+33
-16
lines changed

2 files changed

+33
-16
lines changed

python/datafusion/context.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -797,23 +797,29 @@ def _register_python_object(self, name: str, obj: Any) -> bool:
797797
if isinstance(obj, DataFrame):
798798
self.register_view(name, obj)
799799
registered = True
800-
elif (
801-
obj.__class__.__module__.startswith("polars.")
802-
and obj.__class__.__name__ == "DataFrame"
803-
):
804-
self.from_polars(obj, name=name)
805-
registered = True
806-
elif (
807-
obj.__class__.__module__.startswith("pandas.")
808-
and obj.__class__.__name__ == "DataFrame"
809-
):
810-
self.from_pandas(obj, name=name)
811-
registered = True
812-
elif isinstance(obj, (pa.Table, pa.RecordBatch, pa.RecordBatchReader)) or (
813-
hasattr(obj, "__arrow_c_stream__") or hasattr(obj, "__arrow_c_array__")
814-
):
800+
elif isinstance(obj, (pa.Table, pa.RecordBatch, pa.RecordBatchReader)):
815801
self.from_arrow(obj, name=name)
816802
registered = True
803+
else:
804+
exports_arrow_capsule = hasattr(obj, "__arrow_c_stream__") or hasattr(
805+
obj, "__arrow_c_array__"
806+
)
807+
808+
if exports_arrow_capsule:
809+
self.from_arrow(obj, name=name)
810+
registered = True
811+
elif (
812+
obj.__class__.__module__.startswith("polars.")
813+
and obj.__class__.__name__ == "DataFrame"
814+
):
815+
self.from_polars(obj, name=name)
816+
registered = True
817+
elif (
818+
obj.__class__.__module__.startswith("pandas.")
819+
and obj.__class__.__name__ == "DataFrame"
820+
):
821+
self.from_pandas(obj, name=name)
822+
registered = True
817823

818824
if registered:
819825
try:

python/tests/test_context.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,12 +372,23 @@ def test_sql_auto_register_case_insensitive_lookup():
372372
assert batches[0].column(0).to_pylist()[0] == 5
373373

374374

375-
def test_sql_auto_register_pandas_dataframe():
375+
def test_sql_auto_register_pandas_dataframe(monkeypatch):
376376
pd = pytest.importorskip("pandas")
377377

378378
ctx = SessionContext(auto_register_python_objects=True)
379379
pandas_df = pd.DataFrame({"value": [1, 2, 3, 4]}) # noqa: F841
380380

381+
if not (
382+
hasattr(pandas_df, "__arrow_c_stream__")
383+
or hasattr(pandas_df, "__arrow_c_array__")
384+
):
385+
pytest.skip("pandas does not expose Arrow capsule export")
386+
387+
def fail_from_pandas(*args, **kwargs): # noqa: ANN002, ANN003
388+
raise AssertionError("from_pandas should not be called during auto-registration")
389+
390+
monkeypatch.setattr(SessionContext, "from_pandas", fail_from_pandas)
391+
381392
result = ctx.sql(
382393
"SELECT AVG(value) AS avg_value FROM pandas_df",
383394
).collect()

0 commit comments

Comments
 (0)