Skip to content

Commit 1f36102

Browse files
committed
feat: enhance table name extraction and add tests for local Arrow, Pandas, and Polars dataframes
1 parent 53a62f7 commit 1f36102

File tree

2 files changed

+78
-15
lines changed

2 files changed

+78
-15
lines changed

python/datafusion/context.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -692,21 +692,35 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
692692
"""
693693
return self.sql(query, options)
694694

695-
@staticmethod
696-
def _extract_missing_table_name(error: Exception) -> str | None:
695+
def _extract_missing_table_name(self, error: Exception) -> str | None:
696+
"""Return the missing table name if the exception represents that error."""
697697
message = str(error)
698+
699+
# Try the global pattern first (supports both table and view, case-insensitive)
700+
match = _MISSING_TABLE_PATTERN.search(message)
701+
if match:
702+
table_name = match.group(1)
703+
# Handle dotted table names by extracting the last part
704+
if "." in table_name:
705+
table_name = table_name.rsplit(".", 1)[-1]
706+
return table_name
707+
708+
# Fallback to additional patterns for broader compatibility
698709
patterns = (
699-
r"table '([^']+)' not found",
700710
r"Table not found: ['\"]?([^\s'\"]+)['\"]?",
701711
r"Table or CTE with name ['\"]?([^\s'\"]+)['\"]? not found",
702712
r"Invalid reference to table ['\"]?([^\s'\"]+)['\"]?",
703713
)
704714
for pattern in patterns:
705715
if match := re.search(pattern, message):
706-
return match.group(1)
716+
table_name = match.group(1)
717+
if "." in table_name:
718+
table_name = table_name.rsplit(".", 1)[-1]
719+
return table_name
707720
return None
708721

709722
def _register_missing_table_from_callers(self, table_name: str) -> bool:
723+
"""Register a supported local object from caller stack frames."""
710724
frame = inspect.currentframe()
711725
if frame is None:
712726
return False
@@ -729,12 +743,14 @@ def _register_missing_table_from_callers(self, table_name: str) -> bool:
729743
def _register_from_namespace(
730744
self, table_name: str, namespace: dict[str, Any]
731745
) -> bool:
746+
"""Register a table from a namespace if the table name exists."""
732747
if table_name not in namespace:
733748
return False
734749
value = namespace[table_name]
735750
return self._register_python_value(table_name, value)
736751

737752
def _register_python_value(self, table_name: str, value: Any) -> bool:
753+
"""Register a Python object as a table if it's a supported type."""
738754
pandas = _load_optional_module("pandas")
739755
polars = _load_optional_module("polars")
740756
polars_df = getattr(polars, "DataFrame", None) if polars is not None else None
@@ -753,6 +769,11 @@ def _register_python_value(self, table_name: str, value: Any) -> bool:
753769
polars_df is not None and isinstance(value, polars_df),
754770
self._register_polars_dataframe,
755771
),
772+
# Support objects with Arrow C Stream interface
773+
(
774+
hasattr(value, "__arrow_c_stream__") or hasattr(value, "__arrow_c_array__"),
775+
self._register_arrow_object,
776+
),
756777
)
757778

758779
for matches, handler in handlers:
@@ -762,48 +783,48 @@ def _register_python_value(self, table_name: str, value: Any) -> bool:
762783
return False
763784

764785
def _register_datafusion_dataframe(self, table_name: str, value: DataFrame) -> bool:
786+
"""Register a DataFusion DataFrame as a view."""
765787
try:
766788
self.register_view(table_name, value)
767789
except Exception as exc: # noqa: BLE001
768790
warnings.warn(
769-
"Failed to register DataFusion DataFrame for table "
770-
f"'{table_name}': {exc}",
791+
f"Failed to register DataFusion DataFrame for table '{table_name}': {exc}",
771792
stacklevel=4,
772793
)
773794
return False
774795
return True
775796

776797
def _register_arrow_object(self, table_name: str, value: Any) -> bool:
798+
"""Register an Arrow object (Table, RecordBatch, RecordBatchReader, or stream)."""
777799
try:
778800
self.from_arrow(value, table_name)
779801
except Exception as exc: # noqa: BLE001
780802
warnings.warn(
781-
"Failed to register Arrow data for table "
782-
f"'{table_name}': {exc}",
803+
f"Failed to register Arrow data for table '{table_name}': {exc}",
783804
stacklevel=4,
784805
)
785806
return False
786807
return True
787808

788809
def _register_pandas_dataframe(self, table_name: str, value: Any) -> bool:
810+
"""Register a pandas DataFrame."""
789811
try:
790812
self.from_pandas(value, table_name)
791813
except Exception as exc: # noqa: BLE001
792814
warnings.warn(
793-
"Failed to register pandas DataFrame for table "
794-
f"'{table_name}': {exc}",
815+
f"Failed to register pandas DataFrame for table '{table_name}': {exc}",
795816
stacklevel=4,
796817
)
797818
return False
798819
return True
799820

800821
def _register_polars_dataframe(self, table_name: str, value: Any) -> bool:
822+
"""Register a polars DataFrame."""
801823
try:
802824
self.from_polars(value, table_name)
803825
except Exception as exc: # noqa: BLE001
804826
warnings.warn(
805-
"Failed to register polars DataFrame for table "
806-
f"'{table_name}': {exc}",
827+
f"Failed to register polars DataFrame for table '{table_name}': {exc}",
807828
stacklevel=4,
808829
)
809830
return False

python/tests/test_context.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,9 @@ def test_sql_missing_table_without_auto_register(ctx):
261261
with pytest.raises(Exception, match="not found") as excinfo:
262262
ctx.sql("SELECT * FROM arrow_table").collect()
263263

264-
missing = getattr(excinfo.value, "missing_table_names", None)
265-
assert missing is not None
266-
assert "arrow_table" in set(ctx._extract_missing_table_names(excinfo.value))
264+
# Test that our extraction method works correctly
265+
missing_tables = ctx._extract_missing_table_names(excinfo.value)
266+
assert "arrow_table" in missing_tables
267267

268268

269269
def test_sql_auto_register_arrow_table():
@@ -348,6 +348,48 @@ def test_from_pandas(ctx):
348348
assert df.collect()[0].num_rows == 3
349349

350350

351+
def test_sql_from_local_arrow_table(ctx):
352+
ctx.set_python_table_lookup(True) # Enable implicit table lookup
353+
arrow_table = pa.Table.from_pydict({"a": [1, 2], "b": ["x", "y"]})
354+
355+
result = ctx.sql("SELECT * FROM arrow_table ORDER BY a").collect()
356+
actual = pa.Table.from_batches(result)
357+
expected = pa.Table.from_pydict({"a": [1, 2], "b": ["x", "y"]})
358+
359+
assert actual.equals(expected)
360+
361+
362+
def test_sql_from_local_pandas_dataframe(ctx):
363+
ctx.set_python_table_lookup(True) # Enable implicit table lookup
364+
pd = pytest.importorskip("pandas")
365+
pandas_df = pd.DataFrame({"a": [3, 1], "b": ["z", "y"]})
366+
367+
result = ctx.sql("SELECT * FROM pandas_df ORDER BY a").collect()
368+
actual = pa.Table.from_batches(result)
369+
expected = pa.Table.from_pydict({"a": [1, 3], "b": ["y", "z"]})
370+
371+
assert actual.equals(expected)
372+
373+
374+
def test_sql_from_local_polars_dataframe(ctx):
375+
ctx.set_python_table_lookup(True) # Enable implicit table lookup
376+
pl = pytest.importorskip("polars")
377+
polars_df = pl.DataFrame({"a": [2, 1], "b": ["beta", "alpha"]})
378+
379+
result = ctx.sql("SELECT * FROM polars_df ORDER BY a").collect()
380+
actual = pa.Table.from_batches(result)
381+
expected = pa.Table.from_pydict({"a": [1, 2], "b": ["alpha", "beta"]})
382+
383+
assert actual.equals(expected)
384+
385+
386+
def test_sql_from_local_unsupported_object(ctx):
387+
unsupported = object()
388+
389+
with pytest.raises(Exception, match="table 'unsupported' not found"):
390+
ctx.sql("SELECT * FROM unsupported").collect()
391+
392+
351393
def test_from_polars(ctx):
352394
# create a dataframe from Polars dataframe
353395
pd = pytest.importorskip("polars")

0 commit comments

Comments
 (0)