@@ -88,6 +88,7 @@ def __init__(
8888 padding_mode : Union [GridSamplePadMode , str ] = GridSamplePadMode .BORDER ,
8989 align_corners : bool = False ,
9090 dtype : DtypeLike = np .float64 ,
91+ image_only : bool = False ,
9192 ) -> None :
9293 """
9394 Args:
@@ -114,13 +115,16 @@ def __init__(
114115 dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
115116 If None, use the data type of input data. To be compatible with other modules,
116117 the output data type is always ``np.float32``.
118+ image_only: if True return only the image volume, otherwise return (image, original affine, new affine).
119+
117120 """
118121 self .pixdim = np .array (ensure_tuple (pixdim ), dtype = np .float64 )
119122 self .diagonal = diagonal
120123 self .mode : GridSampleMode = GridSampleMode (mode )
121124 self .padding_mode : GridSamplePadMode = GridSamplePadMode (padding_mode )
122125 self .align_corners = align_corners
123126 self .dtype = dtype
127+ self .image_only = image_only
124128
125129 def __call__ (
126130 self ,
@@ -131,7 +135,7 @@ def __call__(
131135 align_corners : Optional [bool ] = None ,
132136 dtype : DtypeLike = None ,
133137 output_spatial_shape : Optional [np .ndarray ] = None ,
134- ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
138+ ) -> Union [ np . ndarray , Tuple [np .ndarray , np .ndarray , np .ndarray ] ]:
135139 """
136140 Args:
137141 data_array: in shape (num_channels, H[, W, ...]).
@@ -204,7 +208,8 @@ def __call__(
204208 )
205209 output_data = np .asarray (output_data .squeeze (0 ).detach ().cpu ().numpy (), dtype = np .float32 ) # type: ignore
206210 new_affine = to_affine_nd (affine , new_affine )
207- return output_data , affine , new_affine
211+
212+ return output_data if self .image_only else (output_data , affine , new_affine )
208213
209214
210215class Orientation (Transform ):
@@ -217,6 +222,7 @@ def __init__(
217222 axcodes : Optional [str ] = None ,
218223 as_closest_canonical : bool = False ,
219224 labels : Optional [Sequence [Tuple [str , str ]]] = tuple (zip ("LPI" , "RAS" )),
225+ image_only : bool = False ,
220226 ) -> None :
221227 """
222228 Args:
@@ -229,6 +235,7 @@ def __init__(
229235 labels: optional, None or sequence of (2,) sequences
230236 (2,) sequences are labels for (beginning, end) of output axis.
231237 Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
238+ image_only: if True return only the image volume, otherwise return (image, original affine, new affine).
232239
233240 Raises:
234241 ValueError: When ``axcodes=None`` and ``as_closest_canonical=True``. Incompatible values.
@@ -243,10 +250,11 @@ def __init__(
243250 self .axcodes = axcodes
244251 self .as_closest_canonical = as_closest_canonical
245252 self .labels = labels
253+ self .image_only = image_only
246254
247255 def __call__ (
248256 self , data_array : np .ndarray , affine : Optional [np .ndarray ] = None
249- ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
257+ ) -> Union [ np . ndarray , Tuple [np .ndarray , np .ndarray , np .ndarray ] ]:
250258 """
251259 original orientation of `data_array` is defined by `affine`.
252260
@@ -289,7 +297,8 @@ def __call__(
289297 data_array = np .ascontiguousarray (nib .orientations .apply_orientation (data_array , ornt ))
290298 new_affine = affine_ @ nib .orientations .inv_ornt_aff (spatial_ornt , shape )
291299 new_affine = to_affine_nd (affine , new_affine )
292- return data_array , affine , new_affine
300+
301+ return data_array if self .image_only else (data_array , affine , new_affine )
293302
294303
295304class Flip (Transform ):
@@ -1270,6 +1279,7 @@ def __init__(
12701279 padding_mode : Union [GridSamplePadMode , str ] = GridSamplePadMode .REFLECTION ,
12711280 as_tensor_output : bool = False ,
12721281 device : Optional [torch .device ] = None ,
1282+ image_only : bool = False ,
12731283 ) -> None :
12741284 """
12751285 The affine transformations are applied in rotate, shear, translate, scale order.
@@ -1296,6 +1306,7 @@ def __init__(
12961306 as_tensor_output: the computation is implemented using pytorch tensors, this option specifies
12971307 whether to convert it back to numpy arrays.
12981308 device: device on which the tensor will be allocated.
1309+ image_only: if True return only the image volume, otherwise return (image, affine).
12991310 """
13001311 self .affine_grid = AffineGrid (
13011312 rotate_params = rotate_params ,
@@ -1305,6 +1316,7 @@ def __init__(
13051316 as_tensor_output = True ,
13061317 device = device ,
13071318 )
1319+ self .image_only = image_only
13081320 self .resampler = Resample (as_tensor_output = as_tensor_output , device = device )
13091321 self .spatial_size = spatial_size
13101322 self .mode : GridSampleMode = GridSampleMode (mode )
@@ -1316,7 +1328,7 @@ def __call__(
13161328 spatial_size : Optional [Union [Sequence [int ], int ]] = None ,
13171329 mode : Optional [Union [GridSampleMode , str ]] = None ,
13181330 padding_mode : Optional [Union [GridSamplePadMode , str ]] = None ,
1319- ) -> Tuple [ Union [ np . ndarray , torch . Tensor ], Union [ np . ndarray , torch . Tensor ]] :
1331+ ):
13201332 """
13211333 Args:
13221334 img: shape must be (num_channels, H, W[, D]),
@@ -1334,10 +1346,9 @@ def __call__(
13341346 """
13351347 sp_size = fall_back_tuple (spatial_size or self .spatial_size , img .shape [1 :])
13361348 grid , affine = self .affine_grid (spatial_size = sp_size )
1337- return (
1338- self .resampler (img = img , grid = grid , mode = mode or self .mode , padding_mode = padding_mode or self .padding_mode ),
1339- affine ,
1340- )
1349+ ret = self .resampler (img , grid = grid , mode = mode or self .mode , padding_mode = padding_mode or self .padding_mode )
1350+
1351+ return ret if self .image_only else (ret , affine )
13411352
13421353
13431354class RandAffine (RandomizableTransform ):
0 commit comments