Skip to content

Commit 3e09d11

Browse files
authored
Merge branch 'master' into fix/do-not-overwrite-deltas
2 parents 6ea6ba5 + 2f66df4 commit 3e09d11

File tree

3 files changed

+178
-24
lines changed

3 files changed

+178
-24
lines changed

nitransforms/io/x5.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,11 @@ class X5Transform:
7373
For parametric models it is generally possible to obtain it analytically, so this dataset
7474
could not be as useful in that case.
7575
"""
76-
# additional_parameters: Optional[np.ndarray] = None
77-
# AdditionalParameters is empty in the draft spec - ignore for now.
78-
# Only documentation ATM is for SubType:
79-
# The SubType setting enables setting the additional parameters on a dataset called
80-
# "AdditionalParameters" that hangs directly from this transform node.
76+
additional_parameters: Optional[np.ndarray] = None
77+
"""
78+
An OPTIONAL field to store additional parameters, depending on the SubType of the
79+
transform.
80+
"""
8181
array_length: int = 1
8282
"""Undocumented field in the draft to enable a single transform group for 4D transforms."""
8383

@@ -130,11 +130,10 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]):
130130
g.create_dataset("Inverse", data=node.inverse)
131131
if node.jacobian is not None:
132132
g.create_dataset("Jacobian", data=node.jacobian)
133-
# Disabled until we need SubType and AdditionalParameters
134-
# if node.additional_parameters is not None:
135-
# g.create_dataset(
136-
# "AdditionalParameters", data=node.additional_parameters
137-
# )
133+
if node.additional_parameters is not None:
134+
g.create_dataset(
135+
"AdditionalParameters", data=node.additional_parameters
136+
)
138137
return fname
139138

140139

@@ -174,6 +173,9 @@ def _read_x5_group(node) -> X5Transform:
174173
inverse=np.asarray(node["Inverse"]) if "Inverse" in node else None,
175174
jacobian=np.asarray(node["Jacobian"]) if "Jacobian" in node else None,
176175
array_length=int(node.attrs.get("ArrayLength", 1)),
176+
additional_parameters=np.asarray(node["AdditionalParameters"])
177+
if "AdditionalParameters" in node
178+
else None,
177179
)
178180

179181
if "Domain" in node:

nitransforms/nonlinear.py

Lines changed: 119 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@
77
#
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Nonlinear transforms."""
10+
1011
import warnings
1112
from functools import partial
13+
from collections import namedtuple
1214
import numpy as np
15+
import nibabel as nb
1316

1417
from nitransforms import io
1518
from nitransforms.io.base import _ensure_image
19+
from nitransforms.io.x5 import from_filename as load_x5
1620
from nitransforms.interp.bspline import grid_bspline_weights, _cubic_bspline
1721
from nitransforms.base import (
1822
TransformBase,
@@ -22,11 +26,17 @@
2226
)
2327
from scipy.ndimage import map_coordinates
2428

29+
# Avoids circular imports
30+
try:
31+
from nitransforms._version import __version__
32+
except ModuleNotFoundError: # pragma: no cover
33+
__version__ = "0+unknown"
34+
2535

2636
class DenseFieldTransform(TransformBase):
2737
"""Represents dense field (voxel-wise) transforms."""
2838

29-
__slots__ = ("_field", "_deltas")
39+
__slots__ = ("_field", "_deltas", "_is_deltas")
3040

3141
def __init__(self, field=None, is_deltas=True, reference=None):
3242
"""
@@ -60,14 +70,7 @@ def __init__(self, field=None, is_deltas=True, reference=None):
6070

6171
super().__init__()
6272

63-
if field is not None:
64-
field = _ensure_image(field)
65-
self._field = np.squeeze(
66-
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
67-
)
68-
else:
69-
self._field = np.zeros((*reference.shape, reference.ndim), dtype="float32")
70-
is_deltas = True
73+
self._is_deltas = is_deltas
7174

7275
try:
7376
self.reference = ImageGrid(reference if reference is not None else field)
@@ -78,24 +81,44 @@ def __init__(self, field=None, is_deltas=True, reference=None):
7881
else "Reference is not a spatial image"
7982
)
8083

84+
fieldshape = (*self.reference.shape, self.reference.ndim)
85+
if field is not None:
86+
field = _ensure_image(field)
87+
self._field = np.squeeze(
88+
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
89+
)
90+
if fieldshape != self._field.shape:
91+
raise TransformError(
92+
f"Shape of the field ({'x'.join(str(i) for i in self._field.shape)}) "
93+
f"doesn't match that of the reference({'x'.join(str(i) for i in fieldshape)})"
94+
)
95+
else:
96+
self._field = np.zeros(fieldshape, dtype="float32")
97+
self._is_deltas = True
98+
8199
if self._field.shape[-1] != self.ndim:
82100
raise TransformError(
83101
"The number of components of the field (%d) does not match "
84102
"the number of dimensions (%d)" % (self._field.shape[-1], self.ndim)
85103
)
86104

87-
if is_deltas:
105+
if self._is_deltas:
88106
self._deltas = (
89107
self._field.copy()
90108
) # IMPORTANT: you don't want to update deltas
91109
# Convert from displacements (deltas) to deformations fields
92110
# (just add its origin to each delta vector)
93-
self._field += self.reference.ndcoords.T.reshape(self._field.shape)
111+
self._field += self.reference.ndcoords.T.reshape(fieldshape)
94112

95113
def __repr__(self):
96114
"""Beautify the python representation."""
97115
return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>"
98116

117+
@property
118+
def is_deltas(self):
119+
"""Check whether this is a displacements (``True``) or a deformation (``False``) field."""
120+
return self._is_deltas
121+
99122
@property
100123
def ndim(self):
101124
"""Get the dimensions of the transform."""
@@ -224,22 +247,60 @@ def __eq__(self, other):
224247
True
225248
226249
"""
227-
_eq = np.array_equal(self._field, other._field)
250+
_eq = np.allclose(self._field, other._field)
228251
if _eq and self._reference != other._reference:
229252
warnings.warn("Fields are equal, but references do not match.")
230253
return _eq
231254

255+
def to_x5(self, metadata=None):
256+
"""Return an :class:`~nitransforms.io.x5.X5Transform` representation."""
257+
metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {})
258+
259+
domain = None
260+
if (reference := self.reference) is not None:
261+
domain = io.x5.X5Domain(
262+
grid=True,
263+
size=getattr(reference, "shape", (0, 0, 0)),
264+
mapping=reference.affine,
265+
coordinates="cartesian",
266+
)
267+
268+
kinds = tuple("space" for _ in range(self.ndim)) + ("vector",)
269+
270+
return io.x5.X5Transform(
271+
type="nonlinear",
272+
subtype="densefield",
273+
representation="displacements" if self.is_deltas else "deformations",
274+
metadata=metadata,
275+
transform=self._deltas if self.is_deltas else self._field,
276+
dimension_kinds=kinds,
277+
domain=domain,
278+
)
279+
232280
@classmethod
233281
def from_filename(cls, filename, fmt="X5"):
234282
_factory = {
235283
"afni": io.afni.AFNIDisplacementsField,
236284
"itk": io.itk.ITKDisplacementsField,
237285
"fsl": io.fsl.FSLDisplacementsField,
286+
"X5": None,
238287
}
239-
if fmt not in _factory:
288+
fmt = fmt.upper()
289+
if fmt not in {k.upper() for k in _factory}:
240290
raise NotImplementedError(f"Unsupported format <{fmt}>")
241291

242-
return cls(_factory[fmt].from_filename(filename))
292+
if fmt == "X5":
293+
x5_xfm = load_x5(filename)[0]
294+
Domain = namedtuple("Domain", "affine shape")
295+
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
296+
field = nb.Nifti1Image(x5_xfm.transform, reference.affine)
297+
return cls(
298+
field,
299+
is_deltas=x5_xfm.representation == "displacements",
300+
reference=reference,
301+
)
302+
303+
return cls(_factory[fmt.lower()].from_filename(filename))
243304

244305

245306
load = DenseFieldTransform.from_filename
@@ -274,6 +335,24 @@ def ndim(self):
274335
"""Get the dimensions of the transform."""
275336
return self._coeffs.ndim - 1
276337

338+
@classmethod
339+
def from_filename(cls, filename, fmt="X5"):
340+
_factory = {
341+
"X5": None,
342+
}
343+
fmt = fmt.upper()
344+
if fmt not in {k.upper() for k in _factory}:
345+
raise NotImplementedError(f"Unsupported format <{fmt}>")
346+
347+
x5_xfm = load_x5(filename)[0]
348+
Domain = namedtuple("Domain", "affine shape")
349+
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
350+
351+
coefficients = nb.Nifti1Image(x5_xfm.transform, x5_xfm.additional_parameters)
352+
return cls(coefficients, reference=reference)
353+
354+
# return cls(_factory[fmt.lower()].from_filename(filename))
355+
277356
def to_field(self, reference=None, dtype="float32"):
278357
"""Generate a displacements deformation field from this B-Spline field."""
279358
_ref = (
@@ -295,6 +374,32 @@ def to_field(self, reference=None, dtype="float32"):
295374
field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref
296375
)
297376

377+
def to_x5(self, metadata=None):
378+
"""Return an :class:`~nitransforms.io.x5.X5Transform` representation."""
379+
metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {})
380+
381+
domain = None
382+
if (reference := self.reference) is not None:
383+
domain = io.x5.X5Domain(
384+
grid=True,
385+
size=getattr(reference, "shape", (0, 0, 0)),
386+
mapping=reference.affine,
387+
coordinates="cartesian",
388+
)
389+
390+
kinds = tuple("space" for _ in range(self.ndim)) + ("vector",)
391+
392+
return io.x5.X5Transform(
393+
type="nonlinear",
394+
subtype="bspline",
395+
representation="coefficients",
396+
metadata=metadata,
397+
transform=self._coeffs,
398+
dimension_kinds=kinds,
399+
domain=domain,
400+
additional_parameters=self._knots.affine,
401+
)
402+
298403
def map(self, x, inverse=False):
299404
r"""
300405
Apply the transformation to a list of physical coordinate points.

nitransforms/tests/test_nonlinear.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
BSplineFieldTransform,
1515
DenseFieldTransform,
1616
)
17+
from nitransforms import io
1718
from ..io.itk import ITKDisplacementsField
1819

1920

@@ -119,3 +120,49 @@ def test_bspline(tmp_path, testdata_path):
119120
).mean()
120121
< 0.2
121122
)
123+
124+
125+
@pytest.mark.parametrize("is_deltas", [True, False])
126+
def test_densefield_x5_roundtrip(tmp_path, is_deltas):
127+
"""Ensure dense field transforms roundtrip via X5."""
128+
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
129+
disp = nb.Nifti1Image(np.random.rand(2, 2, 2, 3).astype("float32"), np.eye(4))
130+
131+
xfm = DenseFieldTransform(disp, is_deltas=is_deltas, reference=ref)
132+
133+
node = xfm.to_x5(metadata={"GeneratedBy": "pytest"})
134+
assert node.type == "nonlinear"
135+
assert node.subtype == "densefield"
136+
assert node.representation == "displacements" if is_deltas else "deformations"
137+
assert node.domain.size == ref.shape
138+
assert node.metadata["GeneratedBy"] == "pytest"
139+
140+
fname = tmp_path / "test.x5"
141+
io.x5.to_filename(fname, [node])
142+
143+
xfm2 = DenseFieldTransform.from_filename(fname, fmt="X5")
144+
145+
assert xfm2.reference.shape == ref.shape
146+
assert np.allclose(xfm2.reference.affine, ref.affine)
147+
assert xfm == xfm2
148+
149+
150+
def test_bspline_to_x5(tmp_path):
151+
"""Check BSpline transforms export to X5."""
152+
coeff = nb.Nifti1Image(np.zeros((2, 2, 2, 3), dtype="float32"), np.eye(4))
153+
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
154+
155+
xfm = BSplineFieldTransform(coeff, reference=ref)
156+
node = xfm.to_x5(metadata={"tool": "pytest"})
157+
assert node.type == "nonlinear"
158+
assert node.subtype == "bspline"
159+
assert node.representation == "coefficients"
160+
assert node.metadata["tool"] == "pytest"
161+
162+
fname = tmp_path / "bspline.x5"
163+
io.x5.to_filename(fname, [node])
164+
165+
xfm2 = BSplineFieldTransform.from_filename(fname, fmt="X5")
166+
assert np.allclose(xfm._coeffs, xfm2._coeffs)
167+
assert xfm2.reference.shape == ref.shape
168+
assert np.allclose(xfm2.reference.affine, ref.affine)

0 commit comments

Comments
 (0)