12
12
from functools import partial
13
13
from collections import namedtuple
14
14
import numpy as np
15
+ import nibabel as nb
15
16
16
17
from nitransforms import io
17
18
from nitransforms .io .base import _ensure_image
19
+ from nitransforms .io .x5 import from_filename as load_x5
18
20
from nitransforms .interp .bspline import grid_bspline_weights , _cubic_bspline
19
21
from nitransforms .base import (
20
22
TransformBase ,
34
36
class DenseFieldTransform (TransformBase ):
35
37
"""Represents dense field (voxel-wise) transforms."""
36
38
37
- __slots__ = ("_field" , "_deltas" )
39
+ __slots__ = ("_field" , "_deltas" , "_is_deltas" )
38
40
39
41
def __init__ (self , field = None , is_deltas = True , reference = None ):
40
42
"""
@@ -68,14 +70,7 @@ def __init__(self, field=None, is_deltas=True, reference=None):
68
70
69
71
super ().__init__ ()
70
72
71
- if field is not None :
72
- field = _ensure_image (field )
73
- self ._field = np .squeeze (
74
- np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
75
- )
76
- else :
77
- self ._field = np .zeros ((* reference .shape , reference .ndim ), dtype = "float32" )
78
- is_deltas = True
73
+ self ._is_deltas = is_deltas
79
74
80
75
try :
81
76
self .reference = ImageGrid (reference if reference is not None else field )
@@ -86,22 +81,44 @@ def __init__(self, field=None, is_deltas=True, reference=None):
86
81
else "Reference is not a spatial image"
87
82
)
88
83
84
+ fieldshape = (* self .reference .shape , self .reference .ndim )
85
+ if field is not None :
86
+ field = _ensure_image (field )
87
+ self ._field = np .squeeze (
88
+ np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
89
+ )
90
+ if fieldshape != self ._field .shape :
91
+ raise TransformError (
92
+ f"Shape of the field ({ 'x' .join (str (i ) for i in self ._field .shape )} ) "
93
+ f"doesn't match that of the reference({ 'x' .join (str (i ) for i in fieldshape )} )"
94
+ )
95
+ else :
96
+ self ._field = np .zeros (fieldshape , dtype = "float32" )
97
+ self ._is_deltas = True
98
+
89
99
if self ._field .shape [- 1 ] != self .ndim :
90
100
raise TransformError (
91
101
"The number of components of the field (%d) does not match "
92
102
"the number of dimensions (%d)" % (self ._field .shape [- 1 ], self .ndim )
93
103
)
94
104
95
- if is_deltas :
96
- self ._deltas = self ._field
105
+ if self ._is_deltas :
106
+ self ._deltas = (
107
+ self ._field .copy ()
108
+ ) # IMPORTANT: you don't want to update deltas
97
109
# Convert from displacements (deltas) to deformations fields
98
110
# (just add its origin to each delta vector)
99
- self ._field += self .reference .ndcoords .T .reshape (self . _field . shape )
111
+ self ._field += self .reference .ndcoords .T .reshape (fieldshape )
100
112
101
113
def __repr__ (self ):
102
114
"""Beautify the python representation."""
103
115
return f"<{ self .__class__ .__name__ } [{ self ._field .shape [- 1 ]} D] { self ._field .shape [:3 ]} >"
104
116
117
+ @property
118
+ def is_deltas (self ):
119
+ """Check whether this is a displacements (``True``) or a deformation (``False``) field."""
120
+ return self ._is_deltas
121
+
105
122
@property
106
123
def ndim (self ):
107
124
"""Get the dimensions of the transform."""
@@ -230,7 +247,7 @@ def __eq__(self, other):
230
247
True
231
248
232
249
"""
233
- _eq = np .array_equal (self ._field , other ._field )
250
+ _eq = np .allclose (self ._field , other ._field )
234
251
if _eq and self ._reference != other ._reference :
235
252
warnings .warn ("Fields are equal, but references do not match." )
236
253
return _eq
@@ -253,9 +270,9 @@ def to_x5(self, metadata=None):
253
270
return io .x5 .X5Transform (
254
271
type = "nonlinear" ,
255
272
subtype = "densefield" ,
256
- representation = "displacements" ,
273
+ representation = "displacements" if self . is_deltas else "deformations" ,
257
274
metadata = metadata ,
258
- transform = self ._deltas ,
275
+ transform = self ._deltas if self . is_deltas else self . _field ,
259
276
dimension_kinds = kinds ,
260
277
domain = domain ,
261
278
)
@@ -273,12 +290,15 @@ def from_filename(cls, filename, fmt="X5"):
273
290
raise NotImplementedError (f"Unsupported format <{ fmt } >" )
274
291
275
292
if fmt == "X5" :
276
- from .io .x5 import from_filename as load_x5
277
-
278
293
x5_xfm = load_x5 (filename )[0 ]
279
294
Domain = namedtuple ("Domain" , "affine shape" )
280
295
reference = Domain (x5_xfm .domain .mapping , x5_xfm .domain .size )
281
- return cls (x5_xfm .transform , is_deltas = True , reference = reference )
296
+ field = nb .Nifti1Image (x5_xfm .transform , reference .affine )
297
+ return cls (
298
+ field ,
299
+ is_deltas = x5_xfm .representation == "displacements" ,
300
+ reference = reference ,
301
+ )
282
302
283
303
return cls (_factory [fmt .lower ()].from_filename (filename ))
284
304
@@ -315,6 +335,24 @@ def ndim(self):
315
335
"""Get the dimensions of the transform."""
316
336
return self ._coeffs .ndim - 1
317
337
338
+ @classmethod
339
+ def from_filename (cls , filename , fmt = "X5" ):
340
+ _factory = {
341
+ "X5" : None ,
342
+ }
343
+ fmt = fmt .upper ()
344
+ if fmt not in {k .upper () for k in _factory }:
345
+ raise NotImplementedError (f"Unsupported format <{ fmt } >" )
346
+
347
+ x5_xfm = load_x5 (filename )[0 ]
348
+ Domain = namedtuple ("Domain" , "affine shape" )
349
+ reference = Domain (x5_xfm .domain .mapping , x5_xfm .domain .size )
350
+
351
+ coefficients = nb .Nifti1Image (x5_xfm .transform , x5_xfm .additional_parameters )
352
+ return cls (coefficients , reference = reference )
353
+
354
+ # return cls(_factory[fmt.lower()].from_filename(filename))
355
+
318
356
def to_field (self , reference = None , dtype = "float32" ):
319
357
"""Generate a displacements deformation field from this B-Spline field."""
320
358
_ref = (
@@ -349,21 +387,17 @@ def to_x5(self, metadata=None):
349
387
coordinates = "cartesian" ,
350
388
)
351
389
352
- meta = metadata | {
353
- "KnotsAffine" : self ._knots .affine .tolist (),
354
- "KnotsShape" : self ._knots .shape ,
355
- }
356
-
357
390
kinds = tuple ("space" for _ in range (self .ndim )) + ("vector" ,)
358
391
359
392
return io .x5 .X5Transform (
360
393
type = "nonlinear" ,
361
394
subtype = "bspline" ,
362
395
representation = "coefficients" ,
363
- metadata = meta ,
396
+ metadata = metadata ,
364
397
transform = self ._coeffs ,
365
398
dimension_kinds = kinds ,
366
399
domain = domain ,
400
+ additional_parameters = self ._knots .affine ,
367
401
)
368
402
369
403
def map (self , x , inverse = False ):
0 commit comments