diff --git a/nitransforms/io/x5.py b/nitransforms/io/x5.py index 2f86e8a..094bf11 100644 --- a/nitransforms/io/x5.py +++ b/nitransforms/io/x5.py @@ -26,7 +26,7 @@ import numpy as np -@dataclass +@dataclass(eq=True) class X5Domain: """Domain information of a transform representing reference/moving spaces.""" @@ -105,35 +105,7 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]): tg = out_file.create_group("TransformGroup") for i, node in enumerate(x5_list): g = tg.create_group(str(i)) - g.attrs["Type"] = node.type - g.attrs["ArrayLength"] = node.array_length - if node.subtype is not None: - g.attrs["SubType"] = node.subtype - if node.representation is not None: - g.attrs["Representation"] = node.representation - if node.metadata is not None: - g.attrs["Metadata"] = json.dumps(node.metadata) - g.create_dataset("Transform", data=node.transform) - g.create_dataset( - "DimensionKinds", - data=np.asarray(node.dimension_kinds, dtype="S"), - ) - if node.domain is not None: - dgrp = g.create_group("Domain") - dgrp.create_dataset("Grid", data=np.uint8(1 if node.domain.grid else 0)) - dgrp.create_dataset("Size", data=np.asarray(node.domain.size)) - dgrp.create_dataset("Mapping", data=node.domain.mapping) - if node.domain.coordinates is not None: - dgrp.attrs["Coordinates"] = node.domain.coordinates - - if node.inverse is not None: - g.create_dataset("Inverse", data=node.inverse) - if node.jacobian is not None: - g.create_dataset("Jacobian", data=node.jacobian) - if node.additional_parameters is not None: - g.create_dataset( - "AdditionalParameters", data=node.additional_parameters - ) + _write_x5_group(g, node) return fname @@ -188,3 +160,30 @@ def _read_x5_group(node) -> X5Transform: ) return x5 + + +def _write_x5_group(g, node: X5Transform): + """Write one :class:`X5Transform` element into an opened HDF5 group.""" + g.attrs["Type"] = node.type + g.attrs["ArrayLength"] = node.array_length + if node.subtype is not None: + g.attrs["SubType"] = node.subtype + if node.representation is not None: + g.attrs["Representation"] = node.representation + if node.metadata is not None: + g.attrs["Metadata"] = json.dumps(node.metadata) + g.create_dataset("Transform", data=node.transform) + g.create_dataset("DimensionKinds", data=np.asarray(node.dimension_kinds, dtype="S")) + if node.domain is not None: + dgrp = g.create_group("Domain") + dgrp.create_dataset("Grid", data=np.uint8(1 if node.domain.grid else 0)) + dgrp.create_dataset("Size", data=np.asarray(node.domain.size)) + dgrp.create_dataset("Mapping", data=node.domain.mapping) + if node.domain.coordinates is not None: + dgrp.attrs["Coordinates"] = node.domain.coordinates + if node.inverse is not None: + g.create_dataset("Inverse", data=node.inverse) + if node.jacobian is not None: + g.create_dataset("Jacobian", data=node.jacobian) + if node.additional_parameters is not None: + g.create_dataset("AdditionalParameters", data=node.additional_parameters) diff --git a/nitransforms/linear.py b/nitransforms/linear.py index 26bf337..d63c7d4 100644 --- a/nitransforms/linear.py +++ b/nitransforms/linear.py @@ -120,8 +120,15 @@ def __eq__(self, other): >>> xfm2 = Affine(xfm1.matrix) >>> xfm1 == xfm2 True + >>> xfm1 == Affine() + False + >>> xfm1 == TransformBase() + False """ + if not hasattr(other, "matrix"): + return False + _eq = np.allclose(self.matrix, other.matrix, rtol=EQUALITY_TOL) if _eq and self._reference != other._reference: warnings.warn("Affines are equal, but references do not match.") @@ -186,22 +193,9 @@ def from_filename( """Create an affine from a transform file.""" if fmt and fmt.upper() == "X5": - x5_xfm = load_x5(filename)[x5_position] - Transform = cls if x5_xfm.array_length == 1 else LinearTransformsMapping - if ( - x5_xfm.domain - and not x5_xfm.domain.grid - and len(x5_xfm.domain.size) == 3 - ): # pragma: no cover - raise NotImplementedError( - "Only 3D regularly gridded domains are supported" - ) - elif x5_xfm.domain: - # Override reference - Domain = namedtuple("Domain", "affine shape") - reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) - - return Transform(x5_xfm.transform, reference=reference) + return from_x5( + load_x5(filename), reference=reference, x5_position=x5_position + ) fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl") @@ -458,3 +452,20 @@ def load(filename, fmt=None, reference=None, moving=None): xfm = xfm[0] return xfm + + +def from_x5(x5_list, reference=None, x5_position=0): + """Create an affine from a list of :class:`~nitransforms.io.x5.X5Transform` objects.""" + + x5_xfm = x5_list[x5_position] + Transform = Affine if x5_xfm.array_length == 1 else LinearTransformsMapping + if ( + x5_xfm.domain and not x5_xfm.domain.grid and len(x5_xfm.domain.size) == 3 + ): # pragma: no cover + raise NotImplementedError("Only 3D regularly gridded domains are supported") + elif x5_xfm.domain: + # Override reference + Domain = namedtuple("Domain", "affine shape") + reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) + + return Transform(x5_xfm.transform, reference=reference) diff --git a/nitransforms/manip.py b/nitransforms/manip.py index 9389197..fe982a5 100644 --- a/nitransforms/manip.py +++ b/nitransforms/manip.py @@ -7,15 +7,26 @@ # ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Common interface for transforms.""" + +import os from collections.abc import Iterable import numpy as np -from .base import ( +import h5py +from nitransforms.base import ( TransformBase, TransformError, ) -from .linear import Affine -from .nonlinear import DenseFieldTransform +from nitransforms.io import itk, x5 as x5io +from nitransforms.io.x5 import from_filename as load_x5 +from nitransforms.linear import ( # noqa: F401 + Affine, + from_x5 as linear_from_x5, +) +from nitransforms.nonlinear import ( # noqa: F401 + DenseFieldTransform, + from_x5 as nonlinear_from_x5, +) class TransformChain(TransformBase): @@ -183,18 +194,42 @@ def asaffine(self, indices=None): The indices of the values to extract. """ - affines = self.transforms if indices is None else np.take(self.transforms, indices) + affines = ( + self.transforms if indices is None else np.take(self.transforms, indices) + ) retval = affines[0] for xfm in affines[1:]: retval = xfm @ retval return retval @classmethod - def from_filename(cls, filename, fmt="X5", reference=None, moving=None): + def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain=0): """Load a transform file.""" - from .io import itk retval = [] + if fmt and fmt.upper() == "X5": + # Get list of X5 nodes and generate transforms + xfm_list = [ + globals()[f"{node.type}_from_x5"]([node]) for node in load_x5(filename) + ] + if not xfm_list: + raise TransformError("Empty transform group") + + if x5_chain is None: + return xfm_list + + with h5py.File(str(filename), "r") as f: + chain_grp = f.get("TransformChain") + if chain_grp is None: + raise TransformError("X5 file contains no TransformChain") + + chain_path = chain_grp[str(x5_chain)][()] + chain_path = ( + chain_path.decode() if isinstance(chain_path, bytes) else chain_path + ) + + return TransformChain([xfm_list[int(idx)] for idx in chain_path.split("/")]) + if str(filename).endswith(".h5"): reference = None xforms = itk.ITKCompositeH5.from_filename(filename) @@ -208,6 +243,48 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None): raise NotImplementedError + def to_filename(self, filename, fmt="X5"): + """Store the transform chain in X5 format.""" + + if fmt.upper() != "X5": + raise NotImplementedError("Only X5 format is supported for chains") + + existing = ( + self.from_filename(filename, x5_chain=None) + if os.path.exists(filename) + else [] + ) + + xfm_chain = [] + new_xfms = [] + next_xfm_index = len(existing) + for xfm in self.transforms: + for eidx, existing_xfm in enumerate(existing): + if xfm == existing_xfm: + xfm_chain.append(eidx) + break + else: + xfm_chain.append(next_xfm_index) + new_xfms.append((next_xfm_index, xfm)) + existing.append(xfm) + next_xfm_index += 1 + + mode = "r+" if os.path.exists(filename) else "w" + with h5py.File(str(filename), mode) as f: + if "Format" not in f.attrs: + f.attrs["Format"] = "X5" + f.attrs["Version"] = np.uint16(1) + + tg = f.require_group("TransformGroup") + for idx, node in new_xfms: + g = tg.create_group(str(idx)) + x5io._write_x5_group(g, node.to_x5()) + + cg = f.require_group("TransformChain") + cg.create_dataset(str(len(cg)), data="/".join(str(i) for i in xfm_chain)) + + return filename + def _as_chain(x): """Convert a value into a transform chain.""" diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index 26dbe40..0869b9a 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -245,8 +245,15 @@ def __eq__(self, other): >>> xfm2 = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz") >>> xfm1 == xfm2 True + >>> xfm1 == TransformBase() + False + >>> xfm1 == BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz") + False """ + if not hasattr(other, "_field") or self._field.shape != other._field.shape: + return False + _eq = np.allclose(self._field, other._field) if _eq and self._reference != other._reference: warnings.warn("Fields are equal, but references do not match.") @@ -278,7 +285,7 @@ def to_x5(self, metadata=None): ) @classmethod - def from_filename(cls, filename, fmt="X5"): + def from_filename(cls, filename, fmt="X5", x5_position=0): _factory = { "afni": io.afni.AFNIDisplacementsField, "itk": io.itk.ITKDisplacementsField, @@ -290,15 +297,7 @@ def from_filename(cls, filename, fmt="X5"): raise NotImplementedError(f"Unsupported format <{fmt}>") if fmt == "X5": - x5_xfm = load_x5(filename)[0] - Domain = namedtuple("Domain", "affine shape") - reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) - field = nb.Nifti1Image(x5_xfm.transform, reference.affine) - return cls( - field, - is_deltas=x5_xfm.representation == "displacements", - reference=reference, - ) + return from_x5(load_x5(filename), x5_position=x5_position) return cls(_factory[fmt.lower()].from_filename(filename)) @@ -330,13 +329,48 @@ def __init__(self, coefficients, reference=None, order=3): "not match the number of dimensions" ) + def __eq__(self, other): + """ + Overload equals operator. + + Examples + -------- + >>> xfm1 = BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz") + >>> xfm2 = BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz") + >>> xfm1 == xfm2 + True + >>> xfm2._coeffs[:, :, :] = 0 # Let's zero all coefficients + >>> xfm1 == xfm2 + False + >>> xfm2 = BSplineFieldTransform( + ... test_dir / "someones_bspline_coefficients.nii.gz", + ... order=4, + ... ) + >>> xfm1 == xfm2 + False + >>> xfm1 == TransformBase() + False + >>> xfm1 == DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz") + False + + """ + if not hasattr(other, "_coeffs") or self._coeffs.shape != other._coeffs.shape: + return False + + _eq = self._order == other._order + _eq = _eq and np.allclose(self._coeffs, other._coeffs) + + if _eq and self._reference != other._reference: + warnings.warn("Coefficients are equal, but references do not match.") + return _eq + @property def ndim(self): """Get the dimensions of the transform.""" return self._coeffs.ndim - 1 @classmethod - def from_filename(cls, filename, fmt="X5"): + def from_filename(cls, filename, fmt="X5", x5_position=0): _factory = { "X5": None, } @@ -344,13 +378,7 @@ def from_filename(cls, filename, fmt="X5"): if fmt not in {k.upper() for k in _factory}: raise NotImplementedError(f"Unsupported format <{fmt}>") - x5_xfm = load_x5(filename)[0] - Domain = namedtuple("Domain", "affine shape") - reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) - - coefficients = nb.Nifti1Image(x5_xfm.transform, x5_xfm.additional_parameters) - return cls(coefficients, reference=reference) - + return from_x5(load_x5(filename), x5_position=x5_position) # return cls(_factory[fmt.lower()].from_filename(filename)) def to_field(self, reference=None, dtype="float32"): @@ -440,6 +468,31 @@ def map(self, x, inverse=False): return np.array([vfunc(_x).tolist() for _x in np.atleast_2d(x)]) +def from_x5(x5_list, x5_position=0): + """Create a transform from a list of :class:`~nitransforms.io.x5.X5Transform` objects.""" + + x5_xfm = x5_list[x5_position] + + Transform = ( + BSplineFieldTransform if x5_xfm.subtype == "bspline" else DenseFieldTransform + ) + Domain = namedtuple("Domain", "affine shape") + reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) + xfm_params = ( + nb.Nifti1Image(x5_xfm.transform, x5_xfm.additional_parameters) + if x5_xfm.subtype == "bspline" + else x5_xfm.transform + ) + + xfm_kwargs = ( + {} + if x5_xfm.subtype == "bspline" + else {"is_deltas": x5_xfm.representation == "displacements"} + ) + + return Transform(xfm_params, reference=reference, **xfm_kwargs) + + def _map_xyz(x, reference, knots, coeffs): """Apply the transformation to just one coordinate.""" ndim = len(x) diff --git a/nitransforms/tests/test_manip.py b/nitransforms/tests/test_manip.py index b5dd5c6..4d7b3a5 100644 --- a/nitransforms/tests/test_manip.py +++ b/nitransforms/tests/test_manip.py @@ -5,8 +5,13 @@ import pytest import numpy as np +import nibabel as nb +import h5py +from ..base import TransformError from ..manip import TransformChain from ..linear import Affine +from ..nonlinear import DenseFieldTransform +from ..io import x5 FMT = {"lta": "fs", "tfm": "itk"} @@ -37,3 +42,47 @@ def test_collapse_affines(tmp_path, data_path, ext0, ext1, ext2): fmt=f"{FMT[ext2]}", ).matrix, ) + + +def test_transformchain_x5_roundtrip(tmp_path): + """Round-trip TransformChain with X5 storage.""" + + # Test empty transform file + x5.to_filename(tmp_path / "empty.x5", []) + with pytest.raises(TransformError): + TransformChain.from_filename(tmp_path / "empty.x5") + + mat = np.eye(4) + mat[0, 3] = 1 + aff = Affine(mat) + + # Test loading X5 with no transforms chains + x5.to_filename(tmp_path / "nochain.x5", [aff.to_x5()]) + with pytest.raises(TransformError): + TransformChain.from_filename(tmp_path / "nochain.x5") + + field = nb.Nifti1Image(np.zeros((5, 5, 5, 3), dtype="float32"), np.eye(4)) + fdata = field.get_fdata() + fdata[..., 1] = 1 + field = nb.Nifti1Image(fdata, np.eye(4)) + dfield = DenseFieldTransform(field, is_deltas=True) + + # Create a chain + chain = TransformChain([aff, aff, aff, dfield]) + fname = tmp_path / "chain.x5" + chain.to_filename(fname) + + with h5py.File(fname) as f: + assert len(f["TransformGroup"]) == 2 + + chain.to_filename(fname) # append again, should not duplicate transforms + + with h5py.File(fname) as f: + assert len(f["TransformGroup"]) == 2 + + loaded0 = TransformChain.from_filename(fname, fmt="X5", x5_chain=0) + loaded1 = TransformChain.from_filename(fname, fmt="X5", x5_chain=1) + + assert len(loaded0) == len(chain) + assert len(loaded1) == len(chain) + assert np.allclose(chain.map([[0, 0, 0]]), loaded1.map([[0, 0, 0]])) diff --git a/nitransforms/tests/test_resampling.py b/nitransforms/tests/test_resampling.py index 0e11df5..b65bf57 100644 --- a/nitransforms/tests/test_resampling.py +++ b/nitransforms/tests/test_resampling.py @@ -284,7 +284,7 @@ def test_apply_transformchain(tmp_path, testdata_path): / "ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.h5" ) - xfm = nitm.load(xfm_fname) + xfm = nitm.load(xfm_fname, fmt="itk") assert len(xfm) == 2