Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions verde/base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Utility functions for building gridders and checking arguments.
"""
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.metrics import check_scoring


Expand Down Expand Up @@ -51,7 +51,7 @@ def score_estimator(scoring, estimator, coordinates, data, weights=None):
coordinates, data, weights, unpack=False
)
predicted = check_data(estimator.predict(coordinates))
scorer = check_scoring(DummyEstimator, scoring=scoring)
scorer = check_scoring(DummyEstimator(np.array([0])), scoring=scoring)
result = np.mean(
[
scorer(
Expand All @@ -66,7 +66,7 @@ def score_estimator(scoring, estimator, coordinates, data, weights=None):
return result


class DummyEstimator(BaseEstimator):
class DummyEstimator(RegressorMixin, BaseEstimator):
"""
Dummy estimator that does nothing but pass along the predicted data.
Used to fool the scikit-learn scorer functions to fit our API
Expand Down
2 changes: 1 addition & 1 deletion verde/blockreduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def _blocked_mean_variance(self, table, ncomps):
The variance will be the unweighted variance of the blocks.
"""
reduction = {
"data{}".format(i): (("mean", self.reduction), ("variance", np.var))
"data{}".format(i): (("mean", "mean"), ("variance", "var"))
for i in range(ncomps)
}
blocked = table.groupby("block").aggregate(reduction)
Expand Down
9 changes: 9 additions & 0 deletions verde/tests/test_model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ def test_cross_val_score_vector(trend, metric, expected):
npt.assert_allclose(scores, expected, atol=1e-10)


def test_gridder_score(trend):
    "Check that gridder.score works with the default R² scorer."
    coordinates, data = trend[0], trend[1]
    # Fit a first-degree trend to the synthetic data from the fixture.
    fitted = Trend(degree=1).fit(coordinates, data)
    # Scoring currently emits a FutureWarning; make sure it still happens.
    with pytest.warns(FutureWarning):
        result = fitted.score(coordinates, data)
    # A trend fitted to its own data should score a perfect R² of 1.
    npt.assert_allclose(result, 1, atol=1e-10)


def test_cross_val_score_client(trend):
"Test the deprecated dask Client interface"
coords, data = trend[:2]
Expand Down
17 changes: 17 additions & 0 deletions verde/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import numpy as np
import numpy.testing as npt
import pandas as pd
import pytest
import xarray as xr
from scipy.spatial import cKDTree
Expand All @@ -28,6 +29,7 @@
meshgrid_to_1d,
parse_engine,
partition_by_sum,
variance_to_weights,
)


Expand Down Expand Up @@ -108,6 +110,21 @@ def test_partition_by_sum_fails_no_partitions():
assert "Could not find partition points" in str(error)


def test_variance_to_weights_pandas_series():
    "A pandas Series input should work even if its backing array is read-only."
    # Mix of zero, regular, small, and NaN variances to exercise all branches.
    series = pd.Series([0, 2, 0.2, np.nan])
    result = variance_to_weights(series)
    # Zero and NaN variances map to weight 1; the rest are normalized.
    npt.assert_allclose(result, [1, 0.1, 1, 1])


def test_variance_to_weights_readonly_array():
    "A read-only NumPy array input should still produce normalized weights."
    values = np.array([0, 2, 0.2, np.nan])
    # Lock the array to mimic read-only buffers (e.g. pandas 3 Series backing).
    values.setflags(write=False)
    result = variance_to_weights(values)
    # Zero and NaN variances map to weight 1; the rest are normalized.
    npt.assert_allclose(result, [1, 0.1, 1, 1])


def test_make_xarray_grid():
"""
Check if xarray.Dataset is correctly created
Expand Down
4 changes: 3 additions & 1 deletion verde/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ def variance_to_weights(variance, tol=1e-15, dtype="float64"):
variance = check_data(variance)
weights = []
for var in variance:
var = np.nan_to_num(np.atleast_1d(var), copy=False)
# Pandas 3 can expose Series-backed arrays as read-only, so normalize
# NaNs on a private writeable copy.
var = np.nan_to_num(np.atleast_1d(var), copy=True)
w = np.ones_like(var, dtype=dtype)
nonzero = var > tol
if np.any(nonzero):
Expand Down
Loading