Skip to content

ENH: X5 read/write support of TransformChain #253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 29 additions & 30 deletions nitransforms/io/x5.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import numpy as np


@dataclass
@dataclass(eq=True)
class X5Domain:
"""Domain information of a transform representing reference/moving spaces."""

Expand Down Expand Up @@ -105,35 +105,7 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]):
tg = out_file.create_group("TransformGroup")
for i, node in enumerate(x5_list):
g = tg.create_group(str(i))
g.attrs["Type"] = node.type
g.attrs["ArrayLength"] = node.array_length
if node.subtype is not None:
g.attrs["SubType"] = node.subtype
if node.representation is not None:
g.attrs["Representation"] = node.representation
if node.metadata is not None:
g.attrs["Metadata"] = json.dumps(node.metadata)
g.create_dataset("Transform", data=node.transform)
g.create_dataset(
"DimensionKinds",
data=np.asarray(node.dimension_kinds, dtype="S"),
)
if node.domain is not None:
dgrp = g.create_group("Domain")
dgrp.create_dataset("Grid", data=np.uint8(1 if node.domain.grid else 0))
dgrp.create_dataset("Size", data=np.asarray(node.domain.size))
dgrp.create_dataset("Mapping", data=node.domain.mapping)
if node.domain.coordinates is not None:
dgrp.attrs["Coordinates"] = node.domain.coordinates

if node.inverse is not None:
g.create_dataset("Inverse", data=node.inverse)
if node.jacobian is not None:
g.create_dataset("Jacobian", data=node.jacobian)
if node.additional_parameters is not None:
g.create_dataset(
"AdditionalParameters", data=node.additional_parameters
)
_write_x5_group(g, node)
return fname


Expand Down Expand Up @@ -188,3 +160,30 @@ def _read_x5_group(node) -> X5Transform:
)

return x5


def _write_x5_group(g, node: X5Transform):
"""Write one :class:`X5Transform` element into an opened HDF5 group."""
g.attrs["Type"] = node.type
g.attrs["ArrayLength"] = node.array_length
if node.subtype is not None:
g.attrs["SubType"] = node.subtype
if node.representation is not None:
g.attrs["Representation"] = node.representation
if node.metadata is not None:
g.attrs["Metadata"] = json.dumps(node.metadata)
g.create_dataset("Transform", data=node.transform)
g.create_dataset("DimensionKinds", data=np.asarray(node.dimension_kinds, dtype="S"))
if node.domain is not None:
dgrp = g.create_group("Domain")
dgrp.create_dataset("Grid", data=np.uint8(1 if node.domain.grid else 0))
dgrp.create_dataset("Size", data=np.asarray(node.domain.size))
dgrp.create_dataset("Mapping", data=node.domain.mapping)
if node.domain.coordinates is not None:
dgrp.attrs["Coordinates"] = node.domain.coordinates
if node.inverse is not None:
g.create_dataset("Inverse", data=node.inverse)
if node.jacobian is not None:
g.create_dataset("Jacobian", data=node.jacobian)
if node.additional_parameters is not None:
g.create_dataset("AdditionalParameters", data=node.additional_parameters)
43 changes: 27 additions & 16 deletions nitransforms/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,15 @@ def __eq__(self, other):
>>> xfm2 = Affine(xfm1.matrix)
>>> xfm1 == xfm2
True
>>> xfm1 == Affine()
False
>>> xfm1 == TransformBase()
False

"""
if not hasattr(other, "matrix"):
return False

_eq = np.allclose(self.matrix, other.matrix, rtol=EQUALITY_TOL)
if _eq and self._reference != other._reference:
warnings.warn("Affines are equal, but references do not match.")
Expand Down Expand Up @@ -186,22 +193,9 @@ def from_filename(
"""Create an affine from a transform file."""

if fmt and fmt.upper() == "X5":
x5_xfm = load_x5(filename)[x5_position]
Transform = cls if x5_xfm.array_length == 1 else LinearTransformsMapping
if (
x5_xfm.domain
and not x5_xfm.domain.grid
and len(x5_xfm.domain.size) == 3
): # pragma: no cover
raise NotImplementedError(
"Only 3D regularly gridded domains are supported"
)
elif x5_xfm.domain:
# Override reference
Domain = namedtuple("Domain", "affine shape")
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)

return Transform(x5_xfm.transform, reference=reference)
return from_x5(
load_x5(filename), reference=reference, x5_position=x5_position
)

fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl")

Expand Down Expand Up @@ -458,3 +452,20 @@ def load(filename, fmt=None, reference=None, moving=None):
xfm = xfm[0]

return xfm


def from_x5(x5_list, reference=None, x5_position=0):
"""Create an affine from a list of :class:`~nitransforms.io.x5.X5Transform` objects."""

x5_xfm = x5_list[x5_position]
Transform = Affine if x5_xfm.array_length == 1 else LinearTransformsMapping
if (
x5_xfm.domain and not x5_xfm.domain.grid and len(x5_xfm.domain.size) == 3
): # pragma: no cover
raise NotImplementedError("Only 3D regularly gridded domains are supported")
elif x5_xfm.domain:
# Override reference
Domain = namedtuple("Domain", "affine shape")
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)

return Transform(x5_xfm.transform, reference=reference)
89 changes: 83 additions & 6 deletions nitransforms/manip.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,26 @@
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Common interface for transforms."""

import os
from collections.abc import Iterable
import numpy as np

from .base import (
import h5py
from nitransforms.base import (
TransformBase,
TransformError,
)
from .linear import Affine
from .nonlinear import DenseFieldTransform
from nitransforms.io import itk, x5 as x5io
from nitransforms.io.x5 import from_filename as load_x5
from nitransforms.linear import ( # noqa: F401
Affine,
from_x5 as linear_from_x5,
)
from nitransforms.nonlinear import ( # noqa: F401
DenseFieldTransform,
from_x5 as nonlinear_from_x5,
)


class TransformChain(TransformBase):
Expand Down Expand Up @@ -183,18 +194,42 @@ def asaffine(self, indices=None):
The indices of the values to extract.

"""
affines = self.transforms if indices is None else np.take(self.transforms, indices)
affines = (
self.transforms if indices is None else np.take(self.transforms, indices)
)
retval = affines[0]
for xfm in affines[1:]:
retval = xfm @ retval
return retval

@classmethod
def from_filename(cls, filename, fmt="X5", reference=None, moving=None):
def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain=0):
"""Load a transform file."""
from .io import itk

retval = []
if fmt and fmt.upper() == "X5":
# Get list of X5 nodes and generate transforms
xfm_list = [
globals()[f"{node.type}_from_x5"]([node]) for node in load_x5(filename)
]
if not xfm_list:
raise TransformError("Empty transform group")

if x5_chain is None:
return xfm_list

with h5py.File(str(filename), "r") as f:
chain_grp = f.get("TransformChain")
if chain_grp is None:
raise TransformError("X5 file contains no TransformChain")

chain_path = chain_grp[str(x5_chain)][()]
chain_path = (
chain_path.decode() if isinstance(chain_path, bytes) else chain_path
)

return TransformChain([xfm_list[int(idx)] for idx in chain_path.split("/")])

if str(filename).endswith(".h5"):
reference = None
xforms = itk.ITKCompositeH5.from_filename(filename)
Expand All @@ -208,6 +243,48 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None):

raise NotImplementedError

def to_filename(self, filename, fmt="X5"):
"""Store the transform chain in X5 format."""

if fmt.upper() != "X5":
raise NotImplementedError("Only X5 format is supported for chains")

existing = (
self.from_filename(filename, x5_chain=None)
if os.path.exists(filename)
else []
)

xfm_chain = []
new_xfms = []
next_xfm_index = len(existing)
for xfm in self.transforms:
for eidx, existing_xfm in enumerate(existing):
if xfm == existing_xfm:
xfm_chain.append(eidx)
break
else:
xfm_chain.append(next_xfm_index)
new_xfms.append((next_xfm_index, xfm))
existing.append(xfm)
next_xfm_index += 1

mode = "r+" if os.path.exists(filename) else "w"
with h5py.File(str(filename), mode) as f:
if "Format" not in f.attrs:
f.attrs["Format"] = "X5"
f.attrs["Version"] = np.uint16(1)

tg = f.require_group("TransformGroup")
for idx, node in new_xfms:
g = tg.create_group(str(idx))
x5io._write_x5_group(g, node.to_x5())

cg = f.require_group("TransformChain")
cg.create_dataset(str(len(cg)), data="/".join(str(i) for i in xfm_chain))

return filename


def _as_chain(x):
"""Convert a value into a transform chain."""
Expand Down
Loading
Loading