Skip to content

Commit 4bace27

Browse files
Updating split validator to allow duplicates across test sets (#158)
* updating split validator to allow duplicates across test sets
* Using only unique indices in test set to check for out-of-bound indices
* removing print statement
* fixing spacing
* Update polaris/benchmark/_base.py
  Co-authored-by: Honoré Hounwanou <[email protected]>
* adding cleaner way to combine all test set indices

---------

Co-authored-by: Honoré Hounwanou <[email protected]>
1 parent e71bd79 commit 4bace27

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

polaris/benchmark/_base.py

Lines changed: 16 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -182,30 +182,40 @@ def _validate_split(cls, v, info: ValidationInfo):
182182
raise InvalidBenchmarkError("The predefined split contains empty test partitions")
183183

184184
train_idx_list = v[0]
185-
test_idx_list = list(i for part in v[1].values() for i in part) if isinstance(v[1], dict) else v[1]
185+
full_test_idx_list = list(chain.from_iterable(v[1].values())) if isinstance(v[1], dict) else v[1]
186186

187187
if len(train_idx_list) == 0:
188188
logger.info(
189189
"This benchmark only specifies a test set. It will return an empty train set in `get_train_test_split()`"
190190
)
191191

192192
train_idx_set = set(train_idx_list)
193-
test_idx_set = set(test_idx_list)
193+
full_test_idx_set = set(full_test_idx_list)
194194

195195
# The train and test indices do not overlap
196-
if len(train_idx_set & test_idx_set) > 0:
196+
if len(train_idx_set & full_test_idx_set) > 0:
197197
raise InvalidBenchmarkError("The predefined split specifies overlapping train and test sets")
198198

199-
# Duplicate indices
199+
# Check for duplicate indices within the train set
200200
if len(train_idx_set) != len(train_idx_list):
201201
raise InvalidBenchmarkError("The training set contains duplicate indices")
202-
if len(test_idx_set) != len(test_idx_list):
202+
203+
# Check for duplicate indices within a given test set. Because a user can specify
204+
# multiple test sets for a given benchmark and it is acceptable for indices to be shared
205+
# across test sets, we check for duplicates in each test set independently.
206+
if isinstance(v[1], dict):
207+
for test_set_name, test_set_idx_list in v[1].items():
208+
if len(test_set_idx_list) != len(set(test_set_idx_list)):
209+
raise InvalidBenchmarkError(
210+
f'Test set with name "{test_set_name}" contains duplicate indices'
211+
)
212+
elif len(full_test_idx_set) != len(full_test_idx_list):
203213
raise InvalidBenchmarkError("The test set contains duplicate indices")
204214

205215
# All indices are valid given the dataset
206216
if info.data["dataset"] is not None:
207217
max_i = len(info.data["dataset"])
208-
if any(i < 0 or i >= max_i for i in chain(train_idx_list, test_idx_list)):
218+
if any(i < 0 or i >= max_i for i in chain(train_idx_list, full_test_idx_set)):
209219
raise InvalidBenchmarkError("The predefined split contains invalid indices")
210220

211221
return v

tests/test_benchmark.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,13 @@ def test_split_verification(is_single_task, test_single_task_benchmark, test_mul
5353
cls(split=(train_split + train_split[:1], test_split), **default_kwargs)
5454
with pytest.raises(ValidationError):
5555
cls(split=(train_split, test_split + test_split[:1]), **default_kwargs)
56+
with pytest.raises(ValidationError):
57+
cls(
58+
split=(train_split, {"test1": test_split, "test2": test_split + test_split[:1]}), **default_kwargs
59+
)
60+
61+
# It should _not_ fail with duplicate indices across test partitions
62+
cls(split=(train_split, {"test1": test_split, "test2": test_split}), **default_kwargs)
5663
# It should _not_ fail with missing indices
5764
cls(split=(train_split[:-1], test_split), **default_kwargs)
5865
# It should _not_ fail with an empty train set

0 commit comments

Comments
 (0)