|
2 | 2 |
|
3 | 3 | import os |
4 | 4 | from datetime import datetime |
| 5 | +from typing import Any, Dict, Optional |
| 6 | +from urllib.parse import urlparse |
5 | 7 |
|
6 | 8 | import numpy |
7 | 9 | import pytest |
@@ -223,6 +225,69 @@ def test_reader(protocol, filename): |
223 | 225 | assert src.tile(0, 0, 0) |
224 | 226 |
|
225 | 227 |
|
| 228 | +def test_opener(): |
| 229 | + """test custom opener""" |
| 230 | + src_path = "file://" + os.path.join("file://", prefix, "dataset_2d.nc") |
| 231 | + |
| 232 | + def custom_netcdf_opener( # noqa: C901 |
| 233 | + src_path: str, |
| 234 | + special_arg: bool, |
| 235 | + group: Optional[str] = None, |
| 236 | + decode_times: bool = True, |
| 237 | + ) -> xarray.Dataset: |
| 238 | + """Open Xarray dataset with fsspec. |
| 239 | +
|
| 240 | + Args: |
| 241 | + src_path (str): dataset path. |
| 242 | + group (Optional, str): path to the netCDF/Zarr group in the given file to open given as a str. |
| 243 | + decode_times (bool): If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. |
| 244 | +
|
| 245 | + Returns: |
| 246 | + xarray.Dataset |
| 247 | +
|
| 248 | + """ |
| 249 | + import fsspec # noqa |
| 250 | + |
| 251 | + parsed = urlparse(src_path) |
| 252 | + protocol = parsed.scheme or "file" |
| 253 | + |
| 254 | + if not special_arg: |
| 255 | + raise ValueError("you forgot the special_arg :(") |
| 256 | + |
| 257 | + xr_open_args: Dict[str, Any] = { |
| 258 | + "decode_coords": "all", |
| 259 | + "decode_times": decode_times, |
| 260 | + "engine": "h5netcdf", |
| 261 | + "lock": False, |
| 262 | + } |
| 263 | + |
| 264 | + # Argument if we're opening a datatree |
| 265 | + if group is not None: |
| 266 | + xr_open_args["group"] = group |
| 267 | + |
| 268 | + fs = fsspec.filesystem(protocol) |
| 269 | + ds = xarray.open_dataset(fs.open(src_path), **xr_open_args) |
| 270 | + |
| 271 | + return ds |
| 272 | + |
| 273 | + with Reader( |
| 274 | + src_path=src_path, |
| 275 | + opener=custom_netcdf_opener, |
| 276 | + opener_options={"special_arg": True}, |
| 277 | + variable="dataset", |
| 278 | + ) as src: |
| 279 | + assert src.info() |
| 280 | + |
| 281 | + with pytest.raises(ValueError): |
| 282 | + with Reader( |
| 283 | + src_path=src_path, |
| 284 | + opener=custom_netcdf_opener, |
| 285 | + opener_options={"special_arg": False}, |
| 286 | + variable="dataset", |
| 287 | + ) as src: |
| 288 | + pass |
| 289 | + |
| 290 | + |
226 | 291 | @pytest.mark.parametrize( |
227 | 292 | "group", |
228 | 293 | [0, 1, 2], |
|
0 commit comments