Skip to content
Merged
14 changes: 14 additions & 0 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ class SessionContext:
See :ref:`user_guide_concepts` in the online documentation for more information.
"""

_global_instance = None

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why move this from the rust code to the python code?

def __init__(
self,
config: SessionConfig | None = None,
Expand Down Expand Up @@ -498,6 +500,18 @@ def __init__(

self.ctx = SessionContextInternal(config, runtime)

@classmethod
def global_ctx(cls) -> "SessionContextInternal":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want to return the wrapper SessionContext since end users may be getting this and would want to use the associated methods from the wrapper classes. It would mean updating your methods in read accordingly.

"""Retrieve the global context.

Returns:
A `SessionContextInternal` object that corresponds to the global context
"""
if cls._global_instance is None:
internal_ctx = SessionContextInternal.global_ctx()
cls._global_instance = internal_ctx
return cls._global_instance

def enable_url_table(self) -> "SessionContext":
"""Control if local files can be queried as tables.

Expand Down
10 changes: 5 additions & 5 deletions python/datafusion/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from datafusion.dataframe import DataFrame
from datafusion.expr import Expr

from ._internal import SessionContext as SessionContextInternal
from datafusion.context import SessionContext


def read_parquet(
Expand Down Expand Up @@ -65,7 +65,7 @@ def read_parquet(
if table_partition_cols is None:
table_partition_cols = []
return DataFrame(
SessionContextInternal._global_ctx().read_parquet(
SessionContext.global_ctx().read_parquet(
str(path),
table_partition_cols,
parquet_pruning,
Expand Down Expand Up @@ -107,7 +107,7 @@ def read_json(
if table_partition_cols is None:
table_partition_cols = []
return DataFrame(
SessionContextInternal._global_ctx().read_json(
SessionContext.global_ctx().read_json(
str(path),
schema,
schema_infer_max_records,
Expand Down Expand Up @@ -158,7 +158,7 @@ def read_csv(
path = [str(p) for p in path] if isinstance(path, list) else str(path)

return DataFrame(
SessionContextInternal._global_ctx().read_csv(
SessionContext.global_ctx().read_csv(
path,
schema,
has_header,
Expand Down Expand Up @@ -195,7 +195,7 @@ def read_avro(
if file_partition_cols is None:
file_partition_cols = []
return DataFrame(
SessionContextInternal._global_ctx().read_avro(
SessionContext.global_ctx().read_avro(
str(path), schema, file_partition_cols, file_extension
)
)
31 changes: 31 additions & 0 deletions python/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
literal,
)

from datafusion._internal import SessionContext as SessionContextInternal


def test_create_context_no_args():
SessionContext()
Expand Down Expand Up @@ -629,3 +631,32 @@ def test_sql_with_options_no_statements(ctx):
options = SQLOptions().with_allow_statements(False)
with pytest.raises(Exception, match="SetVariable"):
ctx.sql_with_options(sql, options=options)


def test_global_context_type():
    """Check that the global context accessor yields the internal context type."""
    assert isinstance(SessionContext.global_ctx(), SessionContextInternal)


def test_global_context_is_singleton():
    """Check that two successive lookups return the very same object."""
    first, second = SessionContext.global_ctx(), SessionContext.global_ctx()
    assert first is second


@pytest.fixture
def batch():
    """Provide a record batch with one column, ``a``, holding 4, 5 and 6."""
    column_a = pa.array([4, 5, 6])
    return pa.RecordBatch.from_arrays([column_a], names=["a"])


def test_create_dataframe_with_global_ctx(batch):
    """Check that a dataframe created on the global context returns the batch data."""
    dataframe = SessionContext.global_ctx().create_dataframe([[batch]])

    # Only the first collected batch's first column is inspected.
    first_column = dataframe.collect()[0].column(0)

    assert first_column == pa.array([4, 5, 6])
2 changes: 1 addition & 1 deletion src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ impl PySessionContext {

#[classmethod]
#[pyo3(signature = ())]
fn _global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you have it here where you moved the single entry over to the python side, this method goes unused. I would recommend you leave this line as is, but up in the python code you call this method instead of creating _global_instance

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your comments. Just to summarize what's needed here:

  1. Expose the global context (_global_ctx -> global_ctx), which I've currently done.
  2. A Python wrapper should be created for the global context (in the SessionContext class) which calls the above function and wraps it in SessionContext so that users can still use the other associated methods in this class, but with the global context. This should be a class method so that users don't have to instantiate SessionContext first.
  3. The read_* functions (read_parquet, etc) should use the global context from this python wrapper instead of using the one from the internal implementation.

Am I interpreting this correctly? Sorry if I'm overthinking this 😅. I've updated the PR; currently the test_read_csv and test_read_csv_list tests fail, so I'm looking into that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this description looks correct.

Ok(Self {
ctx: get_global_ctx().clone(),
})
Expand Down