7
7
#
8
8
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9
9
"""Common interface for transforms."""
10
+
11
+ import os
10
12
from collections .abc import Iterable
11
13
import numpy as np
12
14
13
- from .base import (
15
+ import h5py
16
+ from nitransforms .base import (
14
17
TransformBase ,
15
18
TransformError ,
16
19
)
17
- from .linear import Affine , LinearTransformsMapping
18
- from .nonlinear import DenseFieldTransform
20
+ from nitransforms .io import itk , x5 as x5io
21
+ from nitransforms .io .x5 import from_filename as load_x5
22
+ from nitransforms .linear import (
23
+ Affine ,
24
+ from_x5 as linear_from_x5 , # noqa: F401
25
+ )
26
+ from nitransforms .nonlinear import (
27
+ DenseFieldTransform ,
28
+ from_x5 as nonlinear_from_x5 , # noqa: F401
29
+ )
19
30
20
31
21
32
class TransformChain (TransformBase ):
@@ -183,7 +194,9 @@ def asaffine(self, indices=None):
183
194
The indices of the values to extract.
184
195
185
196
"""
186
- affines = self .transforms if indices is None else np .take (self .transforms , indices )
197
+ affines = (
198
+ self .transforms if indices is None else np .take (self .transforms , indices )
199
+ )
187
200
retval = affines [0 ]
188
201
for xfm in affines [1 :]:
189
202
retval = xfm @ retval
@@ -192,51 +205,28 @@ def asaffine(self, indices=None):
192
205
@classmethod
193
206
def from_filename (cls , filename , fmt = "X5" , reference = None , moving = None , x5_chain = 0 ):
194
207
"""Load a transform file."""
195
- from .io import itk , x5 as x5io
196
- import h5py
197
- import nibabel as nb
198
- from collections import namedtuple
199
208
200
209
retval = []
201
210
if fmt and fmt .upper () == "X5" :
211
+ xfm_list = load_x5 (filename )
212
+ if not xfm_list :
213
+ raise TransformError ("Empty transform group" )
214
+
202
215
with h5py .File (str (filename ), "r" ) as f :
203
- if f .attrs .get ("Format" ) == "X5" :
204
- tg = [
205
- x5io ._read_x5_group (node )
206
- for _ , node in sorted (f ["TransformGroup" ].items (), key = lambda kv : int (kv [0 ]))
207
- ]
208
- chain_grp = f .get ("TransformChain" )
209
- if chain_grp is None :
210
- raise TransformError ("X5 file contains no TransformChain" )
211
-
212
- chain_path = chain_grp [str (x5_chain )][()]
213
- if isinstance (chain_path , bytes ):
214
- chain_path = chain_path .decode ()
215
- indices = [int (idx ) for idx in chain_path .split ("/" ) if idx ]
216
-
217
- Domain = namedtuple ("Domain" , "affine shape" )
218
- for idx in indices :
219
- node = tg [idx ]
220
- if node .type == "linear" :
221
- Transform = Affine if node .array_length == 1 else LinearTransformsMapping
222
- reference = None
223
- if node .domain is not None :
224
- reference = Domain (node .domain .mapping , node .domain .size )
225
- retval .append (Transform (node .transform , reference = reference ))
226
- elif node .type == "nonlinear" :
227
- reference = Domain (node .domain .mapping , node .domain .size )
228
- field = nb .Nifti1Image (node .transform , reference .affine )
229
- retval .append (
230
- DenseFieldTransform (
231
- field ,
232
- is_deltas = node .representation == "displacements" ,
233
- reference = reference ,
234
- )
235
- )
236
- else : # pragma: no cover - unsupported type
237
- raise NotImplementedError (f"Unsupported transform type { node .type } " )
238
-
239
- return TransformChain (retval )
216
+ chain_grp = f .get ("TransformChain" )
217
+ if chain_grp is None :
218
+ raise TransformError ("X5 file contains no TransformChain" )
219
+
220
+ chain_path = chain_grp [str (x5_chain )][()]
221
+ if isinstance (chain_path , bytes ):
222
+ chain_path = chain_path .decode ()
223
+
224
+ for idx in chain_path .split ("/" ):
225
+ node = x5io ._read_x5_group (xfm_list [int (idx )])
226
+ from_x5 = globals ()[f"{ node .type } _from_x5" ]
227
+ retval .append (from_x5 ([node ]))
228
+
229
+ return TransformChain (retval )
240
230
241
231
if str (filename ).endswith (".h5" ):
242
232
reference = None
@@ -253,57 +243,24 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain
253
243
254
244
def to_filename (self , filename , fmt = "X5" ):
255
245
"""Store the transform chain in X5 format."""
256
- from .io import x5 as x5io
257
- import os
258
- import h5py
259
246
260
247
if fmt .upper () != "X5" :
261
248
raise NotImplementedError ("Only X5 format is supported for chains" )
262
249
263
- if os .path .exists (filename ):
264
- with h5py .File (str (filename ), "r" ) as f :
265
- existing = [
266
- x5io ._read_x5_group (node )
267
- for _ , node in sorted (f ["TransformGroup" ].items (), key = lambda kv : int (kv [0 ]))
268
- ]
269
- else :
270
- existing = []
271
-
272
- # convert to objects for equality check
273
- from collections import namedtuple
274
- import nibabel as nb
275
-
276
- def _as_transform (x5node ):
277
- Domain = namedtuple ("Domain" , "affine shape" )
278
- if x5node .type == "linear" :
279
- Transform = Affine if x5node .array_length == 1 else LinearTransformsMapping
280
- ref = None
281
- if x5node .domain is not None :
282
- ref = Domain (x5node .domain .mapping , x5node .domain .size )
283
- return Transform (x5node .transform , reference = ref )
284
- reference = Domain (x5node .domain .mapping , x5node .domain .size )
285
- field = nb .Nifti1Image (x5node .transform , reference .affine )
286
- return DenseFieldTransform (
287
- field ,
288
- is_deltas = x5node .representation == "displacements" ,
289
- reference = reference ,
290
- )
291
-
292
- existing_objs = [_as_transform (n ) for n in existing ]
293
- path_indices = []
250
+ existing = load_x5 (filename ) if os .path .exists (filename ) else []
251
+ xfm_chain = []
294
252
new_nodes = []
253
+ next_xfm_index = len (existing )
295
254
for xfm in self .transforms :
296
- # find existing
297
- idx = None
298
- for i , obj in enumerate (existing_objs ):
299
- if type (xfm ) is type (obj ) and xfm == obj :
300
- idx = i
255
+ for eidx , existing_xfm in enumerate (existing ):
256
+ if xfm == existing_xfm :
257
+ xfm_chain .append (eidx )
301
258
break
302
- if idx is None :
303
- idx = len ( existing_objs )
304
- new_nodes .append ((idx , xfm . to_x5 () ))
305
- existing_objs .append (xfm )
306
- path_indices . append ( idx )
259
+ else :
260
+ xfm_chain . append ( next_xfm_index )
261
+ new_nodes .append ((next_xfm_index , xfm ))
262
+ existing .append (xfm )
263
+ next_xfm_index += 1
307
264
308
265
mode = "r+" if os .path .exists (filename ) else "w"
309
266
with h5py .File (str (filename ), mode ) as f :
@@ -317,7 +274,7 @@ def _as_transform(x5node):
317
274
x5io ._write_x5_group (g , node )
318
275
319
276
cg = f .require_group ("TransformChain" )
320
- cg .create_dataset (str (len (cg )), data = "/" .join (str (i ) for i in path_indices ))
277
+ cg .create_dataset (str (len (cg )), data = "/" .join (str (i ) for i in xfm_chain ))
321
278
322
279
return filename
323
280
0 commit comments