80 changes: 80 additions & 0 deletions python/xorq/backends/__init__.py
@@ -1,5 +1,85 @@
import functools
import importlib.metadata
from abc import ABC
from typing import Any, Mapping

from xorq.vendor.ibis import BaseBackend
from xorq.vendor.ibis.expr import types as ir


class ExecutionBackend(BaseBackend, ABC):
def _pandas_execute(self, expr: ir.Expr, **kwargs):
from xorq.expr.api import _transform_expr
from xorq.expr.relations import FlightExpr, FlightUDXF

node = expr.op()
if isinstance(node, (FlightExpr, FlightUDXF)):
df = node.to_rbr().read_pandas(timestamp_as_object=True)
return expr.__pandas_result__(df)
(expr, _created) = _transform_expr(expr)  # tables created during transform are not cleaned up on the eager pandas path

return super().execute(expr, **kwargs)

def execute(self, expr, **kwargs) -> Any:
if self.name == "pandas":
return self._pandas_execute(expr, **kwargs)

batch_reader = self.to_pyarrow_batches(expr, **kwargs)
df = batch_reader.read_pandas(timestamp_as_object=True)

return expr.__pandas_result__(df)

def to_pyarrow_batches(
self,
expr: ir.Expr,
*,
chunk_size: int = 1_000_000,
**kwargs: Any,
):
from xorq.common.utils.defer_utils import rbr_wrapper
from xorq.expr.api import _transform_expr
from xorq.expr.relations import FlightExpr, FlightUDXF

if isinstance(expr.op(), (FlightExpr, FlightUDXF)):
return expr.op().to_rbr()
(expr, created) = _transform_expr(expr)
reader = super().to_pyarrow_batches(expr, chunk_size=chunk_size, **kwargs)

def clean_up():
for table_name, conn in created.items():
try:
conn.drop_table(table_name, force=True)
except Exception:
conn.drop_view(table_name)

return rbr_wrapper(reader, clean_up)

def _pandas_to_pyarrow(self, expr, **kwargs):
from xorq.expr.api import _transform_expr
from xorq.expr.relations import FlightExpr, FlightUDXF

node = expr.op()
if isinstance(node, (FlightExpr, FlightUDXF)):
table = node.to_rbr().read_all()
return expr.__pyarrow_result__(table)
(expr, _created) = _transform_expr(expr)  # as above, no cleanup on the pandas path

return super().to_pyarrow(expr, **kwargs)

def to_pyarrow(
self,
expr: ir.Expr,
*,
params: Mapping[ir.Scalar, Any] | None = None,
limit: int | str | None = None,
**kwargs: Any,
):
if self.name == "pandas":
return self._pandas_to_pyarrow(expr, params=params, limit=limit, **kwargs)

batch_reader = self.to_pyarrow_batches(expr, params=params, limit=limit, **kwargs)
arrow_table = batch_reader.read_all()
return expr.__pyarrow_result__(arrow_table)


@functools.cache
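The new ExecutionBackend mixin centralizes the execute / to_pyarrow / to_pyarrow_batches plumbing that the individual backends below previously duplicated. A minimal, self-contained sketch of the dispatch pattern it follows — the classes here are illustrative stand-ins, not the real xorq/ibis types:

import pyarrow as pa


class ArrowCapableBackend:
    # Stand-in for an ibis backend that can stream Arrow record batches.
    name = "duckdb"

    def to_pyarrow_batches(self, expr, **kwargs):
        table = pa.table({"a": [1, 2, 3]})
        return pa.RecordBatchReader.from_batches(table.schema, table.to_batches())


class ExecutionMixin:
    # Mirrors ExecutionBackend.execute: stream batches, read them into
    # pandas, then let the expression shape the final result via its
    # __pandas_result__ hook (scalar, column, or DataFrame as appropriate).
    def execute(self, expr, **kwargs):
        reader = self.to_pyarrow_batches(expr, **kwargs)
        df = reader.read_pandas(timestamp_as_object=True)
        return expr.__pandas_result__(df)


class FakeExpr:
    # Stand-in for an ibis expression; real ones post-process the DataFrame.
    def __pandas_result__(self, df):
        return df


class Backend(ExecutionMixin, ArrowCapableBackend):
    pass


print(Backend().execute(FakeExpr()))  # DataFrame with column "a"

Because the mixin comes first in the bases, its execute wins, while its super() calls still fall through to the concrete backend.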
3 changes: 2 additions & 1 deletion python/xorq/backends/datafusion/__init__.py
@@ -11,6 +11,7 @@
import xorq.vendor.ibis.expr.operations as ops
import xorq.vendor.ibis.expr.schema as sch
import xorq.vendor.ibis.expr.types as ir
from xorq.backends import ExecutionBackend
from xorq.vendor import ibis
from xorq.vendor.ibis.backends.datafusion import Backend as IbisDatafusionBackend
from xorq.vendor.ibis.common.dispatch import lazy_singledispatch
@@ -21,7 +22,7 @@
import pandas as pd


class Backend(IbisDatafusionBackend):
class Backend(ExecutionBackend, IbisDatafusionBackend):
def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
self.con.from_arrow(op.data.to_pyarrow(op.schema), op.name)

7 changes: 6 additions & 1 deletion python/xorq/backends/duckdb/__init__.py
@@ -2,12 +2,13 @@

import pyarrow as pa

from xorq.backends import ExecutionBackend
from xorq.vendor.ibis.backends.duckdb import Backend as IbisDuckDBBackend
from xorq.vendor.ibis.expr import types as ir
from xorq.vendor.ibis.util import gen_name


class Backend(IbisDuckDBBackend):
class BaseExecutionBackend(IbisDuckDBBackend):
def execute(
self,
expr: ir.Expr,
@@ -37,3 +38,7 @@ def to_pyarrow_batches(
return self._to_duckdb_relation(
expr, params=params, limit=limit
).fetch_arrow_reader(chunk_size)


class Backend(ExecutionBackend, BaseExecutionBackend):
pass
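The rename-and-rederive pattern used here (and again in the pandas and pyiceberg backends below) puts the mixin ahead of the concrete implementation in the MRO, so ExecutionBackend's execute and to_pyarrow run first and their super() calls fall through to BaseExecutionBackend. A quick sanity check, assuming the real imports:

from xorq.backends.duckdb import Backend

# Mixin-first resolution order: Backend, ExecutionBackend,
# BaseExecutionBackend, then the vendored ibis DuckDB backend.
print([cls.__name__ for cls in Backend.__mro__[:4]])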
39 changes: 2 additions & 37 deletions python/xorq/backends/let/__init__.py
@@ -8,6 +8,7 @@
import pyarrow_hotfix # noqa: F401
from sqlglot import exp, parse_one

from xorq.backends import ExecutionBackend
from xorq.backends.let.datafusion import Backend as DataFusionBackend
from xorq.common.collections import SourceDict
from xorq.internal import SessionConfig, WindowUDF
@@ -35,7 +36,7 @@ def _get_datafusion_dataframe(con, expr, **kwargs):
return con.con.sql(raw_sql)


class Backend(DataFusionBackend):
class Backend(ExecutionBackend, DataFusionBackend):
name = "let"

def __init__(self, *args, **kwargs):
@@ -158,42 +159,6 @@ def create_table(
self._sources[registered_table.op()] = registered_table.op()
return registered_table

def execute(self, expr: ir.Expr, **kwargs: Any):
batch_reader = self.to_pyarrow_batches(expr, **kwargs)
return expr.__pandas_result__(
batch_reader.read_pandas(timestamp_as_object=True)
)

def to_pyarrow(self, expr: ir.Expr, **kwargs: Any) -> pa.Table:
batch_reader = self.to_pyarrow_batches(expr, **kwargs)
arrow_table = batch_reader.read_all()
return expr.__pyarrow_result__(arrow_table)

def to_pyarrow_batches(
self,
expr: ir.Expr,
*,
chunk_size: int = 1_000_000,
**kwargs: Any,
) -> pa.ipc.RecordBatchReader:
return super().to_pyarrow_batches(expr, chunk_size=chunk_size, **kwargs)

def do_connect(self, config: SessionConfig | None = None) -> None:
"""Creates a connection.

Parameters
----------
config
Mapping of table names to files.

Examples
--------
>>> import xorq.api as xo
>>> con = xo.connect()

"""
super().do_connect(config=config)

def _to_sqlglot(
self, expr: ir.Expr, *, limit: str | None = None, params=None, **_: Any
):
7 changes: 6 additions & 1 deletion python/xorq/backends/pandas/__init__.py
@@ -12,6 +12,7 @@
import xorq.vendor.ibis.expr.operations as ops
import xorq.vendor.ibis.expr.schema as sch
import xorq.vendor.ibis.expr.types as ir
from xorq.backends import ExecutionBackend
from xorq.vendor import ibis
from xorq.vendor.ibis import util
from xorq.vendor.ibis.backends import BaseBackend, NoUrl
@@ -302,7 +303,7 @@ def to_pyarrow_batches(
)


class Backend(BasePandasBackend):
class BaseExecutionBackend(BasePandasBackend):
name = "pandas"

def execute(self, query, params=None, limit="default", **kwargs):
@@ -354,6 +355,10 @@ def read_record_batches(
return self.table(table_name)


class Backend(ExecutionBackend, BaseExecutionBackend):
name = "pandas"


@lazy_singledispatch
def _convert_object(obj: Any, _conn):
raise com.BackendConversionError(
3 changes: 2 additions & 1 deletion python/xorq/backends/postgres/__init__.py
@@ -8,6 +8,7 @@
import toolz

import xorq.vendor.ibis.expr.schema as sch
from xorq.backends import ExecutionBackend
from xorq.backends.postgres.compiler import compiler
from xorq.common.utils.defer_utils import (
read_csv_rbr,
Expand All @@ -21,7 +22,7 @@
)


class Backend(IbisPostgresBackend):
class Backend(ExecutionBackend, IbisPostgresBackend):
_top_level_methods = ("connect_examples", "connect_env")
compiler = compiler

7 changes: 6 additions & 1 deletion python/xorq/backends/pyiceberg/__init__.py
@@ -8,6 +8,7 @@
from pyiceberg.table import Table as IcebergTable

import xorq.vendor.ibis.expr.operations as ops
from xorq.backends import ExecutionBackend
from xorq.backends.postgres.compiler import compiler as postgres_compiler
from xorq.backends.pyiceberg.compiler import PyIceberg, translate
from xorq.backends.pyiceberg.relations import PyIcebergTable
@@ -50,7 +51,7 @@ def _overwrite_table_data(iceberg_table: IcebergTable, data: pa.Table):
tx.commit_transaction()


class Backend(SQLBackend):
class BaseExecutionBackend(SQLBackend):
name = "pyiceberg"
dialect = PyIceberg
compiler = postgres_compiler
@@ -317,3 +318,7 @@ def list_snapshots(self, database=None) -> dict[str, int]:
)

return snapshots


class Backend(ExecutionBackend, BaseExecutionBackend):
pass
3 changes: 2 additions & 1 deletion python/xorq/backends/snowflake/__init__.py
@@ -11,6 +11,7 @@
import xorq.vendor.ibis.expr.api as api
import xorq.vendor.ibis.expr.schema as sch
import xorq.vendor.ibis.expr.types as ir
from xorq.backends import ExecutionBackend
from xorq.common.utils.logging_utils import get_logger
from xorq.expr.relations import replace_cache_table
from xorq.vendor.ibis.backends.snowflake import _SNOWFLAKE_MAP_UDFS
@@ -23,7 +24,7 @@
logger = get_logger(__name__)


class Backend(IbisSnowflakeBackend):
class Backend(ExecutionBackend, IbisSnowflakeBackend):
_top_level_methods = ("connect_env",)

@classmethod
3 changes: 2 additions & 1 deletion python/xorq/backends/sqlite/__init__.py
@@ -6,6 +6,7 @@
import sqlglot as sg
import sqlglot.expressions as sge

from xorq.backends import ExecutionBackend
from xorq.expr.api import read_csv, read_parquet
from xorq.vendor.ibis import Schema, util
from xorq.vendor.ibis.backends.sqlite import Backend as IbisSQLiteBackend
@@ -18,7 +19,7 @@
import pyarrow as pa


class Backend(IbisSQLiteBackend):
class Backend(ExecutionBackend, IbisSQLiteBackend):
def read_record_batches(
self,
record_batches: pa.RecordBatchReader,
3 changes: 2 additions & 1 deletion python/xorq/backends/trino/__init__.py
@@ -1,5 +1,6 @@
from xorq.backends import ExecutionBackend
from xorq.vendor.ibis.backends.trino import Backend as IbisTrinoBackend


class Backend(IbisTrinoBackend):
class Backend(ExecutionBackend, IbisTrinoBackend):
pass
2 changes: 1 addition & 1 deletion python/xorq/caching/__init__.py
@@ -555,6 +555,6 @@ def maybe_prevent_cross_source_caching(expr, storage):
into_backend,
)

if storage.storage.source != expr._find_backend():
if storage.storage.source is not expr._find_backend():
expr = into_backend(expr, storage.storage.source)
return expr
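Switching != to "is not" makes the guard an identity check: the into_backend hop is skipped only when the storage source is literally the same backend object as the expression's, not merely one that compares equal. A toy illustration of the distinction, using a hypothetical Conn class:

from dataclasses import dataclass


@dataclass
class Conn:
    url: str  # dataclass equality compares fields, not object identity


a = Conn("duckdb://local.db")
b = Conn("duckdb://local.db")

assert a == b      # equal by value ...
assert a is not b  # ... yet distinct objects with separate state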
3 changes: 3 additions & 0 deletions python/xorq/common/utils/defer_utils.py
@@ -211,6 +211,9 @@ def deferred_read_parquet(
The name to give to the resulting table in the backend. If not provided,
a unique name will be generated automatically.

normalize_method : Callable, optional
Callable that returns the values used when hashing the Read operation.

**kwargs : dict
Additional keyword arguments passed to the backend's read_parquet method.

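For context, a hedged usage sketch of the newly documented parameter. The path, connection, and lambda are placeholders, and the argument order is assumed from the docstring rather than confirmed against the full signature:

import xorq.api as xo
from xorq.common.utils.defer_utils import deferred_read_parquet

con = xo.connect()
expr = deferred_read_parquet(
    "data/example.parquet",  # placeholder path
    con,
    table_name="example",
    # hypothetical normalize_method: hash the Read op on its kwargs only
    normalize_method=lambda read_op: read_op.read_kwargs,
)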
7 changes: 2 additions & 5 deletions python/xorq/expr/api.py
@@ -288,11 +288,8 @@ def replace_read(node, _kwargs):
),
},
)
if node.source.name == "pandas":
# FIXME: pandas read is not lazy; leave materialization to the pandas executor
node = dt_to_read[node] = node.make_dt()
else:
node = dt_to_read[node] = node.make_dt()
# FIXME: pandas read is not lazy; leave materialization to the pandas executor
node = dt_to_read[node] = node.make_dt()
else:
if _kwargs:
node = node.__recreate__(_kwargs)
2 changes: 1 addition & 1 deletion python/xorq/expr/ml/tests/test_split_lib.py
@@ -383,6 +383,6 @@ def test_calc_split_column(connect_method, n, name):
.value_counts()
.order_by(xo.asc(name))
)
df = xo.execute(expr)
df = expr.execute()
assert tuple(df[name].values) == tuple(range(n))
assert df[f"{name}_count"].sum() == N
17 changes: 5 additions & 12 deletions python/xorq/expr/relations.py
@@ -561,26 +561,19 @@ def flight_udxf(
)


class Read(ops.Relation):
method_name: str
name: str
schema: Schema
source: Any
read_kwargs: Any
normalize_method: Any
values = FrozenDict()
class Read(ops.DatabaseTable):
method_name: str | None = None
read_kwargs: Any = None
normalize_method: Any = None

def make_dt(self):
method = getattr(self.source, self.method_name)
dt = method(**dict(self.read_kwargs)).op()
return dt

def make_unbound_dt(self):
import dask

name = f"{self.name}-{dask.base.tokenize(self)}"
return ops.UnboundTable(
name=name,
name=self.name,
schema=self.schema,
)

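With Read now a subclass of ops.DatabaseTable, it inherits name, schema, and source instead of redeclaring them, and make_unbound_dt can reuse the node's own name rather than minting a dask-token suffix. A sketch of the shape make_unbound_dt now returns, with an illustrative schema and assuming the vendored ibis mirrors upstream:

import xorq.vendor.ibis.expr.operations as ops
from xorq.vendor.ibis import Schema

# Illustrative schema; real Read nodes carry the schema of the data read.
unbound = ops.UnboundTable(
    name="reads",
    schema=Schema.from_tuples([("a", "int64"), ("b", "string")]),
)
assert unbound.name == "reads"  # no token suffix anymore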