diff --git a/geoapps_utils/utils/numerical.py b/geoapps_utils/utils/numerical.py index 6ccd792d..1b2bd04a 100644 --- a/geoapps_utils/utils/numerical.py +++ b/geoapps_utils/utils/numerical.py @@ -78,10 +78,10 @@ def traveling_salesman(locs: np.ndarray) -> np.ndarray: return np.asarray(order) -def weighted_average( # pylint: disable=too-many-locals +def weighted_average( xyz_in: np.ndarray, xyz_out: np.ndarray, - values: list, + values: list[np.ndarray], *, max_distance: float = np.inf, n: int = 8, @@ -104,34 +104,43 @@ def weighted_average( # pylint: disable=too-many-locals :return avg_values: List of values averaged to the output coordinates """ - n = np.min([xyz_in.shape[0], n]) - assert isinstance(values, list), "Input 'values' must be a list of numpy.ndarrays" + if ( + not isinstance(xyz_in, np.ndarray) + or not isinstance(xyz_out, np.ndarray) + or not isinstance(values, list) + or not all(isinstance(val, np.ndarray) for val in values) + ): + raise TypeError( + "Inputs 'xyz_in' and 'xyz_out' must be numpy.ndarrays " + "and 'values' must be a list of numpy.ndarrays." + f"Got {type(xyz_in)} and {type(xyz_out)} for 'xyz_in' and 'xyz_out' " + f"respectively, and {type(values)} for 'values'." + ) - assert all(vals.shape[0] == xyz_in.shape[0] for vals in values), ( - "Input 'values' must have the same shape as input 'locations'" - ) + if not all(vals.shape[0] == xyz_in.shape[0] for vals in values): + raise ValueError( + "Input 'values' must have the same number of rows as input 'xyz_in'. " + f"Got {xyz_in.shape[0]} for 'xyz_in' and " + f"{[val.shape[0] for val in values]} for 'values'." + ) + + n = np.min([xyz_in.shape[0], n]) avg_values = [] for value in values: sub = ~np.isnan(value) - tree = cKDTree(xyz_in[sub, :]) - rad, ind = tree.query(xyz_out, n) - ind = np.c_[ind] - rad = np.c_[rad] - rad[rad > max_distance] = np.nan - - values_interp = np.zeros(xyz_out.shape[0]) - weight = np.zeros(xyz_out.shape[0]) - - for i in range(n): - v = value[sub][ind[:, i]] / (rad[:, i] + threshold) - values_interp = np.nansum([values_interp, v], axis=0) - w = 1.0 / (rad[:, i] + threshold) - weight = np.nansum([weight, w], axis=0) - - values_interp[weight > 0] = values_interp[weight > 0] / weight[weight > 0] - values_interp[weight == 0] = np.nan - avg_values += [values_interp] + rad, ind = cKDTree(xyz_in[sub]).query(xyz_out, n) + + if n == 1: + ind = ind[:, np.newaxis] + rad = rad[:, np.newaxis] + + rad = np.where(rad > max_distance, np.nan, rad) + threshold + + values_interp = np.nansum(value[sub][ind] / rad, axis=1) + weight = np.nansum(1.0 / rad, axis=1) + + avg_values.append(values_interp / weight) if return_indices: return avg_values, ind diff --git a/tests/numerical_test.py b/tests/numerical_test.py index 3afa359f..a350c08b 100644 --- a/tests/numerical_test.py +++ b/tests/numerical_test.py @@ -11,6 +11,7 @@ from __future__ import annotations import numpy as np +import pytest from numpy import random from geoapps_utils.utils.numerical import ( @@ -145,6 +146,28 @@ def test_weighted_average_return_indices(): assert ind[0][0] == 1 +def test_weighted_average_errors(): + expected_results = [ + np.array([[2, 0, 0], [0, 1, 0], [0, 0, 2]]), + np.array([[0, 0, 0]]), + [np.array([1])], + ] + + for idx in range(len(expected_results)): + temp_results = expected_results.copy() + temp_results[idx] = "not expected type" + + with pytest.raises(TypeError, match="Inputs 'xyz_in'"): + _ = weighted_average(*temp_results) # type: ignore + + with pytest.raises(ValueError, match="Input 'values'"): + _ = weighted_average( + np.array([[2, 0, 0], [0, 1, 0], [0, 0, 2]]), + np.array([[0, 0, 0]]), + [np.array([1]), np.array([2])], + ) + + def test_weighted_average_threshold(): # threshold >> r -> arithmetic mean xyz_in = np.array([[1, 0, 0], [0, 100, 0], [0, 0, 1000]])