14
14
TransformBase ,
15
15
TransformError ,
16
16
)
17
- from .linear import Affine
17
+ from .linear import Affine , LinearTransformsMapping
18
18
from .nonlinear import DenseFieldTransform
19
19
20
20
@@ -190,12 +190,15 @@ def asaffine(self, indices=None):
190
190
return retval
191
191
192
192
@classmethod
193
- def from_filename (cls , filename , fmt = "X5" , reference = None , moving = None ):
193
+ def from_filename (cls , filename , fmt = "X5" , reference = None , moving = None , x5_chain = 0 ):
194
194
"""Load a transform file."""
195
- from .io import itk
195
+ from .io import itk , x5 as x5io
196
+ import h5py
197
+ import nibabel as nb
198
+ from collections import namedtuple
196
199
197
200
retval = []
198
- if str (filename ).endswith (".h5" ):
201
+ if str (filename ).endswith (".h5" ) and ( fmt is None or fmt . upper () != "X5" ) :
199
202
reference = None
200
203
xforms = itk .ITKCompositeH5 .from_filename (filename )
201
204
for xfmobj in xforms :
@@ -206,8 +209,120 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None):
206
209
207
210
return TransformChain (retval )
208
211
212
+ if fmt and fmt .upper () == "X5" :
213
+ with h5py .File (str (filename ), "r" ) as f :
214
+ if f .attrs .get ("Format" ) != "X5" :
215
+ raise TypeError ("Input file is not in X5 format" )
216
+
217
+ tg = [
218
+ x5io ._read_x5_group (node )
219
+ for _ , node in sorted (f ["TransformGroup" ].items (), key = lambda kv : int (kv [0 ]))
220
+ ]
221
+ chain_grp = f .get ("TransformChain" )
222
+ if chain_grp is None :
223
+ raise TransformError ("X5 file contains no TransformChain" )
224
+
225
+ chain_path = chain_grp [str (x5_chain )][()]
226
+ if isinstance (chain_path , bytes ):
227
+ chain_path = chain_path .decode ()
228
+ indices = [int (idx ) for idx in chain_path .split ("/" ) if idx ]
229
+
230
+ Domain = namedtuple ("Domain" , "affine shape" )
231
+ for idx in indices :
232
+ node = tg [idx ]
233
+ if node .type == "linear" :
234
+ Transform = Affine if node .array_length == 1 else LinearTransformsMapping
235
+ reference = None
236
+ if node .domain is not None :
237
+ reference = Domain (node .domain .mapping , node .domain .size )
238
+ retval .append (Transform (node .transform , reference = reference ))
239
+ elif node .type == "nonlinear" :
240
+ reference = Domain (node .domain .mapping , node .domain .size )
241
+ field = nb .Nifti1Image (node .transform , reference .affine )
242
+ retval .append (
243
+ DenseFieldTransform (
244
+ field ,
245
+ is_deltas = node .representation == "displacements" ,
246
+ reference = reference ,
247
+ )
248
+ )
249
+ else : # pragma: no cover - unsupported type
250
+ raise NotImplementedError (f"Unsupported transform type { node .type } " )
251
+
252
+ return TransformChain (retval )
253
+
209
254
raise NotImplementedError
210
255
256
+ def to_filename (self , filename , fmt = "X5" ):
257
+ """Store the transform chain in X5 format."""
258
+ from .io import x5 as x5io
259
+ import os
260
+ import h5py
261
+
262
+ if fmt .upper () != "X5" :
263
+ raise NotImplementedError ("Only X5 format is supported for chains" )
264
+
265
+ if os .path .exists (filename ):
266
+ with h5py .File (str (filename ), "r" ) as f :
267
+ existing = [
268
+ x5io ._read_x5_group (node )
269
+ for _ , node in sorted (f ["TransformGroup" ].items (), key = lambda kv : int (kv [0 ]))
270
+ ]
271
+ else :
272
+ existing = []
273
+
274
+ # convert to objects for equality check
275
+ from collections import namedtuple
276
+ import nibabel as nb
277
+
278
+ def _as_transform (x5node ):
279
+ Domain = namedtuple ("Domain" , "affine shape" )
280
+ if x5node .type == "linear" :
281
+ Transform = Affine if x5node .array_length == 1 else LinearTransformsMapping
282
+ ref = None
283
+ if x5node .domain is not None :
284
+ ref = Domain (x5node .domain .mapping , x5node .domain .size )
285
+ return Transform (x5node .transform , reference = ref )
286
+ reference = Domain (x5node .domain .mapping , x5node .domain .size )
287
+ field = nb .Nifti1Image (x5node .transform , reference .affine )
288
+ return DenseFieldTransform (
289
+ field ,
290
+ is_deltas = x5node .representation == "displacements" ,
291
+ reference = reference ,
292
+ )
293
+
294
+ existing_objs = [_as_transform (n ) for n in existing ]
295
+ path_indices = []
296
+ new_nodes = []
297
+ for xfm in self .transforms :
298
+ # find existing
299
+ idx = None
300
+ for i , obj in enumerate (existing_objs ):
301
+ if type (xfm ) is type (obj ) and xfm == obj :
302
+ idx = i
303
+ break
304
+ if idx is None :
305
+ idx = len (existing_objs )
306
+ new_nodes .append ((idx , xfm .to_x5 ()))
307
+ existing_objs .append (xfm )
308
+ path_indices .append (idx )
309
+
310
+ mode = "r+" if os .path .exists (filename ) else "w"
311
+ with h5py .File (str (filename ), mode ) as f :
312
+ if "Format" not in f .attrs :
313
+ f .attrs ["Format" ] = "X5"
314
+ f .attrs ["Version" ] = np .uint16 (1 )
315
+
316
+ tg = f .require_group ("TransformGroup" )
317
+ for idx , node in new_nodes :
318
+ g = tg .create_group (str (idx ))
319
+ x5io ._write_x5_group (g , node )
320
+
321
+ cg = f .require_group ("TransformChain" )
322
+ cg .create_dataset (str (len (cg )), data = "/" .join (str (i ) for i in path_indices ))
323
+
324
+ return filename
325
+
211
326
212
327
def _as_chain (x ):
213
328
"""Convert a value into a transform chain."""
0 commit comments