Skip to content

Commit 5885ed8

Browse files
committed
ENH: support TransformChain in X5
1 parent fc6fbbd commit 5885ed8

File tree

3 files changed

+181
-33
lines changed

3 files changed

+181
-33
lines changed

nitransforms/io/x5.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -105,35 +105,7 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]):
105105
tg = out_file.create_group("TransformGroup")
106106
for i, node in enumerate(x5_list):
107107
g = tg.create_group(str(i))
108-
g.attrs["Type"] = node.type
109-
g.attrs["ArrayLength"] = node.array_length
110-
if node.subtype is not None:
111-
g.attrs["SubType"] = node.subtype
112-
if node.representation is not None:
113-
g.attrs["Representation"] = node.representation
114-
if node.metadata is not None:
115-
g.attrs["Metadata"] = json.dumps(node.metadata)
116-
g.create_dataset("Transform", data=node.transform)
117-
g.create_dataset(
118-
"DimensionKinds",
119-
data=np.asarray(node.dimension_kinds, dtype="S"),
120-
)
121-
if node.domain is not None:
122-
dgrp = g.create_group("Domain")
123-
dgrp.create_dataset("Grid", data=np.uint8(1 if node.domain.grid else 0))
124-
dgrp.create_dataset("Size", data=np.asarray(node.domain.size))
125-
dgrp.create_dataset("Mapping", data=node.domain.mapping)
126-
if node.domain.coordinates is not None:
127-
dgrp.attrs["Coordinates"] = node.domain.coordinates
128-
129-
if node.inverse is not None:
130-
g.create_dataset("Inverse", data=node.inverse)
131-
if node.jacobian is not None:
132-
g.create_dataset("Jacobian", data=node.jacobian)
133-
if node.additional_parameters is not None:
134-
g.create_dataset(
135-
"AdditionalParameters", data=node.additional_parameters
136-
)
108+
_write_x5_group(g, node)
137109
return fname
138110

139111

@@ -188,3 +160,30 @@ def _read_x5_group(node) -> X5Transform:
188160
)
189161

190162
return x5
163+
164+
165+
def _write_x5_group(g, node: X5Transform):
166+
"""Write one :class:`X5Transform` element into an opened HDF5 group."""
167+
g.attrs["Type"] = node.type
168+
g.attrs["ArrayLength"] = node.array_length
169+
if node.subtype is not None:
170+
g.attrs["SubType"] = node.subtype
171+
if node.representation is not None:
172+
g.attrs["Representation"] = node.representation
173+
if node.metadata is not None:
174+
g.attrs["Metadata"] = json.dumps(node.metadata)
175+
g.create_dataset("Transform", data=node.transform)
176+
g.create_dataset("DimensionKinds", data=np.asarray(node.dimension_kinds, dtype="S"))
177+
if node.domain is not None:
178+
dgrp = g.create_group("Domain")
179+
dgrp.create_dataset("Grid", data=np.uint8(1 if node.domain.grid else 0))
180+
dgrp.create_dataset("Size", data=np.asarray(node.domain.size))
181+
dgrp.create_dataset("Mapping", data=node.domain.mapping)
182+
if node.domain.coordinates is not None:
183+
dgrp.attrs["Coordinates"] = node.domain.coordinates
184+
if node.inverse is not None:
185+
g.create_dataset("Inverse", data=node.inverse)
186+
if node.jacobian is not None:
187+
g.create_dataset("Jacobian", data=node.jacobian)
188+
if node.additional_parameters is not None:
189+
g.create_dataset("AdditionalParameters", data=node.additional_parameters)

nitransforms/manip.py

Lines changed: 119 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
TransformBase,
1515
TransformError,
1616
)
17-
from .linear import Affine
17+
from .linear import Affine, LinearTransformsMapping
1818
from .nonlinear import DenseFieldTransform
1919

2020

@@ -190,12 +190,15 @@ def asaffine(self, indices=None):
190190
return retval
191191

192192
@classmethod
193-
def from_filename(cls, filename, fmt="X5", reference=None, moving=None):
193+
def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain=0):
194194
"""Load a transform file."""
195-
from .io import itk
195+
from .io import itk, x5 as x5io
196+
import h5py
197+
import nibabel as nb
198+
from collections import namedtuple
196199

197200
retval = []
198-
if str(filename).endswith(".h5"):
201+
if str(filename).endswith(".h5") and (fmt is None or fmt.upper() != "X5"):
199202
reference = None
200203
xforms = itk.ITKCompositeH5.from_filename(filename)
201204
for xfmobj in xforms:
@@ -206,8 +209,120 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None):
206209

207210
return TransformChain(retval)
208211

212+
if fmt and fmt.upper() == "X5":
213+
with h5py.File(str(filename), "r") as f:
214+
if f.attrs.get("Format") != "X5":
215+
raise TypeError("Input file is not in X5 format")
216+
217+
tg = [
218+
x5io._read_x5_group(node)
219+
for _, node in sorted(f["TransformGroup"].items(), key=lambda kv: int(kv[0]))
220+
]
221+
chain_grp = f.get("TransformChain")
222+
if chain_grp is None:
223+
raise TransformError("X5 file contains no TransformChain")
224+
225+
chain_path = chain_grp[str(x5_chain)][()]
226+
if isinstance(chain_path, bytes):
227+
chain_path = chain_path.decode()
228+
indices = [int(idx) for idx in chain_path.split("/") if idx]
229+
230+
Domain = namedtuple("Domain", "affine shape")
231+
for idx in indices:
232+
node = tg[idx]
233+
if node.type == "linear":
234+
Transform = Affine if node.array_length == 1 else LinearTransformsMapping
235+
reference = None
236+
if node.domain is not None:
237+
reference = Domain(node.domain.mapping, node.domain.size)
238+
retval.append(Transform(node.transform, reference=reference))
239+
elif node.type == "nonlinear":
240+
reference = Domain(node.domain.mapping, node.domain.size)
241+
field = nb.Nifti1Image(node.transform, reference.affine)
242+
retval.append(
243+
DenseFieldTransform(
244+
field,
245+
is_deltas=node.representation == "displacements",
246+
reference=reference,
247+
)
248+
)
249+
else: # pragma: no cover - unsupported type
250+
raise NotImplementedError(f"Unsupported transform type {node.type}")
251+
252+
return TransformChain(retval)
253+
209254
raise NotImplementedError
210255

256+
def to_filename(self, filename, fmt="X5"):
257+
"""Store the transform chain in X5 format."""
258+
from .io import x5 as x5io
259+
import os
260+
import h5py
261+
262+
if fmt.upper() != "X5":
263+
raise NotImplementedError("Only X5 format is supported for chains")
264+
265+
if os.path.exists(filename):
266+
with h5py.File(str(filename), "r") as f:
267+
existing = [
268+
x5io._read_x5_group(node)
269+
for _, node in sorted(f["TransformGroup"].items(), key=lambda kv: int(kv[0]))
270+
]
271+
else:
272+
existing = []
273+
274+
# convert to objects for equality check
275+
from collections import namedtuple
276+
import nibabel as nb
277+
278+
def _as_transform(x5node):
279+
Domain = namedtuple("Domain", "affine shape")
280+
if x5node.type == "linear":
281+
Transform = Affine if x5node.array_length == 1 else LinearTransformsMapping
282+
ref = None
283+
if x5node.domain is not None:
284+
ref = Domain(x5node.domain.mapping, x5node.domain.size)
285+
return Transform(x5node.transform, reference=ref)
286+
reference = Domain(x5node.domain.mapping, x5node.domain.size)
287+
field = nb.Nifti1Image(x5node.transform, reference.affine)
288+
return DenseFieldTransform(
289+
field,
290+
is_deltas=x5node.representation == "displacements",
291+
reference=reference,
292+
)
293+
294+
existing_objs = [_as_transform(n) for n in existing]
295+
path_indices = []
296+
new_nodes = []
297+
for xfm in self.transforms:
298+
# find existing
299+
idx = None
300+
for i, obj in enumerate(existing_objs):
301+
if type(xfm) is type(obj) and xfm == obj:
302+
idx = i
303+
break
304+
if idx is None:
305+
idx = len(existing_objs)
306+
new_nodes.append((idx, xfm.to_x5()))
307+
existing_objs.append(xfm)
308+
path_indices.append(idx)
309+
310+
mode = "r+" if os.path.exists(filename) else "w"
311+
with h5py.File(str(filename), mode) as f:
312+
if "Format" not in f.attrs:
313+
f.attrs["Format"] = "X5"
314+
f.attrs["Version"] = np.uint16(1)
315+
316+
tg = f.require_group("TransformGroup")
317+
for idx, node in new_nodes:
318+
g = tg.create_group(str(idx))
319+
x5io._write_x5_group(g, node)
320+
321+
cg = f.require_group("TransformChain")
322+
cg.create_dataset(str(len(cg)), data="/".join(str(i) for i in path_indices))
323+
324+
return filename
325+
211326

212327
def _as_chain(x):
213328
"""Convert a value into a transform chain."""

nitransforms/tests/test_manip.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55
import pytest
66

77
import numpy as np
8+
import nibabel as nb
9+
import h5py
810
from ..manip import TransformChain
911
from ..linear import Affine
12+
from ..nonlinear import DenseFieldTransform
1013

1114
FMT = {"lta": "fs", "tfm": "itk"}
1215

@@ -37,3 +40,34 @@ def test_collapse_affines(tmp_path, data_path, ext0, ext1, ext2):
3740
fmt=f"{FMT[ext2]}",
3841
).matrix,
3942
)
43+
44+
45+
def test_transformchain_x5_roundtrip(tmp_path):
46+
"""Round-trip TransformChain with X5 storage."""
47+
48+
mat = np.eye(4)
49+
mat[0, 3] = 1
50+
aff = Affine(mat)
51+
52+
field = nb.Nifti1Image(np.zeros((5, 5, 5, 3), dtype="float32"), np.eye(4))
53+
fdata = field.get_fdata()
54+
fdata[..., 1] = 1
55+
field = nb.Nifti1Image(fdata, np.eye(4))
56+
dfield = DenseFieldTransform(field, is_deltas=True)
57+
58+
chain = TransformChain([aff, aff, aff, dfield])
59+
60+
fname = tmp_path / "chain.x5"
61+
chain.to_filename(fname)
62+
chain.to_filename(fname) # append again, should not duplicate transforms
63+
64+
with h5py.File(fname) as f:
65+
assert len(f["TransformGroup"]) == 2
66+
assert len(f["TransformChain"]) == 2
67+
68+
loaded0 = TransformChain.from_filename(fname, fmt="X5", x5_chain=0)
69+
loaded1 = TransformChain.from_filename(fname, fmt="X5", x5_chain=1)
70+
71+
assert len(loaded0) == len(chain)
72+
assert len(loaded1) == len(chain)
73+
assert np.allclose(chain.map([[0, 0, 0]]), loaded1.map([[0, 0, 0]]))

0 commit comments

Comments
 (0)