@@ -57,120 +57,123 @@ def deserialize(index_set: bytes) -> "IndexSet":
5757
5858
5959class SplitV2 (BaseModel ):
60+ """
61+ A single train-test split pair containing training and test index sets.
62+
63+ This represents one train-test split with training and test sets.
64+ Multiple SplitV2 instances can be used together for cross-validation scenarios.
65+ """
66+
6067 training : IndexSet
6168 test : IndexSet
6269
6370 @field_validator ("training" , "test" , mode = "before" )
6471 @classmethod
65- def _parse_index_sets (cls , v : bytes | IndexSet ) -> bytes | IndexSet :
66- """
67- Accepted a binary serialized IndexSet
68- """
72+ def _parse_index_set (cls , v : bytes | IndexSet ) -> IndexSet :
73+ """Accept a binary serialized IndexSet"""
6974 if isinstance (v , bytes ):
7075 return IndexSet .deserialize (v )
7176 return v
7277
7378 @field_validator ("training" )
7479 @classmethod
7580 def _validate_training_set (cls , v : IndexSet ) -> IndexSet :
76- """
77- Training index set can be empty (zero-shot)
78- """
81+ """Training index set can be empty (zero-shot)"""
7982 if v .datapoints == 0 :
80- logger .info (
81- "This benchmark only specifies a test set. It will return an empty train set in `get_train_test_split()`"
83+ logger .debug (
84+ "This train-test split only specifies a test set. It will return an empty train set in `get_train_test_split()`"
8285 )
8386 return v
8487
8588 @field_validator ("test" )
8689 @classmethod
8790 def _validate_test_set (cls , v : IndexSet ) -> IndexSet :
88- """
89- Test index set cannot be empty
90- """
91+ """Test index set cannot be empty"""
9192 if v .datapoints == 0 :
92- raise InvalidBenchmarkError ("The predefined split contains empty test partitions " )
93+ raise InvalidBenchmarkError ("Test set cannot be empty" )
9394 return v
9495
9596 @model_validator (mode = "after" )
9697 def validate_set_overlap (self ) -> Self :
97- """
98- The training and test index sets do not overlap
99- """
98+ """The training and test index sets do not overlap"""
10099 if self .training .intersect (self .test ):
101100 raise InvalidBenchmarkError ("The predefined split specifies overlapping train and test sets" )
102101 return self
103102
104103 @property
105104 def n_train_datapoints (self ) -> int :
106- """
107- The size of the train set.
108- """
105+ """The size of the train set."""
109106 return self .training .datapoints
110107
111108 @property
112- def n_test_sets (self ) -> int :
113- """
114- The number of test sets
115- """
116- # TODO: Until we support multi-test benchmarks
117- return 1
118-
119- @property
120- def n_test_datapoints (self ) -> dict [str , int ]:
121- """
122- The size of (each of) the test set(s).
123- """
124- # TODO: Until we support multi-test benchmarks
125- return {"test" : self .test .datapoints }
109+ def n_test_datapoints (self ) -> int :
110+ """The size of the test set."""
111+ return self .test .datapoints
126112
127113 @property
128114 def max_index (self ) -> int :
129- # TODO: Until we support multi- test benchmarks (need)
130- return max ( self . training . indices . max (), self . test . indices . max ())
115+ """Maximum index across train and test sets"""
116+ max_indices = []
131117
132- def test_items (self ) -> Generator [tuple [str , IndexSet ], None , None ]:
133- # TODO: Until we support multi-test benchmarks
134- yield "test" , self .test
118+ # Only add max if the bitmap is not empty
119+ if len (self .training .indices ) > 0 :
120+ max_indices .append (self .training .indices .max ())
121+ max_indices .append (self .test .indices .max ())
122+
123+ return max (max_indices )
135124
136125
137126class SplitSpecificationV2Mixin (BaseModel ):
138127 """
139- Mixin class to add a split field to a benchmark. This is the V2 implementation.
128+ Mixin class to add splits field to a benchmark. This is the V2 implementation.
140129
141- The internal representation for the split is a roaring bitmap ,
130+ The internal representation for the splits uses roaring bitmaps ,
142131 which drastically improves scalability over the V1 implementation.
143132
144133 Attributes:
145- split : The predefined train-test split to use for evaluation.
134+ splits : The predefined train-test splits to use for evaluation.
146135 """
147136
148- split : SplitV2
137+ splits : dict [str , SplitV2 ]
138+
139+ @model_validator (mode = "after" )
140+ def validate_splits_not_empty (self ) -> Self :
141+ """Ensure at least one split is provided"""
142+ if not self .splits :
143+ raise InvalidBenchmarkError ("At least one split must be specified" )
144+ return self
149145
150146 @computed_field
151147 @property
152- def n_train_datapoints (self ) -> int :
153- """The size of the train set. """
154- return self .split . n_train_datapoints
148+ def n_splits (self ) -> int :
149+ """The number of splits """
150+ return len ( self .splits )
155151
156152 @computed_field
157153 @property
158- def n_test_sets (self ) -> int :
159- """The number of test sets """
160- return self .split . n_test_sets
154+ def split_labels (self ) -> list [ str ] :
155+ """Labels of all splits """
156+ return list ( self .splits . keys ())
161157
162158 @computed_field
163159 @property
164- def n_test_datapoints (self ) -> dict [str , int ]:
165- """The size of (each of) the test set(s) ."""
166- return self . split . n_test_datapoints
160+ def n_train_datapoints (self ) -> dict [str , int ]:
161+ """The size of the train set for each split ."""
162+ return { label : split . n_train_datapoints for label , split in self . splits . items ()}
167163
168164 @computed_field
169165 @property
170- def test_set_sizes (self ) -> dict [str , int ]:
171- return {label : index_set .datapoints for label , index_set in self .split .test_items ()}
166+ def n_test_datapoints (self ) -> dict [str , int ]:
167+ """The size of the test set for each split."""
168+ return {label : split .n_test_datapoints for label , split in self .splits .items ()}
172169
173170 @computed_field
174171 @property
175- def test_set_labels (self ) -> list [str ]:
176- return list (label for label , _ in self .split .test_items ())
172+ def max_index (self ) -> int :
173+ """Maximum index across all splits"""
174+ return max (split .max_index for split in self .splits .values ())
175+
176+ def split_items (self ) -> Generator [tuple [str , SplitV2 ], None , None ]:
177+ """Yield all splits with their labels"""
178+ for label , split in self .splits .items ():
179+ yield label , split
0 commit comments