diff --git a/docs/user-guide/amor/amor-reduction.ipynb b/docs/user-guide/amor/amor-reduction.ipynb index 2c9e82e7..c0c08ec0 100644 --- a/docs/user-guide/amor/amor-reduction.ipynb +++ b/docs/user-guide/amor/amor-reduction.ipynb @@ -30,11 +30,12 @@ "from ess.amor import data # noqa: F401\n", "from ess.reflectometry.types import *\n", "from ess.amor.types import *\n", + "from ess.reflectometry import batch_processor\n", "\n", "# The files used in this tutorial have some issues that makes scippnexus\n", "# raise warnings when loading them. To avoid noise in the notebook the warnings are silenced.\n", "warnings.filterwarnings('ignore', 'Failed to convert .* into a transformation')\n", - "warnings.filterwarnings('ignore', 'Invalid transformation, missing attribute')" + "warnings.filterwarnings('ignore', 'Invalid transformation')" ] }, { @@ -124,19 +125,17 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "## Computing sample reflectivity\n", + "## Computing sample reflectivity from batch reduction\n", "\n", "We now compute the sample reflectivity from 4 runs that used different sample rotation angles.\n", - "The measurements at different rotation angles cover different ranges of $Q$." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ + "The measurements at different rotation angles cover different ranges of $Q$.\n", + "\n", + "We set up a batch reduction helper (using the `batch_processor` function) which makes it easy to process multiple runs at once.\n", + "\n", "In this tutorial we use some Amor data files we have received.\n", "The file paths to the tutorial files are obtained by calling:" ] @@ -184,15 +183,22 @@ " },\n", "}\n", "\n", - "\n", - "reflectivity = {}\n", - "for run_number, params in runs.items():\n", - " wf = workflow.copy()\n", - " for key, value in params.items():\n", - " wf[key] = value\n", - " reflectivity[run_number] = wf.compute(ReflectivityOverQ).hist()\n", - "\n", - "sc.plot(reflectivity, norm='log', vmin=1e-4)" + "batch = batch_processor(workflow, runs)\n", + "batch.param_table" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compute R(Q) for all runs\n", + "reflectivity = batch.compute(ReflectivityOverQ)\n", + "sc.plot(\n", + " {key: r.hist() for key, r in reflectivity.items()},\n", + " norm='log', vmin=1e-4\n", + ")" ] }, { @@ -212,13 +218,16 @@ "source": [ "from ess.reflectometry.tools import scale_reflectivity_curves_to_overlap\n", "\n", - "scaled_reflectivity_curves, scale_factors = scale_reflectivity_curves_to_overlap(\n", - " reflectivity.values(),\n", - " # Optionally specify a Q-interval where the reflectivity is known to be 1.0\n", + "# Pass the batch workflow collection and get a new workflow collection as output,\n", + "# with the correct scaling factors applied.\n", + "scaled_wf = scale_reflectivity_curves_to_overlap(\n", + " batch,\n", " critical_edge_interval=(sc.scalar(0.01, unit='1/angstrom'), sc.scalar(0.014, unit='1/angstrom'))\n", ")\n", "\n", - "sc.plot(dict(zip(reflectivity.keys(), scaled_reflectivity_curves, strict=True)), norm='log', vmin=1e-5)" + "scaled_r = {key: r.hist() for key, r in scaled_wf.compute(ReflectivityOverQ).items()}\n", + "\n", + "sc.plot(scaled_r, norm='log', vmin=1e-5)" ] }, { @@ -235,7 +244,7 @@ "outputs": [], "source": [ "from ess.reflectometry.tools import combine_curves\n", - "combined = combine_curves(scaled_reflectivity_curves, workflow.compute(QBins))\n", + "combined = 
combine_curves(scaled_r.values(), workflow.compute(QBins))\n",
     "combined.plot(norm='log')"
    ]
   },
 @@ -265,26 +274,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Start by computing the `ReflectivityData` for each of the files\n",
-    "diagnostics = {}\n",
-    "for run_number, params in runs.items():\n",
-    "    wf = workflow.copy()\n",
-    "    for key, value in params.items():\n",
-    "        wf[key] = value\n",
-    "    diagnostics[run_number] = wf.compute((ReflectivityOverZW, ThetaBins[SampleRun]))\n",
-    "\n",
-    "# Scale the results using the scale factors computed earlier\n",
-    "for run_number, scale_factor in zip(reflectivity.keys(), scale_factors, strict=True):\n",
-    "    diagnostics[run_number][ReflectivityOverZW] *= scale_factor"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "diagnostics['608'][ReflectivityOverZW].hist().flatten(('blade', 'wire'), to='z').plot(norm='log')"
+    "diagnostics = scaled_wf.compute(ReflectivityOverZW)\n",
+    "diagnostics['608'].hist().flatten(('blade', 'wire'), to='z').plot(norm='log')"
    ]
   },
   {
 @@ -304,8 +295,8 @@
    "from ess.reflectometry.figures import wavelength_theta_figure\n",
    "\n",
    "wavelength_theta_figure(\n",
-    "    [result[ReflectivityOverZW] for result in diagnostics.values()],\n",
-    "    theta_bins=[result[ThetaBins[SampleRun]] for result in diagnostics.values()],\n",
+    "    list(diagnostics.values()),\n",
+    "    theta_bins=list(scaled_wf.compute(ThetaBins[SampleRun]).values()),\n",
    "    q_edges_to_display=(sc.scalar(0.018, unit='1/angstrom'), sc.scalar(0.113, unit='1/angstrom'))\n",
    ")"
    ]
   },
 @@ -334,8 +325,8 @@
    "from ess.reflectometry.figures import q_theta_figure\n",
    "\n",
    "q_theta_figure(\n",
-    "    [res[ReflectivityOverZW] for res in diagnostics.values()],\n",
-    "    theta_bins=[res[ThetaBins[SampleRun]] for res in diagnostics.values()],\n",
+    "    list(diagnostics.values()),\n",
+    "    theta_bins=list(scaled_wf.compute(ThetaBins[SampleRun]).values()),\n",
    "    q_bins=workflow.compute(QBins)\n",
    ")"
    ]
   },
 @@ -380,8 +371,7 @@
    "We can save the computed $I(Q)$ to an [ORSO](https://www.reflectometry.org) [.ort](https://github.com/reflectivity/file_format/blob/master/specification.md) file using the [orsopy](https://orsopy.readthedocs.io/en/latest/index.html) package.\n",
    "\n",
    "First, we need to collect the metadata for that file.\n",
-    "To this end, we build a pipeline with additional providers.\n",
-    "We also insert a parameter to indicate the creator of the processed data."
+    "To this end, we insert a parameter to indicate the creator of the processed data."
   ]
  },
  {
 @@ -400,7 +390,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "workflow[orso.OrsoCreator] = orso.OrsoCreator(\n",
+    "scaled_wf[orso.OrsoCreator] = orso.OrsoCreator(\n",
     "    fileio.base.Person(\n",
     "        name='Max Mustermann',\n",
     "        affiliation='European Spallation Source ERIC',\n",
 @@ -409,20 +399,11 @@
     ")"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "workflow.visualize(orso.OrsoIofQDataset, graph_attr={'rankdir': 'LR'})"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "We build our ORSO dataset from the computed $I(Q)$ and the ORSO metadata:"
+    "We can visualize the workflow for the `OrsoIofQDataset`:"
    ]
   },
   {
 @@ -431,15 +412,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "iofq_dataset = workflow.compute(orso.OrsoIofQDataset)\n",
-    "iofq_dataset"
+    "scaled_wf.visualize(orso.OrsoIofQDataset, graph_attr={'rankdir': 'LR'})"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "We also add the URL of this notebook to make it easier to reproduce the data:"
+    "We build our ORSO dataset from the computed $I(Q)$ and the ORSO metadata:"
    ]
   },
   {
 @@ -448,17 +428,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "iofq_dataset.info.reduction.script = (\n",
-    "    'https://scipp.github.io/essreflectometry/examples/amor.html'\n",
-    ")"
+    "iofq_datasets = scaled_wf.compute(orso.OrsoIofQDataset)"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Now let's repeat this for all the sample measurements!\n",
-    "To do that we can use an utility in `ess.reflectometry.tools`:"
+    "We also add the URL of this notebook to make it easier to reproduce the data:"
    ]
   },
   {
 @@ -467,14 +444,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from ess.reflectometry.tools import orso_datasets_from_measurements\n",
-    "\n",
-    "datasets = orso_datasets_from_measurements(\n",
-    "    workflow,\n",
-    "    runs.values(),\n",
-    "    # Optionally scale the curves to overlap using `scale_reflectivity_curves_to_overlap`\n",
-    "    scale_to_overlap=True\n",
-    ")"
+    "for ds in iofq_datasets.values():\n",
+    "    ds.info.reduction.script = (\n",
+    "        'https://scipp.github.io/essreflectometry/user-guide/amor/amor-reduction.html'\n",
+    "    )"
    ]
   },
   {
 @@ -482,7 +455,7 @@
    "metadata": {},
    "source": [
     "Finally, we can save the data to a file.\n",
-    "Note that `iofq_dataset` is an [orsopy.fileio.orso.OrsoDataset](https://orsopy.readthedocs.io/en/latest/orsopy.fileio.orso.html#orsopy.fileio.orso.OrsoDataset)."
+    "Note that `iofq_datasets` contains [orsopy.fileio.orso.OrsoDataset](https://orsopy.readthedocs.io/en/latest/orsopy.fileio.orso.html#orsopy.fileio.orso.OrsoDataset)s."
    ]
   },
   {
 @@ -491,7 +464,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "fileio.orso.save_orso(datasets=datasets, fname='amor_reduced_iofq.ort')"
+    "fileio.orso.save_orso(datasets=list(iofq_datasets.values()), fname='amor_reduced_iofq.ort')"
    ]
   },
   {
 @@ -527,7 +500,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.14"
+   "version": "3.12.7"
   }
  },
  "nbformat": 4,
diff --git a/src/ess/amor/__init__.py b/src/ess/amor/__init__.py
index 4148b94b..5c67959f 100644
--- a/src/ess/amor/__init__.py
+++ b/src/ess/amor/__init__.py
@@ -16,6 +16,7 @@
     Position,
     RunType,
     SampleRotationOffset,
+    ScalingFactorForOverlap,
 )
 from .
import ( conversions, @@ -74,6 +75,7 @@ def default_parameters() -> dict: ), GravityToggle: True, SampleRotationOffset[RunType]: sc.scalar(0.0, unit='deg'), + ScalingFactorForOverlap[RunType]: 1.0, } diff --git a/src/ess/amor/load.py b/src/ess/amor/load.py index 8ad1272d..029f7908 100644 --- a/src/ess/amor/load.py +++ b/src/ess/amor/load.py @@ -17,6 +17,7 @@ NeXusComponent, NeXusDetectorName, ProtonCurrent, + RawChopper, RawSampleRotation, RunType, SampleRotation, @@ -29,7 +30,6 @@ ChopperFrequency, ChopperPhase, ChopperSeparation, - RawChopper, ) diff --git a/src/ess/amor/normalization.py b/src/ess/amor/normalization.py index eb97d188..d4c450b0 100644 --- a/src/ess/amor/normalization.py +++ b/src/ess/amor/normalization.py @@ -88,6 +88,7 @@ def evaluate_reference_at_sample_coords( ref = ref.transform_coords( ( "Q", + "theta", "wavelength_resolution", "sample_size_resolution", "angular_resolution", diff --git a/src/ess/amor/types.py b/src/ess/amor/types.py index 92eb704e..bb801398 100644 --- a/src/ess/amor/types.py +++ b/src/ess/amor/types.py @@ -27,8 +27,4 @@ class ChopperSeparation(sciline.Scope[RunType, sc.Variable], sc.Variable): """Distance between the two choppers.""" -class RawChopper(sciline.Scope[RunType, sc.DataGroup], sc.DataGroup): - """Chopper data loaded from nexus file.""" - - GravityToggle = NewType("GravityToggle", bool) diff --git a/src/ess/amor/workflow.py b/src/ess/amor/workflow.py index 828abe4f..68dc0644 100644 --- a/src/ess/amor/workflow.py +++ b/src/ess/amor/workflow.py @@ -12,6 +12,8 @@ ProtonCurrent, ReducibleData, RunType, + ScalingFactorForOverlap, + UnscaledReducibleData, WavelengthBins, YIndexLimits, ZIndexLimits, @@ -27,7 +29,7 @@ def add_coords_masks_and_apply_corrections( wbins: WavelengthBins, proton_current: ProtonCurrent[RunType], graph: CoordTransformationGraph, -) -> ReducibleData[RunType]: +) -> UnscaledReducibleData[RunType]: """ Computes coordinates, masks and corrections that are the same for the sample measurement and the reference measurement. @@ -43,7 +45,17 @@ def add_coords_masks_and_apply_corrections( da = add_proton_current_mask(da) da = correct_by_proton_current(da) - return ReducibleData[RunType](da) + return UnscaledReducibleData[RunType](da) + + +def scale_raw_reducible_data( + da: UnscaledReducibleData[RunType], + scale: ScalingFactorForOverlap[RunType], +) -> ReducibleData[RunType]: + """ + Scales the raw data by a given factor. + """ + return ReducibleData[RunType](da * scale) -providers = (add_coords_masks_and_apply_corrections,) +providers = (add_coords_masks_and_apply_corrections, scale_raw_reducible_data) diff --git a/src/ess/reflectometry/__init__.py b/src/ess/reflectometry/__init__.py index 7dd54dce..3b6d18b2 100644 --- a/src/ess/reflectometry/__init__.py +++ b/src/ess/reflectometry/__init__.py @@ -12,6 +12,7 @@ from . 
import conversions, corrections, figures, normalization, orso from .load import load_reference, save_reference +from .tools import batch_processor providers = ( *corrections.providers, @@ -31,9 +32,15 @@ del importlib - __all__ = [ + "__version__", + "batch_processor", + "conversions", + "corrections", "figures", "load_reference", + "normalization", + "orso", + "providers", "save_reference", ] diff --git a/src/ess/reflectometry/tools.py b/src/ess/reflectometry/tools.py index 3409789d..b0808e65 100644 --- a/src/ess/reflectometry/tools.py +++ b/src/ess/reflectometry/tools.py @@ -1,17 +1,23 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +import uuid from collections.abc import Mapping, Sequence from itertools import chain from typing import Any import numpy as np -import sciline +import sciline as sl import scipp as sc import scipy.optimize as opt -from orsopy.fileio.orso import OrsoDataset -from ess.reflectometry import orso -from ess.reflectometry.types import ReflectivityOverQ +from ess.reflectometry.types import ( + QBins, + ReferenceRun, + ReflectivityOverQ, + SampleRun, + ScalingFactorForOverlap, + UnscaledReducibleData, +) _STD_TO_FWHM = sc.scalar(2.0) * sc.sqrt(sc.scalar(2.0) * sc.log(sc.scalar(2.0))) @@ -100,6 +106,59 @@ def linlogspace( return sc.concat(grids, dim) +class WorkflowCollection: + """ + A collection of sciline workflows that can be used to compute multiple + targets from mapping a workflow over a parameter table. + It can also be used to set parameters for all mapped nodes in a single shot. + """ + + def __init__(self, workflow: sl.Pipeline, param_table): + self._original_workflow = workflow.copy() + self.param_table = param_table + self._mapped_workflow = self._original_workflow.map(self.param_table) + + def __setitem__(self, key, value): + if key in self.param_table: + ind = list(self.param_table.keys()).index(key) + self.param_table.iloc[:, ind] = value + self._mapped_workflow = self._original_workflow.map(self.param_table) + else: + self.param_table.insert(len(self.param_table.columns), key, value) + self._original_workflow[key] = None + self._mapped_workflow = self._original_workflow.map(self.param_table) + + def compute(self, keys: type | Sequence[type], **kwargs) -> Mapping[str, Any]: + from sciline.pipeline import _is_multiple_keys + + out = {} + if not _is_multiple_keys(keys): + keys = [keys] + for key in keys: + out[key] = {} + if sl.is_mapped_node(self._mapped_workflow, key): + targets = sl.get_mapped_node_names(self._mapped_workflow, key) + results = self._mapped_workflow.compute(targets, **kwargs) + for node, v in results.items(): + out[key][node.index.values[0]] = v + else: + out[key] = self._mapped_workflow.compute(key, **kwargs) + return next(iter(out.values())) if len(out) == 1 else out + + # TODO: implement get() + + # TODO: implement the group() method to group by params in the parameter table + + def visualize(self, targets, **kwargs): + targets = sl.get_mapped_node_names(self._mapped_workflow, targets) + return self._mapped_workflow.visualize(targets, **kwargs) + + def copy(self) -> 'WorkflowCollection': + return self.__class__( + workflow=self._original_workflow, param_table=self.param_table + ) + + def _sort_by(a, by): return [x for x, _ in sorted(zip(a, by, strict=True), key=lambda x: x[1])] @@ -160,53 +219,99 @@ def _interpolate_on_qgrid(curves, grid): def scale_reflectivity_curves_to_overlap( - curves: Sequence[sc.DataArray], + workflow: WorkflowCollection | sl.Pipeline, 
     critical_edge_interval: tuple[sc.Variable, sc.Variable] | None = None,
+    cache_intermediate_results: bool = True,
-) -> tuple[list[sc.DataArray], list[sc.Variable]]:
+) -> WorkflowCollection | sl.Pipeline:
-    '''Make the curves overlap by scaling all except the first by a factor.
+    '''
+    Set the ``ScalingFactorForOverlap`` parameter on the provided workflows
+    in a way that makes the 1D reflectivity curves overlap.
+    One can supply either a collection of workflows or a single workflow.
+
+    If :code:`critical_edge_interval` is not provided, all workflows are scaled except
+    the data with the lowest Q-range, which is considered to be the reference curve.
 
     The scaling factors are determined by a maximum likelihood estimate
     (assuming the errors are normal distributed).
-    If :code:`critical_edge_interval` is provided then all curves are scaled.
+    If :code:`critical_edge_interval` is provided then all data are scaled.
 
-    All curves must be have the same unit for data and the Q-coordinate.
+    All reflectivity curves must have the same unit for data and the Q-coordinate.
 
     Parameters
     ---------
-    curves:
-        the reflectivity curves that should be scaled together
+    workflow:
+        The workflow or collection of workflows that can compute ``ReflectivityOverQ``.
     critical_edge_interval:
-        a tuple denoting an interval that is known to belong
+        A tuple denoting an interval that is known to belong
         to the critical edge, i.e. where the reflectivity is known to be 1.
+    cache_intermediate_results:
+        If ``True``, the intermediate result ``UnscaledReducibleData`` will be cached
+        (this is the base for all types that are downstream of the scaling factor).
 
     Returns
     ---------
     :
-    A list of scaled reflectivity curves and a list of the scaling factors.
+    A copy of the input workflow (or workflow collection) with the
+    ``ScalingFactorForOverlap`` parameters set such that the reflectivity curves overlap.
     '''
-    if critical_edge_interval is not None:
-        q = next(iter(curves)).coords['Q']
-        N = (
-            ((q >= critical_edge_interval[0]) & (q < critical_edge_interval[1]))
-            .sum()
-            .value
+    not_collection = isinstance(workflow, sl.Pipeline)
+
+    wfc = workflow.copy()
+    if cache_intermediate_results:
+        try:
+            wfc[UnscaledReducibleData[SampleRun]] = wfc.compute(
+                UnscaledReducibleData[SampleRun]
+            )
+        except sl.UnsatisfiedRequirement:
+            pass
+        try:
+            wfc[UnscaledReducibleData[ReferenceRun]] = wfc.compute(
+                UnscaledReducibleData[ReferenceRun]
+            )
+        except sl.UnsatisfiedRequirement:
+            pass
+
+    reflectivities = wfc.compute(ReflectivityOverQ)
+    if not_collection:
+        reflectivities = {"": reflectivities}
+
+    # First sort the dict of reflectivities by the Q min value
+    curves = {
+        k: v.hist() if v.bins is not None else v
+        for k, v in sorted(
+            reflectivities.items(), key=lambda item: item[1].coords['Q'].min().value
+        )
+    }
+
+    critical_edge_key = uuid.uuid4().hex
+    if critical_edge_interval is not None:
+        q = wfc.compute(QBins)
+        if hasattr(q, "items"):
+            # If QBins is a mapping, find the one with the lowest Q start
+            # Note the conversion to a dict, because if pandas is used for the mapping,
+            # it will return a Series, whose `.values` attribute is not callable.
+            q = min(dict(q).values(), key=lambda q_: q_.min())
+
+        # TODO: This is slightly different from before: it extracts the bins from the
+        # QBins variable that cover the critical edge interval. This means that the
+        # resulting curve will not necessarily begin and end exactly at the values
+        # specified, but rather at the closest bin edges.
edge = sc.DataArray( - data=sc.ones(dims=('Q',), shape=(N,), with_variances=True), - coords={'Q': sc.linspace('Q', *critical_edge_interval, N + 1)}, - ) - curves, factors = scale_reflectivity_curves_to_overlap([edge, *curves]) - return curves[1:], factors[1:] - if len({c.data.unit for c in curves}) != 1: + data=sc.ones(sizes={q.dim: q.sizes[q.dim] - 1}, with_variances=True), + coords={q.dim: q}, + )[q.dim, critical_edge_interval[0] : critical_edge_interval[1]] + # Now place the critical edge at the beginning + curves = {critical_edge_key: edge} | curves + + if len({c.data.unit for c in curves.values()}) != 1: raise ValueError('The reflectivity curves must have the same unit') - if len({c.coords['Q'].unit for c in curves}) != 1: + if len({c.coords['Q'].unit for c in curves.values()}) != 1: raise ValueError('The Q-coordinates must have the same unit for each curve') - qgrid = _create_qgrid_where_overlapping([c.coords['Q'] for c in curves]) + qgrid = _create_qgrid_where_overlapping([c.coords['Q'] for c in curves.values()]) - r = _interpolate_on_qgrid(map(sc.values, curves), qgrid).values - v = _interpolate_on_qgrid(map(sc.variances, curves), qgrid).values + r = _interpolate_on_qgrid(map(sc.values, curves.values()), qgrid).values + v = _interpolate_on_qgrid(map(sc.variances, curves.values()), qgrid).values def cost(scaling_factors): scaling_factors = np.concatenate([[1.0], scaling_factors])[:, None] @@ -221,10 +326,17 @@ def cost(scaling_factors): sol = opt.minimize(cost, [1.0] * (len(curves) - 1)) scaling_factors = (1.0, *map(float, sol.x)) - return [ - scaling_factor * curve - for scaling_factor, curve in zip(scaling_factors, curves, strict=True) - ], scaling_factors + + results = { + k: v + for k, v in zip(curves.keys(), scaling_factors, strict=True) + if k != critical_edge_key + } + if not_collection: + results = results[""] + wfc[ScalingFactorForOverlap[SampleRun]] = results + + return wfc def combine_curves( @@ -279,58 +391,64 @@ def combine_curves( ) -def orso_datasets_from_measurements( - workflow: sciline.Pipeline, - runs: Sequence[Mapping[type, Any]], - *, - scale_to_overlap: bool = True, -) -> list[OrsoDataset]: - '''Produces a list of ORSO datasets containing one - reflectivity curve for each of the provided runs. - Each entry of :code:`runs` is a mapping of parameters and - values needed to produce the dataset. - - Optionally, the reflectivity curves can be scaled to overlap in - the regions where they have the same Q-value. +def batch_processor( + workflow: sl.Pipeline, params: Mapping[Any, Mapping[type, Any]] +) -> WorkflowCollection: + """ + Maps the provided workflow over the provided params. + + Example: + + ``` + from ess.reflectometry import amor, tools + + workflow = amor.AmorWorkflow() + + runs = { + '608': { + SampleRotationOffset[SampleRun]: sc.scalar(0.05, unit='deg'), + Filename[SampleRun]: amor.data.amor_run(608), + }, + '609': { + SampleRotationOffset[SampleRun]: sc.scalar(0.06, unit='deg'), + Filename[SampleRun]: amor.data.amor_run(609), + }, + '610': { + SampleRotationOffset[SampleRun]: sc.scalar(0.05, unit='deg'), + Filename[SampleRun]: amor.data.amor_run(610), + }, + '611': { + SampleRotationOffset[SampleRun]: sc.scalar(0.07, unit='deg'), + Filename[SampleRun]: amor.data.amor_run(611), + }, + } + + batch = tools.batch_processor(workflow, runs) + + results = batch.compute(ReflectivityOverQ) + ``` Parameters - ----------- + ---------- workflow: - The sciline workflow used to compute `ReflectivityOverQ` for each of the runs. 
- - runs: - The sciline parameters to be used for each run + The sciline workflow used to compute the targets for each of the runs. + params: + The sciline parameters to be used for each run. + Should be a mapping where the keys are the names of the runs + and the values are mappings of type to value pairs. + """ + import pandas as pd - scale_to_overlap: - If True the curves will be scaled to overlap. - Note that the curve of the first run is unscaled and - the rest are scaled to match it. + all_types = {t for v in params.values() for t in v.keys()} + data = {t: [] for t in all_types} + for param in params.values(): + for t in all_types: + if t in param: + data[t].append(param[t]) + else: + # Set the default value + data[t].append(workflow.compute(t)) - Returns - --------- - list of the computed ORSO datasets, containing one reflectivity curve each - ''' - reflectivity_curves = [] - for parameters in runs: - wf = workflow.copy() - for name, value in parameters.items(): - wf[name] = value - reflectivity_curves.append(wf.compute(ReflectivityOverQ)) - - scale_factors = ( - scale_reflectivity_curves_to_overlap([r.hist() for r in reflectivity_curves])[1] - if scale_to_overlap - else (1,) * len(runs) - ) + param_table = pd.DataFrame(data, index=params.keys()).rename_axis(index='run_id') - datasets = [] - for parameters, curve, scale_factor in zip( - runs, reflectivity_curves, scale_factors, strict=True - ): - wf = workflow.copy() - for name, value in parameters.items(): - wf[name] = value - wf[ReflectivityOverQ] = scale_factor * curve - dataset = wf.compute(orso.OrsoIofQDataset) - datasets.append(dataset) - return datasets + return WorkflowCollection(workflow, param_table) diff --git a/src/ess/reflectometry/types.py b/src/ess/reflectometry/types.py index b14b0243..6959534b 100644 --- a/src/ess/reflectometry/types.py +++ b/src/ess/reflectometry/types.py @@ -24,10 +24,22 @@ CoordTransformationGraph = NewType("CoordTransformationGraph", dict) +class RawChopper(sciline.Scope[RunType, sc.DataGroup], sc.DataGroup): + """Chopper data loaded from nexus file.""" + + +class UnscaledReducibleData(sciline.Scope[RunType, sc.DataArray], sc.DataArray): + """""" + + class ReducibleData(sciline.Scope[RunType, sc.DataArray], sc.DataArray): """Event data with common coordinates added""" +class ScalingFactorForOverlap(sciline.Scope[RunType, float], float): + """""" + + ReducedReference = NewType("ReducedReference", sc.DataArray) """Intensity distribution on the detector for a sample with :math`R(Q) = 1`""" diff --git a/src/ess/reflectometry/workflow.py b/src/ess/reflectometry/workflow.py index 2c0fa9be..8a203eb5 100644 --- a/src/ess/reflectometry/workflow.py +++ b/src/ess/reflectometry/workflow.py @@ -6,7 +6,6 @@ import sciline import scipp as sc -from ess.amor.types import RawChopper from ess.reflectometry.orso import ( OrsoExperiment, OrsoOwner, @@ -14,11 +13,13 @@ OrsoSampleFilenames, ) from ess.reflectometry.types import ( + DetectorRotation, Filename, - ReducibleData, + RawChopper, RunType, SampleRotation, SampleRun, + UnscaledReducibleData, ) @@ -62,26 +63,54 @@ def with_filenames( mapped = wf.map(df) - wf[ReducibleData[runtype]] = mapped[ReducibleData[runtype]].reduce( - index=axis_name, func=_concatenate_event_lists - ) - wf[RawChopper[runtype]] = mapped[RawChopper[runtype]].reduce( - index=axis_name, func=_any_value - ) - wf[SampleRotation[runtype]] = mapped[SampleRotation[runtype]].reduce( - index=axis_name, func=_any_value - ) - - if runtype is SampleRun: - wf[OrsoSample] = 
mapped[OrsoSample].reduce(index=axis_name, func=_any_value) - wf[OrsoExperiment] = mapped[OrsoExperiment].reduce( + try: + wf[UnscaledReducibleData[runtype]] = mapped[ + UnscaledReducibleData[runtype] + ].reduce(index=axis_name, func=_concatenate_event_lists) + except (ValueError, KeyError): + # UnscaledReducibleData[runtype] is independent of Filename[runtype] or is not + # present in the workflow. + pass + try: + wf[RawChopper[runtype]] = mapped[RawChopper[runtype]].reduce( index=axis_name, func=_any_value ) - wf[OrsoOwner] = mapped[OrsoOwner].reduce(index=axis_name, func=lambda x, *_: x) - wf[OrsoSampleFilenames] = mapped[OrsoSampleFilenames].reduce( + except (ValueError, KeyError): + # RawChopper[runtype] is independent of Filename[runtype] or is not + # present in the workflow. + pass + try: + wf[SampleRotation[runtype]] = mapped[SampleRotation[runtype]].reduce( + index=axis_name, func=_any_value + ) + except (ValueError, KeyError): + # SampleRotation[runtype] is independent of Filename[runtype] or is not + # present in the workflow. + pass + try: + wf[DetectorRotation[runtype]] = mapped[DetectorRotation[runtype]].reduce( + index=axis_name, func=_any_value + ) + except (ValueError, KeyError): + # DetectorRotation[runtype] is independent of Filename[runtype] or is not + # present in the workflow. + pass + + if runtype is SampleRun: + if OrsoSample in wf.underlying_graph: + wf[OrsoSample] = mapped[OrsoSample].reduce(index=axis_name, func=_any_value) + if OrsoExperiment in wf.underlying_graph: + wf[OrsoExperiment] = mapped[OrsoExperiment].reduce( + index=axis_name, func=_any_value + ) + if OrsoOwner in wf.underlying_graph: + wf[OrsoOwner] = mapped[OrsoOwner].reduce( + index=axis_name, func=lambda x, *_: x + ) + if OrsoSampleFilenames in wf.underlying_graph: # When we don't map over filenames # each OrsoSampleFilenames is a list with a single entry. 
- index=axis_name, - func=_concatenate_lists, - ) + wf[OrsoSampleFilenames] = mapped[OrsoSampleFilenames].reduce( + index=axis_name, func=_concatenate_lists + ) return wf diff --git a/tests/amor/pipeline_test.py b/tests/amor/pipeline_test.py index b6ca1504..9b5f01b3 100644 --- a/tests/amor/pipeline_test.py +++ b/tests/amor/pipeline_test.py @@ -127,16 +127,15 @@ def test_save_reduced_orso_file(output_folder: Path): ) wf[Filename[ReferenceRun]] = data.amor_run(4152) wf[QBins] = sc.geomspace(dim="Q", start=0.01, stop=0.06, num=201, unit="1/angstrom") - r = wf.compute(ReflectivityOverQ) - _, (s,) = scale_reflectivity_curves_to_overlap( - [r.hist()], + + scaled_wf = scale_reflectivity_curves_to_overlap( + wf, critical_edge_interval=( sc.scalar(0.01, unit='1/angstrom'), sc.scalar(0.014, unit='1/angstrom'), ), ) - wf[ReflectivityOverQ] = s * r - wf[orso.OrsoCreator] = orso.OrsoCreator( + scaled_wf[orso.OrsoCreator] = orso.OrsoCreator( fileio.base.Person( name="Max Mustermann", affiliation="European Spallation Source ERIC", @@ -144,7 +143,7 @@ def test_save_reduced_orso_file(output_folder: Path): ) ) fileio.orso.save_orso( - datasets=[wf.compute(orso.OrsoIofQDataset)], + datasets=[scaled_wf.compute(orso.OrsoIofQDataset)], fname=output_folder / 'amor_reduced_iofq.ort', ) diff --git a/tests/corrections_test.py b/tests/reflectometry/corrections_test.py similarity index 100% rename from tests/corrections_test.py rename to tests/reflectometry/corrections_test.py diff --git a/tests/orso_test.py b/tests/reflectometry/orso_test.py similarity index 100% rename from tests/orso_test.py rename to tests/reflectometry/orso_test.py diff --git a/tests/reflectometry/tools_test.py b/tests/reflectometry/tools_test.py new file mode 100644 index 00000000..70e8838e --- /dev/null +++ b/tests/reflectometry/tools_test.py @@ -0,0 +1,445 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) + +import numpy as np +import pytest +import sciline as sl +import scipp as sc +from orsopy.fileio import Orso, OrsoDataset +from scipp.testing import assert_allclose + +from ess.reflectometry.tools import ( + batch_processor, + combine_curves, + linlogspace, + scale_reflectivity_curves_to_overlap, +) +from ess.reflectometry.types import ( + Filename, + QBins, + ReducibleData, + ReferenceRun, + ReflectivityOverQ, + RunType, + SampleRun, + ScalingFactorForOverlap, + UnscaledReducibleData, +) + + +def make_sample_events(scale, qmin, qmax): + n1 = 10 + n2 = 15 + qbins = sc.linspace('Q', qmin, qmax, n1 + n2 + 1) + data = sc.DataArray( + data=sc.concat( + ( + sc.ones(dims=['Q'], shape=[10], with_variances=True), + 0.5 * sc.ones(dims=['Q'], shape=[15], with_variances=True), + ), + dim='Q', + ) + * scale, + coords={'Q': sc.midpoints(qbins, 'Q')}, + ) + data.variances[:] = 0.1 + data.unit = 'counts' + return data.bin(Q=qbins) + + +def make_reference_events(qmin, qmax): + n = 25 + qbins = sc.linspace('Q', qmin, qmax, n + 1) + data = sc.DataArray( + data=sc.ones(dims=['Q'], shape=[n], with_variances=True), + coords={'Q': sc.midpoints(qbins, 'Q')}, + ) + data.variances[:] = 0.1 + data.unit = 'counts' + return data.bin(Q=qbins) + + +def make_workflow(): + def sample_data_from_filename( + filename: Filename[SampleRun], + ) -> UnscaledReducibleData[SampleRun]: + return UnscaledReducibleData[SampleRun]( + make_sample_events(*(float(x) for x in filename.split('_'))) + ) + + def reference_data_from_filename( + filename: Filename[ReferenceRun], + ) -> UnscaledReducibleData[ReferenceRun]: + return 
UnscaledReducibleData[ReferenceRun]( + make_reference_events(*(float(x) for x in filename.split('_'))) + ) + + def apply_scaling( + da: UnscaledReducibleData[RunType], + scale: ScalingFactorForOverlap[RunType], + ) -> ReducibleData[RunType]: + """ + Scales the raw data by a given factor. + """ + return ReducibleData[RunType](da * scale) + + def reflectivity( + sample: ReducibleData[SampleRun], + reference: ReducibleData[ReferenceRun], + qbins: QBins, + ) -> ReflectivityOverQ: + return ReflectivityOverQ(sample.hist(Q=qbins) / reference.hist(Q=qbins)) + + return sl.Pipeline( + [ + sample_data_from_filename, + reference_data_from_filename, + apply_scaling, + reflectivity, + ] + ) + + +def test_reflectivity_curve_scaling(): + wf = make_workflow() + wf[ScalingFactorForOverlap[SampleRun]] = 1.0 + wf[ScalingFactorForOverlap[ReferenceRun]] = 1.0 + params = {'a': (1.0, 0, 0.3), 'b': (0.8, 0.2, 0.7), 'c': (0.1, 0.6, 1.0)} + table = { + k: { + Filename[SampleRun]: "_".join(map(str, v)), + Filename[ReferenceRun]: "_".join(map(str, v[1:])), + QBins: make_reference_events(*v[1:]).coords['Q'], + } + for k, v in params.items() + } + wfc = batch_processor(wf, table) + + scaled_wf = scale_reflectivity_curves_to_overlap(wfc) + + factors = scaled_wf.compute(ScalingFactorForOverlap[SampleRun]) + + assert np.isclose(factors['a'], 1.0) + assert np.isclose(factors['b'], 0.5 / 0.8) + assert np.isclose(factors['c'], 0.25 / 0.1) + + +def test_reflectivity_curve_scaling_with_critical_edge(): + wf = make_workflow() + wf[ScalingFactorForOverlap[SampleRun]] = 1.0 + wf[ScalingFactorForOverlap[ReferenceRun]] = 1.0 + params = {'a': (2, 0, 0.3), 'b': (0.8, 0.2, 0.7), 'c': (0.1, 0.6, 1.0)} + table = { + k: { + Filename[SampleRun]: "_".join(map(str, v)), + Filename[ReferenceRun]: "_".join(map(str, v[1:])), + QBins: make_reference_events(*v[1:]).coords['Q'], + } + for k, v in params.items() + } + wfc = batch_processor(wf, table) + + scaled_wf = scale_reflectivity_curves_to_overlap( + wfc, critical_edge_interval=(sc.scalar(0.01), sc.scalar(0.05)) + ) + + factors = scaled_wf.compute(ScalingFactorForOverlap[SampleRun]) + + assert np.isclose(factors['a'], 0.5) + assert np.isclose(factors['b'], 0.5 / 0.8) + assert np.isclose(factors['c'], 0.25 / 0.1) + + +def test_reflectivity_curve_scaling_works_with_single_workflow_and_critical_edge(): + wf = make_workflow() + wf[ScalingFactorForOverlap[SampleRun]] = 1.0 + wf[ScalingFactorForOverlap[ReferenceRun]] = 1.0 + wf[Filename[SampleRun]] = '2.5_0.4_0.8' + wf[Filename[ReferenceRun]] = '0.4_0.8' + wf[QBins] = make_reference_events(0.4, 0.8).coords['Q'] + + scaled_wf = scale_reflectivity_curves_to_overlap( + wf, critical_edge_interval=(sc.scalar(0.0), sc.scalar(0.5)) + ) + + factor = scaled_wf.compute(ScalingFactorForOverlap[SampleRun]) + + assert np.isclose(factor, 0.4) + + +def test_reflectivity_curve_scaling_caches_intermediate_results(): + sample_count = 0 + reference_count = 0 + + def sample_data_from_filename( + filename: Filename[SampleRun], + ) -> UnscaledReducibleData[SampleRun]: + nonlocal sample_count + sample_count += 1 + return UnscaledReducibleData[SampleRun]( + make_sample_events(*(float(x) for x in filename.split('_'))) + ) + + def reference_data_from_filename( + filename: Filename[ReferenceRun], + ) -> UnscaledReducibleData[ReferenceRun]: + nonlocal reference_count + reference_count += 1 + return UnscaledReducibleData[ReferenceRun]( + make_reference_events(*(float(x) for x in filename.split('_'))) + ) + + def apply_scaling( + da: UnscaledReducibleData[RunType], + scale: 
ScalingFactorForOverlap[RunType], + ) -> ReducibleData[RunType]: + """ + Scales the raw data by a given factor. + """ + return ReducibleData[RunType](da * scale) + + def reflectivity( + sample: ReducibleData[SampleRun], + reference: ReducibleData[ReferenceRun], + qbins: QBins, + ) -> ReflectivityOverQ: + return ReflectivityOverQ(sample.hist(Q=qbins) / reference.hist(Q=qbins)) + + wf = sl.Pipeline( + [ + sample_data_from_filename, + reference_data_from_filename, + apply_scaling, + reflectivity, + ] + ) + wf[ScalingFactorForOverlap[SampleRun]] = 1.0 + wf[ScalingFactorForOverlap[ReferenceRun]] = 1.0 + params = {'a': (1.0, 0, 0.3), 'b': (0.8, 0.2, 0.7), 'c': (0.1, 0.6, 1.0)} + table = { + k: { + Filename[SampleRun]: "_".join(map(str, v)), + Filename[ReferenceRun]: "_".join(map(str, v[1:])), + QBins: make_reference_events(*v[1:]).coords['Q'], + } + for k, v in params.items() + } + wfc = batch_processor(wf, table) + + scaled_wf = scale_reflectivity_curves_to_overlap( + wfc, cache_intermediate_results=False + ) + scaled_wf.compute(ReflectivityOverQ) + # We expect 6 counts: 3 for each of the 3 runs * 2 for computing ReflectivityOverQ + # inside the scaling function and one more time for the final computation just above + assert sample_count == 6 + assert reference_count == 6 + + sample_count = 0 + reference_count = 0 + + scaled_wf = scale_reflectivity_curves_to_overlap( + wfc, cache_intermediate_results=True + ) + scaled_wf.compute(ReflectivityOverQ) + # We expect 3 counts: 1 for each of the 3 runs * 1 for computing ReflectivityOverQ + assert sample_count == 3 + assert reference_count == 3 + + +def test_combined_curves(): + qgrid = sc.linspace('Q', 0, 1, 26) + curves = ( + make_sample_events(1.0, 0, 0.3).hist(), + 0.5 * make_sample_events(1.0, 0.2, 0.7).hist(), + 0.25 * make_sample_events(1.0, 0.6, 1.0).hist(), + ) + + combined = combine_curves(curves, qgrid) + assert_allclose( + combined.data, + sc.array( + dims='Q', + values=[ + 1.0, + 1, + 1, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.125, + 0.125, + 0.125, + 0.125, + 0.125, + 0.125, + ], + variances=[ + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.02, + 0.02, + 0.025, + 0.025, + 0.025, + 0.025, + 0.025, + 0.025, + 0.025, + 0.025, + 0.005, + 0.005, + 0.00625, + 0.00625, + 0.00625, + 0.00625, + 0.00625, + 0.00625, + 0.00625, + 0.00625, + ], + unit='counts', + ), + ) + + +def test_linlogspace_linear(): + q_lin = linlogspace( + dim='qz', edges=[0.008, 0.08], scale='linear', num=50, unit='1/angstrom' + ) + expected = sc.linspace(dim='qz', start=0.008, stop=0.08, num=50, unit='1/angstrom') + assert sc.allclose(q_lin, expected) + + +def test_linlogspace_linear_list_input(): + q_lin = linlogspace( + dim='qz', edges=[0.008, 0.08], unit='1/angstrom', scale=['linear'], num=[50] + ) + expected = sc.linspace(dim='qz', start=0.008, stop=0.08, num=50, unit='1/angstrom') + assert sc.allclose(q_lin, expected) + + +def test_linlogspace_log(): + q_log = linlogspace( + dim='qz', edges=[0.008, 0.08], unit='1/angstrom', scale='log', num=50 + ) + expected = sc.geomspace(dim='qz', start=0.008, stop=0.08, num=50, unit='1/angstrom') + assert sc.allclose(q_log, expected) + + +def test_linlogspace_linear_log(): + q_linlog = linlogspace( + dim='qz', + edges=[0.008, 0.03, 0.08], + unit='1/angstrom', + scale=['linear', 'log'], + num=[16, 20], + ) + exp_lin = sc.linspace(dim='qz', start=0.008, stop=0.03, num=16, unit='1/angstrom') + exp_log = sc.geomspace(dim='qz', start=0.03, stop=0.08, num=21, 
unit='1/angstrom') + expected = sc.concat([exp_lin, exp_log['qz', 1:]], 'qz') + assert sc.allclose(q_linlog, expected) + + +def test_linlogspace_log_linear(): + q_loglin = linlogspace( + dim='qz', + edges=[0.008, 0.03, 0.08], + unit='1/angstrom', + scale=['log', 'linear'], + num=[16, 20], + ) + exp_log = sc.geomspace(dim='qz', start=0.008, stop=0.03, num=16, unit='1/angstrom') + exp_lin = sc.linspace(dim='qz', start=0.03, stop=0.08, num=21, unit='1/angstrom') + expected = sc.concat([exp_log, exp_lin['qz', 1:]], 'qz') + assert sc.allclose(q_loglin, expected) + + +def test_linlogspace_linear_log_linear(): + q_linloglin = linlogspace( + dim='qz', + edges=[0.008, 0.03, 0.08, 0.12], + unit='1/angstrom', + scale=['linear', 'log', 'linear'], + num=[16, 20, 10], + ) + exp_lin = sc.linspace(dim='qz', start=0.008, stop=0.03, num=16, unit='1/angstrom') + exp_log = sc.geomspace(dim='qz', start=0.03, stop=0.08, num=21, unit='1/angstrom') + exp_lin2 = sc.linspace(dim='qz', start=0.08, stop=0.12, num=11, unit='1/angstrom') + expected = sc.concat([exp_lin, exp_log['qz', 1:], exp_lin2['qz', 1:]], 'qz') + assert sc.allclose(q_linloglin, expected) + + +def test_linlogspace_bad_input(): + with pytest.raises(ValueError, match="Sizes do not match"): + _ = linlogspace( + dim='qz', + edges=[0.008, 0.03, 0.08, 0.12], + unit='1/angstrom', + scale=['linear', 'log'], + num=[16, 20], + ) + + +@pytest.mark.filterwarnings("ignore:No suitable") +def test_batch_processor_tool_uses_expected_parameters_from_each_run(): + def normalized_ioq(filename: Filename[SampleRun]) -> ReflectivityOverQ: + return filename + + def orso_dataset(filename: Filename[SampleRun]) -> OrsoDataset: + class Reduction: + corrections = [] # noqa: RUF012 + + return OrsoDataset( + Orso({}, Reduction, [], name=f'{filename}.orso'), np.ones((0, 0)) + ) + + workflow = sl.Pipeline( + [normalized_ioq, orso_dataset], params={Filename[SampleRun]: 'default'} + ) + + batch = batch_processor(workflow, {'a': {}, 'b': {Filename[SampleRun]: 'special'}}) + + results = batch.compute(OrsoDataset) + assert len(results) == 2 + assert results['a'].info.name == 'default.orso' + assert results['b'].info.name == 'special.orso' + + +# TODO: need to implement groupby in the mapping +# def test_batch_processor_tool_merges_event_lists(): +# wf = make_workflow() +# wf[ScalingFactorForOverlap[SampleRun]] = 1.0 +# wf[ScalingFactorForOverlap[ReferenceRun]] = 1.0 + +# runs = { +# 'a': {Filename[SampleRun]: ('1.0_0.0_0.3', '1.5_0.0_0.3')}, +# 'b': {Filename[SampleRun]: '0.8_0.2_0.7'}, +# 'c': {Filename[SampleRun]: ('0.1_0.6_1.0', '0.2_0.6_1.0')}, +# } +# batch = batch_processor(wf, runs) + +# results = batch.compute(UnscaledReducibleData[SampleRun]) + +# assert_almost_equal(results['a'].sum().value, +# 10 + 15 * 0.5 + (10 + 15 * 0.5) * 1.5) +# assert_almost_equal(results['b'].sum().value, 10 * 0.8 + 15 * 0.5 * 0.8) +# assert_almost_equal( +# results['c'].sum().value, (10 + 15 * 0.5) * 0.1 + (10 + 15 * 0.5) * 0.2 +# ) diff --git a/tests/reflectometry/workflow_collection_test.py b/tests/reflectometry/workflow_collection_test.py new file mode 100644 index 00000000..a7232a72 --- /dev/null +++ b/tests/reflectometry/workflow_collection_test.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) + +import pandas as pd +import sciline as sl + +from ess.reflectometry.tools import WorkflowCollection + + +def int_to_float(x: int) -> float: + return 0.5 * x + + +def int_float_to_str(x: int, y: float) -> str: + return 
f"{x};{y}" + + +def make_param_table(params: dict) -> pd.DataFrame: + all_types = {t for v in params.values() for t in v.keys()} + data = {t: [] for t in all_types} + for param in params.values(): + for t in all_types: + data[t].append(param[t]) + return pd.DataFrame(data, index=params.keys()).rename_axis(index='run_id') + + +def test_compute() -> None: + wf = sl.Pipeline([int_to_float, int_float_to_str]) + + coll = WorkflowCollection(wf, make_param_table({'a': {int: 3}, 'b': {int: 4}})) + + assert dict(coll.compute(float)) == {'a': 1.5, 'b': 2.0} + assert dict(coll.compute(str)) == {'a': '3;1.5', 'b': '4;2.0'} + + +def test_compute_multiple() -> None: + wf = sl.Pipeline([int_to_float, int_float_to_str]) + + coll = WorkflowCollection(wf, make_param_table({'a': {int: 3}, 'b': {int: 4}})) + + # wfa = wf.copy() + # wfa[int] = 3 + # wfb = wf.copy() + # wfb[int] = 4 + # coll = WorkflowCollection({'a': wfa, 'b': wfb}) + + result = coll.compute([float, str]) + + assert result['a'] == {float: 1.5, str: '3;1.5'} + assert result['b'] == {float: 2.0, str: '4;2.0'} + + +def test_setitem_mapping() -> None: + wf = sl.Pipeline([int_to_float, int_float_to_str]) + wfa = wf.copy() + wfa[int] = 3 + wfb = wf.copy() + wfb[int] = 4 + coll = WorkflowCollection({'a': wfa, 'b': wfb}) + + coll[int] = {'a': 7, 'b': 8} + + assert coll.compute(float) == {'a': 3.5, 'b': 4.0} + assert coll.compute(str) == {'a': '7;3.5', 'b': '8;4.0'} + + +def test_setitem_single_value() -> None: + wf = sl.Pipeline([int_to_float, int_float_to_str]) + wfa = wf.copy() + wfa[int] = 3 + wfb = wf.copy() + wfb[int] = 4 + coll = WorkflowCollection({'a': wfa, 'b': wfb}) + + coll[int] = 5 + + assert coll.compute(float) == {'a': 2.5, 'b': 2.5} + assert coll.compute(str) == {'a': '5;2.5', 'b': '5;2.5'} + + +def test_copy() -> None: + wf = sl.Pipeline([int_to_float, int_float_to_str]) + wfa = wf.copy() + wfa[int] = 3 + wfb = wf.copy() + wfb[int] = 4 + coll = WorkflowCollection({'a': wfa, 'b': wfb}) + + coll_copy = coll.copy() + + assert coll_copy.compute(float) == {'a': 1.5, 'b': 2.0} + assert coll_copy.compute(str) == {'a': '3;1.5', 'b': '4;2.0'} + + coll_copy[int] = {'a': 7, 'b': 8} + assert coll.compute(float) == {'a': 1.5, 'b': 2.0} + assert coll.compute(str) == {'a': '3;1.5', 'b': '4;2.0'} + assert coll_copy.compute(float) == {'a': 3.5, 'b': 4.0} + assert coll_copy.compute(str) == {'a': '7;3.5', 'b': '8;4.0'} + + +def test_add_workflow() -> None: + wf = sl.Pipeline([int_to_float, int_float_to_str]) + wfa = wf.copy() + wfa[int] = 3 + wfb = wf.copy() + wfb[int] = 4 + coll = WorkflowCollection({'a': wfa, 'b': wfb}) + + wfc = wf.copy() + wfc[int] = 5 + coll.add('c', wfc) + + assert coll.compute(float) == {'a': 1.5, 'b': 2.0, 'c': 2.5} + assert coll.compute(str) == {'a': '3;1.5', 'b': '4;2.0', 'c': '5;2.5'} + + +def test_add_workflow_with_existing_key() -> None: + wf = sl.Pipeline([int_to_float, int_float_to_str]) + wfa = wf.copy() + wfa[int] = 3 + wfb = wf.copy() + wfb[int] = 4 + coll = WorkflowCollection({'a': wfa, 'b': wfb}) + + wfc = wf.copy() + wfc[int] = 5 + coll.add('a', wfc) + + assert coll.compute(float) == {'a': 2.5, 'b': 2.0} + assert coll.compute(str) == {'a': '5;2.5', 'b': '4;2.0'} + assert 'c' not in coll.keys() # 'c' should not exist + + +def test_remove_workflow() -> None: + wf = sl.Pipeline([int_to_float, int_float_to_str]) + wfa = wf.copy() + wfa[int] = 3 + wfb = wf.copy() + wfb[int] = 4 + coll = WorkflowCollection({'a': wfa, 'b': wfb}) + + coll.remove('b') + + assert 'b' not in coll.keys() + assert coll.compute(float) == {'a': 
1.5} + assert coll.compute(str) == {'a': '3;1.5'} diff --git a/tests/tools_test.py b/tests/tools_test.py deleted file mode 100644 index 03c07aaf..00000000 --- a/tests/tools_test.py +++ /dev/null @@ -1,249 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -import numpy as np -import pytest -import sciline as sl -import scipp as sc -from numpy.testing import assert_allclose as np_assert_allclose -from orsopy.fileio import Orso, OrsoDataset -from scipp.testing import assert_allclose - -from ess.reflectometry.orso import OrsoIofQDataset -from ess.reflectometry.tools import ( - combine_curves, - linlogspace, - orso_datasets_from_measurements, - scale_reflectivity_curves_to_overlap, -) -from ess.reflectometry.types import Filename, ReflectivityOverQ, SampleRun - - -def curve(d, qmin, qmax): - return sc.DataArray(data=d, coords={'Q': sc.linspace('Q', qmin, qmax, len(d) + 1)}) - - -def test_reflectivity_curve_scaling(): - data = sc.concat( - ( - sc.ones(dims=['Q'], shape=[10], with_variances=True), - 0.5 * sc.ones(dims=['Q'], shape=[15], with_variances=True), - ), - dim='Q', - ) - data.variances[:] = 0.1 - - curves, factors = scale_reflectivity_curves_to_overlap( - (curve(data, 0, 0.3), curve(0.8 * data, 0.2, 0.7), curve(0.1 * data, 0.6, 1.0)), - ) - - assert_allclose(curves[0].data, data, rtol=sc.scalar(1e-5)) - assert_allclose(curves[1].data, 0.5 * data, rtol=sc.scalar(1e-5)) - assert_allclose(curves[2].data, 0.25 * data, rtol=sc.scalar(1e-5)) - np_assert_allclose((1, 0.5 / 0.8, 0.25 / 0.1), factors, 1e-4) - - -def test_reflectivity_curve_scaling_with_critical_edge(): - data = sc.concat( - ( - sc.ones(dims=['Q'], shape=[10], with_variances=True), - 0.5 * sc.ones(dims=['Q'], shape=[15], with_variances=True), - ), - dim='Q', - ) - data.variances[:] = 0.1 - - curves, factors = scale_reflectivity_curves_to_overlap( - ( - 2 * curve(data, 0, 0.3), - curve(0.8 * data, 0.2, 0.7), - curve(0.1 * data, 0.6, 1.0), - ), - critical_edge_interval=(sc.scalar(0.01), sc.scalar(0.05)), - ) - - assert_allclose(curves[0].data, data, rtol=sc.scalar(1e-5)) - assert_allclose(curves[1].data, 0.5 * data, rtol=sc.scalar(1e-5)) - assert_allclose(curves[2].data, 0.25 * data, rtol=sc.scalar(1e-5)) - np_assert_allclose((0.5, 0.5 / 0.8, 0.25 / 0.1), factors, 1e-4) - - -def test_combined_curves(): - qgrid = sc.linspace('Q', 0, 1, 26) - data = sc.concat( - ( - sc.ones(dims=['Q'], shape=[10], with_variances=True), - 0.5 * sc.ones(dims=['Q'], shape=[15], with_variances=True), - ), - dim='Q', - ) - data.variances[:] = 0.1 - curves = ( - curve(data, 0, 0.3), - curve(0.5 * data, 0.2, 0.7), - curve(0.25 * data, 0.6, 1.0), - ) - - combined = combine_curves(curves, qgrid) - assert_allclose( - combined.data, - sc.array( - dims='Q', - values=[ - 1.0, - 1, - 1, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.25, - 0.25, - 0.25, - 0.25, - 0.25, - 0.25, - 0.25, - 0.25, - 0.25, - 0.125, - 0.125, - 0.125, - 0.125, - 0.125, - 0.125, - ], - variances=[ - 0.1, - 0.1, - 0.1, - 0.1, - 0.1, - 0.02, - 0.02, - 0.025, - 0.025, - 0.025, - 0.025, - 0.025, - 0.025, - 0.025, - 0.025, - 0.005, - 0.005, - 0.00625, - 0.00625, - 0.00625, - 0.00625, - 0.00625, - 0.00625, - 0.00625, - 0.00625, - ], - ), - ) - - -def test_linlogspace_linear(): - q_lin = linlogspace( - dim='qz', edges=[0.008, 0.08], scale='linear', num=50, unit='1/angstrom' - ) - expected = sc.linspace(dim='qz', start=0.008, stop=0.08, num=50, unit='1/angstrom') - assert sc.allclose(q_lin, expected) - - -def 
test_linlogspace_linear_list_input(): - q_lin = linlogspace( - dim='qz', edges=[0.008, 0.08], unit='1/angstrom', scale=['linear'], num=[50] - ) - expected = sc.linspace(dim='qz', start=0.008, stop=0.08, num=50, unit='1/angstrom') - assert sc.allclose(q_lin, expected) - - -def test_linlogspace_log(): - q_log = linlogspace( - dim='qz', edges=[0.008, 0.08], unit='1/angstrom', scale='log', num=50 - ) - expected = sc.geomspace(dim='qz', start=0.008, stop=0.08, num=50, unit='1/angstrom') - assert sc.allclose(q_log, expected) - - -def test_linlogspace_linear_log(): - q_linlog = linlogspace( - dim='qz', - edges=[0.008, 0.03, 0.08], - unit='1/angstrom', - scale=['linear', 'log'], - num=[16, 20], - ) - exp_lin = sc.linspace(dim='qz', start=0.008, stop=0.03, num=16, unit='1/angstrom') - exp_log = sc.geomspace(dim='qz', start=0.03, stop=0.08, num=21, unit='1/angstrom') - expected = sc.concat([exp_lin, exp_log['qz', 1:]], 'qz') - assert sc.allclose(q_linlog, expected) - - -def test_linlogspace_log_linear(): - q_loglin = linlogspace( - dim='qz', - edges=[0.008, 0.03, 0.08], - unit='1/angstrom', - scale=['log', 'linear'], - num=[16, 20], - ) - exp_log = sc.geomspace(dim='qz', start=0.008, stop=0.03, num=16, unit='1/angstrom') - exp_lin = sc.linspace(dim='qz', start=0.03, stop=0.08, num=21, unit='1/angstrom') - expected = sc.concat([exp_log, exp_lin['qz', 1:]], 'qz') - assert sc.allclose(q_loglin, expected) - - -def test_linlogspace_linear_log_linear(): - q_linloglin = linlogspace( - dim='qz', - edges=[0.008, 0.03, 0.08, 0.12], - unit='1/angstrom', - scale=['linear', 'log', 'linear'], - num=[16, 20, 10], - ) - exp_lin = sc.linspace(dim='qz', start=0.008, stop=0.03, num=16, unit='1/angstrom') - exp_log = sc.geomspace(dim='qz', start=0.03, stop=0.08, num=21, unit='1/angstrom') - exp_lin2 = sc.linspace(dim='qz', start=0.08, stop=0.12, num=11, unit='1/angstrom') - expected = sc.concat([exp_lin, exp_log['qz', 1:], exp_lin2['qz', 1:]], 'qz') - assert sc.allclose(q_linloglin, expected) - - -def test_linlogspace_bad_input(): - with pytest.raises(ValueError, match="Sizes do not match"): - _ = linlogspace( - dim='qz', - edges=[0.008, 0.03, 0.08, 0.12], - unit='1/angstrom', - scale=['linear', 'log'], - num=[16, 20], - ) - - -@pytest.mark.filterwarnings("ignore:No suitable") -def test_orso_datasets_tool(): - def normalized_ioq(filename: Filename[SampleRun]) -> ReflectivityOverQ: - return filename - - def orso_dataset(filename: Filename[SampleRun]) -> OrsoIofQDataset: - class Reduction: - corrections = [] # noqa: RUF012 - - return OrsoDataset( - Orso({}, Reduction, [], name=f'{filename}.orso'), np.ones((0, 0)) - ) - - workflow = sl.Pipeline( - [normalized_ioq, orso_dataset], params={Filename[SampleRun]: 'default'} - ) - datasets = orso_datasets_from_measurements( - workflow, - [{}, {Filename[SampleRun]: 'special'}], - scale_to_overlap=False, - ) - assert len(datasets) == 2 - assert tuple(d.info.name for d in datasets) == ('default.orso', 'special.orso')
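
For orientation, a minimal end-to-end sketch of how the batch API introduced above (`batch_processor` plus `scale_reflectivity_curves_to_overlap`) fits together. It mirrors the notebook and tests in this change; the run numbers, Q binning, and reliance on default Amor parameters are illustrative assumptions rather than part of the patch.

```python
# Sketch of batch reduction with the new API; run numbers and binning are illustrative.
import scipp as sc
from ess import amor
from ess.amor import data
from ess.reflectometry import batch_processor
from ess.reflectometry.tools import scale_reflectivity_curves_to_overlap
from ess.reflectometry.types import (
    Filename,
    QBins,
    ReferenceRun,
    ReflectivityOverQ,
    SampleRun,
)

workflow = amor.AmorWorkflow()
workflow[Filename[ReferenceRun]] = data.amor_run(4152)
workflow[QBins] = sc.geomspace(dim='Q', start=0.01, stop=0.06, num=201, unit='1/angstrom')

# One entry per measurement; the keys become the run ids in all results.
runs = {
    '608': {Filename[SampleRun]: data.amor_run(608)},
    '609': {Filename[SampleRun]: data.amor_run(609)},
}

# Map the workflow over the parameter table and compute R(Q) for every run.
batch = batch_processor(workflow, runs)
reflectivity = batch.compute(ReflectivityOverQ)  # dict keyed by run id

# Determine ScalingFactorForOverlap for each run, then recompute the scaled curves.
scaled = scale_reflectivity_curves_to_overlap(
    batch,
    critical_edge_interval=(
        sc.scalar(0.01, unit='1/angstrom'),
        sc.scalar(0.014, unit='1/angstrom'),
    ),
)
scaled_r = {key: r.hist() for key, r in scaled.compute(ReflectivityOverQ).items()}
```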