diff --git a/Demos/Python/Demo_CuPy_3D.py b/Demos/Python/Demo_CuPy_3D.py index e85d9893..1c439f0c 100644 --- a/Demos/Python/Demo_CuPy_3D.py +++ b/Demos/Python/Demo_CuPy_3D.py @@ -204,217 +204,3 @@ Qtools = QualityTools(phantom_tm, Fourier_cupy) RMSE = Qtools.rmse() print("Root Mean Square Error is {} for Fourier inversion".format(RMSE)) -# %% -print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") -print("%%%%%%%%Reconstructing using Landweber algorithm %%%%%%%%%%%") -print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") -RecToolsCP_iter = RecToolsIRCuPy( - DetectorsDimH=Horiz_det, # Horizontal detector dimension - DetectorsDimV=Vert_det, # Vertical detector dimension (3D case) - CenterRotOffset=0.0, # Center of Rotation scalar or a vector - AnglesVec=angles_rad, # A vector of projection angles in radians - ObjSize=N_size, # Reconstructed object dimensions (scalar) - datafidelity="LS", # Data fidelity, choose from LS, KL, PWLS, SWLS - device_projector="gpu", -) - -# prepare dictionaries with parameters: -_data_ = { - "projection_norm_data": projData3D_analyt_cupy, - "data_axes_labels_order": input_data_labels, -} # data dictionary - -LWrec_cupy = RecToolsCP_iter.Landweber(_data_) - -lwrec = cp.asnumpy(LWrec_cupy) - -sliceSel = int(0.5 * N_size) -plt.figure() -plt.subplot(131) -plt.imshow(lwrec[sliceSel, :, :]) -plt.title("3D Landweber Reconstruction, axial view") - -plt.subplot(132) -plt.imshow(lwrec[:, sliceSel, :]) -plt.title("3D Landweber Reconstruction, coronal view") - -plt.subplot(133) -plt.imshow(lwrec[:, :, sliceSel]) -plt.title("3D Landweber Reconstruction, sagittal view") -plt.show() -# %% -print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") -print("%%%%%%%%%% Reconstructing using SIRT algorithm %%%%%%%%%%%%%") -print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") -RecToolsCP_iter = RecToolsIRCuPy( - DetectorsDimH=Horiz_det, # Horizontal detector dimension - DetectorsDimV=Vert_det, # Vertical detector dimension (3D case) - CenterRotOffset=0.0, # Center of Rotation scalar or a vector - AnglesVec=angles_rad, # A vector of projection angles in radians - ObjSize=N_size, # Reconstructed object dimensions (scalar) - device_projector="gpu", -) - -# prepare dictionaries with parameters: -_data_ = { - "projection_norm_data": projData3D_analyt_cupy, - "data_axes_labels_order": input_data_labels, -} - -_algorithm_ = {"iterations": 300, "nonnegativity": True} - -SIRTrec_cupy = RecToolsCP_iter.SIRT(_data_, _algorithm_) - -sirt_rec = cp.asnumpy(SIRTrec_cupy) - -sliceSel = int(0.5 * N_size) -plt.figure() -plt.subplot(131) -plt.imshow(sirt_rec[sliceSel, :, :]) -plt.title("3D SIRT Reconstruction, axial view") - -plt.subplot(132) -plt.imshow(sirt_rec[:, sliceSel, :]) -plt.title("3D SIRT Reconstruction, coronal view") - -plt.subplot(133) -plt.imshow(sirt_rec[:, :, sliceSel]) -plt.title("3D SIRT Reconstruction, sagittal view") -plt.show() -# %% -print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") -print("%%%%%%%%%% Reconstructing using CGLS algorithm %%%%%%%%%%%%%") -print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") -RecToolsCP_iter = RecToolsIRCuPy( - DetectorsDimH=Horiz_det, # Horizontal detector dimension - DetectorsDimV=Vert_det, # Vertical detector dimension (3D case) - CenterRotOffset=0.0, # Center of Rotation scalar or a vector - AnglesVec=angles_rad, # A vector of projection angles in radians - ObjSize=N_size, # Reconstructed object dimensions (scalar) - device_projector="gpu", -) - -# prepare dictionaries with parameters: -_data_ = { - "projection_norm_data": projData3D_analyt_cupy, - "data_axes_labels_order": input_data_labels, -} - -_algorithm_ = {"iterations": 20, "nonnegativity": True} -CGLSrec_cupy = RecToolsCP_iter.CGLS(_data_, _algorithm_) - -cgls_rec = cp.asnumpy(CGLSrec_cupy) - -sliceSel = int(0.5 * N_size) -plt.figure() -plt.subplot(131) -plt.imshow(cgls_rec[sliceSel, :, :]) -plt.title("3D CGLS Reconstruction, axial view") - -plt.subplot(132) -plt.imshow(cgls_rec[:, sliceSel, :]) -plt.title("3D CGLS Reconstruction, coronal view") - -plt.subplot(133) -plt.imshow(cgls_rec[:, :, sliceSel]) -plt.title("3D CGLS Reconstruction, sagittal view") -plt.show() -# %% -print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") -print("%%%%%%%%%% Reconstructing using FISTA algorithm %%%%%%%%%%%%") -print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") -RecToolsCP_iter = RecToolsIRCuPy( - DetectorsDimH=Horiz_det, # Horizontal detector dimension - DetectorsDimV=Vert_det, # Vertical detector dimension (3D case) - CenterRotOffset=0.0, # Center of Rotation scalar or a vector - AnglesVec=angles_rad, # A vector of projection angles in radians - ObjSize=N_size, # Reconstructed object dimensions (scalar) - datafidelity="LS", - device_projector=0, -) - -# prepare dictionaries with parameters: -_data_ = { - "projection_norm_data": projData3D_analyt_cupy, - "data_axes_labels_order": input_data_labels, -} # data dictionary - -lc = RecToolsCP_iter.powermethod(_data_) - -_algorithm_ = {"iterations": 300, "lipschitz_const": lc.get()} - -start_time = timeit.default_timer() -RecFISTA = RecToolsCP_iter.FISTA(_data_, _algorithm_, _regularisation_={}) -txtstr = "%s = %.3fs" % ("elapsed time", timeit.default_timer() - start_time) -print(txtstr) - -fista_rec_np = cp.asnumpy(RecFISTA) - -sliceSel = int(0.5 * N_size) -plt.figure() -plt.subplot(131) -plt.imshow(fista_rec_np[sliceSel, :, :]) -plt.title("3D FISTA Reconstruction, axial view") - -plt.subplot(132) -plt.imshow(fista_rec_np[:, sliceSel, :]) -plt.title("3D FISTA Reconstruction, coronal view") - -plt.subplot(133) -plt.imshow(fista_rec_np[:, :, sliceSel]) -plt.title("3D FISTA Reconstruction, sagittal view") -plt.show() -# %% -print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") -print("%%%%% Reconstructing using regularised FISTA-OS algorithm %%") -print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") -# NOTE that you'd need to install CuPy modules for the regularisers from the regularisation toolkit -RecToolsCP_iter = RecToolsIRCuPy( - DetectorsDimH=Horiz_det, # Horizontal detector dimension - DetectorsDimV=Vert_det, # Vertical detector dimension (3D case) - CenterRotOffset=0.0, # Center of Rotation scalar or a vector - AnglesVec=angles_rad, # A vector of projection angles in radians - ObjSize=N_size, # Reconstructed object dimensions (scalar) - datafidelity="LS", # Data fidelity, choose from LS, KL, PWLS - device_projector=0, -) - -start_time = timeit.default_timer() -# prepare dictionaries with parameters: -_data_ = { - "projection_norm_data": projData3D_analyt_cupy, - "OS_number": 8, - "data_axes_labels_order": input_data_labels, -} # data dictionary - -lc = RecToolsCP_iter.powermethod(_data_) -_algorithm_ = {"iterations": 15, "lipschitz_const": lc.get()} - -_regularisation_ = { - "method": "PD_TV", - "regul_param": 0.0005, - "iterations": 35, - "device_regulariser": 0, -} - -RecFISTA = RecToolsCP_iter.FISTA(_data_, _algorithm_, _regularisation_) -txtstr = "%s = %.3fs" % ("elapsed time", timeit.default_timer() - start_time) -print(txtstr) - -fista_rec_np = cp.asnumpy(RecFISTA) - -sliceSel = int(0.5 * N_size) -plt.figure() -plt.subplot(131) -plt.imshow(fista_rec_np[sliceSel, :, :]) -plt.title("3D FISTA-OS Reconstruction, axial view") - -plt.subplot(132) -plt.imshow(fista_rec_np[:, sliceSel, :]) -plt.title("3D FISTA-OS Reconstruction, coronal view") - -plt.subplot(133) -plt.imshow(fista_rec_np[:, :, sliceSel]) -plt.title("3D FISTA-OS Reconstruction, sagittal view") -plt.show() -# %% diff --git a/Demos/Python/Demo_CuPy_3D_search_optimal.py b/Demos/Python/Demo_CuPy_3D_search_optimal.py new file mode 100644 index 00000000..a3a15cb0 --- /dev/null +++ b/Demos/Python/Demo_CuPy_3D_search_optimal.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +GPLv3 license (ASTRA toolbox) + +Script that demonstrates the reconstruction of CuPy arrays while keeping +the data on the GPU (device-to-device) + +Dependencies: + * astra-toolkit, install conda install -c astra-toolbox astra-toolbox + * TomoPhantom, https://github.com/dkazanc/TomoPhantom + * CuPy package + +@author: Daniil Kazantsev +""" +import timeit +import os +import matplotlib.pyplot as plt +import numpy as np +import cupy as cp +import tomophantom +from tomophantom import TomoP3D +from tomophantom.qualitymetrics import QualityTools +from tomobar.methodsDIR_CuPy import RecToolsDIRCuPy +from tomobar.methodsIR_CuPy import RecToolsIRCuPy + +print("center_size, phantom size, angle number, slice number, time") + +for N_size in [512, 1024, 1536, 2048]: + for angles_num in [128, 256, 512, 1024, 1536, 2048]: + + # print("Building 3D phantom using TomoPhantom software") + tic = timeit.default_timer() + model = 13 # select a model number from the library + # N_size = 256 # Define phantom dimensions using a scalar value (cubic phantom) + path = os.path.dirname(tomophantom.__file__) + path_library3D = os.path.join(path, "phantomlib", "Phantom3DLibrary.dat") + + phantom_tm = TomoP3D.Model(model, N_size, path_library3D) + + # Projection geometry related parameters: + Horiz_det = int(np.sqrt(2) * N_size) # detector column count (horizontal) + Vert_det = N_size # detector row count (vertical) (no reason for it to be > N) + # angles_num = int(0.3 * np.pi * N_size) # angles number + angles = np.linspace(0.0, 179.9, angles_num, dtype="float32") # in degrees + angles_rad = angles * (np.pi / 180.0) + + # print("Generate 3D analytical projection data with TomoPhantom") + projData3D_analyt = TomoP3D.ModelSino( + model, N_size, Horiz_det, Vert_det, angles, path_library3D + ) + input_data_labels = ["detY", "angles", "detX"] + + # print(np.shape(projData3D_analyt)) + + # transfering numpy array to CuPy array + projData3D_analyt_cupy = cp.asarray(projData3D_analyt, order="C") + + for slice_number in [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 18, 20]: + + RecToolsCP = RecToolsDIRCuPy( + DetectorsDimH=Horiz_det, # Horizontal detector dimension + DetectorsDimV=slice_number, # Vertical detector dimension (3D case) + CenterRotOffset=0.0, # Center of Rotation scalar or a vector + AnglesVec=angles_rad, # A vector of projection angles in radians + ObjSize=N_size, # Reconstructed object dimensions (scalar) + device_projector="gpu", + ) + + # tic = timeit.default_timer() + # for x in range(80): + # Fourier_cupy = RecToolsCP.FOURIER_INV( + # projData3D_analyt_cupy[1:slice_number, :, :], + # recon_mask_radius=0.95, + # data_axes_labels_order=input_data_labels, + # center_size=2048, + # block_dim=[16, 16], + # block_dim_center=[32, 4], + # ) + # toc = timeit.default_timer() + + # Run_time = (toc - tic)/80 + # print("Phantom size: {}, angle number: {}, and slice number: {}, in time: {} seconds".format(N_size, angles_num, slice_number, Run_time)) + + for center_size in [0, 128, 256, 384, 448, 512, 640, 672, 704, 768, 800, 864, 928, 1024, 1280, 1536, 1792, 2048, 2560, 3072]: + tic = timeit.default_timer() + for x in range(80): + Fourier_cupy = RecToolsCP.FOURIER_INV( + projData3D_analyt_cupy[1:slice_number, :, :], + recon_mask_radius=0.95, + center_size=center_size, + data_axes_labels_order=input_data_labels, + ) + toc = timeit.default_timer() + Run_time = (toc - tic)/80 + print("{}, {}, {}, {}, {}".format(center_size, N_size, angles_num, slice_number, Run_time)) diff --git a/Demos/Python/Demo_CuPy_3D_temp.py b/Demos/Python/Demo_CuPy_3D_temp.py new file mode 100644 index 00000000..1e16b7bd --- /dev/null +++ b/Demos/Python/Demo_CuPy_3D_temp.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +GPLv3 license (ASTRA toolbox) + +Script that demonstrates the reconstruction of CuPy arrays while keeping +the data on the GPU (device-to-device) + +Dependencies: + * astra-toolkit, install conda install -c astra-toolbox astra-toolbox + * TomoPhantom, https://github.com/dkazanc/TomoPhantom + * CuPy package + +@author: Daniil Kazantsev +""" +import timeit +import os +import matplotlib.pyplot as plt +import numpy as np +import cupy as cp +import tomophantom +from tomophantom import TomoP3D +from tomophantom.qualitymetrics import QualityTools +from tomobar.methodsDIR_CuPy import RecToolsDIRCuPy +from tomobar.methodsIR_CuPy import RecToolsIRCuPy + +print("Building 3D phantom using TomoPhantom software") +tic = timeit.default_timer() +model = 13 # select a model number from the library +N_size = 256 # Define phantom dimensions using a scalar value (cubic phantom) +path = os.path.dirname(tomophantom.__file__) +path_library3D = os.path.join(path, "phantomlib", "Phantom3DLibrary.dat") + +phantom_tm = TomoP3D.Model(model, N_size, path_library3D) + +# Projection geometry related parameters: +Horiz_det = int(np.sqrt(2) * N_size) # detector column count (horizontal) +Vert_det = N_size # detector row count (vertical) (no reason for it to be > N) +angles_num = int(0.3 * np.pi * N_size) # angles number +angles = np.linspace(0.0, 179.9, angles_num, dtype="float32") # in degrees +angles_rad = angles * (np.pi / 180.0) + +print("Generate 3D analytical projection data with TomoPhantom") +projData3D_analyt = TomoP3D.ModelSino( + model, N_size, Horiz_det, Vert_det, angles, path_library3D +) +input_data_labels = ["detY", "angles", "detX"] + +# transfering numpy array to CuPy array +projData3D_analyt_cupy = cp.asarray(projData3D_analyt, order="C") + +# %% +print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") +print("%%%%%%%%%Reconstructing with 3D Fourier-CuPy method %%%%%%%%") +print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") +RecToolsCP = RecToolsDIRCuPy( + DetectorsDimH=Horiz_det, # Horizontal detector dimension + DetectorsDimV=Vert_det, # Vertical detector dimension (3D case) + CenterRotOffset=0.0, # Center of Rotation scalar or a vector + AnglesVec=angles_rad, # A vector of projection angles in radians + ObjSize=N_size, # Reconstructed object dimensions (scalar) + device_projector="gpu", +) + +Fourier_cupy = RecToolsCP.FOURIER_INV( + projData3D_analyt_cupy, + recon_mask_radius=0.95, + data_axes_labels_order=input_data_labels, +) + +tic = timeit.default_timer() +for x in range(80): + Fourier_cupy = RecToolsCP.FOURIER_INV( + projData3D_analyt_cupy, + recon_mask_radius=0.95, + data_axes_labels_order=input_data_labels, + ) +toc = timeit.default_timer() + +Run_time = (toc - tic)/80 +print("Log-polar 3D reconstruction in {} seconds".format(Run_time)) + +# for block_dim in [[32, 8], [64, 4], [32, 16], [16, 16], [32, 32]]: +# for block_dim_center in [[32, 8], [64, 4], [32, 16], [32, 4]]: +# for center_size in [448, 512, 640, 672, 704, 768]: +# tic = timeit.default_timer() +# for x in range(80): +# Fourier_cupy = RecToolsCP.FOURIER_INV( +# projData3D_analyt_cupy, +# recon_mask_radius=0.95, +# center_size=center_size, +# block_dim=block_dim, +# block_dim_center=block_dim_center, +# data_axes_labels_order=input_data_labels, +# ) +# toc = timeit.default_timer() + +# Run_time = (toc - tic)/80 +# print("Log-polar 3D reconstruction center_size; {}; block dim; {}; block_dim_center; {}; in ; {}; seconds".format(center_size, block_dim, block_dim_center, Run_time)) diff --git a/Demos/Python/Demo_RealData3_bench.py b/Demos/Python/Demo_RealData3_bench.py new file mode 100644 index 00000000..8ba95a6d --- /dev/null +++ b/Demos/Python/Demo_RealData3_bench.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +A script to reconstruct tomographic X-ray data (macromollecular crystallography) +obtained at Diamond Light Source (UK synchrotron), beamline i23 + +""" +import timeit +import numpy as np +import cupy as cp +from cupy import mean +import matplotlib.pyplot as plt +from tomobar.supp.suppTools import normaliser + +from numpy import float32 +from typing import Tuple + +def normalize_origin( + data: cp.ndarray, + flats: cp.ndarray, + darks: cp.ndarray, + cutoff: float = 10.0, + minus_log: bool = True, + nonnegativity: bool = False, + remove_nans: bool = False, +) -> cp.ndarray: + """ + Normalize raw projection data using the flat and dark field projections. + This is a raw CUDA kernel implementation with CuPy wrappers. + + Parameters + ---------- + data : cp.ndarray + Projection data as a CuPy array. + flats : cp.ndarray + 3D flat field data as a CuPy array. + darks : cp.ndarray + 3D dark field data as a CuPy array. + cutoff : float, optional + Permitted maximum value for the normalised data. + minus_log : bool, optional + Apply negative log to the normalised data. + nonnegativity : bool, optional + Remove negative values in the normalised data. + remove_nans : bool, optional + Remove NaN and Inf values in the normalised data. + + Returns + ------- + cp.ndarray + Normalised 3D tomographic data as a CuPy array. + """ + _check_valid_input(data, flats, darks) + + dark0 = cp.empty(darks.shape[1:], dtype=float32) + flat0 = cp.empty(flats.shape[1:], dtype=float32) + out = cp.empty(data.shape, dtype=float32) + mean(darks, axis=0, dtype=float32, out=dark0) + mean(flats, axis=0, dtype=float32, out=flat0) + + kernel_name = "normalisation" + kernel = r""" + float denom = float(flats) - float(darks); + if (denom < eps) { + denom = eps; + } + float v = (float(data) - float(darks))/denom; + """ + if minus_log: + kernel += "v = -log(v);\n" + kernel_name += "_mlog" + if nonnegativity: + kernel += "if (v < 0.0f) v = 0.0f;\n" + kernel_name += "_nneg" + if remove_nans: + kernel += "if (isnan(v)) v = 0.0f;\n" + kernel += "if (isinf(v)) v = 0.0f;\n" + kernel_name += "_remnan" + kernel += "if (v > cutoff) v = cutoff;\n" + kernel += "if (v < -cutoff) v = cutoff;\n" + kernel += "out = v;\n" + + normalisation_kernel = cp.ElementwiseKernel( + "T data, U flats, U darks, raw float32 cutoff", + "float32 out", + kernel, + kernel_name, + options=("-std=c++11",), + loop_prep="constexpr float eps = 1.0e-07;", + no_return=True, + ) + + normalisation_kernel(data, flat0, dark0, float32(cutoff), out) + + return out + +def _check_valid_input(data, flats, darks) -> None: + """Helper function to check the validity of inputs to normalisation functions""" + if data.ndim != 3: + raise ValueError("Input data must be a 3D stack of projections") + + if flats.ndim not in (2, 3): + raise ValueError("Input flats must be 2D or 3D data only") + + if darks.ndim not in (2, 3): + raise ValueError("Input darks must be 2D or 3D data only") + + if flats.ndim == 2: + flats = flats[cp.newaxis, :, :] + if darks.ndim == 2: + darks = darks[cp.newaxis, :, :] + + +# data = np.load("data/i12_dataset2.npz") +# data = np.load("data/i13_dataset2.npz") +data = np.load("data/geant4_dataset1.npz") +projdata = cp.asarray(data['projdata']) +angles = data['angles'] +flats = cp.asarray(data['flats']) +darks = cp.asarray(data['darks']) +del data +#%% normalising data +data_normalised = normalize_origin(projdata, flats, darks, minus_log=True) + +del projdata, flats, darks +cp._default_memory_pool.free_all_blocks() + +data_labels3D = ["angles", "detY", "detX"] # set the input data labels +angles_number, detectorVec, detectorHoriz = np.shape(data_normalised) +print(np.shape(data_normalised)) +angles_rad = angles[:] * (np.pi / 180.0) + +N_size = detectorHoriz + +# %% +print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") +print("%%%%%%%%%%%%Reconstructing with FBP method %%%%%%%%%%%%%%%%%") +print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") +from tomobar.methodsDIR import RecToolsDIR +from tomobar.methodsDIR_CuPy import RecToolsDIRCuPy +from tomobar.methodsIR_CuPy import RecToolsIRCuPy + +RecToolsCP = RecToolsDIRCuPy( + DetectorsDimH=detectorHoriz, # Horizontal detector dimension + DetectorsDimV=detectorVec, # Vertical detector dimension (3D case) + CenterRotOffset=None, # Centre of Rotation scalar + AnglesVec=angles_rad, # A vector of projection angles in radians + ObjSize=N_size, # Reconstructed object dimensions (scalar) + device_projector="gpu", +) + +tic = timeit.default_timer() +Fourier_cupy = RecToolsCP.FOURIER_INV( + data_normalised, + filter_freq_cutoff=0.35, + recon_mask_radius=0.95, + data_axes_labels_order=data_labels3D, +) +toc = timeit.default_timer() + +tic = timeit.default_timer() +for x in range(10): + Fourier_cupy = RecToolsCP.FOURIER_INV( + data_normalised, + filter_freq_cutoff=0.35, + recon_mask_radius=0.95, + data_axes_labels_order=data_labels3D, + center_size=1024, + ) +toc = timeit.default_timer() + +Run_time = (toc - tic)/10 +print("Log-polar 3D reconstruction in {} seconds".format(Run_time)) + +# for block_dim in [[32, 8], [64, 4], [32, 16], [16, 16], [32, 32]]: +# for block_dim_center in [[32, 8], [64, 4], [32, 16], [32, 4]]: +# for center_size in [448, 512, 640, 672, 704, 768]: +# tic = timeit.default_timer() +# for x in range(10): +# Fourier_cupy = RecToolsCP.FOURIER_INV( +# data_normalised, +# filter_freq_cutoff=0.35, +# recon_mask_radius=0.95, +# data_axes_labels_order=data_labels3D, +# block_dim=block_dim, +# block_dim_center=block_dim_center, +# center_size=center_size, +# ) +# toc = timeit.default_timer() + +# Run_time = (toc - tic)/10 +# print("Log-polar 3D reconstruction center_size; {}; block dim; {}; block_dim_center; {}; in ; {}; seconds".format(center_size, block_dim, block_dim_center, Run_time)) diff --git a/Demos/Python/Demo_RealData3_temp.py b/Demos/Python/Demo_RealData3_temp.py new file mode 100644 index 00000000..09cfe8c1 --- /dev/null +++ b/Demos/Python/Demo_RealData3_temp.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +A script to reconstruct tomographic X-ray data (macromollecular crystallography) +obtained at Diamond Light Source (UK synchrotron), beamline i23 + +""" +import timeit +import numpy as np +import cupy as cp +from cupy import mean +import matplotlib.pyplot as plt +from tomobar.supp.suppTools import normaliser + +from numpy import float32 +from typing import Tuple + +def normalize_origin( + data: cp.ndarray, + flats: cp.ndarray, + darks: cp.ndarray, + cutoff: float = 10.0, + minus_log: bool = True, + nonnegativity: bool = False, + remove_nans: bool = False, +) -> cp.ndarray: + """ + Normalize raw projection data using the flat and dark field projections. + This is a raw CUDA kernel implementation with CuPy wrappers. + + Parameters + ---------- + data : cp.ndarray + Projection data as a CuPy array. + flats : cp.ndarray + 3D flat field data as a CuPy array. + darks : cp.ndarray + 3D dark field data as a CuPy array. + cutoff : float, optional + Permitted maximum value for the normalised data. + minus_log : bool, optional + Apply negative log to the normalised data. + nonnegativity : bool, optional + Remove negative values in the normalised data. + remove_nans : bool, optional + Remove NaN and Inf values in the normalised data. + + Returns + ------- + cp.ndarray + Normalised 3D tomographic data as a CuPy array. + """ + _check_valid_input(data, flats, darks) + + dark0 = cp.empty(darks.shape[1:], dtype=float32) + flat0 = cp.empty(flats.shape[1:], dtype=float32) + out = cp.empty(data.shape, dtype=float32) + mean(darks, axis=0, dtype=float32, out=dark0) + mean(flats, axis=0, dtype=float32, out=flat0) + + kernel_name = "normalisation" + kernel = r""" + float denom = float(flats) - float(darks); + if (denom < eps) { + denom = eps; + } + float v = (float(data) - float(darks))/denom; + """ + if minus_log: + kernel += "v = -log(v);\n" + kernel_name += "_mlog" + if nonnegativity: + kernel += "if (v < 0.0f) v = 0.0f;\n" + kernel_name += "_nneg" + if remove_nans: + kernel += "if (isnan(v)) v = 0.0f;\n" + kernel += "if (isinf(v)) v = 0.0f;\n" + kernel_name += "_remnan" + kernel += "if (v > cutoff) v = cutoff;\n" + kernel += "if (v < -cutoff) v = cutoff;\n" + kernel += "out = v;\n" + + normalisation_kernel = cp.ElementwiseKernel( + "T data, U flats, U darks, raw float32 cutoff", + "float32 out", + kernel, + kernel_name, + options=("-std=c++11",), + loop_prep="constexpr float eps = 1.0e-07;", + no_return=True, + ) + + normalisation_kernel(data, flat0, dark0, float32(cutoff), out) + + return out + +def _check_valid_input(data, flats, darks) -> None: + """Helper function to check the validity of inputs to normalisation functions""" + if data.ndim != 3: + raise ValueError("Input data must be a 3D stack of projections") + + if flats.ndim not in (2, 3): + raise ValueError("Input flats must be 2D or 3D data only") + + if darks.ndim not in (2, 3): + raise ValueError("Input darks must be 2D or 3D data only") + + if flats.ndim == 2: + flats = flats[cp.newaxis, :, :] + if darks.ndim == 2: + darks = darks[cp.newaxis, :, :] + + + +# data = np.load("data/i13_dataset2.npz") +data = np.load("data/geant4_dataset1.npz") +projdata = cp.asarray(data['projdata']) +angles = data['angles'] +flats = cp.asarray(data['flats']) +darks = cp.asarray(data['darks']) +del data +#%% normalising data +data_normalised = normalize_origin(projdata, flats, darks, minus_log=True) + +del projdata, flats, darks +cp._default_memory_pool.free_all_blocks() + +data_labels3D = ["angles", "detY", "detX"] # set the input data labels + +print(angles) +print(np.shape(data_normalised)) + +angles_number, detectorVec, detectorHoriz = np.shape(data_normalised) +plt.figure(1) +plt.imshow(data_normalised[:, detectorVec/2, :].get(), cmap="gray") +plt.title("Sinogram of i23 data") +plt.show() + +angles_rad = angles[:] * (np.pi / 180.0) + +N_size = detectorHoriz + +# %% +print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") +print("%%%%%%%%%%%%Reconstructing with FBP method %%%%%%%%%%%%%%%%%") +print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") +from tomobar.methodsDIR import RecToolsDIR +from tomobar.methodsDIR_CuPy import RecToolsDIRCuPy +from tomobar.methodsIR_CuPy import RecToolsIRCuPy + +RecToolsCP = RecToolsDIRCuPy( + DetectorsDimH=detectorHoriz, # Horizontal detector dimension + DetectorsDimV=detectorVec, # Vertical detector dimension (3D case) + CenterRotOffset=None, # Centre of Rotation scalar + AnglesVec=angles_rad, # A vector of projection angles in radians + ObjSize=N_size, # Reconstructed object dimensions (scalar) + device_projector="gpu", +) + +tic = timeit.default_timer() +Fourier_cupy = RecToolsCP.FOURIER_INV( + data_normalised, + filter_freq_cutoff=0.35, + recon_mask_radius=0.95, + data_axes_labels_order=data_labels3D, +) +toc = timeit.default_timer() + +# bring data from the device to the host +Fourier_cupy = cp.asnumpy(Fourier_cupy) + +recon_x, recon_y, recon_z = cp.shape(Fourier_cupy) + +plt.figure() +plt.subplot(131) +plt.imshow(Fourier_cupy[recon_x//2, :, :], cmap='gray') +plt.title("3D Fourier Reconstruction, axial view") + +plt.subplot(132) +plt.imshow(Fourier_cupy[:, recon_y//2, :], cmap='gray') +plt.title("3D Fourier Reconstruction, coronal view") + +plt.subplot(133) +plt.imshow(Fourier_cupy[:, :, recon_z//2], cmap='gray') +plt.title("3D Fourier Reconstruction, sagittal view") +plt.show() diff --git a/tomobar/cuda_kernels/fft_us_kernels.cu b/tomobar/cuda_kernels/fft_us_kernels.cu index fed606bb..46119055 100644 --- a/tomobar/cuda_kernels/fft_us_kernels.cu +++ b/tomobar/cuda_kernels/fft_us_kernels.cu @@ -2,56 +2,386 @@ #define M_PI 3.1415926535897932384626433832795f #endif -extern "C" __global__ void gather_kernel(float2 *g, float2 *f, float *theta, int m, - float mu, int n, int nproj, int nz) +template +__device__ void update_f_value(float2 *f, float2 g0t, float x0, float y0, + float coeff0, float coeff1, + int center_half_size, int ell0, int ell1, + int stride, int n); + +template<> +__device__ void update_f_value(float2 *f, float2 g0, float x0, float y0, + float coeff0, float coeff1, + int center_half_size, int ell0, int ell1, + int stride, int n) { + float w0 = ell0 / (float)(2 * n) - x0; + float w1 = ell1 / (float)(2 * n) - y0; + float w = coeff0 * __expf(coeff1 * (w0 * w0 + w1 * w1)); + float2 g0t = make_float2(w*g0.x, w*g0.y); + int f_ind = ell0 + stride * ell1; + atomicAdd(&(f[f_ind].x), g0t.x); + atomicAdd(&(f[f_ind].y), g0t.y); +} + +template<> +__device__ void update_f_value(float2 *f, float2 g0, float x0, float y0, + float coeff0, float coeff1, + int center_half_size, int ell0, int ell1, + int stride, int n) +{ + if( ell0 < -center_half_size || ell0 >= center_half_size || + ell1 < -center_half_size || ell1 >= center_half_size ) { + float w0 = ell0 / (float)(2 * n) - x0; + float w1 = ell1 / (float)(2 * n) - y0; + float w = coeff0 * __expf(coeff1 * (w0 * w0 + w1 * w1)); + float2 g0t = make_float2(w*g0.x, w*g0.y); + int f_ind = ell0 + stride * ell1; + atomicAdd(&(f[f_ind].x), g0t.x); + atomicAdd(&(f[f_ind].y), g0t.y); + } +} +template +__device__ void gather_kernel_common(float2 *g, float2 *f, float *theta, + int m, float mu, + int center_size, int n, int nproj, int nz) +{ int tx = blockDim.x * blockIdx.x + threadIdx.x; int ty = blockDim.y * blockIdx.y + threadIdx.y; int tz = blockDim.z * blockIdx.z + threadIdx.z; + const int center_half_size = center_size/2; + if (tx >= n || ty >= nproj || tz >= nz) return; float2 g0, g0t; - float w, coeff0; - float w0, w1, x0, y0, coeff1; + float coeff0, coeff1; + float x0, y0; int ell0, ell1, g_ind, f_ind; g_ind = tx + ty * n + tz * n * nproj; coeff0 = M_PI / mu; coeff1 = -M_PI * M_PI / mu; - x0 = (tx - n / 2) / (float)n * __cosf(theta[ty]); - y0 = -(tx - n / 2) / (float)n * __sinf(theta[ty]); + float sintheta, costheta; + __sincosf(theta[ty], &sintheta, &costheta); + x0 = (tx - n / 2) / (float)n * costheta; + y0 = -(tx - n / 2) / (float)n * sintheta; if (x0 >= 0.5f) x0 = 0.5f - 1e-5; if (y0 >= 0.5f) y0 = 0.5f - 1e-5; + + int stride1 = 2*n + 2*m; + int stride2 = stride1 * stride1; + g0.x = g[g_ind].x; g0.y = g[g_ind].y; + // offset f by [tz, n+m, n+m] - int stride1 = 2*n + 2*m; - int stride2 = stride1 * stride1; f += n+m + (n+m) * stride1 + tz * stride2; + + #pragma unroll for (int i1 = 0; i1 < 2 * m + 1; i1++) { ell1 = floorf(2 * n * y0) - m + i1; + #pragma unroll for (int i0 = 0; i0 < 2 * m + 1; i0++) { ell0 = floorf(2 * n * x0) - m + i0; - w0 = ell0 / (float)(2 * n) - x0; - w1 = ell1 / (float)(2 * n) - y0; - w = coeff0 * __expf(coeff1 * (w0 * w0 + w1 * w1)); - g0t.x = w*g0.x; - g0t.y = w*g0.y; - f_ind = ell0 + stride1 * ell1 ; - atomicAdd(&(f[f_ind].x), g0t.x); - atomicAdd(&(f[f_ind].y), g0t.y); + update_f_value(f, g0, x0, y0, coeff0, coeff1, + center_half_size, + ell0, ell1, stride1, n); } } } +extern "C" __global__ void gather_kernel_partial(float2 *g, float2 *f, float *theta, + int m, float mu, + int center_size, int n, int nproj, int nz) +{ + gather_kernel_common(g, f, theta, m, mu, center_size, n, nproj, nz); +} + +extern "C" __global__ void gather_kernel(float2 *g, float2 *f, float *theta, + int m, float mu, int n, int nproj, int nz) +{ + gather_kernel_common(g, f, theta, m, mu, 0, n, nproj, nz); +} + +/*m = 4 +mu = 2.6356625556996645e-05 +n = 362 +nproj = 241 +nz = 128 +g (128, 241, 362) +f (128, 732, 732) +theta (241,)*/ + +#define FULL_MASK 0xffffffff + +extern "C" __global__ void gather_kernel_center_prune(int* angle_range, float *theta, + int m, int center_size, + int n, int nproj) +{ + + const int center_half_size = center_size/2; + + int thread_x = threadIdx.x; + int thread_y = blockDim.y * blockIdx.y + threadIdx.y; + int thread_z = blockDim.z * blockIdx.z + threadIdx.z; + + int tx = max(0, n + m - center_half_size) + thread_y; + int ty = max(0, n + m - center_half_size) + thread_z; + + if (thread_y >= center_size || thread_z >= center_size) + return; + + int f_stride = 2*n + 2*m; + int f_stride_2 = f_stride * f_stride; + + const float radius_2 = 2.f * (float(m) + 0.5f) * (float(m) + 0.5f) / f_stride_2; + + // offset angle_index_out by thread_x and thread_y + angle_range += (unsigned long long)3 * (thread_y + thread_z * center_size); + // Point coordinates + float2 point = make_float2(float(tx - (n+m)) / float(2 * n), float((n+m) - ty) / float(2 * n)); + + unsigned thread_mask = FULL_MASK >> (32 - thread_x); + + // Result value + int valid_index = 0; + int proj_valid_index_min = nproj; + int proj_valid_index_max = 0; + int proj_invalid_index_min = nproj; + int proj_invalid_index_max = 0; + int nproj_ceil = (nproj / 32 + 1) * 32; + for (int proj_index = thread_x; proj_index < nproj_ceil; proj_index +=32) { + float sintheta, costheta; + __sincosf(theta[proj_index%nproj], &sintheta, &costheta); + + float polar_radius = 0.5; + float polar_radius_2 = polar_radius * polar_radius; + + float2 vector_polar = make_float2(polar_radius * costheta, polar_radius * sintheta); + float2 vector_point = make_float2(point.x, point.y); + + float dot = vector_polar.x * vector_point.x + vector_polar.y * vector_point.y; + float2 mid_point = make_float2(dot * vector_polar.x / polar_radius_2, + dot * vector_polar.y / polar_radius_2); + + float distance_2 = (mid_point.x - vector_point.x) * (mid_point.x - vector_point.x) + + (mid_point.y - vector_point.y) * (mid_point.y - vector_point.y); + + unsigned mask = __ballot_sync(FULL_MASK, radius_2 >= distance_2 && proj_index < nproj); + + if( proj_index < nproj ) { + if(radius_2 >= distance_2) { + int valid_count = __popc(mask&thread_mask); + proj_valid_index_min = min(proj_valid_index_min, proj_index); + proj_valid_index_max = max(proj_valid_index_max, proj_index); + } else { + proj_invalid_index_min = min(proj_invalid_index_min, proj_index); + proj_invalid_index_max = max(proj_invalid_index_max, proj_index); + } + } + + valid_index += __popc(mask); + } + + // Find the minimum and maximum indices + #pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + int proj_valid_index_min_temp = __shfl_down_sync(FULL_MASK, proj_valid_index_min, offset); + proj_valid_index_min = min(proj_valid_index_min, proj_valid_index_min_temp); + int proj_valid_index_max_temp = __shfl_down_sync(FULL_MASK, proj_valid_index_max, offset); + proj_valid_index_max = max(proj_valid_index_max, proj_valid_index_max_temp); + + int proj_invalid_index_min_temp = __shfl_down_sync(FULL_MASK, proj_invalid_index_min, offset); + proj_invalid_index_min = min(proj_invalid_index_min, proj_invalid_index_min_temp); + int proj_invalid_index_max_temp = __shfl_down_sync(FULL_MASK, proj_invalid_index_max, offset); + proj_invalid_index_max = max(proj_invalid_index_max, proj_invalid_index_max_temp); + } + + if( thread_x == 0 ) { + if((valid_index - 1) == (proj_valid_index_max - proj_valid_index_min)) { + angle_range[0] = proj_valid_index_min; + angle_range[1] = proj_valid_index_max; + angle_range[2] = 1; + } else { + angle_range[0] = proj_invalid_index_min; + angle_range[1] = proj_invalid_index_max; + angle_range[2] = 0; + } + } +} + +__device__ void inline +gather_kernel_center_common(float2 *g, float *theta, + float2& f_value, const float2& point, + const float& radius_2, + int proj_index, int tz, + const float coeff0, + const float coeff1, + int n, int nproj) +{ + float sintheta, costheta; + __sincosf(theta[proj_index], &sintheta, &costheta); + + float polar_radius = 0.5; + float polar_radius_2 = polar_radius * polar_radius; + + float2 vector_polar = make_float2(polar_radius * costheta, polar_radius * sintheta); + float2 vector_point = make_float2(point.x, point.y); + + float dot = vector_polar.x * vector_point.x + vector_polar.y * vector_point.y; + float2 mid_point = make_float2(dot * vector_polar.x / polar_radius_2, + dot * vector_polar.y / polar_radius_2); + + float distance_2 = (mid_point.x - vector_point.x) * (mid_point.x - vector_point.x) + + (mid_point.y - vector_point.y) * (mid_point.y - vector_point.y); + + if( radius_2 >= distance_2 ) { + + // Distance to intersect + float distance_to_intersect = sqrtf(radius_2 - distance_2); + + int radius_min, radius_max; + if( fabsf(vector_polar.x) > fabsf(vector_polar.y) ) { + radius_min = n/2 - 1 + floorf((mid_point.x - distance_to_intersect * vector_polar.x / polar_radius) / (2.f * vector_polar.x / n)); + radius_max = n/2 + 1 + floorf((mid_point.x + distance_to_intersect * vector_polar.x / polar_radius) / (2.f * vector_polar.x / n)); + } else { + radius_min = n/2 - 1 + floorf((mid_point.y - distance_to_intersect * vector_polar.y / polar_radius) / (2.f * vector_polar.y / n)); + radius_max = n/2 + 1 + floorf((mid_point.y + distance_to_intersect * vector_polar.y / polar_radius) / (2.f * vector_polar.y / n)); + } + + if( radius_min > radius_max ) { + int temp(radius_max); radius_max = radius_min; radius_min = temp; + } + + radius_min = min( max(radius_min, 0), (n-1)); + radius_max = min( max(radius_max, 0), (n-1)); + + constexpr int length = 4; + float2 f_values[length]; + for (int radius_index = radius_min; radius_index < radius_max; radius_index+=length) { + + #pragma unroll + for (int i = 0; i < length; i++) { + int g_ind = radius_index + i + proj_index * n + tz * n * nproj; + if( radius_index + i < radius_max ) { + f_values[i].x = g[g_ind].x; + f_values[i].y = g[g_ind].y; + } else { + f_values[i].x = 0.f; + f_values[i].y = 0.f; + } + } + + #pragma unroll + for (int i = 0; i < length; i++) { + float x0 = (radius_index + i - n / 2) / (float)n * costheta; + float y0 = (radius_index + i - n / 2) / (float)n * sintheta; + + if (x0 >= 0.5f) + x0 = 0.5f - 1e-5; + if (y0 >= 0.5f) + y0 = 0.5f - 1e-5; + + float w0 = point.x - x0; + float w1 = point.y - y0; + float w = coeff0 * __expf(coeff1 * (w0 * w0 + w1 * w1)); + + f_values[i].x *= w; + f_values[i].y *= w; + } + + #pragma unroll + for (int i = 0; i < length; i++) { + f_value.x += f_values[i].x; + f_value.y += f_values[i].y; + } + } + } +} + +extern "C" __global__ void gather_kernel_center(float2 *g, float2 *f, + int* angle_range, float *theta, + int m, float mu, + int center_size, + int n, int nproj, int nz) +{ + + const int center_half_size = center_size/2; + + int thread_x = blockDim.x * blockIdx.x + threadIdx.x; + int thread_y = blockDim.y * blockIdx.y + threadIdx.y; + int thread_z = blockDim.z * blockIdx.z + threadIdx.z; + + int tx = max(0, n + m - center_half_size) + thread_x; + int ty = max(0, n + m - center_half_size) + thread_y; + int tz = thread_z; + + if (thread_x >= center_size || thread_y >= center_size || tz >= nz) + return; + + const float coeff0 = M_PI / mu; + const float coeff1 = -M_PI * M_PI / mu; + + int f_stride = 2*n + 2*m; + int f_stride_2 = f_stride * f_stride; + + // offset f by tz + f += (unsigned long long)tz * f_stride_2; + // offset angle_index_out by thread_x and thread_y + angle_range += (unsigned long long)3 * (thread_x + thread_y * center_size); + + const float radius_2 = 2.f * (float(m) + 0.5f) * (float(m) + 0.5f) / f_stride_2; + + // Result value + float2 f_value = make_float2(0.f, 0.f); + // Point coordinates + float2 point = make_float2(float(tx - (n+m)) / float(2 * n), float((n+m) - ty) / float(2 * n)); + + if( angle_range[2] ) { + for (int proj_index = angle_range[0]; proj_index <= angle_range[1]; proj_index++) { + gather_kernel_center_common(g, theta, + f_value, point, + radius_2, + proj_index, tz, + coeff0, + coeff1, + n, nproj); + } + } else { + for (int proj_index = 0; proj_index < angle_range[0]; proj_index++) { + gather_kernel_center_common(g, theta, + f_value, point, + radius_2, + proj_index, tz, + coeff0, + coeff1, + n, nproj); + } + for (int proj_index = angle_range[1] + 1; proj_index < nproj; proj_index++) { + gather_kernel_center_common(g, theta, + f_value, point, + radius_2, + proj_index, tz, + coeff0, + coeff1, + n, nproj); + } + } + + // index of the force + int f_ind = tx + ty * f_stride; + + f[f_ind].x = f_value.x; + f[f_ind].y = f_value.y; +} -extern "C" __global__ void wrap_kernel(float2 *f, int n, int nz, int m) +extern "C" __global__ void wrap_kernel(float2 *f, + int n, int nz, int m) { int tx = blockDim.x * blockIdx.x + threadIdx.x; int ty = blockDim.y * blockIdx.y + threadIdx.y; diff --git a/tomobar/methodsDIR_CuPy.py b/tomobar/methodsDIR_CuPy.py index 284441c7..c42829a9 100644 --- a/tomobar/methodsDIR_CuPy.py +++ b/tomobar/methodsDIR_CuPy.py @@ -6,7 +6,8 @@ """ import numpy as np - +import timeit +import matplotlib.pyplot as plt try: import cupy as xp @@ -159,9 +160,22 @@ def FOURIER_INV(self, data: xp.ndarray, **kwargs) -> xp.ndarray: cutoff_freq = 1.0 # default value filter_type = "shepp" # default filter + center_size = 2048 + block_dim = [16, 16] + block_dim_prune = 4 + block_dim_center = [32, 4] + for key, value in kwargs.items(): if key == "data_axes_labels_order" and value is not None: data = _data_dims_swapper(data, value, ["detY", "angles", "detX"]) + elif key == "center_size" and value is not None: + center_size = value + elif key == "block_dim" and value is not None: + block_dim = value + elif key == "block_dim_prune" and value is not None: + block_dim_prune = value + elif key == "block_dim_center" and value is not None: + block_dim_center = value if key == "cutoff_freq" and value is not None: cutoff_freq = value if key == "filter_type" and value is not None: @@ -183,6 +197,9 @@ def FOURIER_INV(self, data: xp.ndarray, **kwargs) -> xp.ndarray: # extract kernels from CUDA modules module = load_cuda_module("fft_us_kernels") gather_kernel = module.get_function("gather_kernel") + gather_kernel_partial = module.get_function("gather_kernel_partial") + gather_kernel_center_prune = module.get_function("gather_kernel_center_prune") + gather_kernel_center = module.get_function("gather_kernel_center") wrap_kernel = module.get_function("wrap_kernel") # initialisation @@ -226,6 +243,9 @@ def FOURIER_INV(self, data: xp.ndarray, **kwargs) -> xp.ndarray: ) oversampling_level = 2 # at least 2 or larger required + # Limit the center size parameter + center_size = min(center_size, n * 2 + m * 2) + # memory for recon if odd_horiz: recon_up = xp.empty([nz, n + 1, n + 1], dtype=xp.float32) @@ -237,8 +257,10 @@ def FOURIER_INV(self, data: xp.ndarray, **kwargs) -> xp.ndarray: t = xp.linspace(-1 / 2, 1 / 2, n, endpoint=False, dtype=xp.float32) [dx, dy] = xp.meshgrid(t, t) phi = xp.exp(mu * (n * n) * (dx * dx + dy * dy)) * ((1 - n % 4) / nproj) - # padded fft, reusable by chunks - fde = xp.zeros([nz // 2, 2 * m + 2 * n, 2 * m + 2 * n], dtype=xp.complex64) + + if center_size > 0: + angle_range = xp.empty([center_size, center_size, 3], dtype=xp.int32) + # (+1,-1) arrays for fftshift c1dfftshift = xp.empty(n, dtype=xp.int8) c1dfftshift[::2] = -1 @@ -271,25 +293,83 @@ def FOURIER_INV(self, data: xp.ndarray, **kwargs) -> xp.ndarray: # can be done without introducing array datac, saves memory, see tomocupy (TODO) del tmp_p + # padded fft, reusable by chunks + fde = xp.zeros([nz // 2, 2 * m + 2 * n, 2 * m + 2 * n], dtype=xp.complex64) + # STEP1: fft 1d datac = fft(c1dfftshift * datac) * c1dfftshift * (4 / n) # STEP2: interpolation (gathering) in the frequency domain - # When profiling gather_kernel takes up to 50% of the time! - gather_kernel( - (int(xp.ceil(n / 32)), int(xp.ceil(nproj / 32)), nz // 2), - (32, 32, 1), - ( - datac, - fde, - theta, - np.int32(m), - np.float32(mu), - np.int32(n), - np.int32(nproj), - np.int32(nz // 2), - ), - ) + if center_size > 0: + + if center_size != (n * 2 + m * 2): + + gather_kernel_partial( + (int(xp.ceil(n / block_dim[0])), int(xp.ceil(nproj / block_dim[1])), nz // 2), + (block_dim[0], block_dim[1], 1), + ( + datac, + fde, + theta, + np.int32(m), + np.float32(mu), + np.int32(center_size), + np.int32(n), + np.int32(nproj), + np.int32(nz // 2), + ), + ) + + gather_kernel_center_prune( + (1, int(xp.ceil(center_size / block_dim_prune)), center_size), + (32, block_dim_prune, 1), + ( + angle_range, + theta, + np.int32(m), + np.int32(center_size), + np.int32(n), + np.int32(nproj), + ), + ) + + gather_kernel_center( + ( + int(xp.ceil(center_size / block_dim_center[0])), + int(xp.ceil(center_size / block_dim_center[1])), + nz // 2, + ), + (block_dim_center[0], block_dim_center[1], 1), + ( + datac, + fde, + angle_range, + theta, + np.int32(m), + np.float32(mu), + np.int32(center_size), + np.int32(n), + np.int32(nproj), + np.int32(nz // 2), + ), + ) + + else: + gather_kernel( + (int(xp.ceil(n / block_dim[0])), int(xp.ceil(nproj / block_dim[1])), nz // 2), + (block_dim[0], block_dim[1], 1), + ( + datac, + fde, + theta, + np.int32(m), + np.float32(mu), + np.int32(n), + np.int32(nproj), + np.int32(nz // 2), + ), + ) + wrap_kernel( ( int(np.ceil((2 * n + 2 * m) / 32)), @@ -300,6 +380,12 @@ def FOURIER_INV(self, data: xp.ndarray, **kwargs) -> xp.ndarray: (fde, n, nz // 2, m), ) + if center_size > 0: + del angle_range, datac + else: + del datac + xp._default_memory_pool.free_all_blocks() + # STEP3: ifft 2d fde2 = fde[ :, m:-m, m:-m