From 5885ed8f5747a7b582f2cb4f4da981dfaa0fc5b2 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Sun, 20 Jul 2025 12:37:05 +0200 Subject: [PATCH 1/5] ENH: support TransformChain in X5 --- nitransforms/io/x5.py | 57 +++++++------- nitransforms/manip.py | 123 ++++++++++++++++++++++++++++++- nitransforms/tests/test_manip.py | 34 +++++++++ 3 files changed, 181 insertions(+), 33 deletions(-) diff --git a/nitransforms/io/x5.py b/nitransforms/io/x5.py index 2f86e8ab..f167a35e 100644 --- a/nitransforms/io/x5.py +++ b/nitransforms/io/x5.py @@ -105,35 +105,7 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]): tg = out_file.create_group("TransformGroup") for i, node in enumerate(x5_list): g = tg.create_group(str(i)) - g.attrs["Type"] = node.type - g.attrs["ArrayLength"] = node.array_length - if node.subtype is not None: - g.attrs["SubType"] = node.subtype - if node.representation is not None: - g.attrs["Representation"] = node.representation - if node.metadata is not None: - g.attrs["Metadata"] = json.dumps(node.metadata) - g.create_dataset("Transform", data=node.transform) - g.create_dataset( - "DimensionKinds", - data=np.asarray(node.dimension_kinds, dtype="S"), - ) - if node.domain is not None: - dgrp = g.create_group("Domain") - dgrp.create_dataset("Grid", data=np.uint8(1 if node.domain.grid else 0)) - dgrp.create_dataset("Size", data=np.asarray(node.domain.size)) - dgrp.create_dataset("Mapping", data=node.domain.mapping) - if node.domain.coordinates is not None: - dgrp.attrs["Coordinates"] = node.domain.coordinates - - if node.inverse is not None: - g.create_dataset("Inverse", data=node.inverse) - if node.jacobian is not None: - g.create_dataset("Jacobian", data=node.jacobian) - if node.additional_parameters is not None: - g.create_dataset( - "AdditionalParameters", data=node.additional_parameters - ) + _write_x5_group(g, node) return fname @@ -188,3 +160,30 @@ def _read_x5_group(node) -> X5Transform: ) return x5 + + +def _write_x5_group(g, node: X5Transform): + """Write one :class:`X5Transform` element into an opened HDF5 group.""" + g.attrs["Type"] = node.type + g.attrs["ArrayLength"] = node.array_length + if node.subtype is not None: + g.attrs["SubType"] = node.subtype + if node.representation is not None: + g.attrs["Representation"] = node.representation + if node.metadata is not None: + g.attrs["Metadata"] = json.dumps(node.metadata) + g.create_dataset("Transform", data=node.transform) + g.create_dataset("DimensionKinds", data=np.asarray(node.dimension_kinds, dtype="S")) + if node.domain is not None: + dgrp = g.create_group("Domain") + dgrp.create_dataset("Grid", data=np.uint8(1 if node.domain.grid else 0)) + dgrp.create_dataset("Size", data=np.asarray(node.domain.size)) + dgrp.create_dataset("Mapping", data=node.domain.mapping) + if node.domain.coordinates is not None: + dgrp.attrs["Coordinates"] = node.domain.coordinates + if node.inverse is not None: + g.create_dataset("Inverse", data=node.inverse) + if node.jacobian is not None: + g.create_dataset("Jacobian", data=node.jacobian) + if node.additional_parameters is not None: + g.create_dataset("AdditionalParameters", data=node.additional_parameters) diff --git a/nitransforms/manip.py b/nitransforms/manip.py index 9389197d..f2b85579 100644 --- a/nitransforms/manip.py +++ b/nitransforms/manip.py @@ -14,7 +14,7 @@ TransformBase, TransformError, ) -from .linear import Affine +from .linear import Affine, LinearTransformsMapping from .nonlinear import DenseFieldTransform @@ -190,12 +190,15 @@ def asaffine(self, indices=None): return retval @classmethod - def from_filename(cls, filename, fmt="X5", reference=None, moving=None): + def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain=0): """Load a transform file.""" - from .io import itk + from .io import itk, x5 as x5io + import h5py + import nibabel as nb + from collections import namedtuple retval = [] - if str(filename).endswith(".h5"): + if str(filename).endswith(".h5") and (fmt is None or fmt.upper() != "X5"): reference = None xforms = itk.ITKCompositeH5.from_filename(filename) for xfmobj in xforms: @@ -206,8 +209,120 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None): return TransformChain(retval) + if fmt and fmt.upper() == "X5": + with h5py.File(str(filename), "r") as f: + if f.attrs.get("Format") != "X5": + raise TypeError("Input file is not in X5 format") + + tg = [ + x5io._read_x5_group(node) + for _, node in sorted(f["TransformGroup"].items(), key=lambda kv: int(kv[0])) + ] + chain_grp = f.get("TransformChain") + if chain_grp is None: + raise TransformError("X5 file contains no TransformChain") + + chain_path = chain_grp[str(x5_chain)][()] + if isinstance(chain_path, bytes): + chain_path = chain_path.decode() + indices = [int(idx) for idx in chain_path.split("/") if idx] + + Domain = namedtuple("Domain", "affine shape") + for idx in indices: + node = tg[idx] + if node.type == "linear": + Transform = Affine if node.array_length == 1 else LinearTransformsMapping + reference = None + if node.domain is not None: + reference = Domain(node.domain.mapping, node.domain.size) + retval.append(Transform(node.transform, reference=reference)) + elif node.type == "nonlinear": + reference = Domain(node.domain.mapping, node.domain.size) + field = nb.Nifti1Image(node.transform, reference.affine) + retval.append( + DenseFieldTransform( + field, + is_deltas=node.representation == "displacements", + reference=reference, + ) + ) + else: # pragma: no cover - unsupported type + raise NotImplementedError(f"Unsupported transform type {node.type}") + + return TransformChain(retval) + raise NotImplementedError + def to_filename(self, filename, fmt="X5"): + """Store the transform chain in X5 format.""" + from .io import x5 as x5io + import os + import h5py + + if fmt.upper() != "X5": + raise NotImplementedError("Only X5 format is supported for chains") + + if os.path.exists(filename): + with h5py.File(str(filename), "r") as f: + existing = [ + x5io._read_x5_group(node) + for _, node in sorted(f["TransformGroup"].items(), key=lambda kv: int(kv[0])) + ] + else: + existing = [] + + # convert to objects for equality check + from collections import namedtuple + import nibabel as nb + + def _as_transform(x5node): + Domain = namedtuple("Domain", "affine shape") + if x5node.type == "linear": + Transform = Affine if x5node.array_length == 1 else LinearTransformsMapping + ref = None + if x5node.domain is not None: + ref = Domain(x5node.domain.mapping, x5node.domain.size) + return Transform(x5node.transform, reference=ref) + reference = Domain(x5node.domain.mapping, x5node.domain.size) + field = nb.Nifti1Image(x5node.transform, reference.affine) + return DenseFieldTransform( + field, + is_deltas=x5node.representation == "displacements", + reference=reference, + ) + + existing_objs = [_as_transform(n) for n in existing] + path_indices = [] + new_nodes = [] + for xfm in self.transforms: + # find existing + idx = None + for i, obj in enumerate(existing_objs): + if type(xfm) is type(obj) and xfm == obj: + idx = i + break + if idx is None: + idx = len(existing_objs) + new_nodes.append((idx, xfm.to_x5())) + existing_objs.append(xfm) + path_indices.append(idx) + + mode = "r+" if os.path.exists(filename) else "w" + with h5py.File(str(filename), mode) as f: + if "Format" not in f.attrs: + f.attrs["Format"] = "X5" + f.attrs["Version"] = np.uint16(1) + + tg = f.require_group("TransformGroup") + for idx, node in new_nodes: + g = tg.create_group(str(idx)) + x5io._write_x5_group(g, node) + + cg = f.require_group("TransformChain") + cg.create_dataset(str(len(cg)), data="/".join(str(i) for i in path_indices)) + + return filename + def _as_chain(x): """Convert a value into a transform chain.""" diff --git a/nitransforms/tests/test_manip.py b/nitransforms/tests/test_manip.py index b5dd5c62..3fe7b7e9 100644 --- a/nitransforms/tests/test_manip.py +++ b/nitransforms/tests/test_manip.py @@ -5,8 +5,11 @@ import pytest import numpy as np +import nibabel as nb +import h5py from ..manip import TransformChain from ..linear import Affine +from ..nonlinear import DenseFieldTransform FMT = {"lta": "fs", "tfm": "itk"} @@ -37,3 +40,34 @@ def test_collapse_affines(tmp_path, data_path, ext0, ext1, ext2): fmt=f"{FMT[ext2]}", ).matrix, ) + + +def test_transformchain_x5_roundtrip(tmp_path): + """Round-trip TransformChain with X5 storage.""" + + mat = np.eye(4) + mat[0, 3] = 1 + aff = Affine(mat) + + field = nb.Nifti1Image(np.zeros((5, 5, 5, 3), dtype="float32"), np.eye(4)) + fdata = field.get_fdata() + fdata[..., 1] = 1 + field = nb.Nifti1Image(fdata, np.eye(4)) + dfield = DenseFieldTransform(field, is_deltas=True) + + chain = TransformChain([aff, aff, aff, dfield]) + + fname = tmp_path / "chain.x5" + chain.to_filename(fname) + chain.to_filename(fname) # append again, should not duplicate transforms + + with h5py.File(fname) as f: + assert len(f["TransformGroup"]) == 2 + assert len(f["TransformChain"]) == 2 + + loaded0 = TransformChain.from_filename(fname, fmt="X5", x5_chain=0) + loaded1 = TransformChain.from_filename(fname, fmt="X5", x5_chain=1) + + assert len(loaded0) == len(chain) + assert len(loaded1) == len(chain) + assert np.allclose(chain.map([[0, 0, 0]]), loaded1.map([[0, 0, 0]])) From 23a425d1b8e45a039a98f91bca4b07e9c561f02a Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Sun, 20 Jul 2025 13:48:44 +0200 Subject: [PATCH 2/5] fix chain loading fallback --- nitransforms/manip.py | 84 +++++++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 43 deletions(-) diff --git a/nitransforms/manip.py b/nitransforms/manip.py index f2b85579..d816182e 100644 --- a/nitransforms/manip.py +++ b/nitransforms/manip.py @@ -198,7 +198,47 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain from collections import namedtuple retval = [] - if str(filename).endswith(".h5") and (fmt is None or fmt.upper() != "X5"): + if fmt and fmt.upper() == "X5": + with h5py.File(str(filename), "r") as f: + if f.attrs.get("Format") == "X5": + tg = [ + x5io._read_x5_group(node) + for _, node in sorted(f["TransformGroup"].items(), key=lambda kv: int(kv[0])) + ] + chain_grp = f.get("TransformChain") + if chain_grp is None: + raise TransformError("X5 file contains no TransformChain") + + chain_path = chain_grp[str(x5_chain)][()] + if isinstance(chain_path, bytes): + chain_path = chain_path.decode() + indices = [int(idx) for idx in chain_path.split("/") if idx] + + Domain = namedtuple("Domain", "affine shape") + for idx in indices: + node = tg[idx] + if node.type == "linear": + Transform = Affine if node.array_length == 1 else LinearTransformsMapping + reference = None + if node.domain is not None: + reference = Domain(node.domain.mapping, node.domain.size) + retval.append(Transform(node.transform, reference=reference)) + elif node.type == "nonlinear": + reference = Domain(node.domain.mapping, node.domain.size) + field = nb.Nifti1Image(node.transform, reference.affine) + retval.append( + DenseFieldTransform( + field, + is_deltas=node.representation == "displacements", + reference=reference, + ) + ) + else: # pragma: no cover - unsupported type + raise NotImplementedError(f"Unsupported transform type {node.type}") + + return TransformChain(retval) + + if str(filename).endswith(".h5"): reference = None xforms = itk.ITKCompositeH5.from_filename(filename) for xfmobj in xforms: @@ -209,48 +249,6 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain return TransformChain(retval) - if fmt and fmt.upper() == "X5": - with h5py.File(str(filename), "r") as f: - if f.attrs.get("Format") != "X5": - raise TypeError("Input file is not in X5 format") - - tg = [ - x5io._read_x5_group(node) - for _, node in sorted(f["TransformGroup"].items(), key=lambda kv: int(kv[0])) - ] - chain_grp = f.get("TransformChain") - if chain_grp is None: - raise TransformError("X5 file contains no TransformChain") - - chain_path = chain_grp[str(x5_chain)][()] - if isinstance(chain_path, bytes): - chain_path = chain_path.decode() - indices = [int(idx) for idx in chain_path.split("/") if idx] - - Domain = namedtuple("Domain", "affine shape") - for idx in indices: - node = tg[idx] - if node.type == "linear": - Transform = Affine if node.array_length == 1 else LinearTransformsMapping - reference = None - if node.domain is not None: - reference = Domain(node.domain.mapping, node.domain.size) - retval.append(Transform(node.transform, reference=reference)) - elif node.type == "nonlinear": - reference = Domain(node.domain.mapping, node.domain.size) - field = nb.Nifti1Image(node.transform, reference.affine) - retval.append( - DenseFieldTransform( - field, - is_deltas=node.representation == "displacements", - reference=reference, - ) - ) - else: # pragma: no cover - unsupported type - raise NotImplementedError(f"Unsupported transform type {node.type}") - - return TransformChain(retval) - raise NotImplementedError def to_filename(self, filename, fmt="X5"): From 6bc2bc163173992a4abe343551bf5b238da1b838 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Sun, 20 Jul 2025 17:18:39 +0200 Subject: [PATCH 3/5] fix: general re-design of the PR --- nitransforms/io/x5.py | 2 +- nitransforms/linear.py | 36 +++++----- nitransforms/manip.py | 137 +++++++++++++------------------------- nitransforms/nonlinear.py | 47 ++++++++----- 4 files changed, 97 insertions(+), 125 deletions(-) diff --git a/nitransforms/io/x5.py b/nitransforms/io/x5.py index f167a35e..094bf114 100644 --- a/nitransforms/io/x5.py +++ b/nitransforms/io/x5.py @@ -26,7 +26,7 @@ import numpy as np -@dataclass +@dataclass(eq=True) class X5Domain: """Domain information of a transform representing reference/moving spaces.""" diff --git a/nitransforms/linear.py b/nitransforms/linear.py index 26bf3374..0fb9f507 100644 --- a/nitransforms/linear.py +++ b/nitransforms/linear.py @@ -186,22 +186,9 @@ def from_filename( """Create an affine from a transform file.""" if fmt and fmt.upper() == "X5": - x5_xfm = load_x5(filename)[x5_position] - Transform = cls if x5_xfm.array_length == 1 else LinearTransformsMapping - if ( - x5_xfm.domain - and not x5_xfm.domain.grid - and len(x5_xfm.domain.size) == 3 - ): # pragma: no cover - raise NotImplementedError( - "Only 3D regularly gridded domains are supported" - ) - elif x5_xfm.domain: - # Override reference - Domain = namedtuple("Domain", "affine shape") - reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) - - return Transform(x5_xfm.transform, reference=reference) + return from_x5( + load_x5(filename), reference=reference, x5_position=x5_position + ) fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl") @@ -458,3 +445,20 @@ def load(filename, fmt=None, reference=None, moving=None): xfm = xfm[0] return xfm + + +def from_x5(x5_list, reference=None, x5_position=0): + """Create an affine from a list of :class:`~nitransforms.io.x5.X5Transform` objects.""" + + x5_xfm = x5_list[x5_position] + Transform = Affine if x5_xfm.array_length == 1 else LinearTransformsMapping + if ( + x5_xfm.domain and not x5_xfm.domain.grid and len(x5_xfm.domain.size) == 3 + ): # pragma: no cover + raise NotImplementedError("Only 3D regularly gridded domains are supported") + elif x5_xfm.domain: + # Override reference + Domain = namedtuple("Domain", "affine shape") + reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) + + return Transform(x5_xfm.transform, reference=reference) diff --git a/nitransforms/manip.py b/nitransforms/manip.py index d816182e..ee15317c 100644 --- a/nitransforms/manip.py +++ b/nitransforms/manip.py @@ -7,15 +7,26 @@ # ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Common interface for transforms.""" + +import os from collections.abc import Iterable import numpy as np -from .base import ( +import h5py +from nitransforms.base import ( TransformBase, TransformError, ) -from .linear import Affine, LinearTransformsMapping -from .nonlinear import DenseFieldTransform +from nitransforms.io import itk, x5 as x5io +from nitransforms.io.x5 import from_filename as load_x5 +from nitransforms.linear import ( + Affine, + from_x5 as linear_from_x5, # noqa: F401 +) +from nitransforms.nonlinear import ( + DenseFieldTransform, + from_x5 as nonlinear_from_x5, # noqa: F401 +) class TransformChain(TransformBase): @@ -183,7 +194,9 @@ def asaffine(self, indices=None): The indices of the values to extract. """ - affines = self.transforms if indices is None else np.take(self.transforms, indices) + affines = ( + self.transforms if indices is None else np.take(self.transforms, indices) + ) retval = affines[0] for xfm in affines[1:]: retval = xfm @ retval @@ -192,51 +205,28 @@ def asaffine(self, indices=None): @classmethod def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain=0): """Load a transform file.""" - from .io import itk, x5 as x5io - import h5py - import nibabel as nb - from collections import namedtuple retval = [] if fmt and fmt.upper() == "X5": + xfm_list = load_x5(filename) + if not xfm_list: + raise TransformError("Empty transform group") + with h5py.File(str(filename), "r") as f: - if f.attrs.get("Format") == "X5": - tg = [ - x5io._read_x5_group(node) - for _, node in sorted(f["TransformGroup"].items(), key=lambda kv: int(kv[0])) - ] - chain_grp = f.get("TransformChain") - if chain_grp is None: - raise TransformError("X5 file contains no TransformChain") - - chain_path = chain_grp[str(x5_chain)][()] - if isinstance(chain_path, bytes): - chain_path = chain_path.decode() - indices = [int(idx) for idx in chain_path.split("/") if idx] - - Domain = namedtuple("Domain", "affine shape") - for idx in indices: - node = tg[idx] - if node.type == "linear": - Transform = Affine if node.array_length == 1 else LinearTransformsMapping - reference = None - if node.domain is not None: - reference = Domain(node.domain.mapping, node.domain.size) - retval.append(Transform(node.transform, reference=reference)) - elif node.type == "nonlinear": - reference = Domain(node.domain.mapping, node.domain.size) - field = nb.Nifti1Image(node.transform, reference.affine) - retval.append( - DenseFieldTransform( - field, - is_deltas=node.representation == "displacements", - reference=reference, - ) - ) - else: # pragma: no cover - unsupported type - raise NotImplementedError(f"Unsupported transform type {node.type}") - - return TransformChain(retval) + chain_grp = f.get("TransformChain") + if chain_grp is None: + raise TransformError("X5 file contains no TransformChain") + + chain_path = chain_grp[str(x5_chain)][()] + if isinstance(chain_path, bytes): + chain_path = chain_path.decode() + + for idx in chain_path.split("/"): + node = x5io._read_x5_group(xfm_list[int(idx)]) + from_x5 = globals()[f"{node.type}_from_x5"] + retval.append(from_x5([node])) + + return TransformChain(retval) if str(filename).endswith(".h5"): reference = None @@ -253,57 +243,24 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain def to_filename(self, filename, fmt="X5"): """Store the transform chain in X5 format.""" - from .io import x5 as x5io - import os - import h5py if fmt.upper() != "X5": raise NotImplementedError("Only X5 format is supported for chains") - if os.path.exists(filename): - with h5py.File(str(filename), "r") as f: - existing = [ - x5io._read_x5_group(node) - for _, node in sorted(f["TransformGroup"].items(), key=lambda kv: int(kv[0])) - ] - else: - existing = [] - - # convert to objects for equality check - from collections import namedtuple - import nibabel as nb - - def _as_transform(x5node): - Domain = namedtuple("Domain", "affine shape") - if x5node.type == "linear": - Transform = Affine if x5node.array_length == 1 else LinearTransformsMapping - ref = None - if x5node.domain is not None: - ref = Domain(x5node.domain.mapping, x5node.domain.size) - return Transform(x5node.transform, reference=ref) - reference = Domain(x5node.domain.mapping, x5node.domain.size) - field = nb.Nifti1Image(x5node.transform, reference.affine) - return DenseFieldTransform( - field, - is_deltas=x5node.representation == "displacements", - reference=reference, - ) - - existing_objs = [_as_transform(n) for n in existing] - path_indices = [] + existing = load_x5(filename) if os.path.exists(filename) else [] + xfm_chain = [] new_nodes = [] + next_xfm_index = len(existing) for xfm in self.transforms: - # find existing - idx = None - for i, obj in enumerate(existing_objs): - if type(xfm) is type(obj) and xfm == obj: - idx = i + for eidx, existing_xfm in enumerate(existing): + if xfm == existing_xfm: + xfm_chain.append(eidx) break - if idx is None: - idx = len(existing_objs) - new_nodes.append((idx, xfm.to_x5())) - existing_objs.append(xfm) - path_indices.append(idx) + else: + xfm_chain.append(next_xfm_index) + new_nodes.append((next_xfm_index, xfm)) + existing.append(xfm) + next_xfm_index += 1 mode = "r+" if os.path.exists(filename) else "w" with h5py.File(str(filename), mode) as f: @@ -317,7 +274,7 @@ def _as_transform(x5node): x5io._write_x5_group(g, node) cg = f.require_group("TransformChain") - cg.create_dataset(str(len(cg)), data="/".join(str(i) for i in path_indices)) + cg.create_dataset(str(len(cg)), data="/".join(str(i) for i in xfm_chain)) return filename diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index 26dbe40f..9c6aadb5 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -278,7 +278,7 @@ def to_x5(self, metadata=None): ) @classmethod - def from_filename(cls, filename, fmt="X5"): + def from_filename(cls, filename, fmt="X5", x5_position=0): _factory = { "afni": io.afni.AFNIDisplacementsField, "itk": io.itk.ITKDisplacementsField, @@ -290,15 +290,7 @@ def from_filename(cls, filename, fmt="X5"): raise NotImplementedError(f"Unsupported format <{fmt}>") if fmt == "X5": - x5_xfm = load_x5(filename)[0] - Domain = namedtuple("Domain", "affine shape") - reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) - field = nb.Nifti1Image(x5_xfm.transform, reference.affine) - return cls( - field, - is_deltas=x5_xfm.representation == "displacements", - reference=reference, - ) + return from_x5(load_x5(filename), x5_position=x5_position) return cls(_factory[fmt.lower()].from_filename(filename)) @@ -336,7 +328,7 @@ def ndim(self): return self._coeffs.ndim - 1 @classmethod - def from_filename(cls, filename, fmt="X5"): + def from_filename(cls, filename, fmt="X5", x5_position=0): _factory = { "X5": None, } @@ -344,13 +336,7 @@ def from_filename(cls, filename, fmt="X5"): if fmt not in {k.upper() for k in _factory}: raise NotImplementedError(f"Unsupported format <{fmt}>") - x5_xfm = load_x5(filename)[0] - Domain = namedtuple("Domain", "affine shape") - reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) - - coefficients = nb.Nifti1Image(x5_xfm.transform, x5_xfm.additional_parameters) - return cls(coefficients, reference=reference) - + return from_x5(load_x5(filename), x5_position=x5_position) # return cls(_factory[fmt.lower()].from_filename(filename)) def to_field(self, reference=None, dtype="float32"): @@ -440,6 +426,31 @@ def map(self, x, inverse=False): return np.array([vfunc(_x).tolist() for _x in np.atleast_2d(x)]) +def from_x5(x5_list, x5_position=0): + """Create a transform from a list of :class:`~nitransforms.io.x5.X5Transform` objects.""" + + x5_xfm = x5_list[x5_position] + + Transform = ( + BSplineFieldTransform if x5_xfm.subtype == "bspline" else DenseFieldTransform + ) + Domain = namedtuple("Domain", "affine shape") + reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) + xfm_params = ( + nb.Nifti1Image(x5_xfm.transform, x5_xfm.additional_parameters) + if x5_xfm.subtype == "bspline" + else x5_xfm.transform + ) + + xfm_kwargs = ( + {} + if x5_xfm.subtype == "bspline" + else {"is_deltas": x5_xfm.representation == "displacements"} + ) + + return Transform(xfm_params, reference=reference, **xfm_kwargs) + + def _map_xyz(x, reference, knots, coeffs): """Apply the transformation to just one coordinate.""" ndim = len(x) From e550a7d0c28bc76b1c0a6bf228f8a30a2346b244 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Sun, 20 Jul 2025 22:02:44 +0200 Subject: [PATCH 4/5] fix: improve loading and equalities --- nitransforms/linear.py | 3 +++ nitransforms/manip.py | 30 ++++++++++++++++----------- nitransforms/nonlinear.py | 25 ++++++++++++++++++++++ nitransforms/tests/test_manip.py | 5 ++++- nitransforms/tests/test_resampling.py | 2 +- 5 files changed, 51 insertions(+), 14 deletions(-) diff --git a/nitransforms/linear.py b/nitransforms/linear.py index 0fb9f507..22a13f07 100644 --- a/nitransforms/linear.py +++ b/nitransforms/linear.py @@ -122,6 +122,9 @@ def __eq__(self, other): True """ + if not hasattr(other, "matrix"): + return False + _eq = np.allclose(self.matrix, other.matrix, rtol=EQUALITY_TOL) if _eq and self._reference != other._reference: warnings.warn("Affines are equal, but references do not match.") diff --git a/nitransforms/manip.py b/nitransforms/manip.py index ee15317c..fdcd5305 100644 --- a/nitransforms/manip.py +++ b/nitransforms/manip.py @@ -208,10 +208,16 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain retval = [] if fmt and fmt.upper() == "X5": - xfm_list = load_x5(filename) + # Get list of X5 nodes and generate transforms + xfm_list = [ + globals()[f"{node.type}_from_x5"]([node]) for node in load_x5(filename) + ] if not xfm_list: raise TransformError("Empty transform group") + if x5_chain is None: + return xfm_list + with h5py.File(str(filename), "r") as f: chain_grp = f.get("TransformChain") if chain_grp is None: @@ -221,12 +227,7 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain if isinstance(chain_path, bytes): chain_path = chain_path.decode() - for idx in chain_path.split("/"): - node = x5io._read_x5_group(xfm_list[int(idx)]) - from_x5 = globals()[f"{node.type}_from_x5"] - retval.append(from_x5([node])) - - return TransformChain(retval) + return TransformChain([xfm_list[int(idx)] for idx in chain_path.split("/")]) if str(filename).endswith(".h5"): reference = None @@ -247,9 +248,14 @@ def to_filename(self, filename, fmt="X5"): if fmt.upper() != "X5": raise NotImplementedError("Only X5 format is supported for chains") - existing = load_x5(filename) if os.path.exists(filename) else [] + existing = ( + self.from_filename(filename, x5_chain=None) + if os.path.exists(filename) + else [] + ) + xfm_chain = [] - new_nodes = [] + new_xfms = [] next_xfm_index = len(existing) for xfm in self.transforms: for eidx, existing_xfm in enumerate(existing): @@ -258,7 +264,7 @@ def to_filename(self, filename, fmt="X5"): break else: xfm_chain.append(next_xfm_index) - new_nodes.append((next_xfm_index, xfm)) + new_xfms.append((next_xfm_index, xfm)) existing.append(xfm) next_xfm_index += 1 @@ -269,9 +275,9 @@ def to_filename(self, filename, fmt="X5"): f.attrs["Version"] = np.uint16(1) tg = f.require_group("TransformGroup") - for idx, node in new_nodes: + for idx, node in new_xfms: g = tg.create_group(str(idx)) - x5io._write_x5_group(g, node) + x5io._write_x5_group(g, node.to_x5()) cg = f.require_group("TransformChain") cg.create_dataset(str(len(cg)), data="/".join(str(i) for i in xfm_chain)) diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index 9c6aadb5..e3ac63ff 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -247,6 +247,9 @@ def __eq__(self, other): True """ + if not hasattr(other, "_field"): + return False + _eq = np.allclose(self._field, other._field) if _eq and self._reference != other._reference: warnings.warn("Fields are equal, but references do not match.") @@ -322,6 +325,28 @@ def __init__(self, coefficients, reference=None, order=3): "not match the number of dimensions" ) + def __eq__(self, other): + """ + Overload equals operator. + + Examples + -------- + >>> xfm1 = BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz") + >>> xfm2 = BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz") + >>> xfm1 == xfm2 + True + + """ + if not hasattr(other, "_coeffs"): + return False + + _eq = np.allclose(self._coeffs, other._coeffs) + _eq = _eq and self._order == other._order + + if _eq and self._reference != other._reference: + warnings.warn("Coefficients are equal, but references do not match.") + return _eq + @property def ndim(self): """Get the dimensions of the transform.""" diff --git a/nitransforms/tests/test_manip.py b/nitransforms/tests/test_manip.py index 3fe7b7e9..6e9d132f 100644 --- a/nitransforms/tests/test_manip.py +++ b/nitransforms/tests/test_manip.py @@ -59,11 +59,14 @@ def test_transformchain_x5_roundtrip(tmp_path): fname = tmp_path / "chain.x5" chain.to_filename(fname) + + with h5py.File(fname) as f: + assert len(f["TransformGroup"]) == 2 + chain.to_filename(fname) # append again, should not duplicate transforms with h5py.File(fname) as f: assert len(f["TransformGroup"]) == 2 - assert len(f["TransformChain"]) == 2 loaded0 = TransformChain.from_filename(fname, fmt="X5", x5_chain=0) loaded1 = TransformChain.from_filename(fname, fmt="X5", x5_chain=1) diff --git a/nitransforms/tests/test_resampling.py b/nitransforms/tests/test_resampling.py index 0e11df5b..b65bf579 100644 --- a/nitransforms/tests/test_resampling.py +++ b/nitransforms/tests/test_resampling.py @@ -284,7 +284,7 @@ def test_apply_transformchain(tmp_path, testdata_path): / "ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.h5" ) - xfm = nitm.load(xfm_fname) + xfm = nitm.load(xfm_fname, fmt="itk") assert len(xfm) == 2 From 73f3db11fac98da9f7ed1e19c06a39d1ca4af4c3 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Mon, 21 Jul 2025 00:06:34 +0200 Subject: [PATCH 5/5] enh: reduce uncovered lines --- nitransforms/linear.py | 4 ++++ nitransforms/manip.py | 13 +++++++------ nitransforms/nonlinear.py | 25 +++++++++++++++++++++---- nitransforms/tests/test_manip.py | 14 +++++++++++++- 4 files changed, 45 insertions(+), 11 deletions(-) diff --git a/nitransforms/linear.py b/nitransforms/linear.py index 22a13f07..d63c7d4f 100644 --- a/nitransforms/linear.py +++ b/nitransforms/linear.py @@ -120,6 +120,10 @@ def __eq__(self, other): >>> xfm2 = Affine(xfm1.matrix) >>> xfm1 == xfm2 True + >>> xfm1 == Affine() + False + >>> xfm1 == TransformBase() + False """ if not hasattr(other, "matrix"): diff --git a/nitransforms/manip.py b/nitransforms/manip.py index fdcd5305..fe982a58 100644 --- a/nitransforms/manip.py +++ b/nitransforms/manip.py @@ -19,13 +19,13 @@ ) from nitransforms.io import itk, x5 as x5io from nitransforms.io.x5 import from_filename as load_x5 -from nitransforms.linear import ( +from nitransforms.linear import ( # noqa: F401 Affine, - from_x5 as linear_from_x5, # noqa: F401 + from_x5 as linear_from_x5, ) -from nitransforms.nonlinear import ( +from nitransforms.nonlinear import ( # noqa: F401 DenseFieldTransform, - from_x5 as nonlinear_from_x5, # noqa: F401 + from_x5 as nonlinear_from_x5, ) @@ -224,8 +224,9 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain raise TransformError("X5 file contains no TransformChain") chain_path = chain_grp[str(x5_chain)][()] - if isinstance(chain_path, bytes): - chain_path = chain_path.decode() + chain_path = ( + chain_path.decode() if isinstance(chain_path, bytes) else chain_path + ) return TransformChain([xfm_list[int(idx)] for idx in chain_path.split("/")]) diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index e3ac63ff..0869b9af 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -245,9 +245,13 @@ def __eq__(self, other): >>> xfm2 = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz") >>> xfm1 == xfm2 True + >>> xfm1 == TransformBase() + False + >>> xfm1 == BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz") + False """ - if not hasattr(other, "_field"): + if not hasattr(other, "_field") or self._field.shape != other._field.shape: return False _eq = np.allclose(self._field, other._field) @@ -335,13 +339,26 @@ def __eq__(self, other): >>> xfm2 = BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz") >>> xfm1 == xfm2 True + >>> xfm2._coeffs[:, :, :] = 0 # Let's zero all coefficients + >>> xfm1 == xfm2 + False + >>> xfm2 = BSplineFieldTransform( + ... test_dir / "someones_bspline_coefficients.nii.gz", + ... order=4, + ... ) + >>> xfm1 == xfm2 + False + >>> xfm1 == TransformBase() + False + >>> xfm1 == DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz") + False """ - if not hasattr(other, "_coeffs"): + if not hasattr(other, "_coeffs") or self._coeffs.shape != other._coeffs.shape: return False - _eq = np.allclose(self._coeffs, other._coeffs) - _eq = _eq and self._order == other._order + _eq = self._order == other._order + _eq = _eq and np.allclose(self._coeffs, other._coeffs) if _eq and self._reference != other._reference: warnings.warn("Coefficients are equal, but references do not match.") diff --git a/nitransforms/tests/test_manip.py b/nitransforms/tests/test_manip.py index 6e9d132f..4d7b3a5b 100644 --- a/nitransforms/tests/test_manip.py +++ b/nitransforms/tests/test_manip.py @@ -7,9 +7,11 @@ import numpy as np import nibabel as nb import h5py +from ..base import TransformError from ..manip import TransformChain from ..linear import Affine from ..nonlinear import DenseFieldTransform +from ..io import x5 FMT = {"lta": "fs", "tfm": "itk"} @@ -45,18 +47,28 @@ def test_collapse_affines(tmp_path, data_path, ext0, ext1, ext2): def test_transformchain_x5_roundtrip(tmp_path): """Round-trip TransformChain with X5 storage.""" + # Test empty transform file + x5.to_filename(tmp_path / "empty.x5", []) + with pytest.raises(TransformError): + TransformChain.from_filename(tmp_path / "empty.x5") + mat = np.eye(4) mat[0, 3] = 1 aff = Affine(mat) + # Test loading X5 with no transforms chains + x5.to_filename(tmp_path / "nochain.x5", [aff.to_x5()]) + with pytest.raises(TransformError): + TransformChain.from_filename(tmp_path / "nochain.x5") + field = nb.Nifti1Image(np.zeros((5, 5, 5, 3), dtype="float32"), np.eye(4)) fdata = field.get_fdata() fdata[..., 1] = 1 field = nb.Nifti1Image(fdata, np.eye(4)) dfield = DenseFieldTransform(field, is_deltas=True) + # Create a chain chain = TransformChain([aff, aff, aff, dfield]) - fname = tmp_path / "chain.x5" chain.to_filename(fname)