diff --git a/src/ess/reflectometry/tools.py b/src/ess/reflectometry/tools.py index 3409789d..529e4060 100644 --- a/src/ess/reflectometry/tools.py +++ b/src/ess/reflectometry/tools.py @@ -229,7 +229,7 @@ def cost(scaling_factors): def combine_curves( curves: Sequence[sc.DataArray], - q_bin_edges: sc.Variable | None = None, + q_bin_edges: sc.Variable, ) -> sc.DataArray: '''Combines the given curves by interpolating them on a 1d grid defined by :code:`q_bin_edges` and averaging @@ -261,22 +261,42 @@ def combine_curves( if len({c.coords['Q'].unit for c in curves}) != 1: raise ValueError('The Q-coordinates must have the same unit for each curve') - r = _interpolate_on_qgrid(map(sc.values, curves), q_bin_edges).values - v = _interpolate_on_qgrid(map(sc.variances, curves), q_bin_edges).values + r = _interpolate_on_qgrid(map(sc.values, curves), q_bin_edges) + v = _interpolate_on_qgrid(map(sc.variances, curves), q_bin_edges) - v[v == 0] = np.nan + v = sc.where(v == 0, sc.scalar(np.nan, unit=v.unit), v) inv_v = 1.0 / v - r_avg = np.nansum(r * inv_v, axis=0) / np.nansum(inv_v, axis=0) - v_avg = 1 / np.nansum(inv_v, axis=0) - return sc.DataArray( + r_avg = sc.nansum(r * inv_v, dim='curves') / sc.nansum(inv_v, dim='curves') + v_avg = 1 / sc.nansum(inv_v, dim='curves') + + out = sc.DataArray( data=sc.array( dims='Q', - values=r_avg, - variances=v_avg, + values=r_avg.values, + variances=v_avg.values, unit=next(iter(curves)).data.unit, ), coords={'Q': q_bin_edges}, ) + if any('Q_resolution' in c.coords for c in curves): + # This might need to be revisited. The question about how to combine curves + # with different Q-resolution is not completely resolved. + # However, in practice the difference in Q-resolution between different curves + # is small so it's not likely to make a big difference. + q_res = ( + sc.DataArray( + data=c.coords.get( + 'Q_resolution', sc.full_like(c.coords['Q'], value=np.nan) + ), + coords={'Q': c.coords['Q']}, + ) + for c in curves + ) + qs = _interpolate_on_qgrid(q_res, q_bin_edges) + out.coords['Q_resolution'] = sc.nansum(qs * inv_v, dim='curves') / sc.nansum( + sc.where(sc.isnan(qs), sc.scalar(0.0, unit=inv_v.unit), inv_v), dim='curves' + ) + return out def orso_datasets_from_measurements( diff --git a/tests/tools_test.py b/tests/tools_test.py index 03c07aaf..345d233c 100644 --- a/tests/tools_test.py +++ b/tests/tools_test.py @@ -146,6 +146,29 @@ def test_combined_curves(): ) +@pytest.mark.filterwarnings("ignore:invalid value encountered in divide") +def test_combined_curves_resolution(): + qgrid = sc.linspace('Q', 0, 1, 26) + data = sc.concat( + ( + sc.ones(dims=['Q'], shape=[10], with_variances=True), + 0.5 * sc.ones(dims=['Q'], shape=[15], with_variances=True), + ), + dim='Q', + ) + data.variances[:] = 0.1 + curves = ( + curve(data, 0, 0.3), + curve(0.5 * data, 0.2, 0.7), + curve(0.25 * data, 0.6, 1.0), + ) + curves[0].coords['Q_resolution'] = sc.midpoints(curves[0].coords['Q']) / 5 + combined = combine_curves(curves, qgrid) + assert 'Q_resolution' in combined.coords + assert combined.coords['Q_resolution'][0] == curves[0].coords['Q_resolution'][1] + assert sc.isnan(combined.coords['Q_resolution'][-1]) + + def test_linlogspace_linear(): q_lin = linlogspace( dim='qz', edges=[0.008, 0.08], scale='linear', num=50, unit='1/angstrom'