Skip to content

Commit a6bfde3

Browse files
committed
add catalog provider for review
1 parent f685712 commit a6bfde3

File tree

3 files changed

+181
-0
lines changed

3 files changed

+181
-0
lines changed

catalog.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import xarray as xr
2+
import datafusion as dfn
3+
from xarray_sql.reader import read_xarray_table
4+
5+
6+
def group_vars_by_dims(ds):
7+
"""
8+
Group variables in the dataset based on shared dims
9+
10+
("time", "lat", "lon"): ["temperature_2m", "wind_speed"],
11+
("time", "lat", "lon", "level"): ["pressure", "humidity"]
12+
"""
13+
groups = {}
14+
15+
for var_name, var in ds.data_vars.items():
16+
dims = var.dims
17+
18+
if dims not in groups:
19+
groups[dims] = []
20+
21+
groups[dims].append(var_name)
22+
23+
return groups
24+
25+
26+
def dims_to_table_name(dims):
27+
"""
28+
"time", "lat", "lon" -> "time_lat_lon"
29+
"""
30+
return "_".join(dims)
31+
32+
33+
class XarraySchemaProvider(dfn.catalog.SchemaProvider):
34+
"""
35+
Custom datafusion schema that holds the tables
36+
"""
37+
38+
def __init__(self, ds, groups, chunks):
39+
# dictionary to store the tables
40+
self.tables = {}
41+
42+
# create a table for for each group of vars
43+
for dims, var_names in groups.items():
44+
table_name = dims_to_table_name(dims)
45+
subset = ds[var_names]
46+
self.tables[table_name] = read_xarray_table(subset, chunks)
47+
48+
def table_names(self):
49+
return set(self.tables.keys())
50+
51+
def table(self, name):
52+
return self.tables.get(name)
53+
54+
def table_exist(self, name):
55+
return name in self.tables
56+
57+
def register_table(self, name, table):
58+
self.tables[name] = table
59+
60+
def deregister_table(self, name, cascade=True):
61+
del self.tables[name]
62+
63+
64+
class XarrayCatalogProvider(dfn.catalog.CatalogProvider):
65+
"""
66+
Custom datafusion catalog that holds the schemas
67+
"""
68+
#Constructor
69+
def __init__(self, ds, schema_name, chunks):
70+
groups = group_vars_by_dims(ds)
71+
72+
# dictionary to store schemas using previous schema class
73+
"""
74+
"data": {
75+
"time_lat_lon": [temperature_2m, wind_speed],
76+
"time_lat_lon_level": [pressure, humidity]
77+
}
78+
"""
79+
self.schemas = {
80+
schema_name: XarraySchemaProvider(ds, groups, chunks)
81+
}
82+
83+
"""
84+
Other methods from test_catalog.py
85+
"""
86+
def schema_names(self):
87+
return set(self.schemas.keys())
88+
89+
def schema(self, name):
90+
return self.schemas.get(name)
91+
92+
def register_schema(self, name, schema):
93+
self.schemas[name] = schema
94+
95+
def deregister_schema(self, name, cascade=True):
96+
del self.schemas[name]
97+
98+
99+
def register_catalog_from_dataset(ctx, ds, catalog_name="xarray", schema_name="data", chunks=None):
100+
"""
101+
Main function. Takes an xarray dataset and registers it
102+
with DataFusion so you can query it with SQL.
103+
"""
104+
catalog = XarrayCatalogProvider(ds, schema_name, chunks)
105+
ctx.register_catalog_provider(catalog_name, catalog)

catalog_test.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import numpy as np
2+
import xarray as xr
3+
from context import XarrayContext
4+
5+
6+
# create a fake era5 dataset for testing
7+
times = np.array(["2020-01-01", "2020-01-02", "2020-01-03"], dtype="datetime64")
8+
lats = np.array([0.0, 1.0, 2.0])
9+
lons = np.array([0.0, 1.0, 2.0])
10+
levels = np.array([500, 850])
11+
12+
ds = xr.Dataset(
13+
{
14+
"temperature_2m": (["time", "lat", "lon"], np.random.rand(3, 3, 3)),
15+
"wind_speed": (["time", "lat", "lon"], np.random.rand(3, 3, 3)),
16+
"pressure": (["time", "lat", "lon", "level"], np.random.rand(3, 3, 3, 2)),
17+
"humidity": (["time", "lat", "lon", "level"], np.random.rand(3, 3, 3, 2)),
18+
},
19+
coords={
20+
"time": times,
21+
"lat": lats,
22+
"lon": lons,
23+
"level": levels,
24+
}
25+
).chunk({"time": 1})
26+
27+
print("Variables:", list(ds.data_vars))
28+
print("Dimensions:", list(ds.dims))
29+
30+
ctx = XarrayContext()
31+
ctx.register_catalog_from_dataset(ds)
32+
33+
print("\nCatalogs:", ctx.catalog_names())
34+
print("Schemas:", ctx.catalog("xarray").schema_names())
35+
print("Tables:", ctx.catalog("xarray").schema("data").table_names())
36+
37+
print("\n--- Surface variables (time, lat, lon) ---")
38+
result = ctx.sql("SELECT * FROM xarray.data.time_lat_lon LIMIT 5").collect()
39+
for batch in result:
40+
print(batch.to_pandas())
41+
42+
43+
print("\n--- Atmospheric variables (time, lat, lon, level) ---")
44+
result = ctx.sql("SELECT * FROM xarray.data.time_lat_lon_level LIMIT 5").collect()
45+
for batch in result:
46+
print(batch.to_pandas())
47+
48+
print("\n--- Joined surface + atmospheric on shared dims ---")
49+
result = ctx.sql("""
50+
SELECT
51+
s.time, s.lat, s.lon,
52+
s.temperature_2m,
53+
a.level,
54+
a.pressure
55+
FROM xarray.data.time_lat_lon s
56+
JOIN xarray.data.time_lat_lon_level a
57+
ON s.time = a.time
58+
AND s.lat = a.lat
59+
AND s.lon = a.lon
60+
LIMIT 10
61+
""").collect()
62+
for batch in result:
63+
print(batch.to_pandas())

xarray_context.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import xarray as xr
2+
from datafusion import SessionContext
3+
from catalog import register_catalog_from_dataset
4+
5+
6+
class XarrayContext(SessionContext):
7+
"""
8+
A regular DataFusion SessionContext but with an extra method
9+
for registering xarray datasets.
10+
"""
11+
12+
def register_catalog_from_dataset(self, ds, catalog_name="xarray", schema_name="data", chunks=None):
13+
register_catalog_from_dataset(self, ds, catalog_name, schema_name, chunks)

0 commit comments

Comments
 (0)