Skip to content

Commit b0e722f

Browse files
author
Fahad Alghanim
committed
Allow passing .zarr store roots
1 parent 3920595 commit b0e722f

3 files changed

Lines changed: 98 additions & 14 deletions

File tree

docs/source/stream.mdx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,15 @@ Scientific datasets are often stored in formats like HDF5 or Zarr (chunked, N-di
7272

7373
Zarr support is currently experimental. Install it with `pip install "datasets[zarr]"` (or `pip install zarr`).
7474

75-
Zarr stores are directory-based. To load them, point `data_files` to the Zarr root metadata file:
75+
Zarr stores are directory-based. You can point `data_files` to either the Zarr store root directory (recommended for convenience) or the Zarr root metadata file:
7676

77+
- Zarr store root directory: `.../store.zarr` (auto-detects the root metadata: `zarr.json` is tried first, then `.zmetadata`)
7778
- Zarr v3: `.../store.zarr/zarr.json`
7879
- Zarr v2 (consolidated): `.../store.zarr/.zmetadata`
7980

8081
```py
8182
>>> from datasets import load_dataset
82-
>>> ds = load_dataset("zarr", data_files=["/path/to/store.zarr/zarr.json"], split="train", streaming=True)
83+
>>> ds = load_dataset("zarr", data_files=["/path/to/store.zarr"], split="train", streaming=True)
8384
>>> print(next(iter(ds)))
8485
{'int32': 0, 'float32': 0.0, 'matrix_2d': [[...], ...]}
8586
```

src/datasets/packaged_modules/zarr/zarr.py

Lines changed: 86 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass
2+
import os
23
from typing import Optional
34

45
import numpy as np
@@ -23,7 +24,9 @@ class ZarrConfig(datasets.BuilderConfig):
2324
group as a column and the first dimension as the row dimension.
2425
2526
Notes:
26-
- Pass the path(s) to a Zarr root metadata file via `data_files`:
27+
- Pass either the Zarr store root directory (usually ending in `.zarr`) or the Zarr root metadata file(s)
28+
via `data_files`:
29+
- Store root directory: `.../store.zarr` (auto-detects metadata)
2730
- Zarr v2 (consolidated): `.zmetadata`
2831
- Zarr v3: `zarr.json`
2932
- Nested groups are not yet supported (only arrays at the selected group root).
@@ -69,17 +72,9 @@ def _generate_tables(self, metadata_files, storage_options):
6972

7073
for store_idx, metadata_file in enumerate(metadata_files):
7174
metadata_file = str(metadata_file)
72-
is_v2_consolidated = metadata_file.endswith("/.zmetadata") or metadata_file.endswith("\\.zmetadata")
73-
is_v3_root = metadata_file.endswith("/zarr.json") or metadata_file.endswith("\\zarr.json")
74-
if not (is_v2_consolidated or is_v3_root):
75-
raise ValueError(
76-
"Zarr packaged module expects a Zarr root metadata file. Supported values:\n"
77-
"- Zarr v2 consolidated: `.zmetadata`\n"
78-
"- Zarr v3: `zarr.json`\n"
79-
f"Got: {metadata_file}"
80-
)
81-
82-
store_root = xdirname(metadata_file)
75+
store_root, is_v2_consolidated, is_v3_root = _resolve_store_root_and_version(
76+
metadata_file, storage_options=storage_options
77+
)
8378
mapper = _get_fsspec_mapper(store_root, storage_options)
8479

8580
if is_v2_consolidated and self.config.consolidated:
@@ -128,6 +123,85 @@ def _get_fsspec_mapper(store_root: str, storage_options: dict):
128123
return fsspec.get_mapper(store_root, **(per_protocol or {}))
129124

130125

126+
def _resolve_store_root_and_version(path: str, *, storage_options: dict) -> tuple[str, bool, bool]:
    """
    Normalize the user input into a Zarr store root and detect which root metadata the store exposes.

    Supported `path` values:
    - a Zarr root metadata file: `.../.zmetadata` (v2 consolidated) or `.../zarr.json` (v3)
    - a Zarr store root directory: `.../store.zarr` (auto-detect `zarr.json` / `.zmetadata`)

    Args:
        path: Path or URL given by the user (a root metadata file or a store root directory).
        storage_options: Mapping of protocol name to the keyword arguments used to open that
            filesystem. Anything that is not a dict is treated as "no options".

    Returns:
        A `(store_root, is_v2_consolidated, is_v3_root)` tuple. Exactly one of the two flags is True.
        When a store root is given and both metadata files exist, `zarr.json` (v3) wins.

    Raises:
        ValueError: If `path` is neither a root metadata file nor a store root directory, if the
            filesystem of a URL cannot be opened, if probing the store fails, or if the store root
            contains neither `zarr.json` nor `.zmetadata`.
    """
    # Explicit metadata files need no filesystem access: the store root is simply their parent.
    if path.endswith(("/.zmetadata", "\\.zmetadata")):
        return xdirname(path), True, False
    if path.endswith(("/zarr.json", "\\zarr.json")):
        return xdirname(path), False, True

    is_url = "://" in path
    protocol = path.split("://", 1)[0] if is_url else "file"
    per_protocol = (storage_options.get(protocol) if isinstance(storage_options, dict) else None) or {}
    store_root = path.rstrip("/\\")

    fs = None
    fs_path = None
    try:
        from fsspec.core import url_to_fs

        fs, fs_path = url_to_fs(path, **per_protocol)
    except Exception as err:
        # `os.path` can only stand in for fsspec on plain local paths. For a URL, probing the
        # local disk would always fail and hide the real problem, so surface it instead.
        if is_url:
            raise ValueError(f"Could not open the filesystem for Zarr store '{path}': {err}") from err
        fs = None
        fs_path = None

    # Treat paths ending in ".zarr" as store roots, and also accept directories for convenience.
    if fs is None:
        is_directory = os.path.isdir(path)
    else:
        try:
            is_directory = bool(fs.isdir(fs_path))
        except Exception:
            # Best effort: a backend that cannot answer `isdir` must not reject a `.zarr` path.
            is_directory = False

    if not (is_directory or store_root.endswith(".zarr")):
        raise ValueError(
            "Zarr packaged module expects either a Zarr store root directory (usually ending in `.zarr`) or a Zarr "
            "root metadata file:\n"
            "- Zarr store root directory: `.../store.zarr`\n"
            "- Zarr v2 consolidated: `.../store.zarr/.zmetadata`\n"
            "- Zarr v3: `.../store.zarr/zarr.json`\n"
            f"Got: {path}"
        )

    if fs is None:

        def _has_metadata(name: str) -> bool:
            return os.path.exists(os.path.join(store_root, name))

    else:
        # Build candidate paths in the fs namespace (protocol prefix already stripped by fsspec).
        fs_root = fs_path.rstrip("/") if isinstance(fs_path, str) else fs_path

        def _has_metadata(name: str) -> bool:
            return bool(fs.exists(f"{fs_root}/{name}"))

    try:
        found_v3 = _has_metadata("zarr.json")
        # Only probe for v2 metadata when v3 metadata is absent: v3 takes precedence.
        found_v2 = (not found_v3) and _has_metadata(".zmetadata")
    except Exception as err:
        # An I/O or authentication failure is not the same thing as "metadata is missing".
        raise ValueError(f"Failed to look for Zarr metadata under '{store_root}': {err}") from err

    if found_v3:
        return store_root, False, True
    if found_v2:
        return store_root, True, False

    raise ValueError(
        "Zarr store root directory does not contain expected metadata. Looked for:\n"
        f"- {store_root}/zarr.json\n"
        f"- {store_root}/.zmetadata"
    )
203+
204+
131205
def _get_root_arrays(zgroup) -> dict[str, "zarr.Array"]:
132206
out = {}
133207
for name in zgroup.array_keys():

tests/packaged_modules/test_zarr.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
22
import pytest
33

4+
from pathlib import Path
5+
46
from datasets import Array2D, Features, Value, load_dataset
57
from datasets.builder import InvalidConfigName
68
from datasets.data_files import DataFilesList
@@ -75,6 +77,13 @@ def test_zarr_basic_loading(zarr_root_metadata_file):
7577
assert ds["int32"] == [0, 1, 2, 3, 4]
7678

7779

80+
def test_zarr_loading_from_store_root_directory(zarr_root_metadata_file):
    """Check that `data_files` may point at the store root directory instead of a metadata file."""
    # The fixture path is treated as a file inside the store; its parent is passed as the store root.
    root_dir = Path(zarr_root_metadata_file).parent
    dataset = load_dataset("zarr", data_files=[str(root_dir)], split="train")

    expected_columns = {"int32", "float32", "matrix_2d"}
    assert set(dataset.column_names) == expected_columns
    assert dataset["int32"] == [0, 1, 2, 3, 4]
85+
86+
7887
def test_zarr_loading_with_features_override(zarr_root_metadata_file):
7988
features = Features(
8089
{

0 commit comments

Comments
 (0)