Skip to content

Commit 8de112c

Browse files
cache re-used test functions on demand
1 parent e592926 commit 8de112c

File tree

1 file changed

+46
-55
lines changed

1 file changed

+46
-55
lines changed

tests/statespace/filters/test_kalman_filter.py

Lines changed: 46 additions & 55 deletions
Original file line number · Diff line number · Diff line change
@@ -1,3 +1,6 @@
1+
from collections.abc import Callable
2+
from functools import cache
3+
14
import numpy as np
25
import pytensor
36
import pytensor.tensor as pt
@@ -31,28 +34,24 @@
3134
RTOL = 1e-6 if floatX.endswith("64") else 1e-3
3235

3336

34-
@pytest.fixture(scope="session")
35-
def f_standard():
36-
standard_inout = initialize_filter(StandardFilter())
37-
f_standard = pytensor.function(*standard_inout, on_unused_input="ignore")
38-
return f_standard
39-
40-
41-
@pytest.fixture(scope="session")
42-
def f_cholesky():
43-
cholesky_inout = initialize_filter(SquareRootFilter())
44-
f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore")
45-
return f_cholesky
37+
@cache
38+
def get_filter_function(filter_name: str) -> Callable:
39+
"""
40+
Compile and return a filter function given its name, caching the result to make tests as fast as possible
41+
"""
42+
match filter_name:
43+
case "StandardFilter":
44+
filter_inout = initialize_filter(StandardFilter())
45+
case "CholeskyFilter":
46+
filter_inout = initialize_filter(SquareRootFilter())
47+
case "UnivariateFilter":
48+
filter_inout = initialize_filter(UnivariateFilter())
49+
case _:
50+
raise ValueError(f"Unknown filter name: {filter_name}")
4651

52+
filter_func = pytensor.function(*filter_inout, on_unused_input="ignore")
53+
return filter_func
4754

48-
@pytest.fixture(scope="session")
49-
def f_univariate():
50-
univariate_inout = initialize_filter(UnivariateFilter())
51-
f_univariate = pytensor.function(*univariate_inout, on_unused_input="ignore")
52-
return f_univariate
53-
54-
55-
filter_funcs = [f_standard, f_cholesky, f_univariate]
5655

5756
filter_names = [
5857
"StandardFilter",
@@ -79,11 +78,11 @@ def test_base_class_update_raises():
7978
filter.update(*inputs)
8079

8180

82-
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
83-
def test_output_shapes_one_state_one_observed(filter_func, rng):
81+
@pytest.mark.parametrize("filter_name", filter_names)
82+
def test_output_shapes_one_state_one_observed(filter_name, rng):
8483
p, m, r, n = 1, 1, 1, 10
8584
inputs = make_test_inputs(p, m, r, n, rng)
86-
outputs = filter_func(*inputs)
85+
outputs = get_filter_function(filter_name)(*inputs)
8786

8887
for output_idx, name in enumerate(output_names):
8988
expected_output = get_expected_shape(name, p, m, r, n)
@@ -92,25 +91,25 @@ def test_output_shapes_one_state_one_observed(filter_func, rng):
9291
), f"Shape of {name} does not match expected"
9392

9493

95-
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
96-
def test_output_shapes_when_all_states_are_stochastic(filter_func, rng):
94+
@pytest.mark.parametrize("filter_name", filter_names)
95+
def test_output_shapes_when_all_states_are_stochastic(filter_name, rng):
9796
p, m, r, n = 1, 2, 2, 10
9897
inputs = make_test_inputs(p, m, r, n, rng)
9998

100-
outputs = filter_func(*inputs)
99+
outputs = get_filter_function(filter_name)(*inputs)
101100
for output_idx, name in enumerate(output_names):
102101
expected_output = get_expected_shape(name, p, m, r, n)
103102
assert (
104103
outputs[output_idx].shape == expected_output
105104
), f"Shape of {name} does not match expected"
106105

107106

108-
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
109-
def test_output_shapes_when_some_states_are_deterministic(filter_func, rng):
107+
@pytest.mark.parametrize("filter_name", filter_names)
108+
def test_output_shapes_when_some_states_are_deterministic(filter_name, rng):
110109
p, m, r, n = 1, 5, 2, 10
111110
inputs = make_test_inputs(p, m, r, n, rng)
112111

113-
outputs = filter_func(*inputs)
112+
outputs = get_filter_function(filter_name)(*inputs)
114113
for output_idx, name in enumerate(output_names):
115114
expected_output = get_expected_shape(name, p, m, r, n)
116115
assert (
@@ -180,12 +179,12 @@ def test_output_shapes_with_time_varying_matrices(f_standard_nd, rng):
180179
), f"Shape of {name} does not match expected"
181180

182181

183-
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
184-
def test_output_with_deterministic_observation_equation(filter_func, rng):
182+
@pytest.mark.parametrize("filter_name", filter_names)
183+
def test_output_with_deterministic_observation_equation(filter_name, rng):
185184
p, m, r, n = 1, 5, 1, 10
186185
inputs = make_test_inputs(p, m, r, n, rng)
187186

188-
outputs = filter_func(*inputs)
187+
outputs = get_filter_function(filter_name)(*inputs)
189188

190189
for output_idx, name in enumerate(output_names):
191190
expected_output = get_expected_shape(name, p, m, r, n)
@@ -194,61 +193,55 @@ def test_output_with_deterministic_observation_equation(filter_func, rng):
194193
), f"Shape of {name} does not match expected"
195194

196195

197-
@pytest.mark.parametrize(
198-
("filter_func", "filter_name"), zip(filter_funcs, filter_names), ids=filter_names
199-
)
200-
def test_output_with_multiple_observed(filter_func, filter_name, rng):
196+
@pytest.mark.parametrize("filter_name", filter_names)
197+
def test_output_with_multiple_observed(filter_name, rng):
201198
p, m, r, n = 5, 5, 1, 10
202199
inputs = make_test_inputs(p, m, r, n, rng)
203200

204-
outputs = filter_func(*inputs)
201+
outputs = get_filter_function(filter_name)(*inputs)
205202
for output_idx, name in enumerate(output_names):
206203
expected_output = get_expected_shape(name, p, m, r, n)
207204
assert (
208205
outputs[output_idx].shape == expected_output
209206
), f"Shape of {name} does not match expected"
210207

211208

212-
@pytest.mark.parametrize(
213-
("filter_func", "filter_name"), zip(filter_funcs, filter_names), ids=filter_names
214-
)
209+
@pytest.mark.parametrize("filter_name", filter_names)
215210
@pytest.mark.parametrize("p", [1, 5], ids=["univariate (p=1)", "multivariate (p=5)"])
216-
def test_missing_data(filter_func, filter_name, p, rng):
211+
def test_missing_data(filter_name, p, rng):
217212
m, r, n = 5, 1, 10
218213
inputs = make_test_inputs(p, m, r, n, rng, missing_data=1)
219214

220-
outputs = filter_func(*inputs)
215+
outputs = get_filter_function(filter_name)(*inputs)
221216
for output_idx, name in enumerate(output_names):
222217
expected_output = get_expected_shape(name, p, m, r, n)
223218
assert (
224219
outputs[output_idx].shape == expected_output
225220
), f"Shape of {name} does not match expected"
226221

227222

228-
@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names)
223+
@pytest.mark.parametrize("filter_name", filter_names)
229224
@pytest.mark.parametrize("output_idx", [(0, 2), (3, 5)], ids=["smoothed_states", "smoothed_covs"])
230-
def test_last_smoother_is_last_filtered(filter_func, output_idx, rng):
225+
def test_last_smoother_is_last_filtered(filter_name, output_idx, rng):
231226
p, m, r, n = 1, 5, 1, 10
232227
inputs = make_test_inputs(p, m, r, n, rng)
233-
outputs = filter_func(*inputs)
228+
outputs = get_filter_function(filter_name)(*inputs)
234229

235230
filtered = outputs[output_idx[0]]
236231
smoothed = outputs[output_idx[1]]
237232

238233
assert_allclose(filtered[-1], smoothed[-1])
239234

240235

241-
@pytest.mark.parametrize(
242-
"filter_func, filter_name", zip(filter_funcs, filter_names), ids=filter_names
243-
)
236+
@pytest.mark.parametrize("filter_name", filter_names)
244237
@pytest.mark.parametrize("n_missing", [0, 5], ids=["n_missing=0", "n_missing=5"])
245238
@pytest.mark.skipif(floatX == "float32", reason="Tests are too sensitive for float32")
246-
def test_filters_match_statsmodel_output(filter_func, filter_name, n_missing, rng):
239+
def test_filters_match_statsmodel_output(filter_name, n_missing, rng):
247240
fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing)
248241
if filter_name == "CholeskyFilter":
249242
P0 = np.linalg.cholesky(P0)
250243
inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
251-
outputs = filter_func(*inputs)
244+
outputs = get_filter_function(filter_name)(*inputs)
252245

253246
for output_idx, name in enumerate(output_names):
254247
ref_val = get_sm_state_from_output_name(fit_sm_mod, name)
@@ -283,12 +276,10 @@ def test_filters_match_statsmodel_output(filter_func, filter_name, n_missing, rn
283276
)
284277

285278

286-
@pytest.mark.parametrize(
287-
"filter_func, filter_name", zip(filter_funcs[:-1], filter_names[:-1]), ids=filter_names[:-1]
288-
)
279+
@pytest.mark.parametrize("filter_name", filter_names)
289280
@pytest.mark.parametrize("n_missing", [0, 5], ids=["n_missing=0", "n_missing=5"])
290281
@pytest.mark.parametrize("obs_noise", [True, False])
291-
def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, obs_noise, rng):
282+
def test_all_covariance_matrices_are_PSD(filter_name, n_missing, obs_noise, rng):
292283
if (floatX == "float32") & (filter_name == "UnivariateFilter"):
293284
# TODO: These tests all pass locally for me with float32 but they fail on the CI, so i'm just disabling them.
294285
pytest.skip("Univariate filter not stable at half precision without measurement error")
@@ -299,7 +290,7 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob
299290

300291
H *= int(obs_noise)
301292
inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
302-
outputs = filter_func(*inputs)
293+
outputs = get_filter_function(filter_name)(*inputs)
303294

304295
for output_idx, name in zip([3, 4, 5], output_names[3:-2]):
305296
cov_stack = outputs[output_idx]

0 commit comments

Comments (0)