Commit 4dd22f1

Feat: Add Support for multiple test-set benchmarks (#296)
* updated splits to support multiple test sets
1 parent 52d3f8d commit 4dd22f1

6 files changed: 346 additions & 127 deletions


polaris/benchmark/_benchmark_v2.py

Lines changed: 22 additions & 19 deletions
@@ -30,6 +30,7 @@ class BenchmarkV2Specification(
 
     Attributes:
         dataset: The dataset the benchmark specification is based on.
+        splits: The predefined train-test splits to use for evaluation.
        n_classes: The number of classes for each of the target columns.
        readme: Markdown text that can be used to provide a formatted description of the benchmark.
        artifact_version: The version of the benchmark.
@@ -85,7 +86,7 @@ def _validate_split_in_dataset(self) -> Self:
         - All indices are valid given the dataset
         """
         dataset_length = len(self.dataset)
-        if self.split.max_index >= dataset_length:
+        if self.max_index >= dataset_length:
             raise InvalidBenchmarkError("The predefined split contains invalid indices")
 
         return self
@@ -102,17 +103,24 @@ def _validate_cols_in_dataset(self) -> Self:
 
         return self
 
-    def _get_test_sets(
+    def _get_splits(
         self, hide_targets=True, featurization_fn: Callable | None = None
-    ) -> dict[str, Subset]:
+    ) -> dict[str, tuple[Subset, Subset]]:
         """
-        Construct the test set(s), given the split in the benchmark specification. Used
-        internally to construct the test set for client use and evaluation.
+        Construct all train-test split pairs, given the splits in the benchmark specification.
+        Used internally to construct the splits for client use and evaluation.
         """
         # TODO: We need a subset class that can handle very large index sets without copying or materializing all of them
         return {
-            label: self._get_subset(index_set.indices, hide_targets, featurization_fn)
-            for label, index_set in self.split.test_items()
+            label: (
+                self._get_subset(
+                    split.training.indices, hide_targets=False, featurization_fn=featurization_fn
+                ),
+                self._get_subset(
+                    split.test.indices, hide_targets=hide_targets, featurization_fn=featurization_fn
+                ),
+            )
+            for label, split in self.split_items()
         }
 
     def _get_subset(self, indices, hide_targets=True, featurization_fn=None) -> Subset:
@@ -129,8 +137,8 @@ def _get_subset(self, indices, hide_targets=True, featurization_fn=None) -> Subs
 
     def get_train_test_split(
         self, featurization_fn: Callable | None = None
-    ) -> tuple[Subset, dict[str, Subset]]:
-        """Construct the train and test sets, given the split in the benchmark specification.
+    ) -> dict[str, tuple[Subset, Subset]]:
+        """Construct the train and test sets for all splits, given the splits in the benchmark specification.
 
         Returns [`Subset`][polaris.dataset.Subset] objects, which offer several ways of accessing the data
         and can thus easily serve as a basis to build framework-specific (e.g. PyTorch, Tensorflow)
@@ -141,15 +149,10 @@ def get_train_test_split(
         expects an input in the format specified by the `input_format` parameter.
 
         Returns:
-            A tuple with the train `Subset` and test `Subset` objects.
-            If there are multiple test sets, these are returned in a dictionary and each test set has
-            an associated name. The targets of the test set can not be accessed.
+            A dictionary mapping split labels to (train, test) tuples of `Subset` objects.
+            The targets of the test sets cannot be accessed.
         """
-        train = self._get_subset(
-            self.split.training.indices, hide_targets=False, featurization_fn=featurization_fn
-        )
-        test = self._get_test_sets(hide_targets=True, featurization_fn=featurization_fn)
-        return train, test
+        return self._get_splits(hide_targets=True, featurization_fn=featurization_fn)
 
     def upload_to_hub(
         self,
@@ -208,8 +211,8 @@ def submit_predictions(
             benchmark_artifact_id=self.artifact_id,
             predictions=predictions,
             target_labels=list(self.target_cols),
-            test_set_labels=self.test_set_labels,
-            test_set_sizes=self.test_set_sizes,
+            test_set_labels=self.split_labels,
+            test_set_sizes=self.n_test_datapoints,
             contributors=contributors or [],
             model=model,
             description=description,

polaris/benchmark/_split_v2.py

Lines changed: 58 additions & 55 deletions
@@ -57,120 +57,123 @@ def deserialize(index_set: bytes) -> "IndexSet":
 
 
 class SplitV2(BaseModel):
+    """
+    A single train-test split pair containing training and test index sets.
+
+    This represents one train-test split with training and test sets.
+    Multiple SplitV2 instances can be used together for cross-validation scenarios.
+    """
+
     training: IndexSet
     test: IndexSet
 
     @field_validator("training", "test", mode="before")
     @classmethod
-    def _parse_index_sets(cls, v: bytes | IndexSet) -> bytes | IndexSet:
-        """
-        Accepted a binary serialized IndexSet
-        """
+    def _parse_index_set(cls, v: bytes | IndexSet) -> IndexSet:
+        """Accept a binary serialized IndexSet"""
         if isinstance(v, bytes):
             return IndexSet.deserialize(v)
         return v
 
     @field_validator("training")
     @classmethod
     def _validate_training_set(cls, v: IndexSet) -> IndexSet:
-        """
-        Training index set can be empty (zero-shot)
-        """
+        """Training index set can be empty (zero-shot)"""
         if v.datapoints == 0:
-            logger.info(
-                "This benchmark only specifies a test set. It will return an empty train set in `get_train_test_split()`"
+            logger.debug(
+                "This train-test split only specifies a test set. It will return an empty train set in `get_train_test_split()`"
             )
         return v
 
     @field_validator("test")
     @classmethod
     def _validate_test_set(cls, v: IndexSet) -> IndexSet:
-        """
-        Test index set cannot be empty
-        """
+        """Test index set cannot be empty"""
         if v.datapoints == 0:
-            raise InvalidBenchmarkError("The predefined split contains empty test partitions")
+            raise InvalidBenchmarkError("Test set cannot be empty")
         return v
 
     @model_validator(mode="after")
     def validate_set_overlap(self) -> Self:
-        """
-        The training and test index sets do not overlap
-        """
+        """The training and test index sets do not overlap"""
         if self.training.intersect(self.test):
             raise InvalidBenchmarkError("The predefined split specifies overlapping train and test sets")
         return self
 
     @property
     def n_train_datapoints(self) -> int:
-        """
-        The size of the train set.
-        """
+        """The size of the train set."""
         return self.training.datapoints
 
     @property
-    def n_test_sets(self) -> int:
-        """
-        The number of test sets
-        """
-        # TODO: Until we support multi-test benchmarks
-        return 1
-
-    @property
-    def n_test_datapoints(self) -> dict[str, int]:
-        """
-        The size of (each of) the test set(s).
-        """
-        # TODO: Until we support multi-test benchmarks
-        return {"test": self.test.datapoints}
+    def n_test_datapoints(self) -> int:
+        """The size of the test set."""
+        return self.test.datapoints
 
     @property
     def max_index(self) -> int:
-        # TODO: Until we support multi-test benchmarks (need)
-        return max(self.training.indices.max(), self.test.indices.max())
+        """Maximum index across train and test sets"""
+        max_indices = []
 
-    def test_items(self) -> Generator[tuple[str, IndexSet], None, None]:
-        # TODO: Until we support multi-test benchmarks
-        yield "test", self.test
+        # Only add max if the bitmap is not empty
+        if len(self.training.indices) > 0:
+            max_indices.append(self.training.indices.max())
+        max_indices.append(self.test.indices.max())
+
+        return max(max_indices)
 
 
 class SplitSpecificationV2Mixin(BaseModel):
     """
-    Mixin class to add a split field to a benchmark. This is the V2 implementation.
+    Mixin class to add splits field to a benchmark. This is the V2 implementation.
 
-    The internal representation for the split is a roaring bitmap,
+    The internal representation for the splits uses roaring bitmaps,
     which drastically improves scalability over the V1 implementation.
 
     Attributes:
-        split: The predefined train-test split to use for evaluation.
+        splits: The predefined train-test splits to use for evaluation.
     """
 
-    split: SplitV2
+    splits: dict[str, SplitV2]
+
+    @model_validator(mode="after")
+    def validate_splits_not_empty(self) -> Self:
+        """Ensure at least one split is provided"""
+        if not self.splits:
+            raise InvalidBenchmarkError("At least one split must be specified")
+        return self
 
     @computed_field
     @property
-    def n_train_datapoints(self) -> int:
-        """The size of the train set."""
-        return self.split.n_train_datapoints
+    def n_splits(self) -> int:
+        """The number of splits"""
+        return len(self.splits)
 
     @computed_field
     @property
-    def n_test_sets(self) -> int:
-        """The number of test sets"""
-        return self.split.n_test_sets
+    def split_labels(self) -> list[str]:
+        """Labels of all splits"""
+        return list(self.splits.keys())
 
     @computed_field
     @property
-    def n_test_datapoints(self) -> dict[str, int]:
-        """The size of (each of) the test set(s)."""
-        return self.split.n_test_datapoints
+    def n_train_datapoints(self) -> dict[str, int]:
+        """The size of the train set for each split."""
+        return {label: split.n_train_datapoints for label, split in self.splits.items()}
 
     @computed_field
     @property
-    def test_set_sizes(self) -> dict[str, int]:
-        return {label: index_set.datapoints for label, index_set in self.split.test_items()}
+    def n_test_datapoints(self) -> dict[str, int]:
+        """The size of the test set for each split."""
+        return {label: split.n_test_datapoints for label, split in self.splits.items()}
 
     @computed_field
     @property
-    def test_set_labels(self) -> list[str]:
-        return list(label for label, _ in self.split.test_items())
+    def max_index(self) -> int:
+        """Maximum index across all splits"""
+        return max(split.max_index for split in self.splits.values())
+
+    def split_items(self) -> Generator[tuple[str, SplitV2], None, None]:
+        """Yield all splits with their labels"""
+        for label, split in self.splits.items():
+            yield label, split
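
The mixin now holds `splits: dict[str, SplitV2]`, and its computed fields are keyed by split label. A short sketch of what constructing and inspecting a two-split specification could look like (the labels and index values are illustrative only):

    from polaris.benchmark._split_v2 import IndexSet, SplitV2

    # Illustrative indices; a real benchmark would derive these from its dataset.
    splits = {
        "random": SplitV2(training=IndexSet(indices=[0, 1, 2, 3]), test=IndexSet(indices=[4, 5])),
        "scaffold": SplitV2(training=IndexSet(indices=[0, 2, 4]), test=IndexSet(indices=[1, 3, 5])),
    }

    # On a benchmark built from these splits, the computed fields would then read:
    #   split_labels       -> ["random", "scaffold"]
    #   n_train_datapoints -> {"random": 4, "scaffold": 3}
    #   n_test_datapoints  -> {"random": 2, "scaffold": 3}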

polaris/hub/client.py

Lines changed: 27 additions & 6 deletions
@@ -484,12 +484,33 @@ def _get_v2_benchmark(self, owner: str | HubOwner, slug: str) -> BenchmarkV2Spec
 
         response_data["dataset"] = self.get_dataset(*response_data["dataset"]["artifactId"].split("/"))
 
-        split = {}
-        for label, url in response_data.get("split", {}).items():
-            with fsspec.open(url, mode="rb") as f:
-                split[label] = f.read()
-
-        return BenchmarkV2Specification(**{**response_data, "split": split})
+        # Handle split data - each split contains training and test data
+        split_data = response_data["split"]
+        splits = {}
+
+        # Import SplitV2 and IndexSet for creating proper split objects
+        from polaris.benchmark._split_v2 import SplitV2, IndexSet
+
+        for split_label, split_urls in split_data.items():
+            # Each split should have 'training' and 'test' objects with filePath, datapoints, md5Checksum
+            split_indices = {}
+            for data_type, url_info in split_urls.items():
+                # Extract the actual URL from the filePath field
+                url = url_info["filePath"]
+                with fsspec.open(url, mode="rb") as f:
+                    # Deserialize the roaring bitmap data into an IndexSet
+                    roaring_data = f.read()
+                    index_set = IndexSet.deserialize(roaring_data)
+                    split_indices[data_type] = index_set
+
+            # Create a SplitV2 object from the training and test IndexSets
+            splits[split_label] = SplitV2(training=split_indices["training"], test=split_indices["test"])
+
+        # Remove the original 'split' field and add 'splits' field
+        response_data.pop("split", None)
+        response_data["splits"] = splits
+
+        return BenchmarkV2Specification(**response_data)
 
     def upload_results(
         self,
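
The rewritten `_get_v2_benchmark` assumes each entry under the response's `split` key carries separate `training` and `test` descriptors, each with a `filePath` pointing at a serialized roaring bitmap. A hedged sketch of that payload shape (the URLs, sizes, and checksum values are placeholders):

    # Hypothetical shape of response_data["split"] that the new loop walks.
    response_split = {
        "split_1": {
            "training": {"filePath": "https://hub.example/splits/split_1/training.bin", "datapoints": 5, "md5Checksum": "..."},
            "test": {"filePath": "https://hub.example/splits/split_1/test.bin", "datapoints": 3, "md5Checksum": "..."},
        },
    }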

polaris/hub/oauth.py

Lines changed: 1 addition & 2 deletions
@@ -95,8 +95,7 @@ class DatasetV2Paths(ArtifactPaths):
 
 class BenchmarkV2Paths(ArtifactPaths):
     training: AnyUrlString = Field(json_schema_extra={"file": True})
-    test: AnyUrlString = Field(json_schema_extra={"file": True})
-    test_2: int = 0
+    test_sets: dict[str, AnyUrlString] = Field(json_schema_extra={"file": True})
 
 
 class PredictionPaths(ArtifactPaths):
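
`BenchmarkV2Paths` now maps each named test set to its own storage URL alongside the single training URL. A sketch of the corresponding payload (any base `ArtifactPaths` fields are omitted and the URLs are placeholders):

    # Hypothetical paths payload after this change.
    benchmark_v2_paths = {
        "training": "https://hub.example/benchmarks/my-benchmark/training.bin",
        "test_sets": {
            "split_1": "https://hub.example/benchmarks/my-benchmark/split_1/test.bin",
            "split_2": "https://hub.example/benchmarks/my-benchmark/split_2/test.bin",
        },
    }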

tests/conftest.py

Lines changed: 25 additions & 3 deletions
@@ -403,15 +403,36 @@ def test_benchmark_v2(test_dataset_v2, test_org_owner):
         name="v2-benchmark-float-dtype",
         owner=test_org_owner,
         dataset=test_dataset_v2,
-        split=split,
+        splits={"default": split},
         target_cols=["A"],
         input_cols=["B"],
     )
     return benchmark
 
 
+@pytest.fixture(scope="function")
+def test_benchmark_v2_multiple_test_sets(test_dataset_v2, test_org_owner):
+    benchmark = BenchmarkV2Specification(
+        name="v2-benchmark-multiple-test-sets",
+        owner=test_org_owner,
+        dataset=test_dataset_v2,
+        splits={
+            "split_1": SplitV2(training=IndexSet(indices=[0, 1, 2, 3, 4]), test=IndexSet(indices=[5, 6, 7])),
+            "split_2": SplitV2(training=IndexSet(indices=[0, 1, 2, 3, 5, 6]), test=IndexSet(indices=[4, 7])),
+            "split_3": SplitV2(
+                training=IndexSet(indices=[0, 1, 2, 4, 7]), test=IndexSet(indices=[3, 5, 6, 8])
+            ),
+        },
+        target_cols=["A"],
+        input_cols=["B"],
+    )
+
+    return benchmark
+
+
 @pytest.fixture(scope="function")
 def v2_benchmark_with_rdkit_object_dtype(tmp_path, test_org_owner):
+    """Fixture for a benchmark with RDKit object dtype"""
     from polaris.utils.zarr.codecs import RDKitMolCodec
 
     zarr_path = tmp_path / "test_rdkit_object_dtype.zarr"
@@ -442,7 +463,7 @@ def v2_benchmark_with_rdkit_object_dtype(tmp_path, test_org_owner):
         name="v2-benchmark-rdkit-object-dtype",
         owner=test_org_owner,
         dataset=dataset,
-        split=split,
+        splits={"test": split},
         target_cols=["expt"],
         input_cols=["smiles"],
     )
@@ -451,6 +472,7 @@ def v2_benchmark_with_rdkit_object_dtype(tmp_path, test_org_owner):
 
 @pytest.fixture(scope="function")
 def v2_benchmark_with_atomarray_object_dtype(tmp_path, test_org_owner):
+    """Fixture for a benchmark with AtomArray object dtype"""
     from polaris.utils.zarr.codecs import AtomArrayCodec
 
     zarr_path = tmp_path / "test_atomarray_object_dtype.zarr"
@@ -481,7 +503,7 @@ def v2_benchmark_with_atomarray_object_dtype(tmp_path, test_org_owner):
         name="v2-benchmark-atomarray-object-dtype",
         owner=test_org_owner,
         dataset=dataset,
-        split=split,
+        splits={"test": split},
         target_cols=["expt"],
         input_cols=["smiles"],
     )
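
A hedged sketch of a test that the new `test_benchmark_v2_multiple_test_sets` fixture could support; the test itself is not part of this commit, and its assertions mirror the computed fields added in `_split_v2.py`:

    def test_multiple_test_set_properties(test_benchmark_v2_multiple_test_sets):
        # Hypothetical test: checks the per-split computed fields on the fixture above.
        benchmark = test_benchmark_v2_multiple_test_sets
        assert benchmark.n_splits == 3
        assert benchmark.split_labels == ["split_1", "split_2", "split_3"]
        assert benchmark.n_test_datapoints == {"split_1": 3, "split_2": 2, "split_3": 4}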
