Skip to content

Commit ea6870a

Browse files
committed
Add test to create test datasets for rust tests.
1 parent 986b498 commit ea6870a

File tree

2 files changed

+313
-0
lines changed

2 files changed

+313
-0
lines changed

.github/workflows/ci-rust.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ jobs:
1818
- uses: actions/checkout@v4
1919
with:
2020
submodules: "recursive"
21+
- name: Install uv
22+
uses: astral-sh/setup-uv@v5
2123
- name: Install Rust
2224
uses: dtolnay/rust-toolchain@stable
2325
with:
@@ -27,6 +29,8 @@ jobs:
2729
run: cargo clippy --all-features --tests -- -D warnings
2830
- name: Check
2931
run: cargo check --all-features
32+
- name: create test data
33+
run: uv run examples/create_test_datasets.py
3034
- name: Test
3135
run: cargo test --all-features
3236

examples/create_test_datasets.py

Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
#!/usr/bin/env python3
2+
# /// script
3+
# requires-python = ">=3.11"
4+
# dependencies = [
5+
# "xarray",
6+
# "zarr",
7+
# ]
8+
# ///
9+
"""
10+
Create diverse test Zarr datasets for comprehensive SQL testing.
11+
"""
12+
13+
import numpy as np
14+
import xarray as xr
15+
import shutil
16+
from pathlib import Path
17+
18+
19+
def get_project_root():
    """Return the repository root directory.

    The script is assumed to live in ``examples/``, one level below the
    project root, so the root is the parent of the script's own directory.
    """
    script_dir = Path(__file__).parent
    return script_dir.parent
23+
24+
25+
def create_weather_dataset():
    """Write the 4D weather fixture to ``test_data/weather.zarr``.

    The dataset has dimensions time (5) x lat (3) x lon (4) x altitude (2)
    and three float variables: temperature, pressure and humidity. Any
    existing store at the target path is removed first.

    Returns:
        Path to the Zarr store that was written.
    """
    print('Creating weather dataset...')

    dims = ['time', 'lat', 'lon', 'altitude']
    time = np.arange(0, 5)
    lat = np.array([30.0, 35.0, 40.0])
    lon = np.array([-120.0, -115.0, -110.0, -105.0])
    altitude = np.array([0, 1000])  # metres: 0m and 1000m
    alt_levels = np.arange(2)

    shape = (5, 3, 4, 2)  # 120 total points

    # Temperature: noise plus a per-(lat, altitude) offset, warmer at
    # higher lat values and 10 degrees cooler at the upper altitude level.
    temperature_data = np.random.normal(20, 5, shape)
    lat_alt_offset = ((lat - 35) * 0.5)[:, np.newaxis] - (alt_levels * 10)[
        np.newaxis, :
    ]
    temperature_data += lat_alt_offset[np.newaxis, :, np.newaxis, :]

    # Pressure: noise, lowered by 100 at the upper altitude level.
    pressure_data = np.random.normal(1013, 20, shape)
    pressure_data -= alt_levels * 100

    # Humidity: uniform between 30 and 90.
    humidity_data = np.random.uniform(30, 90, shape)

    variables = {
        'temperature': (dims, temperature_data),
        'pressure': (dims, pressure_data),
        'humidity': (dims, humidity_data),
    }
    coords = {'time': time, 'lat': lat, 'lon': lon, 'altitude': altitude}
    ds = xr.Dataset(variables, coords=coords)

    zarr_path = get_project_root() / 'test_data' / 'weather.zarr'
    # Start from a clean store so stale arrays never survive a re-run.
    if zarr_path.exists():
        shutil.rmtree(zarr_path)
    ds.to_zarr(str(zarr_path))

    print(f'✅ Created weather dataset: {zarr_path}')
    print(f' Shape: {shape} = {np.prod(shape)} rows')
    print(' Variables: temperature, pressure, humidity')
    return zarr_path
78+
79+
80+
def create_ocean_dataset():
    """Write the 3D ocean fixture to ``test_data/ocean.zarr``.

    The dataset has dimensions depth (4) x lat (5) x lon (6) and two float
    variables: sea_temperature (deterministic) and salinity (random). Any
    existing store at the target path is removed first.

    Returns:
        Path to the Zarr store that was written.
    """
    print('\nCreating ocean dataset...')

    dims = ['depth', 'lat', 'lon']
    depth = np.array([0, 10, 50, 100])
    # Overlaps the weather dataset's lat values at 30, 35 and 40, which is
    # what makes lat-based joins between the two possible.
    lat = np.array([25.0, 30.0, 35.0, 40.0, 45.0])
    lon = np.array([-130.0, -125.0, -120.0, -115.0, -110.0, -105.0])

    shape = (4, 5, 6)  # 120 total points

    # Sea temperature: depends only on (depth, lat); the same value is
    # repeated across every longitude.
    by_lat = 25 + (lat - 35) * 0.3
    by_depth = depth * 0.1
    sea_temp_data = np.zeros(shape)
    sea_temp_data[:] = (by_lat[np.newaxis, :] - by_depth[:, np.newaxis])[
        :, :, np.newaxis
    ]

    # Salinity: noise around 35, rising by 0.2 per depth level.
    salinity_data = np.random.normal(35, 1, shape)
    salinity_data += (np.arange(4) * 0.2)[:, np.newaxis, np.newaxis]

    variables = {
        'sea_temperature': (dims, sea_temp_data),
        'salinity': (dims, salinity_data),
    }
    ds = xr.Dataset(variables, coords={'depth': depth, 'lat': lat, 'lon': lon})

    zarr_path = get_project_root() / 'test_data' / 'ocean.zarr'
    # Start from a clean store so stale arrays never survive a re-run.
    if zarr_path.exists():
        shutil.rmtree(zarr_path)
    ds.to_zarr(str(zarr_path))

    print(f'✅ Created ocean dataset: {zarr_path}')
    print(f' Shape: {shape} = {np.prod(shape)} rows')
    print(' Variables: sea_temperature, salinity')
    return zarr_path
128+
129+
130+
def create_simple_timeseries():
    """Write the 2D timeseries fixture to ``test_data/timeseries.zarr``.

    The dataset has dimensions time (10) x station (3) and two float
    variables: value (normal noise) and count (Poisson draws cast to
    float). Any existing store at the target path is removed first.

    Returns:
        Path to the Zarr store that was written.
    """
    print('\nCreating simple timeseries dataset...')

    dims = ['time', 'station']
    coords = {
        'time': np.arange(0, 10),
        'station': np.array([1, 2, 3]),
    }
    shape = (10, 3)  # 30 total points

    # Draw value before count, keeping the RNG consumption order fixed.
    value_data = np.random.normal(100, 10, shape)
    count_data = np.random.poisson(5, shape)

    ds = xr.Dataset(
        {
            'value': (dims, value_data),
            'count': (dims, count_data.astype(float)),
        },
        coords=coords,
    )

    zarr_path = get_project_root() / 'test_data' / 'timeseries.zarr'
    # Start from a clean store so stale arrays never survive a re-run.
    if zarr_path.exists():
        shutil.rmtree(zarr_path)
    ds.to_zarr(str(zarr_path))

    print(f'✅ Created timeseries dataset: {zarr_path}')
    print(f' Shape: {shape} = {np.prod(shape)} rows')
    print(' Variables: value, count')
    return zarr_path
165+
166+
167+
def create_single_dimension_dataset():
    """Write the 1D edge-case fixture to ``test_data/single_dim.zarr``.

    The dataset has a single ``index`` dimension of length 8 and one
    variable, ``measurement``, holding fixed (non-random) values. Any
    existing store at the target path is removed first.

    Returns:
        Path to the Zarr store that was written.
    """
    print('\nCreating single dimension dataset...')

    shape = (8,)  # 8 total points
    measurement_data = np.array([10.5, 15.2, 20.1, 18.7, 12.3, 8.9, 14.6, 22.1])

    ds = xr.Dataset(
        {'measurement': (['index'], measurement_data)},
        coords={'index': np.arange(0, 8)},
    )

    zarr_path = get_project_root() / 'test_data' / 'single_dim.zarr'
    # Start from a clean store so stale arrays never survive a re-run.
    if zarr_path.exists():
        shutil.rmtree(zarr_path)
    ds.to_zarr(str(zarr_path))

    print(f'✅ Created single dimension dataset: {zarr_path}')
    print(f' Shape: {shape} = {np.prod(shape)} rows')
    print(' Variables: measurement')
    return zarr_path
198+
199+
200+
def create_large_sparse_dataset():
    """Write the 3D business fixture to ``test_data/business.zarr``.

    The dataset has dimensions category (4) x region (6) x period (8) and
    two float variables: activity, which follows a different pattern per
    category, and revenue, which is activity scaled by a noisy factor
    around 2. Any existing store at the target path is removed first.

    Returns:
        Path to the Zarr store that was written.
    """
    print('\nCreating large sparse dataset...')

    dims = ['category', 'region', 'period']
    category = np.array([0, 1, 2, 3])
    region = np.arange(0, 6)
    period = np.arange(0, 8)

    shape = (4, 6, 8)  # 192 total points

    activity_data = np.zeros(shape)
    revenue_data = np.zeros(shape)

    # ndindex walks the grid with the last axis varying fastest, i.e. the
    # same order as nested category -> region -> period loops, so random
    # draws are consumed in an unchanged sequence.
    for cat, reg, per in np.ndindex(*shape):
        if cat == 0:
            # Declines with period: active in the first half.
            raw = 100 - per * 10 + np.random.normal(0, 5)
        elif cat == 1:
            # Scales with region index.
            raw = reg * 15 + np.random.normal(0, 8)
        elif cat == 2:
            # Sinusoidal in period.
            raw = 50 + 30 * np.sin(per * np.pi / 4) + np.random.normal(0, 10)
        else:
            # Sparse: stays zero unless the uniform draw exceeds 0.6.
            raw = np.random.exponential(5) if np.random.random() > 0.6 else 0

        # Activity is clamped so it never goes negative.
        activity_data[cat, reg, per] = max(0, raw)

        # Revenue is correlated with activity.
        revenue_data[cat, reg, per] = activity_data[cat, reg, per] * (
            2 + np.random.normal(0, 0.5)
        )

    ds = xr.Dataset(
        {
            'activity': (dims, activity_data),
            'revenue': (dims, revenue_data),
        },
        coords={'category': category, 'region': region, 'period': period},
    )

    zarr_path = get_project_root() / 'test_data' / 'business.zarr'
    # Start from a clean store so stale arrays never survive a re-run.
    if zarr_path.exists():
        shutil.rmtree(zarr_path)
    ds.to_zarr(str(zarr_path))

    print(f'✅ Created business dataset: {zarr_path}')
    print(f' Shape: {shape} = {np.prod(shape)} rows')
    print(' Variables: activity, revenue')
    return zarr_path
263+
264+
265+
if __name__ == '__main__':
    # Script entry point: create every fixture under <project root>/test_data
    # and print a summary. Any failure is reported and turned into a non-zero
    # exit status so callers (e.g. a CI step) do not continue with missing or
    # partial test data.
    try:
        # Create test data directory
        test_data_dir = get_project_root() / 'test_data'
        test_data_dir.mkdir(exist_ok=True)

        print('🏗️ Creating diverse test datasets for SQL integration tests...\n')

        # Create all test datasets (order determines the printed sequence).
        datasets = [
            create_weather_dataset(),
            create_ocean_dataset(),
            create_simple_timeseries(),
            create_single_dimension_dataset(),
            create_large_sparse_dataset(),
        ]

        print(f'\n🎉 Successfully created {len(datasets)} test datasets!')
        print('\n📊 Dataset Summary:')
        print(
            ' 1. weather.zarr - 4D (time×lat×lon×altitude) - temperature, pressure, humidity'
        )
        print(
            ' 2. ocean.zarr - 3D (depth×lat×lon) - sea_temperature, salinity'
        )
        print(' 3. timeseries.zarr - 2D (time×station) - value, count')
        print(' 4. single_dim.zarr - 1D (index) - measurement')
        print(
            ' 5. business.zarr - 3D (category×region×period) - activity, revenue'
        )

        print('\n🔗 Join Testing Opportunities:')
        print(' • Weather ⋈ Ocean: matching lat coordinates')
        print(' • Different dimensionalities: 4D ⋈ 3D ⋈ 2D ⋈ 1D')
        print(' • Time-based joins: weather.time ⋈ timeseries.time')
        print(' • Categorical joins: various coordinate-based relationships')

        print('\n💡 Ready for SQL integration tests!')
        print(' Run: cargo run --example sql_integration_tests')

    except ImportError as e:
        # Top-level imports run before this try block, so this only catches
        # imports performed lazily while writing (presumably the zarr
        # backend loaded by xarray -- confirm).
        print(f'❌ Missing dependencies: {e}')
        print('💡 Install with: pip install xarray zarr numpy')
        # Exit non-zero: printing alone would let the process report success.
        raise SystemExit(1) from e
    except Exception as e:
        print(f'❌ Error creating test datasets: {e}')
        # Exit non-zero so the failure is visible to whatever ran the script.
        raise SystemExit(1) from e

0 commit comments

Comments
 (0)