@@ -87,7 +87,7 @@ def test_check_min_max_scaling(self):
8787 X = 0.1 + 0.8 * torch .rand (4 , 2 , 3 )
8888 with warnings .catch_warnings (record = True ) as ws :
8989 check_min_max_scaling (X = X )
90- self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
90+ self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
9191 check_min_max_scaling (X = X , raise_on_fail = True )
9292 with self .assertWarnsRegex (
9393 expected_warning = InputDataWarning , expected_regex = "not scaled"
@@ -100,30 +100,34 @@ def test_check_min_max_scaling(self):
100100 Xstd = (X - Xmin ) / (Xmax - Xmin )
101101 with warnings .catch_warnings (record = True ) as ws :
102102 check_min_max_scaling (X = Xstd )
103- self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
103+ self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
104104 check_min_max_scaling (X = Xstd , raise_on_fail = True )
105105 with warnings .catch_warnings (record = True ) as ws :
106106 check_min_max_scaling (X = Xstd , strict = True )
107- self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
107+ self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
108108 check_min_max_scaling (X = Xstd , strict = True , raise_on_fail = True )
109109 # check violation
110110 X [0 , 0 , 0 ] = 2
111111 with warnings .catch_warnings (record = True ) as ws :
112112 check_min_max_scaling (X = X )
113- self .assertTrue (any (issubclass (w .category , InputDataWarning ) for w in ws ))
114- self .assertTrue (any ("not contained" in str (w .message ) for w in ws ))
113+ self .assertTrue (any (issubclass (w .category , InputDataWarning ) for w in ws ))
114+ self .assertTrue (any ("not contained" in str (w .message ) for w in ws ))
115115 with self .assertRaises (InputDataError ):
116116 check_min_max_scaling (X = X , raise_on_fail = True )
117117 with warnings .catch_warnings (record = True ) as ws :
118118 check_min_max_scaling (X = X , strict = True )
119- self .assertTrue (any (issubclass (w .category , InputDataWarning ) for w in ws ))
120- self .assertTrue (any ("not contained" in str (w .message ) for w in ws ))
119+ self .assertTrue (any (issubclass (w .category , InputDataWarning ) for w in ws ))
120+ self .assertTrue (any ("not contained" in str (w .message ) for w in ws ))
121121 with self .assertRaises (InputDataError ):
122122 check_min_max_scaling (X = X , strict = True , raise_on_fail = True )
123123 # check ignore_dims
124124 with warnings .catch_warnings (record = True ) as ws :
125125 check_min_max_scaling (X = X , ignore_dims = [0 ])
126- self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
126+ self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
127+ # all dims ignored
128+ with warnings .catch_warnings (record = True ) as ws :
129+ check_min_max_scaling (X = X , ignore_dims = [0 , 1 , 2 ])
130+ self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
127131
128132 def test_check_standardization (self ):
129133 # Ensure that it is not filtered out.
@@ -181,6 +185,11 @@ def test_validate_input_scaling(self):
181185 # check that errors are raised when requested
182186 with self .assertRaises (InputDataError ):
183187 validate_input_scaling (train_X = train_X , train_Y = train_Y , raise_on_fail = True )
188+ # check that normalization & standardization checks & errors are skipped when
189+ # check_nans_only is True
190+ validate_input_scaling (
191+ train_X = train_X , train_Y = train_Y , raise_on_fail = True , check_nans_only = True
192+ )
184193 # check that no errors are being raised if everything is standardized
185194 train_X_min = train_X .min (dim = - 1 , keepdim = True )[0 ]
186195 train_X_max = train_X .max (dim = - 1 , keepdim = True )[0 ]
@@ -202,6 +211,11 @@ def test_validate_input_scaling(self):
202211 train_X_std [0 , 0 , 0 ] = float ("nan" )
203212 with self .assertRaises (InputDataError ):
204213 validate_input_scaling (train_X = train_X_std , train_Y = train_Y_std )
214+ # NaNs still raise errors when check_nans_only is True
215+ with self .assertRaises (InputDataError ):
216+ validate_input_scaling (
217+ train_X = train_X_std , train_Y = train_Y_std , check_nans_only = True
218+ )
205219
206220
207221class TestGPTPosteriorSettings (BotorchTestCase ):
0 commit comments