Skip to content

Commit cc2f1f0

Browse files
committed
Refactor abs_error_stats, fix zero-index bug
pred_steps and ref_cut were always equal at every call site, making using two parameters redundant. ref_cut was also a misleading name. Collapse both into a single steps parameter and replace negative-index slicing with (err[n-steps:], err[:n-steps]), which correctly handles steps=0 (full ref, empty pred) — fixing the err[-0:] bug. Change comment wording.
1 parent ec90c2b commit cc2f1f0

File tree

3 files changed

+44
-17
lines changed

3 files changed

+44
-17
lines changed

Notebooks/Jumper.ipynb

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -600,8 +600,7 @@
600600
"for spec in TERM_SPECS:\n",
601601
" error_stats[spec.key] = abs_error_stats(\n",
602602
" errs[spec.key],\n",
603-
" pred_steps=steps,\n",
604-
" ref_cut=steps, # baseline: all but the forecast window\n",
603+
" steps=steps,\n",
605604
" axes=spec.err_axes,\n",
606605
" )\n",
607606
"\n",

src/nemo_spinup_forecast/pipeline_utils.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ def compute_rmse_for_terms(
133133
sims : Mapping[str, Simulation]
134134
Prepared and decomposed simulations keyed by :attr:`TermSpec.key`.
135135
n_components : int or None, default=None
136-
Number of components to use for reconstruction. When ``None``,
137-
all fitted components (``len(s.pca.components_)``) are used.
136+
Number of components to use for reconstruction.
137+
None uses all fitted components (``len(s.pca.components_)``).
138138
139139
Returns
140140
-------
@@ -291,8 +291,7 @@ def forecast_all(
291291
def abs_error_stats(
292292
err: np.ndarray,
293293
*,
294-
pred_steps: int,
295-
ref_cut: int,
294+
steps: int,
296295
axes: tuple[int, ...],
297296
) -> dict[str, Any]:
298297
"""Compute absolute-error summary statistics for prediction and reference windows.
@@ -301,11 +300,9 @@ def abs_error_stats(
301300
----------
302301
err : np.ndarray
303302
Absolute-error array, typically ``abs(reference - prediction)``.
304-
pred_steps : int
305-
Number of trailing time steps considered as the prediction window.
306-
ref_cut : int
307-
Number of trailing time steps excluded from the reference window.
308-
If set to ``0``, the full ``err`` array is used as the reference.
303+
steps : int
304+
Number of time steps at the end of ``err`` used as the forecast window;
305+
the rest of ``err``is the reference.
309306
axes : tuple[int, ...]
310307
Axes reduced with ``nanmean`` and ``nanstd``.
311308
@@ -314,13 +311,10 @@ def abs_error_stats(
314311
dict[str, Any]
315312
Dictionary with keys ``pred_mean``, ``pred_std``, ``ref_mean``,
316313
and ``ref_std``.
317-
318-
Notes
319-
-----
320-
This function prints the prediction and reference window shapes.
321314
"""
322-
pred = err[-pred_steps:]
323-
ref = err[:-ref_cut] if ref_cut else err
315+
n = len(err)
316+
pred = err[n - steps :]
317+
ref = err[: n - steps]
324318
print("pred shape:", pred.shape)
325319
print("ref shape:", ref.shape)
326320
return {

tests/test_pipeline_utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import numpy as np
2+
import pytest
3+
4+
from nemo_spinup_forecast.pipeline_utils import abs_error_stats
5+
6+
7+
def test_3d_spatial_reduction_shape_and_values():
8+
"""Reducing over spatial axes (y, x) leaves a time-indexed output."""
9+
err = np.zeros((8, 3, 4)) # (time, y, x)
10+
err[-3:] = 10.0 # pred window: last 3 time steps
11+
err[:-3] = 1.0 # ref window: first 5 time steps
12+
result = abs_error_stats(err, steps=3, axes=(1, 2))
13+
# pred = err[5:] → shape (3, 3, 4), reduce (y, x) → shape (3,)
14+
# ref = err[:5] → shape (5, 3, 4), reduce (y, x) → shape (5,)
15+
assert result["pred_mean"].shape == (3,)
16+
assert result["ref_mean"].shape == (5,)
17+
np.testing.assert_allclose(result["pred_mean"], 10.0)
18+
np.testing.assert_allclose(result["ref_mean"], 1.0)
19+
20+
21+
def test_steps_zero_gives_full_ref():
22+
"""steps=0 → pred is empty (nan), ref is the full array (baseline case)."""
23+
err = np.ones((10, 4, 5))
24+
err[7:] = 3.0
25+
result = abs_error_stats(err, steps=0, axes=(0, 1, 2))
26+
assert np.isnan(result["pred_mean"])
27+
assert result["ref_mean"] == pytest.approx((7 * 1.0 + 3 * 3.0) / 10)
28+
29+
30+
def test_steps_equals_full_length():
31+
"""steps=len(err) → pred covers the entire time axis."""
32+
err = np.arange(6.0).reshape(6, 1, 1)
33+
result = abs_error_stats(err, steps=6, axes=(0, 1, 2))
34+
assert result["pred_mean"] == pytest.approx(2.5) # mean of 0..5

0 commit comments

Comments
 (0)