Skip to content

Commit 8ced1ac

Browse files
committed
fix: improve loading and equalities
1 parent 6bc2bc1 commit 8ced1ac

File tree

6 files changed

+54
-14
lines changed

6 files changed

+54
-14
lines changed

nitransforms/linear.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ def __eq__(self, other):
122122
True
123123
124124
"""
125+
if not hasattr(other, "matrix"):
126+
return False
127+
125128
_eq = np.allclose(self.matrix, other.matrix, rtol=EQUALITY_TOL)
126129
if _eq and self._reference != other._reference:
127130
warnings.warn("Affines are equal, but references do not match.")

nitransforms/manip.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -208,10 +208,16 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain
208208

209209
retval = []
210210
if fmt and fmt.upper() == "X5":
211-
xfm_list = load_x5(filename)
211+
# Get list of X5 nodes and generate transforms
212+
xfm_list = [
213+
globals()[f"{node.type}_from_x5"]([node]) for node in load_x5(filename)
214+
]
212215
if not xfm_list:
213216
raise TransformError("Empty transform group")
214217

218+
if x5_chain is None:
219+
return xfm_list
220+
215221
with h5py.File(str(filename), "r") as f:
216222
chain_grp = f.get("TransformChain")
217223
if chain_grp is None:
@@ -221,12 +227,7 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain
221227
if isinstance(chain_path, bytes):
222228
chain_path = chain_path.decode()
223229

224-
for idx in chain_path.split("/"):
225-
node = x5io._read_x5_group(xfm_list[int(idx)])
226-
from_x5 = globals()[f"{node.type}_from_x5"]
227-
retval.append(from_x5([node]))
228-
229-
return TransformChain(retval)
230+
return TransformChain([xfm_list[int(idx)] for idx in chain_path.split("/")])
230231

231232
if str(filename).endswith(".h5"):
232233
reference = None
@@ -247,9 +248,14 @@ def to_filename(self, filename, fmt="X5"):
247248
if fmt.upper() != "X5":
248249
raise NotImplementedError("Only X5 format is supported for chains")
249250

250-
existing = load_x5(filename) if os.path.exists(filename) else []
251+
existing = (
252+
self.from_filename(filename, x5_chain=None)
253+
if os.path.exists(filename)
254+
else []
255+
)
256+
251257
xfm_chain = []
252-
new_nodes = []
258+
new_xfms = []
253259
next_xfm_index = len(existing)
254260
for xfm in self.transforms:
255261
for eidx, existing_xfm in enumerate(existing):
@@ -258,7 +264,7 @@ def to_filename(self, filename, fmt="X5"):
258264
break
259265
else:
260266
xfm_chain.append(next_xfm_index)
261-
new_nodes.append((next_xfm_index, xfm))
267+
new_xfms.append((next_xfm_index, xfm))
262268
existing.append(xfm)
263269
next_xfm_index += 1
264270

@@ -269,9 +275,9 @@ def to_filename(self, filename, fmt="X5"):
269275
f.attrs["Version"] = np.uint16(1)
270276

271277
tg = f.require_group("TransformGroup")
272-
for idx, node in new_nodes:
278+
for idx, node in new_xfms:
273279
g = tg.create_group(str(idx))
274-
x5io._write_x5_group(g, node)
280+
x5io._write_x5_group(g, node.to_x5())
275281

276282
cg = f.require_group("TransformChain")
277283
cg.create_dataset(str(len(cg)), data="/".join(str(i) for i in xfm_chain))

nitransforms/nonlinear.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,9 @@ def __eq__(self, other):
247247
True
248248
249249
"""
250+
if not hasattr(other, "_field"):
251+
return False
252+
250253
_eq = np.allclose(self._field, other._field)
251254
if _eq and self._reference != other._reference:
252255
warnings.warn("Fields are equal, but references do not match.")
@@ -322,6 +325,28 @@ def __init__(self, coefficients, reference=None, order=3):
322325
"not match the number of dimensions"
323326
)
324327

328+
def __eq__(self, other):
329+
"""
330+
Overload equals operator.
331+
332+
Examples
333+
--------
334+
>>> xfm1 = BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz")
335+
>>> xfm2 = BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz")
336+
>>> xfm1 == xfm2
337+
True
338+
339+
"""
340+
if not hasattr(other, "_coeffs"):
341+
return False
342+
343+
_eq = np.allclose(self._coeffs, other._coeffs)
344+
_eq = _eq and self._order == other._order
345+
346+
if _eq and self._reference != other._reference:
347+
warnings.warn("Coefficients are equal, but references do not match.")
348+
return _eq
349+
325350
@property
326351
def ndim(self):
327352
"""Get the dimensions of the transform."""

nitransforms/tests/test_manip.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,14 @@ def test_transformchain_x5_roundtrip(tmp_path):
5959

6060
fname = tmp_path / "chain.x5"
6161
chain.to_filename(fname)
62+
63+
with h5py.File(fname) as f:
64+
assert len(f["TransformGroup"]) == 2
65+
6266
chain.to_filename(fname) # append again, should not duplicate transforms
6367

6468
with h5py.File(fname) as f:
6569
assert len(f["TransformGroup"]) == 2
66-
assert len(f["TransformChain"]) == 2
6770

6871
loaded0 = TransformChain.from_filename(fname, fmt="X5", x5_chain=0)
6972
loaded1 = TransformChain.from_filename(fname, fmt="X5", x5_chain=1)

nitransforms/tests/test_resampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def test_apply_transformchain(tmp_path, testdata_path):
284284
/ "ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.h5"
285285
)
286286

287-
xfm = nitm.load(xfm_fname)
287+
xfm = nitm.load(xfm_fname, fmt="itk")
288288

289289
assert len(xfm) == 2
290290

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,6 @@ ignore = [
107107
"E231",
108108
"W503",
109109
]
110+
per-file-ignores = """
111+
nitransforms/manip.py: F401
112+
"""

0 commit comments

Comments
 (0)