Skip to content

Commit aa93ee5

Browse files
committed
Claude address non-ambiguous items of feedback.
1 parent 1485a27 commit aa93ee5

File tree

2 files changed

+67
-47
lines changed

2 files changed

+67
-47
lines changed

xarray_sql/cft.py

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,17 @@
3030
# ---------------------------------------------------------------------------
3131

3232
#: Calendars close enough to proleptic Gregorian for ``pa.timestamp('us')``.
33-
GREGORIAN_LIKE_CALENDARS: frozenset[str] = frozenset({
34-
'standard', 'gregorian', 'proleptic_gregorian',
35-
'noleap', '365_day',
36-
'all_leap', '366_day',
37-
})
33+
GREGORIAN_LIKE_CALENDARS: frozenset[str] = frozenset(
34+
{
35+
'standard',
36+
'gregorian',
37+
'proleptic_gregorian',
38+
'noleap',
39+
'365_day',
40+
'all_leap',
41+
'366_day',
42+
}
43+
)
3844

3945
#: Default CF-convention units when no encoding is available on the coordinate.
4046
#: Microseconds give sub-second precision and fit int64 for ±292 k years.
@@ -50,10 +56,12 @@ def is_gregorian_like(calendar: str) -> bool:
5056
# Detection helpers (avoid materializing Dask/Zarr data where possible)
5157
# ---------------------------------------------------------------------------
5258

53-
def is_cftime(values) -> bool:
59+
60+
def is_cftime(values: np.ndarray) -> bool:
5461
"""Check if a numpy array contains cftime datetime objects."""
5562
try:
5663
import cftime
64+
5765
if values.dtype == np.dtype('O') and len(values) > 0:
5866
sample = values.ravel()[0]
5967
return isinstance(sample, cftime.datetime)
@@ -68,8 +76,9 @@ def is_cftime_index(ds: xr.Dataset, coord_name: str) -> bool:
6876
idx = ds.indexes.get(coord_name)
6977
if idx is not None:
7078
from xarray import CFTimeIndex
79+
7180
return isinstance(idx, CFTimeIndex)
72-
except Exception:
81+
except (ImportError, AttributeError):
7382
pass
7483
return False
7584

@@ -84,15 +93,16 @@ def calendar(ds: xr.Dataset, coord_name: str) -> str | None:
8493
idx = ds.indexes.get(coord_name)
8594
if idx is not None:
8695
from xarray import CFTimeIndex
96+
8797
if isinstance(idx, CFTimeIndex):
88-
return idx.calendar # type: ignore[attr-defined]
89-
except Exception:
98+
return str(idx.calendar) # type: ignore[attr-defined]
99+
except (ImportError, AttributeError):
90100
pass
91101
try:
92102
values = ds.coords[coord_name].values
93103
if is_cftime(values):
94-
return values.ravel()[0].calendar
95-
except Exception:
104+
return str(values.ravel()[0].calendar)
105+
except (AttributeError, KeyError):
96106
pass
97107
return None
98108

@@ -113,13 +123,15 @@ def encoding(ds: xr.Dataset, coord_name: str) -> tuple[str, str]:
113123
# Numeric conversion
114124
# ---------------------------------------------------------------------------
115125

126+
116127
def to_microseconds(values) -> np.ndarray:
117128
"""Convert cftime objects to int64 microseconds since Unix epoch.
118129
119130
Used for Gregorian-like calendars. Vectorised via ``cftime.date2num``
120131
(implemented in C).
121132
"""
122133
import cftime as _cftime
134+
123135
us = _cftime.date2num(
124136
values.ravel(),
125137
units=DEFAULT_UNITS,
@@ -134,6 +146,7 @@ def to_offsets(values, units: str, cal: str) -> np.ndarray:
134146
Used for non-Gregorian calendars where data is stored as ``pa.int64()``.
135147
"""
136148
import cftime as _cftime
149+
137150
raw = _cftime.date2num(values.ravel(), units=units, calendar=cal)
138151
return np.asarray(raw, dtype=np.float64).astype(np.int64)
139152

@@ -156,6 +169,7 @@ def convert_for_field(values, field: pa.Field) -> np.ndarray:
156169
# Partition pruning helpers
157170
# ---------------------------------------------------------------------------
158171

172+
159173
def partition_bounds(
160174
values,
161175
) -> tuple[int, int, str]:
@@ -178,6 +192,7 @@ def partition_bounds(
178192
# Arrow schema helpers
179193
# ---------------------------------------------------------------------------
180194

195+
181196
def arrow_field(name: str, units: str, cal: str) -> pa.Field:
182197
"""Build a ``pa.Field`` for a cftime coordinate.
183198
@@ -192,3 +207,42 @@ def arrow_field(name: str, units: str, cal: str) -> pa.Field:
192207
if is_gregorian_like(cal):
193208
return pa.field(name, pa.timestamp('us'), metadata=meta)
194209
return pa.field(name, pa.int64(), metadata=meta)
210+
211+
212+
# ---------------------------------------------------------------------------
213+
# DataFusion UDF
214+
# ---------------------------------------------------------------------------
215+
216+
217+
def make_cftime_udf(units: str, calendar: str):
218+
"""Create a DataFusion scalar UDF that converts date strings to int64 offsets.
219+
220+
This enables ergonomic SQL filtering on non-Gregorian cftime columns::
221+
222+
SELECT * FROM ds360 WHERE time > cftime('0500-01-01')
223+
224+
The UDF parses the input string as a cftime datetime in the given
225+
calendar system and returns the corresponding int64 offset in the
226+
specified units.
227+
"""
228+
import cftime as _cftime
229+
from datafusion import udf
230+
231+
def _cftime_scalar(date_strings: pa.Array) -> pa.Array:
232+
results: list[int | None] = []
233+
for s in date_strings.to_pylist():
234+
if s is None:
235+
results.append(None)
236+
continue
237+
dt = _cftime.datetime.strptime(s, '%Y-%m-%d', calendar=calendar)
238+
val = _cftime.date2num(dt, units=units, calendar=calendar)
239+
results.append(int(val))
240+
return pa.array(results, type=pa.int64())
241+
242+
return udf(
243+
_cftime_scalar,
244+
[pa.utf8()],
245+
pa.int64(),
246+
'immutable',
247+
'cftime',
248+
)

xarray_sql/sql.py

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,11 @@
1-
import pyarrow as pa
21
import xarray as xr
3-
from datafusion import SessionContext, udf
2+
from datafusion import SessionContext
43

54
from . import cft
65
from .df import Chunks
76
from .reader import read_xarray_table
87

98

10-
def _make_cftime_udf(units: str, calendar: str):
11-
"""Create a DataFusion scalar UDF that converts date strings to int64 offsets.
12-
13-
This enables ergonomic SQL filtering on non-Gregorian cftime columns::
14-
15-
SELECT * FROM ds360 WHERE time > cftime('0500-01-01')
16-
17-
The UDF parses the input string as a cftime datetime in the given
18-
calendar system and returns the corresponding int64 offset in the
19-
specified units.
20-
"""
21-
import cftime as _cftime
22-
23-
def _cftime_scalar(date_strings: pa.Array) -> pa.Array:
24-
results = []
25-
for s in date_strings.to_pylist():
26-
if s is None:
27-
results.append(None)
28-
continue
29-
dt = _cftime.datetime.strptime(s, '%Y-%m-%d', calendar=calendar)
30-
val = _cftime.date2num(dt, units=units, calendar=calendar)
31-
results.append(int(val))
32-
return pa.array(results, type=pa.int64())
33-
34-
return udf(
35-
_cftime_scalar,
36-
[pa.utf8()],
37-
pa.int64(),
38-
'immutable',
39-
'cftime',
40-
)
41-
42-
439
class XarrayContext(SessionContext):
4410
"""A datafusion `SessionContext` that also supports `xarray.Dataset`s."""
4511

@@ -58,7 +24,7 @@ def from_dataset(
5824
if cft.is_cftime_index(input_table, coord_name):
5925
units, cal = cft.encoding(input_table, coord_name)
6026
if not cft.is_gregorian_like(cal):
61-
self.register_udf(_make_cftime_udf(units, cal))
27+
self.register_udf(cft.make_cftime_udf(units, cal))
6228
break # One UDF per context is enough.
6329

6430
return self

0 commit comments

Comments
 (0)