@@ -125,16 +125,14 @@ def output_var(self):
125125 return "output"
126126
127127 def _data_setter (self , X : pd .Series | np .ndarray , y : pd .Series | np .ndarray = None ):
128-
129128 with self .model :
130-
131129 X = X .values if isinstance (X , pd .Series ) else X .ravel ()
132-
130+
133131 pm .set_data ({"x" : X })
134-
132+
135133 if y is not None :
136134 y = y .values if isinstance (y , pd .Series ) else y .ravel ()
137-
135+
138136 pm .set_data ({"y_data" : y })
139137
140138 @property
@@ -263,12 +261,16 @@ def test_sample_xxx_extend_idata_param(fitted_model_instance, group, extend_idat
263261
264262 prediction_data = pd .DataFrame ({"input" : x_pred })
265263 if group == "prior_predictive" :
266- pred = fitted_model_instance .sample_prior_predictive (prediction_data ["input" ], combined = False , extend_idata = extend_idata )
264+ pred = fitted_model_instance .sample_prior_predictive (
265+ prediction_data ["input" ], combined = False , extend_idata = extend_idata
266+ )
267267 else : # group == "posterior_predictive":
268- pred = fitted_model_instance .sample_posterior_predictive (prediction_data ["input" ], combined = False , predictions = False , extend_idata = extend_idata )
268+ pred = fitted_model_instance .sample_posterior_predictive (
269+ prediction_data ["input" ], combined = False , predictions = False , extend_idata = extend_idata
270+ )
269271
270272 pred_unstacked = pred [output_var ].values
271-
273+
272274 idata_now = fitted_model_instance .idata [group ][output_var ].values
273275
274276 if extend_idata :
@@ -314,7 +316,9 @@ def test_id():
314316
315317@pytest .mark .parametrize ("predictions" , [True , False ])
316318@pytest .mark .parametrize ("predict_method" , ["predict" , "predict_posterior" ])
317- def test_predict_method_respects_predictions_flag (fitted_model_instance , predictions , predict_method ):
319+ def test_predict_method_respects_predictions_flag (
320+ fitted_model_instance , predictions , predict_method
321+ ):
318322 x_pred = np .random .uniform (0 , 1 , 100 )
319323 prediction_data = pd .DataFrame ({"input" : x_pred })
320324 output_var = fitted_model_instance .output_var
@@ -332,7 +336,7 @@ def test_predict_method_respects_predictions_flag(fitted_model_instance, predict
332336 extend_idata = True ,
333337 predictions = predictions ,
334338 )
335- else :# predict_method == "predict_posterior":
339+ else : # predict_method == "predict_posterior":
336340 fitted_model_instance .predict_posterior (
337341 X_pred = prediction_data [["input" ]],
338342 extend_idata = True ,
@@ -350,4 +354,3 @@ def test_predict_method_respects_predictions_flag(fitted_model_instance, predict
350354 assert "predictions" not in fitted_model_instance .idata .groups ()
351355 # Posterior predictive should be updated
352356 assert not np .array_equal (pp_before , pp_after )
353-
0 commit comments