7
7
#
8
8
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9
9
"""Nonlinear transforms."""
10
+
10
11
import warnings
11
12
from functools import partial
13
+ from collections import namedtuple
12
14
import numpy as np
15
+ import nibabel as nb
13
16
14
17
from nitransforms import io
15
18
from nitransforms .io .base import _ensure_image
19
+ from nitransforms .io .x5 import from_filename as load_x5
16
20
from nitransforms .interp .bspline import grid_bspline_weights , _cubic_bspline
17
21
from nitransforms .base import (
18
22
TransformBase ,
22
26
)
23
27
from scipy .ndimage import map_coordinates
24
28
29
+ # Avoids circular imports
30
+ try :
31
+ from nitransforms ._version import __version__
32
+ except ModuleNotFoundError : # pragma: no cover
33
+ __version__ = "0+unknown"
34
+
25
35
26
36
class DenseFieldTransform (TransformBase ):
27
37
"""Represents dense field (voxel-wise) transforms."""
28
38
29
- __slots__ = ("_field" , "_deltas" )
39
+ __slots__ = ("_field" , "_deltas" , "_is_deltas" )
30
40
31
41
def __init__ (self , field = None , is_deltas = True , reference = None ):
32
42
"""
@@ -60,14 +70,7 @@ def __init__(self, field=None, is_deltas=True, reference=None):
60
70
61
71
super ().__init__ ()
62
72
63
- if field is not None :
64
- field = _ensure_image (field )
65
- self ._field = np .squeeze (
66
- np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
67
- )
68
- else :
69
- self ._field = np .zeros ((* reference .shape , reference .ndim ), dtype = "float32" )
70
- is_deltas = True
73
+ self ._is_deltas = is_deltas
71
74
72
75
try :
73
76
self .reference = ImageGrid (reference if reference is not None else field )
@@ -78,22 +81,44 @@ def __init__(self, field=None, is_deltas=True, reference=None):
78
81
else "Reference is not a spatial image"
79
82
)
80
83
84
+ fieldshape = (* self .reference .shape , self .reference .ndim )
85
+ if field is not None :
86
+ field = _ensure_image (field )
87
+ self ._field = np .squeeze (
88
+ np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
89
+ )
90
+ if fieldshape != self ._field .shape :
91
+ raise TransformError (
92
+ f"Shape of the field ({ 'x' .join (str (i ) for i in self ._field .shape )} ) "
93
+ f"doesn't match that of the reference({ 'x' .join (str (i ) for i in fieldshape )} )"
94
+ )
95
+ else :
96
+ self ._field = np .zeros (fieldshape , dtype = "float32" )
97
+ self ._is_deltas = True
98
+
81
99
if self ._field .shape [- 1 ] != self .ndim :
82
100
raise TransformError (
83
101
"The number of components of the field (%d) does not match "
84
102
"the number of dimensions (%d)" % (self ._field .shape [- 1 ], self .ndim )
85
103
)
86
104
87
- if is_deltas :
88
- self ._deltas = self ._field
105
+ if self ._is_deltas :
106
+ self ._deltas = (
107
+ self ._field .copy ()
108
+ ) # IMPORTANT: you don't want to update deltas
89
109
# Convert from displacements (deltas) to deformations fields
90
110
# (just add its origin to each delta vector)
91
- self ._field += self .reference .ndcoords .T .reshape (self . _field . shape )
111
+ self ._field += self .reference .ndcoords .T .reshape (fieldshape )
92
112
93
113
def __repr__ (self ):
94
114
"""Beautify the python representation."""
95
115
return f"<{ self .__class__ .__name__ } [{ self ._field .shape [- 1 ]} D] { self ._field .shape [:3 ]} >"
96
116
117
+ @property
118
+ def is_deltas (self ):
119
+ """Check whether this is a displacements (``True``) or a deformation (``False``) field."""
120
+ return self ._is_deltas
121
+
97
122
@property
98
123
def ndim (self ):
99
124
"""Get the dimensions of the transform."""
@@ -222,22 +247,60 @@ def __eq__(self, other):
222
247
True
223
248
224
249
"""
225
- _eq = np .array_equal (self ._field , other ._field )
250
+ _eq = np .allclose (self ._field , other ._field )
226
251
if _eq and self ._reference != other ._reference :
227
252
warnings .warn ("Fields are equal, but references do not match." )
228
253
return _eq
229
254
255
+ def to_x5 (self , metadata = None ):
256
+ """Return an :class:`~nitransforms.io.x5.X5Transform` representation."""
257
+ metadata = {"WrittenBy" : f"NiTransforms { __version__ } " } | (metadata or {})
258
+
259
+ domain = None
260
+ if (reference := self .reference ) is not None :
261
+ domain = io .x5 .X5Domain (
262
+ grid = True ,
263
+ size = getattr (reference , "shape" , (0 , 0 , 0 )),
264
+ mapping = reference .affine ,
265
+ coordinates = "cartesian" ,
266
+ )
267
+
268
+ kinds = tuple ("space" for _ in range (self .ndim )) + ("vector" ,)
269
+
270
+ return io .x5 .X5Transform (
271
+ type = "nonlinear" ,
272
+ subtype = "densefield" ,
273
+ representation = "displacements" if self .is_deltas else "deformations" ,
274
+ metadata = metadata ,
275
+ transform = self ._deltas if self .is_deltas else self ._field ,
276
+ dimension_kinds = kinds ,
277
+ domain = domain ,
278
+ )
279
+
230
280
@classmethod
231
281
def from_filename (cls , filename , fmt = "X5" ):
232
282
_factory = {
233
283
"afni" : io .afni .AFNIDisplacementsField ,
234
284
"itk" : io .itk .ITKDisplacementsField ,
235
285
"fsl" : io .fsl .FSLDisplacementsField ,
286
+ "X5" : None ,
236
287
}
237
- if fmt not in _factory :
288
+ fmt = fmt .upper ()
289
+ if fmt not in {k .upper () for k in _factory }:
238
290
raise NotImplementedError (f"Unsupported format <{ fmt } >" )
239
291
240
- return cls (_factory [fmt ].from_filename (filename ))
292
+ if fmt == "X5" :
293
+ x5_xfm = load_x5 (filename )[0 ]
294
+ Domain = namedtuple ("Domain" , "affine shape" )
295
+ reference = Domain (x5_xfm .domain .mapping , x5_xfm .domain .size )
296
+ field = nb .Nifti1Image (x5_xfm .transform , reference .affine )
297
+ return cls (
298
+ field ,
299
+ is_deltas = x5_xfm .representation == "displacements" ,
300
+ reference = reference ,
301
+ )
302
+
303
+ return cls (_factory [fmt .lower ()].from_filename (filename ))
241
304
242
305
243
306
load = DenseFieldTransform .from_filename
@@ -272,6 +335,24 @@ def ndim(self):
272
335
"""Get the dimensions of the transform."""
273
336
return self ._coeffs .ndim - 1
274
337
338
+ @classmethod
339
+ def from_filename (cls , filename , fmt = "X5" ):
340
+ _factory = {
341
+ "X5" : None ,
342
+ }
343
+ fmt = fmt .upper ()
344
+ if fmt not in {k .upper () for k in _factory }:
345
+ raise NotImplementedError (f"Unsupported format <{ fmt } >" )
346
+
347
+ x5_xfm = load_x5 (filename )[0 ]
348
+ Domain = namedtuple ("Domain" , "affine shape" )
349
+ reference = Domain (x5_xfm .domain .mapping , x5_xfm .domain .size )
350
+
351
+ coefficients = nb .Nifti1Image (x5_xfm .transform , x5_xfm .additional_parameters )
352
+ return cls (coefficients , reference = reference )
353
+
354
+ # return cls(_factory[fmt.lower()].from_filename(filename))
355
+
275
356
def to_field (self , reference = None , dtype = "float32" ):
276
357
"""Generate a displacements deformation field from this B-Spline field."""
277
358
_ref = (
@@ -293,6 +374,32 @@ def to_field(self, reference=None, dtype="float32"):
293
374
field .astype (dtype ).reshape (* _ref .shape , - 1 ), reference = _ref
294
375
)
295
376
377
+ def to_x5 (self , metadata = None ):
378
+ """Return an :class:`~nitransforms.io.x5.X5Transform` representation."""
379
+ metadata = {"WrittenBy" : f"NiTransforms { __version__ } " } | (metadata or {})
380
+
381
+ domain = None
382
+ if (reference := self .reference ) is not None :
383
+ domain = io .x5 .X5Domain (
384
+ grid = True ,
385
+ size = getattr (reference , "shape" , (0 , 0 , 0 )),
386
+ mapping = reference .affine ,
387
+ coordinates = "cartesian" ,
388
+ )
389
+
390
+ kinds = tuple ("space" for _ in range (self .ndim )) + ("vector" ,)
391
+
392
+ return io .x5 .X5Transform (
393
+ type = "nonlinear" ,
394
+ subtype = "bspline" ,
395
+ representation = "coefficients" ,
396
+ metadata = metadata ,
397
+ transform = self ._coeffs ,
398
+ dimension_kinds = kinds ,
399
+ domain = domain ,
400
+ additional_parameters = self ._knots .affine ,
401
+ )
402
+
296
403
def map (self , x , inverse = False ):
297
404
r"""
298
405
Apply the transformation to a list of physical coordinate points.
0 commit comments