Skip to content

Commit 92dde5b

Browse files
committed
feat: enable automatic registration of Python objects in SQL queries and add corresponding tests
1 parent 1f36102 commit 92dde5b

File tree

3 files changed

+163
-183
lines changed

3 files changed

+163
-183
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,31 @@ Core Classes
228228
* :py:meth:`~datafusion.SessionContext.from_pandas` - Create from Pandas DataFrame
229229
* :py:meth:`~datafusion.SessionContext.from_arrow` - Create from Arrow data
230230

231+
``SessionContext`` automatically resolves SQL table names that match
232+
in-scope Python data objects. When ``auto_register_python_objects`` is
233+
enabled (the default), a query such as ``ctx.sql("SELECT * FROM pdf")``
234+
will register a pandas or PyArrow object named ``pdf`` without calling
235+
:py:meth:`~datafusion.SessionContext.from_pandas` or
236+
:py:meth:`~datafusion.SessionContext.from_arrow` explicitly. This requires
237+
the corresponding library (``pandas`` for pandas objects, ``pyarrow`` for
238+
Arrow objects) to be installed.
239+
240+
.. code-block:: python
241+
242+
import pandas as pd
243+
from datafusion import SessionContext
244+
245+
ctx = SessionContext()
246+
pdf = pd.DataFrame({"value": [1, 2, 3]})
247+
248+
df = ctx.sql("SELECT SUM(value) AS total FROM pdf")
249+
print(df.to_pandas()) # automatically registers `pdf`
250+
251+
To opt out, either pass ``auto_register_python_objects=False`` when
252+
constructing the session, or call
253+
:py:meth:`~datafusion.SessionContext.set_python_table_lookup` with
254+
``False`` to require explicit registration.
255+
231256
See: :py:class:`datafusion.SessionContext`
232257

233258
Expression Classes

python/datafusion/context.py

Lines changed: 100 additions & 181 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@
4545
from ._internal import SQLOptions as SQLOptionsInternal
4646
from ._internal import expr as expr_internal
4747

48-
_MISSING_TABLE_PATTERN = re.compile(r"(?i)(?:table|view) '([^']+)' not found")
49-
5048
if TYPE_CHECKING:
5149
import pathlib
5250
from collections.abc import Sequence
@@ -505,7 +503,7 @@ def __init__(
505503
config: SessionConfig | None = None,
506504
runtime: RuntimeEnvBuilder | None = None,
507505
*,
508-
auto_register_python_variables: bool = False,
506+
auto_register_python_objects: bool = True,
509507
) -> None:
510508
"""Main interface for executing queries with DataFusion.
511509
@@ -516,9 +514,9 @@ def __init__(
516514
Args:
517515
config: Session configuration options.
518516
runtime: Runtime configuration options.
519-
auto_register_python_variables: Automatically register Arrow-like
520-
Python objects referenced in SQL queries when they are available
521-
in the caller's scope.
517+
auto_register_python_objects: Automatically register referenced
518+
Python objects (such as pandas or PyArrow data) when ``sql``
519+
queries reference them by name.
522520
523521
Example usage:
524522
@@ -530,17 +528,11 @@ def __init__(
530528
ctx = SessionContext()
531529
df = ctx.read_csv("data.csv")
532530
"""
533-
python_table_lookup = auto_register_python_variables # Use parameter as default
534-
if config is not None:
535-
python_table_lookup = config._python_table_lookup
536-
config_internal = config.config_internal
537-
else:
538-
config_internal = None
539-
540-
runtime_internal = runtime.config_internal if runtime is not None else None
541-
542-
self.ctx = SessionContextInternal(config_internal, runtime_internal)
543-
self._python_table_lookup = python_table_lookup
531+
self.ctx = SessionContextInternal(
532+
config.config_internal if config is not None else None,
533+
runtime.config_internal if runtime is not None else None,
534+
)
535+
self._auto_python_table_lookup = auto_register_python_objects
544536

545537
def __repr__(self) -> str:
546538
"""Print a string representation of the Session Context."""
@@ -567,27 +559,25 @@ def enable_url_table(self) -> SessionContext:
567559
klass = self.__class__
568560
obj = klass.__new__(klass)
569561
obj.ctx = self.ctx.enable_url_table()
570-
obj._python_table_lookup = self._python_table_lookup
562+
obj._auto_python_table_lookup = getattr(
563+
self, "_auto_python_table_lookup", True
564+
)
571565
return obj
572566

573-
def set_python_table_lookup(self, enabled: bool) -> None:
574-
"""Enable or disable implicit table lookup for Python objects."""
575-
self._python_table_lookup = enabled
567+
def set_python_table_lookup(self, enabled: bool = True) -> SessionContext:
568+
"""Enable or disable automatic registration of Python objects in SQL.
576569
577-
# Backward compatibility properties
578-
@property
579-
def auto_register_python_variables(self) -> bool:
580-
"""Toggle automatic registration of Python variables in SQL queries."""
581-
return self._python_table_lookup
582-
583-
@auto_register_python_variables.setter
584-
def auto_register_python_variables(self, enabled: bool) -> None:
585-
self._python_table_lookup = bool(enabled)
570+
Args:
571+
enabled: When ``True`` (default), SQL queries automatically attempt
572+
to resolve missing table names by looking up Python objects in
573+
the caller's scope. When ``False``, missing tables will raise an
574+
error unless they have been explicitly registered.
586575
587-
def _extract_missing_table_names(self, error: Exception) -> set[str]:
588-
"""Extract missing table names from error (backward compatibility)."""
589-
missing_table = self._extract_missing_table_name(error)
590-
return {missing_table} if missing_table else set()
576+
Returns:
577+
The current :py:class:`SessionContext` instance for chaining.
578+
"""
579+
self._auto_python_table_lookup = enabled
580+
return self
591581

592582
def register_object_store(
593583
self, schema: str, store: Any, host: str | None = None
@@ -653,29 +643,28 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
653643
Returns:
654644
DataFrame representation of the SQL query.
655645
"""
656-
attempted_missing_tables: set[str] = set()
646+
def _execute_sql() -> DataFrame:
647+
if options is None:
648+
return DataFrame(self.ctx.sql(query))
649+
return DataFrame(
650+
self.ctx.sql_with_options(query, options.options_internal)
651+
)
657652

658-
while True:
659-
try:
660-
if options is None:
661-
result = self.ctx.sql(query)
662-
else:
663-
result = self.ctx.sql_with_options(query, options.options_internal)
664-
except Exception as exc:
665-
missing_table = self._extract_missing_table_name(exc)
666-
if (
667-
missing_table is None
668-
or missing_table in attempted_missing_tables
669-
or not self._python_table_lookup
670-
):
671-
raise
672-
673-
attempted_missing_tables.add(missing_table)
674-
if not self._register_missing_table_from_callers(missing_table):
675-
raise
676-
continue
653+
try:
654+
return _execute_sql()
655+
except Exception as err:
656+
if not getattr(self, "_auto_python_table_lookup", True):
657+
raise
658+
659+
missing_tables = self._extract_missing_table_names(err)
660+
if not missing_tables:
661+
raise
677662

678-
return DataFrame(result)
663+
registered = self._register_python_tables(missing_tables)
664+
if not registered:
665+
raise
666+
667+
return _execute_sql()
679668

680669
def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
681670
"""Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text.
@@ -692,144 +681,74 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
692681
"""
693682
return self.sql(query, options)
694683

695-
def _extract_missing_table_name(self, error: Exception) -> str | None:
696-
"""Return the missing table name if the exception represents that error."""
697-
message = str(error)
698-
699-
# Try the global pattern first (supports both table and view, case-insensitive)
700-
match = _MISSING_TABLE_PATTERN.search(message)
701-
if match:
702-
table_name = match.group(1)
703-
# Handle dotted table names by extracting the last part
704-
if "." in table_name:
705-
table_name = table_name.rsplit(".", 1)[-1]
706-
return table_name
707-
708-
# Fallback to additional patterns for broader compatibility
709-
patterns = (
710-
r"Table not found: ['\"]?([^\s'\"]+)['\"]?",
711-
r"Table or CTE with name ['\"]?([^\s'\"]+)['\"]? not found",
712-
r"Invalid reference to table ['\"]?([^\s'\"]+)['\"]?",
713-
)
714-
for pattern in patterns:
715-
if match := re.search(pattern, message):
716-
table_name = match.group(1)
717-
if "." in table_name:
718-
table_name = table_name.rsplit(".", 1)[-1]
719-
return table_name
720-
return None
684+
@staticmethod
685+
def _extract_missing_table_names(err: Exception) -> list[str]:
686+
message = str(err)
687+
matches = set()
688+
for pattern in (r"table '([^']+)' not found", r"No table named '([^']+)'"):
689+
matches.update(re.findall(pattern, message))
690+
691+
tables: list[str] = []
692+
for raw_name in matches:
693+
if not raw_name:
694+
continue
695+
tables.append(raw_name.rsplit(".", 1)[-1])
696+
return tables
721697

722-
def _register_missing_table_from_callers(self, table_name: str) -> bool:
723-
"""Register a supported local object from caller stack frames."""
724-
frame = inspect.currentframe()
725-
if frame is None:
726-
return False
698+
def _register_python_tables(self, tables: list[str]) -> bool:
699+
registered_any = False
700+
for table_name in tables:
701+
if not table_name or self.table_exist(table_name):
702+
continue
703+
704+
python_obj = self._lookup_python_object(table_name)
705+
if python_obj is None:
706+
continue
727707

708+
if self._register_python_object(table_name, python_obj):
709+
registered_any = True
710+
711+
return registered_any
712+
713+
@staticmethod
714+
def _lookup_python_object(name: str) -> Any | None:
715+
frame = inspect.currentframe()
728716
try:
729-
frame = frame.f_back
730-
if frame is None:
731-
return False
732-
frame = frame.f_back
717+
if frame is not None:
718+
frame = frame.f_back
733719
while frame is not None:
734-
if self._register_from_namespace(table_name, frame.f_locals):
735-
return True
736-
if self._register_from_namespace(table_name, frame.f_globals):
737-
return True
720+
locals_dict = frame.f_locals
721+
if name in locals_dict:
722+
return locals_dict[name]
723+
globals_dict = frame.f_globals
724+
if name in globals_dict:
725+
return globals_dict[name]
738726
frame = frame.f_back
739727
finally:
740728
del frame
741-
return False
729+
return None
742730

743-
def _register_from_namespace(
744-
self, table_name: str, namespace: dict[str, Any]
745-
) -> bool:
746-
"""Register a table from a namespace if the table name exists."""
747-
if table_name not in namespace:
748-
return False
749-
value = namespace[table_name]
750-
return self._register_python_value(table_name, value)
751-
752-
def _register_python_value(self, table_name: str, value: Any) -> bool:
753-
"""Register a Python object as a table if it's a supported type."""
754-
pandas = _load_optional_module("pandas")
755-
polars = _load_optional_module("polars")
756-
polars_df = getattr(polars, "DataFrame", None) if polars is not None else None
757-
758-
handlers = (
759-
(isinstance(value, DataFrame), self._register_datafusion_dataframe),
760-
(
761-
isinstance(value, (pa.Table, pa.RecordBatch, pa.RecordBatchReader)),
762-
self._register_arrow_object,
763-
),
764-
(
765-
pandas is not None and isinstance(value, pandas.DataFrame),
766-
self._register_pandas_dataframe,
767-
),
768-
(
769-
polars_df is not None and isinstance(value, polars_df),
770-
self._register_polars_dataframe,
771-
),
772-
# Support objects with Arrow C Stream interface
773-
(
774-
hasattr(value, "__arrow_c_stream__") or hasattr(value, "__arrow_c_array__"),
775-
self._register_arrow_object,
776-
),
777-
)
731+
def _register_python_object(self, name: str, obj: Any) -> bool:
732+
if isinstance(obj, DataFrame):
733+
self.register_view(name, obj)
734+
return True
778735

779-
for matches, handler in handlers:
780-
if matches:
781-
return handler(table_name, value)
736+
if (
737+
obj.__class__.__module__.startswith("pandas.")
738+
and obj.__class__.__name__ == "DataFrame"
739+
):
740+
self.from_pandas(obj, name=name)
741+
return True
782742

783-
return False
743+
if isinstance(obj, (pa.Table, pa.RecordBatch, pa.RecordBatchReader)):
744+
self.from_arrow(obj, name=name)
745+
return True
784746

785-
def _register_datafusion_dataframe(self, table_name: str, value: DataFrame) -> bool:
786-
"""Register a DataFusion DataFrame as a view."""
787-
try:
788-
self.register_view(table_name, value)
789-
except Exception as exc: # noqa: BLE001
790-
warnings.warn(
791-
f"Failed to register DataFusion DataFrame for table '{table_name}': {exc}",
792-
stacklevel=4,
793-
)
794-
return False
795-
return True
796-
797-
def _register_arrow_object(self, table_name: str, value: Any) -> bool:
798-
"""Register an Arrow object (Table, RecordBatch, RecordBatchReader, or stream)."""
799-
try:
800-
self.from_arrow(value, table_name)
801-
except Exception as exc: # noqa: BLE001
802-
warnings.warn(
803-
f"Failed to register Arrow data for table '{table_name}': {exc}",
804-
stacklevel=4,
805-
)
806-
return False
807-
return True
808-
809-
def _register_pandas_dataframe(self, table_name: str, value: Any) -> bool:
810-
"""Register a pandas DataFrame."""
811-
try:
812-
self.from_pandas(value, table_name)
813-
except Exception as exc: # noqa: BLE001
814-
warnings.warn(
815-
f"Failed to register pandas DataFrame for table '{table_name}': {exc}",
816-
stacklevel=4,
817-
)
818-
return False
819-
return True
820-
821-
def _register_polars_dataframe(self, table_name: str, value: Any) -> bool:
822-
"""Register a polars DataFrame."""
823-
try:
824-
self.from_polars(value, table_name)
825-
except Exception as exc: # noqa: BLE001
826-
warnings.warn(
827-
f"Failed to register polars DataFrame for table '{table_name}': {exc}",
828-
stacklevel=4,
829-
)
830-
return False
831-
return True
747+
if hasattr(obj, "__arrow_c_stream__") or hasattr(obj, "__arrow_c_array__"):
748+
self.from_arrow(obj, name=name)
749+
return True
832750

751+
return False
833752
def create_dataframe(
834753
self,
835754
partitions: list[list[pa.RecordBatch]],

0 commit comments

Comments
 (0)