Skip to content

Commit 27e95c7

Browse files
committed
enh: enable B-Splines X5 i/o and DenseFields' is_deltas
1 parent aa973a3 commit 27e95c7

File tree

3 files changed

+81
-43
lines changed

3 files changed

+81
-43
lines changed

nitransforms/io/x5.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,11 @@ class X5Transform:
7373
For parametric models it is generally possible to obtain it analytically, so this dataset
7474
could not be as useful in that case.
7575
"""
76-
# additional_parameters: Optional[np.ndarray] = None
77-
# AdditionalParameters is empty in the draft spec - ignore for now.
78-
# Only documentation ATM is for SubType:
79-
# The SubType setting enables setting the additional parameters on a dataset called
80-
# "AdditionalParameters" that hangs directly from this transform node.
76+
additional_parameters: Optional[np.ndarray] = None
77+
"""
78+
An OPTIONAL field to store additional parameters, depending on the SubType of the
79+
transform.
80+
"""
8181
array_length: int = 1
8282
"""Undocumented field in the draft to enable a single transform group for 4D transforms."""
8383

@@ -130,11 +130,10 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]):
130130
g.create_dataset("Inverse", data=node.inverse)
131131
if node.jacobian is not None:
132132
g.create_dataset("Jacobian", data=node.jacobian)
133-
# Disabled until we need SubType and AdditionalParameters
134-
# if node.additional_parameters is not None:
135-
# g.create_dataset(
136-
# "AdditionalParameters", data=node.additional_parameters
137-
# )
133+
if node.additional_parameters is not None:
134+
g.create_dataset(
135+
"AdditionalParameters", data=node.additional_parameters
136+
)
138137
return fname
139138

140139

@@ -174,6 +173,9 @@ def _read_x5_group(node) -> X5Transform:
174173
inverse=np.asarray(node["Inverse"]) if "Inverse" in node else None,
175174
jacobian=np.asarray(node["Jacobian"]) if "Jacobian" in node else None,
176175
array_length=int(node.attrs.get("ArrayLength", 1)),
176+
additional_parameters=np.asarray(node["AdditionalParameters"])
177+
if "AdditionalParameters" in node
178+
else None,
177179
)
178180

179181
if "Domain" in node:

nitransforms/nonlinear.py

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
from functools import partial
1313
from collections import namedtuple
1414
import numpy as np
15+
import nibabel as nb
1516

1617
from nitransforms import io
1718
from nitransforms.io.base import _ensure_image
19+
from nitransforms.io.x5 import from_filename as load_x5
1820
from nitransforms.interp.bspline import grid_bspline_weights, _cubic_bspline
1921
from nitransforms.base import (
2022
TransformBase,
@@ -34,7 +36,7 @@
3436
class DenseFieldTransform(TransformBase):
3537
"""Represents dense field (voxel-wise) transforms."""
3638

37-
__slots__ = ("_field", "_deltas")
39+
__slots__ = ("_field", "_deltas", "_is_deltas")
3840

3941
def __init__(self, field=None, is_deltas=True, reference=None):
4042
"""
@@ -68,14 +70,7 @@ def __init__(self, field=None, is_deltas=True, reference=None):
6870

6971
super().__init__()
7072

71-
if field is not None:
72-
field = _ensure_image(field)
73-
self._field = np.squeeze(
74-
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
75-
)
76-
else:
77-
self._field = np.zeros((*reference.shape, reference.ndim), dtype="float32")
78-
is_deltas = True
73+
self._is_deltas = is_deltas
7974

8075
try:
8176
self.reference = ImageGrid(reference if reference is not None else field)
@@ -86,22 +81,44 @@ def __init__(self, field=None, is_deltas=True, reference=None):
8681
else "Reference is not a spatial image"
8782
)
8883

84+
fieldshape = (*self.reference.shape, self.reference.ndim)
85+
if field is not None:
86+
field = _ensure_image(field)
87+
self._field = np.squeeze(
88+
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
89+
)
90+
if fieldshape != self._field.shape:
91+
raise TransformError(
92+
f"Shape of the field ({'x'.join(str(i) for i in self._field.shape)}) "
93+
f"doesn't match that of the reference({'x'.join(str(i) for i in fieldshape)})"
94+
)
95+
else:
96+
self._field = np.zeros(fieldshape, dtype="float32")
97+
self._is_deltas = True
98+
8999
if self._field.shape[-1] != self.ndim:
90100
raise TransformError(
91101
"The number of components of the field (%d) does not match "
92102
"the number of dimensions (%d)" % (self._field.shape[-1], self.ndim)
93103
)
94104

95-
if is_deltas:
96-
self._deltas = self._field
105+
if self._is_deltas:
106+
self._deltas = (
107+
self._field.copy()
108+
) # IMPORTANT: you don't want to update deltas
97109
# Convert from displacements (deltas) to deformations fields
98110
# (just add its origin to each delta vector)
99-
self._field += self.reference.ndcoords.T.reshape(self._field.shape)
111+
self._field += self.reference.ndcoords.T.reshape(fieldshape)
100112

101113
def __repr__(self):
102114
"""Beautify the python representation."""
103115
return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>"
104116

117+
@property
118+
def is_deltas(self):
119+
"""Check whether this is a displacements (``True``) or a deformation (``False``) field."""
120+
return self._is_deltas
121+
105122
@property
106123
def ndim(self):
107124
"""Get the dimensions of the transform."""
@@ -230,7 +247,7 @@ def __eq__(self, other):
230247
True
231248
232249
"""
233-
_eq = np.array_equal(self._field, other._field)
250+
_eq = np.allclose(self._field, other._field)
234251
if _eq and self._reference != other._reference:
235252
warnings.warn("Fields are equal, but references do not match.")
236253
return _eq
@@ -253,9 +270,9 @@ def to_x5(self, metadata=None):
253270
return io.x5.X5Transform(
254271
type="nonlinear",
255272
subtype="densefield",
256-
representation="displacements",
273+
representation="displacements" if self.is_deltas else "deformations",
257274
metadata=metadata,
258-
transform=self._deltas,
275+
transform=self._deltas if self.is_deltas else self._field,
259276
dimension_kinds=kinds,
260277
domain=domain,
261278
)
@@ -273,12 +290,15 @@ def from_filename(cls, filename, fmt="X5"):
273290
raise NotImplementedError(f"Unsupported format <{fmt}>")
274291

275292
if fmt == "X5":
276-
from .io.x5 import from_filename as load_x5
277-
278293
x5_xfm = load_x5(filename)[0]
279294
Domain = namedtuple("Domain", "affine shape")
280295
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
281-
return cls(x5_xfm.transform, is_deltas=True, reference=reference)
296+
field = nb.Nifti1Image(x5_xfm.transform, reference.affine)
297+
return cls(
298+
field,
299+
is_deltas=x5_xfm.representation == "displacements",
300+
reference=reference,
301+
)
282302

283303
return cls(_factory[fmt.lower()].from_filename(filename))
284304

@@ -315,6 +335,24 @@ def ndim(self):
315335
"""Get the dimensions of the transform."""
316336
return self._coeffs.ndim - 1
317337

338+
@classmethod
339+
def from_filename(cls, filename, fmt="X5"):
340+
_factory = {
341+
"X5": None,
342+
}
343+
fmt = fmt.upper()
344+
if fmt not in {k.upper() for k in _factory}:
345+
raise NotImplementedError(f"Unsupported format <{fmt}>")
346+
347+
x5_xfm = load_x5(filename)[0]
348+
Domain = namedtuple("Domain", "affine shape")
349+
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
350+
351+
coefficients = nb.Nifti1Image(x5_xfm.transform, x5_xfm.additional_parameters)
352+
return cls(coefficients, reference=reference)
353+
354+
# return cls(_factory[fmt.lower()].from_filename(filename))
355+
318356
def to_field(self, reference=None, dtype="float32"):
319357
"""Generate a displacements deformation field from this B-Spline field."""
320358
_ref = (
@@ -349,21 +387,17 @@ def to_x5(self, metadata=None):
349387
coordinates="cartesian",
350388
)
351389

352-
meta = metadata | {
353-
"KnotsAffine": self._knots.affine.tolist(),
354-
"KnotsShape": self._knots.shape,
355-
}
356-
357390
kinds = tuple("space" for _ in range(self.ndim)) + ("vector",)
358391

359392
return io.x5.X5Transform(
360393
type="nonlinear",
361394
subtype="bspline",
362395
representation="coefficients",
363-
metadata=meta,
396+
metadata=metadata,
364397
transform=self._coeffs,
365398
dimension_kinds=kinds,
366399
domain=domain,
400+
additional_parameters=self._knots.affine,
367401
)
368402

369403
def map(self, x, inverse=False):

nitransforms/tests/test_nonlinear.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -122,29 +122,29 @@ def test_bspline(tmp_path, testdata_path):
122122
)
123123

124124

125-
def test_densefield_x5_roundtrip(tmp_path):
125+
@pytest.mark.parametrize("is_deltas", [True, False])
126+
def test_densefield_x5_roundtrip(tmp_path, is_deltas):
126127
"""Ensure dense field transforms roundtrip via X5."""
127128
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
128129
disp = nb.Nifti1Image(np.random.rand(2, 2, 2, 3).astype("float32"), np.eye(4))
129130

130-
xfm = DenseFieldTransform(disp, reference=ref)
131+
xfm = DenseFieldTransform(disp, is_deltas=is_deltas, reference=ref)
131132

132133
node = xfm.to_x5(metadata={"GeneratedBy": "pytest"})
133134
assert node.type == "nonlinear"
134135
assert node.subtype == "densefield"
135-
assert node.representation == "displacements"
136+
assert node.representation == "displacements" if is_deltas else "deformations"
136137
assert node.domain.size == ref.shape
137138
assert node.metadata["GeneratedBy"] == "pytest"
138139

139140
fname = tmp_path / "test.x5"
140141
io.x5.to_filename(fname, [node])
141142

142143
xfm2 = DenseFieldTransform.from_filename(fname, fmt="X5")
143-
diff = xfm2._deltas - xfm._deltas
144-
coords = xfm.reference.ndcoords.T.reshape(xfm._deltas.shape)
145-
assert np.allclose(diff, coords)
144+
146145
assert xfm2.reference.shape == ref.shape
147146
assert np.allclose(xfm2.reference.affine, ref.affine)
147+
assert xfm == xfm2
148148

149149

150150
def test_bspline_to_x5(tmp_path):
@@ -161,6 +161,8 @@ def test_bspline_to_x5(tmp_path):
161161

162162
fname = tmp_path / "bspline.x5"
163163
io.x5.to_filename(fname, [node])
164-
node2 = io.x5.from_filename(fname)[0]
165-
assert np.allclose(node2.transform, node.transform)
166-
assert node2.metadata["tool"] == "pytest"
164+
165+
xfm2 = BSplineFieldTransform.from_filename(fname, fmt="X5")
166+
assert np.allclose(xfm._coeffs, xfm2._coeffs)
167+
assert xfm2.reference.shape == ref.shape
168+
assert np.allclose(xfm2.reference.affine, ref.affine)

0 commit comments

Comments
 (0)