Skip to content

Commit 23a425d

Browse files
committed
fix chain loading fallback
1 parent 5885ed8 commit 23a425d

File tree

1 file changed

+41
-43
lines changed

1 file changed

+41
-43
lines changed

nitransforms/manip.py

Lines changed: 41 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,47 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain
198198
from collections import namedtuple
199199

200200
retval = []
201-
if str(filename).endswith(".h5") and (fmt is None or fmt.upper() != "X5"):
201+
if fmt and fmt.upper() == "X5":
202+
with h5py.File(str(filename), "r") as f:
203+
if f.attrs.get("Format") == "X5":
204+
tg = [
205+
x5io._read_x5_group(node)
206+
for _, node in sorted(f["TransformGroup"].items(), key=lambda kv: int(kv[0]))
207+
]
208+
chain_grp = f.get("TransformChain")
209+
if chain_grp is None:
210+
raise TransformError("X5 file contains no TransformChain")
211+
212+
chain_path = chain_grp[str(x5_chain)][()]
213+
if isinstance(chain_path, bytes):
214+
chain_path = chain_path.decode()
215+
indices = [int(idx) for idx in chain_path.split("/") if idx]
216+
217+
Domain = namedtuple("Domain", "affine shape")
218+
for idx in indices:
219+
node = tg[idx]
220+
if node.type == "linear":
221+
Transform = Affine if node.array_length == 1 else LinearTransformsMapping
222+
reference = None
223+
if node.domain is not None:
224+
reference = Domain(node.domain.mapping, node.domain.size)
225+
retval.append(Transform(node.transform, reference=reference))
226+
elif node.type == "nonlinear":
227+
reference = Domain(node.domain.mapping, node.domain.size)
228+
field = nb.Nifti1Image(node.transform, reference.affine)
229+
retval.append(
230+
DenseFieldTransform(
231+
field,
232+
is_deltas=node.representation == "displacements",
233+
reference=reference,
234+
)
235+
)
236+
else: # pragma: no cover - unsupported type
237+
raise NotImplementedError(f"Unsupported transform type {node.type}")
238+
239+
return TransformChain(retval)
240+
241+
if str(filename).endswith(".h5"):
202242
reference = None
203243
xforms = itk.ITKCompositeH5.from_filename(filename)
204244
for xfmobj in xforms:
@@ -209,48 +249,6 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain
209249

210250
return TransformChain(retval)
211251

212-
if fmt and fmt.upper() == "X5":
213-
with h5py.File(str(filename), "r") as f:
214-
if f.attrs.get("Format") != "X5":
215-
raise TypeError("Input file is not in X5 format")
216-
217-
tg = [
218-
x5io._read_x5_group(node)
219-
for _, node in sorted(f["TransformGroup"].items(), key=lambda kv: int(kv[0]))
220-
]
221-
chain_grp = f.get("TransformChain")
222-
if chain_grp is None:
223-
raise TransformError("X5 file contains no TransformChain")
224-
225-
chain_path = chain_grp[str(x5_chain)][()]
226-
if isinstance(chain_path, bytes):
227-
chain_path = chain_path.decode()
228-
indices = [int(idx) for idx in chain_path.split("/") if idx]
229-
230-
Domain = namedtuple("Domain", "affine shape")
231-
for idx in indices:
232-
node = tg[idx]
233-
if node.type == "linear":
234-
Transform = Affine if node.array_length == 1 else LinearTransformsMapping
235-
reference = None
236-
if node.domain is not None:
237-
reference = Domain(node.domain.mapping, node.domain.size)
238-
retval.append(Transform(node.transform, reference=reference))
239-
elif node.type == "nonlinear":
240-
reference = Domain(node.domain.mapping, node.domain.size)
241-
field = nb.Nifti1Image(node.transform, reference.affine)
242-
retval.append(
243-
DenseFieldTransform(
244-
field,
245-
is_deltas=node.representation == "displacements",
246-
reference=reference,
247-
)
248-
)
249-
else: # pragma: no cover - unsupported type
250-
raise NotImplementedError(f"Unsupported transform type {node.type}")
251-
252-
return TransformChain(retval)
253-
254252
raise NotImplementedError
255253

256254
def to_filename(self, filename, fmt="X5"):

0 commit comments

Comments
 (0)