1+ import xarray as xr
2+ import datafusion as dfn
3+ from xarray_sql .reader import read_xarray_table
4+
5+
6+ def group_vars_by_dims (ds ):
7+ """
8+ Group variables in the dataset based on shared dims
9+
10+ ("time", "lat", "lon"): ["temperature_2m", "wind_speed"],
11+ ("time", "lat", "lon", "level"): ["pressure", "humidity"]
12+ """
13+ groups = {}
14+
15+ for var_name , var in ds .data_vars .items ():
16+ dims = var .dims
17+
18+ if dims not in groups :
19+ groups [dims ] = []
20+
21+ groups [dims ].append (var_name )
22+
23+ return groups
24+
25+
26+ def dims_to_table_name (dims ):
27+ """
28+ "time", "lat", "lon" -> "time_lat_lon"
29+ """
30+ return "_" .join (dims )
31+
32+
33+ class XarraySchemaProvider (dfn .catalog .SchemaProvider ):
34+ """
35+ Custom datafusion schema that holds the tables
36+ """
37+
38+ def __init__ (self , ds , groups , chunks ):
39+ # dictionary to store the tables
40+ self .tables = {}
41+
42+ # create a table for for each group of vars
43+ for dims , var_names in groups .items ():
44+ table_name = dims_to_table_name (dims )
45+ subset = ds [var_names ]
46+ self .tables [table_name ] = read_xarray_table (subset , chunks )
47+
48+ def table_names (self ):
49+ return set (self .tables .keys ())
50+
51+ def table (self , name ):
52+ return self .tables .get (name )
53+
54+ def table_exist (self , name ):
55+ return name in self .tables
56+
57+ def register_table (self , name , table ):
58+ self .tables [name ] = table
59+
60+ def deregister_table (self , name , cascade = True ):
61+ del self .tables [name ]
62+
63+
64+ class XarrayCatalogProvider (dfn .catalog .CatalogProvider ):
65+ """
66+ Custom datafusion catalog that holds the schemas
67+ """
68+ #Constructor
69+ def __init__ (self , ds , schema_name , chunks ):
70+ groups = group_vars_by_dims (ds )
71+
72+ # dictionary to store schemas using previous schema class
73+ """
74+ "data": {
75+ "time_lat_lon": [temperature_2m, wind_speed],
76+ "time_lat_lon_level": [pressure, humidity]
77+ }
78+ """
79+ self .schemas = {
80+ schema_name : XarraySchemaProvider (ds , groups , chunks )
81+ }
82+
83+ """
84+ Other methods from test_catalog.py
85+ """
86+ def schema_names (self ):
87+ return set (self .schemas .keys ())
88+
89+ def schema (self , name ):
90+ return self .schemas .get (name )
91+
92+ def register_schema (self , name , schema ):
93+ self .schemas [name ] = schema
94+
95+ def deregister_schema (self , name , cascade = True ):
96+ del self .schemas [name ]
97+
98+
99+ def register_catalog_from_dataset (ctx , ds , catalog_name = "xarray" , schema_name = "data" , chunks = None ):
100+ """
101+ Main function. Takes an xarray dataset and registers it
102+ with DataFusion so you can query it with SQL.
103+ """
104+ catalog = XarrayCatalogProvider (ds , schema_name , chunks )
105+ ctx .register_catalog_provider (catalog_name , catalog )
0 commit comments