Skip to content

Commit 50b2b0a

Browse files
committed
fix: general re-design of the PR
1 parent 23a425d commit 50b2b0a

File tree

4 files changed

+95
-125
lines changed

4 files changed

+95
-125
lines changed

nitransforms/io/x5.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import numpy as np
2727

2828

29-
@dataclass
29+
@dataclass(eq=True)
3030
class X5Domain:
3131
"""Domain information of a transform representing reference/moving spaces."""
3232

nitransforms/linear.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -186,22 +186,7 @@ def from_filename(
186186
"""Create an affine from a transform file."""
187187

188188
if fmt and fmt.upper() == "X5":
189-
x5_xfm = load_x5(filename)[x5_position]
190-
Transform = cls if x5_xfm.array_length == 1 else LinearTransformsMapping
191-
if (
192-
x5_xfm.domain
193-
and not x5_xfm.domain.grid
194-
and len(x5_xfm.domain.size) == 3
195-
): # pragma: no cover
196-
raise NotImplementedError(
197-
"Only 3D regularly gridded domains are supported"
198-
)
199-
elif x5_xfm.domain:
200-
# Override reference
201-
Domain = namedtuple("Domain", "affine shape")
202-
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
203-
204-
return Transform(x5_xfm.transform, reference=reference)
189+
return from_x5(load_x5(filename), x5_position)
205190

206191
fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl")
207192

@@ -458,3 +443,20 @@ def load(filename, fmt=None, reference=None, moving=None):
458443
xfm = xfm[0]
459444

460445
return xfm
446+
447+
448+
def from_x5(x5_list, x5_position=0):
449+
"""Create an affine from a list of :class:`~nitransforms.io.x5.X5Transform` objects."""
450+
451+
x5_xfm = x5_list[x5_position]
452+
Transform = Affine if x5_xfm.array_length == 1 else LinearTransformsMapping
453+
if (
454+
x5_xfm.domain and not x5_xfm.domain.grid and len(x5_xfm.domain.size) == 3
455+
): # pragma: no cover
456+
raise NotImplementedError("Only 3D regularly gridded domains are supported")
457+
elif x5_xfm.domain:
458+
# Override reference
459+
Domain = namedtuple("Domain", "affine shape")
460+
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
461+
462+
return Transform(x5_xfm.transform, reference=reference)

nitransforms/manip.py

Lines changed: 47 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,26 @@
77
#
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Common interface for transforms."""
10+
11+
import os
1012
from collections.abc import Iterable
1113
import numpy as np
1214

13-
from .base import (
15+
import h5py
16+
from nitransforms.base import (
1417
TransformBase,
1518
TransformError,
1619
)
17-
from .linear import Affine, LinearTransformsMapping
18-
from .nonlinear import DenseFieldTransform
20+
from nitransforms.io import itk, x5 as x5io
21+
from nitransforms.io.x5 import from_filename as load_x5
22+
from nitransforms.linear import (
23+
Affine,
24+
from_x5 as linear_from_x5, # noqa: F401
25+
)
26+
from nitransforms.nonlinear import (
27+
DenseFieldTransform,
28+
from_x5 as nonlinear_from_x5, # noqa: F401
29+
)
1930

2031

2132
class TransformChain(TransformBase):
@@ -183,7 +194,9 @@ def asaffine(self, indices=None):
183194
The indices of the values to extract.
184195
185196
"""
186-
affines = self.transforms if indices is None else np.take(self.transforms, indices)
197+
affines = (
198+
self.transforms if indices is None else np.take(self.transforms, indices)
199+
)
187200
retval = affines[0]
188201
for xfm in affines[1:]:
189202
retval = xfm @ retval
@@ -192,51 +205,28 @@ def asaffine(self, indices=None):
192205
@classmethod
193206
def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain=0):
194207
"""Load a transform file."""
195-
from .io import itk, x5 as x5io
196-
import h5py
197-
import nibabel as nb
198-
from collections import namedtuple
199208

200209
retval = []
201210
if fmt and fmt.upper() == "X5":
211+
xfm_list = load_x5(filename)
212+
if not xfm_list:
213+
raise TransformError("Empty transform group")
214+
202215
with h5py.File(str(filename), "r") as f:
203-
if f.attrs.get("Format") == "X5":
204-
tg = [
205-
x5io._read_x5_group(node)
206-
for _, node in sorted(f["TransformGroup"].items(), key=lambda kv: int(kv[0]))
207-
]
208-
chain_grp = f.get("TransformChain")
209-
if chain_grp is None:
210-
raise TransformError("X5 file contains no TransformChain")
211-
212-
chain_path = chain_grp[str(x5_chain)][()]
213-
if isinstance(chain_path, bytes):
214-
chain_path = chain_path.decode()
215-
indices = [int(idx) for idx in chain_path.split("/") if idx]
216-
217-
Domain = namedtuple("Domain", "affine shape")
218-
for idx in indices:
219-
node = tg[idx]
220-
if node.type == "linear":
221-
Transform = Affine if node.array_length == 1 else LinearTransformsMapping
222-
reference = None
223-
if node.domain is not None:
224-
reference = Domain(node.domain.mapping, node.domain.size)
225-
retval.append(Transform(node.transform, reference=reference))
226-
elif node.type == "nonlinear":
227-
reference = Domain(node.domain.mapping, node.domain.size)
228-
field = nb.Nifti1Image(node.transform, reference.affine)
229-
retval.append(
230-
DenseFieldTransform(
231-
field,
232-
is_deltas=node.representation == "displacements",
233-
reference=reference,
234-
)
235-
)
236-
else: # pragma: no cover - unsupported type
237-
raise NotImplementedError(f"Unsupported transform type {node.type}")
238-
239-
return TransformChain(retval)
216+
chain_grp = f.get("TransformChain")
217+
if chain_grp is None:
218+
raise TransformError("X5 file contains no TransformChain")
219+
220+
chain_path = chain_grp[str(x5_chain)][()]
221+
if isinstance(chain_path, bytes):
222+
chain_path = chain_path.decode()
223+
224+
for idx in chain_path.split("/"):
225+
node = x5io._read_x5_group(xfm_list[int(idx)])
226+
from_x5 = globals()[f"{node.type}_from_x5"]
227+
retval.append(from_x5([node]))
228+
229+
return TransformChain(retval)
240230

241231
if str(filename).endswith(".h5"):
242232
reference = None
@@ -253,57 +243,24 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain
253243

254244
def to_filename(self, filename, fmt="X5"):
255245
"""Store the transform chain in X5 format."""
256-
from .io import x5 as x5io
257-
import os
258-
import h5py
259246

260247
if fmt.upper() != "X5":
261248
raise NotImplementedError("Only X5 format is supported for chains")
262249

263-
if os.path.exists(filename):
264-
with h5py.File(str(filename), "r") as f:
265-
existing = [
266-
x5io._read_x5_group(node)
267-
for _, node in sorted(f["TransformGroup"].items(), key=lambda kv: int(kv[0]))
268-
]
269-
else:
270-
existing = []
271-
272-
# convert to objects for equality check
273-
from collections import namedtuple
274-
import nibabel as nb
275-
276-
def _as_transform(x5node):
277-
Domain = namedtuple("Domain", "affine shape")
278-
if x5node.type == "linear":
279-
Transform = Affine if x5node.array_length == 1 else LinearTransformsMapping
280-
ref = None
281-
if x5node.domain is not None:
282-
ref = Domain(x5node.domain.mapping, x5node.domain.size)
283-
return Transform(x5node.transform, reference=ref)
284-
reference = Domain(x5node.domain.mapping, x5node.domain.size)
285-
field = nb.Nifti1Image(x5node.transform, reference.affine)
286-
return DenseFieldTransform(
287-
field,
288-
is_deltas=x5node.representation == "displacements",
289-
reference=reference,
290-
)
291-
292-
existing_objs = [_as_transform(n) for n in existing]
293-
path_indices = []
250+
existing = load_x5(filename) if os.path.exists(filename) else []
251+
xfm_chain = []
294252
new_nodes = []
253+
next_xfm_index = len(existing)
295254
for xfm in self.transforms:
296-
# find existing
297-
idx = None
298-
for i, obj in enumerate(existing_objs):
299-
if type(xfm) is type(obj) and xfm == obj:
300-
idx = i
255+
for eidx, existing_xfm in enumerate(existing):
256+
if xfm == existing_xfm:
257+
xfm_chain.append(eidx)
301258
break
302-
if idx is None:
303-
idx = len(existing_objs)
304-
new_nodes.append((idx, xfm.to_x5()))
305-
existing_objs.append(xfm)
306-
path_indices.append(idx)
259+
else:
260+
xfm_chain.append(next_xfm_index)
261+
new_nodes.append((next_xfm_index, xfm))
262+
existing.append(xfm)
263+
next_xfm_index += 1
307264

308265
mode = "r+" if os.path.exists(filename) else "w"
309266
with h5py.File(str(filename), mode) as f:
@@ -317,7 +274,7 @@ def _as_transform(x5node):
317274
x5io._write_x5_group(g, node)
318275

319276
cg = f.require_group("TransformChain")
320-
cg.create_dataset(str(len(cg)), data="/".join(str(i) for i in path_indices))
277+
cg.create_dataset(str(len(cg)), data="/".join(str(i) for i in xfm_chain))
321278

322279
return filename
323280

nitransforms/nonlinear.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def to_x5(self, metadata=None):
278278
)
279279

280280
@classmethod
281-
def from_filename(cls, filename, fmt="X5"):
281+
def from_filename(cls, filename, fmt="X5", x5_position=0):
282282
_factory = {
283283
"afni": io.afni.AFNIDisplacementsField,
284284
"itk": io.itk.ITKDisplacementsField,
@@ -290,15 +290,7 @@ def from_filename(cls, filename, fmt="X5"):
290290
raise NotImplementedError(f"Unsupported format <{fmt}>")
291291

292292
if fmt == "X5":
293-
x5_xfm = load_x5(filename)[0]
294-
Domain = namedtuple("Domain", "affine shape")
295-
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
296-
field = nb.Nifti1Image(x5_xfm.transform, reference.affine)
297-
return cls(
298-
field,
299-
is_deltas=x5_xfm.representation == "displacements",
300-
reference=reference,
301-
)
293+
return from_x5(load_x5(filename), x5_position=x5_position)
302294

303295
return cls(_factory[fmt.lower()].from_filename(filename))
304296

@@ -336,21 +328,15 @@ def ndim(self):
336328
return self._coeffs.ndim - 1
337329

338330
@classmethod
339-
def from_filename(cls, filename, fmt="X5"):
331+
def from_filename(cls, filename, fmt="X5", x5_position=0):
340332
_factory = {
341333
"X5": None,
342334
}
343335
fmt = fmt.upper()
344336
if fmt not in {k.upper() for k in _factory}:
345337
raise NotImplementedError(f"Unsupported format <{fmt}>")
346338

347-
x5_xfm = load_x5(filename)[0]
348-
Domain = namedtuple("Domain", "affine shape")
349-
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
350-
351-
coefficients = nb.Nifti1Image(x5_xfm.transform, x5_xfm.additional_parameters)
352-
return cls(coefficients, reference=reference)
353-
339+
return from_x5(load_x5(filename), x5_position=x5_position)
354340
# return cls(_factory[fmt.lower()].from_filename(filename))
355341

356342
def to_field(self, reference=None, dtype="float32"):
@@ -440,6 +426,31 @@ def map(self, x, inverse=False):
440426
return np.array([vfunc(_x).tolist() for _x in np.atleast_2d(x)])
441427

442428

429+
def from_x5(x5_list, x5_position=0):
430+
"""Create a transform from a list of :class:`~nitransforms.io.x5.X5Transform` objects."""
431+
432+
x5_xfm = x5_list[x5_position]
433+
434+
Transform = (
435+
BSplineFieldTransform if x5_xfm.subtype == "bspline" else DenseFieldTransform
436+
)
437+
Domain = namedtuple("Domain", "affine shape")
438+
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
439+
xfm_params = (
440+
nb.Nifti1Image(x5_xfm.transform, x5_xfm.additional_parameters)
441+
if x5_xfm.subtype == "bspline"
442+
else x5_xfm.transform
443+
)
444+
445+
xfm_kwargs = (
446+
{}
447+
if x5_xfm.subtype == "bspline"
448+
else {"is_deltas": x5_xfm.representation == "displacements"}
449+
)
450+
451+
return Transform(xfm_params, reference=reference, **xfm_kwargs)
452+
453+
443454
def _map_xyz(x, reference, knots, coeffs):
444455
"""Apply the transformation to just one coordinate."""
445456
ndim = len(x)

0 commit comments

Comments
 (0)