diff --git a/.gitignore b/.gitignore index aa1ae78..77db517 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ diff_gaussian_rasterization.egg-info/ dist/ diff_gaussian_rasterization/__pycache__/ *so +*.pyc \ No newline at end of file diff --git a/cuda_rasterizer/forward.cu b/cuda_rasterizer/forward.cu index 9362275..5f83847 100644 --- a/cuda_rasterizer/forward.cu +++ b/cuda_rasterizer/forward.cu @@ -18,16 +18,16 @@ namespace cg = cooperative_groups; // Forward method for converting the input spherical harmonics // coefficients of each Gaussian to a simple RGB color. -__device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, bool* clamped) +__device__ glm::vec3 computeColorFromSH(int point_idx, int result_idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, bool* clamped) { // The implementation is loosely based on code for // "Differentiable Point-Based Radiance Fields for // Efficient View Synthesis" by Zhang et al. (2022) - glm::vec3 pos = means[idx]; + glm::vec3 pos = means[point_idx]; glm::vec3 dir = pos - campos; dir = dir / glm::length(dir); - glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs; + glm::vec3* sh = ((glm::vec3*)shs) + point_idx * max_coeffs; glm::vec3 result = SH_C0 * sh[0]; if (deg > 0) @@ -65,9 +65,9 @@ __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const // RGB colors are clamped to positive values. If values are // clamped, we need to keep track of this for the backward pass. - clamped[3 * idx + 0] = (result.x < 0); - clamped[3 * idx + 1] = (result.y < 0); - clamped[3 * idx + 2] = (result.z < 0); + clamped[3 * result_idx + 0] = (result.x < 0); + clamped[3 * result_idx + 1] = (result.y < 0); + clamped[3 * result_idx + 2] = (result.z < 0); return glm::max(result, 0.0f); } @@ -213,7 +213,6 @@ __global__ void preprocessCUDA(int P, int D, int M, computeCov3D(scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6); cov3D = cov3Ds + idx * 6; } - // Compute 2D screen-space covariance matrix float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix); @@ -242,7 +241,7 @@ __global__ void preprocessCUDA(int P, int D, int M, // spherical harmonics coefficients to RGB color. if (colors_precomp == nullptr) { - glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped); + glm::vec3 result = computeColorFromSH(idx, idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped); rgb[idx * C + 0] = result.x; rgb[idx * C + 1] = result.y; rgb[idx * C + 2] = result.z; @@ -471,7 +470,7 @@ void FORWARD::preprocess(int P, int D, int M, uint32_t* tiles_touched, bool prefiltered) { - preprocessCUDA << <(P + ONE_DIM_BLOCK_SIZE - 1) / ONE_DIM_BLOCK_SIZE, ONE_DIM_BLOCK_SIZE >> > ( + preprocessCUDA << > > ( P, D, M, means3D, scales, @@ -498,4 +497,155 @@ void FORWARD::preprocess(int P, int D, int M, tiles_touched, prefiltered ); -} \ No newline at end of file +} + + +template +__global__ void preprocessCUDABatched( + int P, int D, int M, + const float* orig_points, const glm::vec3* scales, const float scale_modifier, + const glm::vec4* rotations, const float* opacities, const float* shs, + bool* clamped, const float* cov3D_precomp, const float* colors_precomp, + const float* viewmatrix_arr, const float* projmatrix_arr, const glm::vec3* cam_pos, + const int W, int H, const float* tan_fovx, const float* tan_fovy, + int* radii, float2* points_xy_image, float* depths, float* cov3Ds, + float* rgb, float4* conic_opacity, const dim3 grid, uint32_t* tiles_touched, + bool prefiltered, const int num_viewpoints) +{ + auto point_idx = blockIdx.x * blockDim.x + threadIdx.x; + auto viewpoint_idx = blockIdx.y; + + if (viewpoint_idx >= num_viewpoints || point_idx >= P) return; + + auto idx = viewpoint_idx * P + point_idx; + const float* viewmatrix = viewmatrix_arr + viewpoint_idx * 16; + const float* projmatrix = projmatrix_arr + viewpoint_idx * 16; + + // Initialize radius and touched tiles to 0. If this isn't changed, + // this Gaussian will not be processed further. + radii[idx] = 0; + tiles_touched[idx] = 0; + + // Perform near culling, quit if outside. + float3 p_view; + if (!in_frustum(point_idx, orig_points, viewmatrix, projmatrix, prefiltered, p_view)) return; + + // Transform point by projecting + float3 p_orig = { orig_points[3 * point_idx], orig_points[3 * point_idx + 1], orig_points[3 * point_idx + 2] }; + + float4 p_hom = transformPoint4x4(p_orig, projmatrix); + float p_w = 1.0f / (p_hom.w + 0.0000001f); + float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w }; + + // If 3D covariance matrix is precomputed, use it, otherwise compute + // from scaling and rotation parameters. + const float* cov3D; + if (cov3D_precomp != nullptr) { + cov3D = cov3D_precomp + idx * 6; + } else { + computeCov3D(scales[point_idx], scale_modifier, rotations[point_idx], cov3Ds + idx * 6); + cov3D = cov3Ds + idx * 6; + } + + + // Compute 2D screen-space covariance matrix + const float focal_x = W / (2.0f * tan_fovx[viewpoint_idx]); + const float focal_y = H / (2.0f * tan_fovy[viewpoint_idx]); + float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx[viewpoint_idx], tan_fovy[viewpoint_idx], cov3D, viewmatrix); + + + // Invert covariance (EWA algorithm) + float det = (cov.x * cov.z - cov.y * cov.y); + if (det == 0.0f) return; + float det_inv = 1.f / det; + float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv }; + + // Compute extent in screen space (by finding eigenvalues of + // 2D covariance matrix). Use extent to compute a bounding rectangle + // of screen-space tiles that this Gaussian overlaps with. Quit if + // rectangle covers 0 tiles. + float mid = 0.5f * (cov.x + cov.z); + float lambda1 = mid + sqrt(max(0.1f, mid * mid - det)); + float lambda2 = mid - sqrt(max(0.1f, mid * mid - det)); + float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2))); + float2 point_image = { ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H) }; + uint2 rect_min, rect_max; + getRect(point_image, my_radius, rect_min, rect_max, grid); + if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0) return; + + // If colors have been precomputed, use them, otherwise convert + // spherical harmonics coefficients to RGB color. + + if (colors_precomp == nullptr) { + + glm::vec3 result = computeColorFromSH(point_idx, idx, D, M, (glm::vec3*)orig_points, cam_pos[viewpoint_idx], shs, clamped); + rgb[idx * C + 0] = result.x; + rgb[idx * C + 1] = result.y; + rgb[idx * C + 2] = result.z; + } + + // Store some useful helper data for the next steps. + depths[idx] = p_view.z; + radii[idx] = my_radius; + points_xy_image[idx] = point_image; + + // Inverse 2D covariance and opacity neatly pack into one float4 + conic_opacity[idx] = { conic.x, conic.y, conic.z, opacities[point_idx] }; + tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x); +} + +void FORWARD::preprocess_batch(int P, int D, int M, + const float* means3D, + const glm::vec3* scales, + const float scale_modifier, + const glm::vec4* rotations, + const float* opacities, + const float* shs, + bool* clamped, + const float* cov3D_precomp, + const float* colors_precomp, + const float* viewmatrix, + const float* projmatrix, + const glm::vec3* cam_pos, + const int W, int H, + const float* tan_fovx, const float* tan_fovy, + int* radii, + float2* means2D, + float* depths, + float* cov3Ds, + float* rgb, + float4* conic_opacity, + const dim3 grid, + uint32_t* tiles_touched, + bool prefiltered, + const int num_viewpoints) +{ + dim3 tile_grid(cdiv(P, ONE_DIM_BLOCK_SIZE), num_viewpoints); + preprocessCUDABatched<<>>( + P, D, M, + means3D, + scales, + scale_modifier, + rotations, + opacities, + shs, + clamped, + cov3D_precomp, + colors_precomp, + viewmatrix, + projmatrix, + cam_pos, + W, H, + tan_fovx, tan_fovy, + radii, + means2D, + depths, + cov3Ds, + rgb, + conic_opacity, + grid, + tiles_touched, + prefiltered, + num_viewpoints + ); +} diff --git a/cuda_rasterizer/forward.h b/cuda_rasterizer/forward.h index 86e5cb9..902fa00 100644 --- a/cuda_rasterizer/forward.h +++ b/cuda_rasterizer/forward.h @@ -45,7 +45,35 @@ namespace FORWARD float4* conic_opacity, const dim3 grid, uint32_t* tiles_touched, - bool prefiltered); + bool prefiltered + ); + + void preprocess_batch(int P, int D, int M, + const float* means3D, + const glm::vec3* scales, + const float scale_modifier, + const glm::vec4* rotations, + const float* opacities, + const float* shs, + bool* clamped, + const float* cov3D_precomp, + const float* colors_precomp, + const float* viewmatrix, + const float* projmatrix, + const glm::vec3* cam_pos, + const int W, int H, + const float* tan_fovx, const float* tan_fovy, + int* radii, + float2* means2D, + float* depths, + float* cov3Ds, + float* rgb, + float4* conic_opacity, + const dim3 grid, + uint32_t* tiles_touched, + bool prefiltered, + const int num_viewpoints + ); // Main rasterization method. void render( @@ -61,7 +89,8 @@ namespace FORWARD uint32_t* n_contrib2loss, const int* compute_locally_1D_2D_map, const float* bg_color, - float* out_color); + float* out_color + ); } diff --git a/cuda_rasterizer/rasterizer.h b/cuda_rasterizer/rasterizer.h index ddc989c..b7f93fd 100644 --- a/cuda_rasterizer/rasterizer.h +++ b/cuda_rasterizer/rasterizer.h @@ -65,6 +65,31 @@ namespace CudaRasterizer bool debug,//raster_settings const pybind11::dict &args); + static int preprocessForwardBatches( + float2* means2D, + float* depths, + int* radii, + float* cov3D, + float4* conic_opacity, + float* rgb, + bool* clamped,//the above are all per-Gaussian intemediate results. + const int P, int D, int M, + const int width, int height, + const float* means3D, + const float* scales, + const float* rotations, + const float* shs, + const float* opacities,//3dgs parameters + const float scale_modifier, + const float* viewmatrix, + const float* projmatrix, + const float* cam_pos, + const float* tan_fovx, const float* tan_fovy, + const bool prefiltered, + const int num_viewpoints, + bool debug,//raster_settings + const pybind11::dict &args); + static void preprocessBackward( const int* radii, const float* cov3D, diff --git a/cuda_rasterizer/rasterizer_impl.cu b/cuda_rasterizer/rasterizer_impl.cu index dec20fa..44d09f4 100644 --- a/cuda_rasterizer/rasterizer_impl.cu +++ b/cuda_rasterizer/rasterizer_impl.cu @@ -424,6 +424,100 @@ int CudaRasterizer::Rasterizer::preprocessForward( return num_rendered; } + +int CudaRasterizer::Rasterizer::preprocessForwardBatches( + float2* means2D, + float* depths, + int* radii, + float* cov3D, + float4* conic_opacity, + float* rgb, + bool* clamped,//the above are all per-Gaussian intemediate results. + const int P, int D, int M, + const int width, int height, + const float* means3D, + const float* scales, + const float* rotations, + const float* shs, + const float* opacities,//3dgs parameters + const float scale_modifier, + const float* viewmatrix, + const float* projmatrix, + const float* cam_pos, + const float* tan_fovx, const float* tan_fovy, + const bool prefiltered, + const int num_viewpoints, + bool debug,//raster_settings + const pybind11::dict &args) +{ + auto [global_rank, world_size, iteration, log_interval, device, zhx_debug, zhx_time, mode, dist_division_mode, log_folder] = prepareArgs(args); + char* log_tmp = new char[500]; + + // print out the environment variables + if (mode == "train" && zhx_debug && iteration % log_interval == 1) { + sprintf(log_tmp, "world_size: %d, global_rank: %d, iteration: %d, log_folder: %s, zhx_debug: %d, zhx_time: %d, device: %d, log_interval: %d, dist_division_mode: %s", + world_size, global_rank, iteration, log_folder.c_str(), zhx_debug, zhx_time, device, log_interval, dist_division_mode.c_str()); + save_log_in_file(iteration, global_rank, world_size, log_folder, "cuda", log_tmp); + } + + MyTimerOnGPU timer; + + //CONVERT ALL VECTORS TO FLOATSSSSSS PRAPTIIIIIII + + dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); + dim3 block(BLOCK_X, BLOCK_Y, 1); + int tile_num = tile_grid.x * tile_grid.y; + + // allocate temporary buffer for tiles_touched. + // In sep_rendering==True case, we will compute tiles_touched in the renderForward. + // TODO: remove it later by modifying FORWARD::preprocess when we deprecate sep_rendering==False case + uint32_t* tiles_touched_temp_buffer; + CHECK_CUDA(cudaMalloc(&tiles_touched_temp_buffer, num_viewpoints * P * sizeof(uint32_t)), debug); + CHECK_CUDA(cudaMemset(tiles_touched_temp_buffer, 0, num_viewpoints * P * sizeof(uint32_t)), debug); + + timer.start("10 preprocess"); + // Run preprocessing per-Gaussian (transformation, bounding, conversion of SHs to RGB) + CHECK_CUDA(FORWARD::preprocess_batch( + P, D, M, + means3D, + (glm::vec3*)scales, + scale_modifier, + (glm::vec4*)rotations, + opacities, + shs, + clamped, + nullptr,//cov3D_precomp, + nullptr,//colors_precomp,TODO: this is correct? + viewmatrix, + projmatrix, + (glm::vec3*)cam_pos, + width, height, + tan_fovx, tan_fovy, + radii, + means2D, + depths, + cov3D, + rgb, + conic_opacity, + tile_grid, + tiles_touched_temp_buffer, + prefiltered, + num_viewpoints + ), debug) + timer.stop("10 preprocess"); + + int num_rendered = 0;//TODO: should I calculate this here? + + // Print out timing information + if (zhx_time && iteration % log_interval == 1) { + timer.printAllTimes(iteration, world_size, global_rank, log_folder, true); + } + delete log_tmp; + // free temporary buffer for tiles_touched. TODO: remove it. + CHECK_CUDA(cudaFree(tiles_touched_temp_buffer), debug); + return num_rendered; +} + void CudaRasterizer::Rasterizer::preprocessBackward( const int* radii, const float* cov3D, diff --git a/diff_gaussian_rasterization/__init__.py b/diff_gaussian_rasterization/__init__.py index 7e3ad04..7667063 100644 --- a/diff_gaussian_rasterization/__init__.py +++ b/diff_gaussian_rasterization/__init__.py @@ -10,10 +10,12 @@ # from typing import NamedTuple -import torch.nn as nn + import torch +import torch.nn as nn + from . import _C -import time + def cpu_deep_copy_tuple(input_tuple): copied_tensors = [item.cpu().clone() if isinstance(item, torch.Tensor) else item for item in input_tuple] @@ -31,7 +33,7 @@ def preprocess_gaussians( sh, opacities, raster_settings, - cuda_args, + cuda_args ): return _PreprocessGaussians.apply( means3D, @@ -40,7 +42,7 @@ def preprocess_gaussians( sh, opacities, raster_settings, - cuda_args, + cuda_args ) class _PreprocessGaussians(torch.autograd.Function): @@ -53,10 +55,28 @@ def forward( sh, opacities, raster_settings, - cuda_args, + cuda_args ): # Restructure arguments the way that the C++ lib expects them + if isinstance(raster_settings, list): + viewmatrix, projmatrix, campos = [ + torch.stack(tensors) for tensors in zip( + *[(rs.viewmatrix, rs.projmatrix, rs.campos) for rs in raster_settings] + ) + ] + tanfovx, tanfovy = [ + torch.tensor(vals, device=means3D.device) + for vals in zip(*[(rs.tanfovx, rs.tanfovy) for rs in raster_settings]) + ] + raster_settings = raster_settings[0]._replace( + tanfovx=tanfovx, + tanfovy=tanfovy, + viewmatrix=viewmatrix, + projmatrix=projmatrix, + campos=campos + ) + args = ( means3D, scales, @@ -78,7 +98,10 @@ def forward( ) # TODO: update this. - num_rendered, means2D, depths, radii, cov3D, conic_opacity, rgb, clamped = _C.preprocess_gaussians(*args) + if not torch.is_tensor(raster_settings.tanfovx): + num_rendered, means2D, depths, radii, cov3D, conic_opacity, rgb, clamped = _C.preprocess_gaussians(*args) + else: + num_rendered, means2D, depths, radii, cov3D, conic_opacity, rgb, clamped = _C.preprocess_gaussians_batched(*args) # Keep relevant tensors for backward ctx.raster_settings = raster_settings @@ -304,6 +327,33 @@ class GaussianRasterizationSettings(NamedTuple): prefiltered : bool debug : bool +class GaussianRasterizerBatches(nn.Module): + def __init__(self, raster_settings_batch): + super().__init__() + self.raster_settings_batch = raster_settings_batch + + def markVisible(self, positions): + # Mark visible points (based on frustum culling for camera) with a boolean + with torch.no_grad(): + visible = [] + for raster_settings in self.raster_settings_batch: + viewmatrix = raster_settings.viewmatrix + projmatrix = raster_settings.projmatrix + visible.append(_C.mark_visible(positions, viewmatrix, projmatrix)) + return visible + + def preprocess_gaussians(self, means3D, scales, rotations, shs, opacities, batched_cuda_args=None): + # Invoke C++/CUDA rasterization routine + + return preprocess_gaussians( + means3D, + scales, + rotations, + shs, + opacities, + self.raster_settings_batch, + batched_cuda_args) + class GaussianRasterizer(nn.Module): def __init__(self, raster_settings): super().__init__() diff --git a/ext.cpp b/ext.cpp index a957cd2..e4249bb 100644 --- a/ext.cpp +++ b/ext.cpp @@ -16,6 +16,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("mark_visible", &markVisible); m.def("preprocess_gaussians", &PreprocessGaussiansCUDA); + m.def("preprocess_gaussians_batched", &PreprocessGaussiansCUDABatches); m.def("preprocess_gaussians_backward", &PreprocessGaussiansBackwardCUDA); m.def("get_distribution_strategy", &GetDistributionStrategyCUDA); m.def("render_gaussians", &RenderGaussiansCUDA); diff --git a/rasterization_tests.py b/rasterization_tests.py new file mode 100644 index 0000000..90afdcd --- /dev/null +++ b/rasterization_tests.py @@ -0,0 +1,276 @@ +import math + +import torch + +from diff_gaussian_rasterization import ( + GaussianRasterizationSettings, + GaussianRasterizer, + GaussianRasterizerBatches, +) + +num_gaussians = 1000000 +num_batches=64 +means3D = torch.randn(num_gaussians, 3).cuda() +scales = torch.randn(num_gaussians, 3).cuda() +rotations = torch.randn(num_gaussians, 4).cuda() +shs = torch.randn(num_gaussians, 16, 3).cuda() +opacity = torch.randn(num_gaussians, 1).cuda() +SH_ACTIVE_DEGREE = 3 + +def get_cuda_args(strategy, mode="train"): + cuda_args = { + "mode": mode, + "world_size": "1", + "global_rank": "0", + "local_rank": "0", + "mp_world_size": "1", + "mp_rank": "0", + "log_folder": "./logs", + "log_interval": "10", + "iteration": "0", + "zhx_debug": "False", + "zhx_time": "False", + "dist_global_strategy": "default", + "avoid_pixel_all2all": False, + "stats_collector": {}, + } + return cuda_args + + +def test_batched_gaussian_rasterizer(): + # Set up the viewpoint cameras + batched_viewpoint_cameras = [] + for _ in range(num_batches): + viewpoint_camera = type('ViewpointCamera', (), {}) + viewpoint_camera.FoVx = math.radians(60) + viewpoint_camera.FoVy = math.radians(60) + viewpoint_camera.image_height = 512 + viewpoint_camera.image_width = 512 + viewpoint_camera.world_view_transform = torch.eye(4).cuda() + viewpoint_camera.full_proj_transform = torch.eye(4).cuda() + viewpoint_camera.camera_center = torch.zeros(3).cuda() + batched_viewpoint_cameras.append(viewpoint_camera) + + # Set up the strategies + batched_strategies = [None] * num_batches + + # Set up other parameters + bg_color = torch.ones(3).cuda() + scaling_modifier = 1.0 + pc = type('PC', (), {}) + pc.active_sh_degree = SH_ACTIVE_DEGREE + pipe = type('Pipe', (), {}) + pipe.debug = False + mode = "train" + + batched_rasterizers = [] + batched_cuda_args = [] + batched_screenspace_params = [] + batched_means2D = [] + batched_radii = [] + batched_conic_opacity=[] + batched_depths=[] + batched_rgb=[] + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_event.record() + for i, (viewpoint_camera, strategy) in enumerate(zip(batched_viewpoint_cameras, batched_strategies)): + ########## [START] Prepare CUDA Rasterization Settings ########## + cuda_args = get_cuda_args(strategy, mode) + batched_cuda_args.append(cuda_args) + + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=pc.active_sh_degree, + campos=viewpoint_camera.camera_center, + prefiltered=False, + debug=pipe.debug + ) + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + ########## [END] Prepare CUDA Rasterization Settings ########## + + #[3DGS-wise preprocess] + means2D, rgb, conic_opacity, radii, depths = rasterizer.preprocess_gaussians( + means3D=means3D, + scales=scales, + rotations=rotations, + shs=shs, + opacities=opacity, + cuda_args=cuda_args + ) + + # TODO: make the below work + # if mode == "train": + # means2D.retain_grad() + + batched_means2D.append(means2D) + screenspace_params = [means2D, rgb, conic_opacity, radii, depths] + batched_rasterizers.append(rasterizer) + batched_screenspace_params.append(screenspace_params) + batched_radii.append(radii) + batched_rgb.append(rgb) + batched_conic_opacity.append(conic_opacity) + batched_depths.append(depths) + + + end_event.record() + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + print(f"Time taken by test_batched_gaussian_rasterizer: {elapsed_time_ms:.4f} ms") + # Perform further operations with the batched results + # Test results and performance + + batched_means2D = torch.stack(batched_means2D, dim=0) + batched_radii = torch.stack(batched_radii, dim=0) + batched_conic_opacity=torch.stack(batched_conic_opacity,dim=0) + batched_rgb=torch.stack(batched_rgb,dim=0) + batched_depths=torch.stack(batched_depths,dim=0) + + return batched_means2D, batched_radii, batched_screenspace_params,batched_conic_opacity,batched_rgb,batched_depths + + +def test_batched_gaussian_rasterizer_batch_processing(): + # Set up the input data + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() # Wait for the events to be recorded! + start_event.record() + # Set up the viewpoint cameras + batched_viewpoint_cameras = [] + for _ in range(num_batches): + viewpoint_camera = type('ViewpointCamera', (), {}) + viewpoint_camera.FoVx = math.radians(60) + viewpoint_camera.FoVy = math.radians(60) + viewpoint_camera.image_height = 512 + viewpoint_camera.image_width = 512 + viewpoint_camera.world_view_transform = torch.eye(4).cuda() + viewpoint_camera.full_proj_transform = torch.eye(4).cuda() + viewpoint_camera.camera_center = torch.zeros(3).cuda() + batched_viewpoint_cameras.append(viewpoint_camera) + + # Set up the strategies + batched_strategies = [None] * num_batches + + # Set up other parameters + bg_color = torch.ones(3).cuda() + scaling_modifier = 1.0 + pc = type('PC', (), {}) + pc.active_sh_degree = SH_ACTIVE_DEGREE + pipe = type('Pipe', (), {}) + pipe.debug = False + mode = "train" + + # Set up rasterization configuration for the batch + raster_settings_batch = [] + batched_cuda_args = [] + for i, (viewpoint_camera, strategy) in enumerate(zip(batched_viewpoint_cameras, batched_strategies)): + ########## [START] Prepare CUDA Rasterization Settings ########## + cuda_args = get_cuda_args(strategy, mode) + batched_cuda_args.append(cuda_args) + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + raster_settings = GaussianRasterizationSettings( + image_height=int(batched_viewpoint_cameras[0].image_height), + image_width=int(batched_viewpoint_cameras[0].image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=pc.active_sh_degree, + campos=viewpoint_camera.camera_center, + prefiltered=False, + debug=pipe.debug + ) + raster_settings_batch.append(raster_settings) + + # Create the GaussianRasterizer for the batch + rasterizer = GaussianRasterizerBatches(raster_settings_batch=raster_settings_batch) + + # Preprocess the Gaussians for the entire batch + batched_means2D, batched_rgb, batched_conic_opacity, batched_radii, batched_depths = rasterizer.preprocess_gaussians( + means3D=means3D, + scales=scales, + rotations=rotations, + shs=shs, + opacities=opacity, + batched_cuda_args=batched_cuda_args[0] #TODO: look into sending list of cuda_args/strategies + ) + end_event.record() + torch.cuda.synchronize() # Wait for the events to be recorded! + elapsed_time_ms = start_event.elapsed_time(end_event) + print(f"Time taken by test_batched_gaussian_rasterizer_batch_processing: {elapsed_time_ms:.4f} ms") + + # TODO: make the below work + # if mode == "train": + # batched_means2D.retain_grad() + + + # Perform assertions on the preprocessed data + + assert batched_means2D.shape == (num_batches, num_gaussians, 2) + assert batched_rgb.shape == (num_batches, num_gaussians, 3) + assert batched_conic_opacity.shape == (num_batches, num_gaussians,4) + assert batched_radii.shape == (num_batches, num_gaussians) + assert batched_depths.shape == (num_batches, num_gaussians) + + batched_screenspace_params = [] + for i in range(num_batches): + means2D = batched_means2D[i] + rgb = batched_rgb[i] + conic_opacity = batched_conic_opacity[i] + radii = batched_radii[i] + depths = batched_depths[i] + + screenspace_params = [means2D, rgb, conic_opacity, radii, depths] + batched_screenspace_params.append(screenspace_params) + + return batched_means2D, batched_radii, batched_screenspace_params, batched_conic_opacity,batched_rgb,batched_depths + + +def compare_tensors(tensor1, tensor2): + if tensor1.shape != tensor2.shape: + print("Tensors have different shapes:") + print("Tensor 1 shape:", tensor1.shape) + print("Tensor 2 shape:", tensor2.shape) + return False + + equality_matrix = torch.eq(tensor1, tensor2) + if torch.all(equality_matrix): + return True + else: + print("Tensors have non-matching values.") + non_matching_indices = torch.where(equality_matrix == False) + for idx in zip(*non_matching_indices[:5]): + value1 = tensor1[idx].item() + value2 = tensor2[idx].item() + print(f"Non-matching values at index {idx}: {value1} != {value2}") + return False + +if __name__ == "__main__": + batched_means2D, batched_radii, batched_screenspace_params,batched_conic_opacity,batched_rgb,batched_depths = test_batched_gaussian_rasterizer() + batched_means2D_batch_processed, batched_radii_batch_processed, batched_screenspace_params_batch_processed,batched_conic_opacity_batch_processed,batched_rgb_batch_processed,batched_depths_batch_processed = test_batched_gaussian_rasterizer_batch_processing() + + assert compare_tensors(batched_means2D, batched_means2D_batch_processed) + assert compare_tensors(batched_radii, batched_radii_batch_processed) + assert compare_tensors(batched_conic_opacity, batched_conic_opacity_batch_processed) + + assert compare_tensors(batched_rgb, batched_rgb_batch_processed) + assert compare_tensors(batched_depths, batched_depths_batch_processed) + assert len(batched_screenspace_params) == len(batched_screenspace_params_batch_processed) + + diff --git a/rasterize_points.cu b/rasterize_points.cu index e4400a6..e8eb8a7 100644 --- a/rasterize_points.cu +++ b/rasterize_points.cu @@ -142,6 +142,84 @@ PreprocessGaussiansCUDA( return std::make_tuple(rendered, means2D, depths, radii, cov3D, conic_opacity, rgb, clamped); } +std::tuple +PreprocessGaussiansCUDABatches( + const torch::Tensor& means3D, + const torch::Tensor& scales, + const torch::Tensor& rotations, + const torch::Tensor& sh, + const torch::Tensor& opacity,//3dgs' parametes. + const float scale_modifier, + const torch::Tensor& viewmatrix, + const torch::Tensor& projmatrix, + const torch::Tensor& tan_fovx, + const torch::Tensor& tan_fovy, + const int image_height, + const int image_width, + const int degree, + const torch::Tensor& campos, + const bool prefiltered,//raster_settings + const bool debug, + const pybind11::dict &args) { + + if (means3D.ndimension() != 2 || means3D.size(1) != 3) { + AT_ERROR("means3D must have dimensions (num_points, 3)"); + } + + const int P = means3D.size(0); + const int num_viewpoints = viewmatrix.size(0); + + // of shape (P, 2). means2D is (P, 2) in cuda. It will be converted to (P, 3) when is sent back to python to meet torch graph's requirement. + torch::Tensor means2D = torch::full({num_viewpoints, P, 2}, 0.0, means3D.options());//TODO: what about require_grads? + // of shape (P) + torch::Tensor depths = torch::full({num_viewpoints, P}, 0.0, means3D.options()); + // of shape (P) + torch::Tensor radii = torch::full({num_viewpoints, P}, 0, means3D.options().dtype(torch::kInt32)); + // of shape (P, 6) + torch::Tensor cov3D = torch::full({num_viewpoints, P, 6}, 0.0, means3D.options()); + // of shape (P, 4) + torch::Tensor conic_opacity = torch::full({num_viewpoints, P, 4}, 0.0, means3D.options()); + // of shape (P, 3) + torch::Tensor rgb = torch::full({num_viewpoints, P, 3}, 0.0, means3D.options()); + // of shape (P) + torch::Tensor clamped = torch::full({num_viewpoints, P, 3}, false, means3D.options().dtype(at::kBool)); + //TODO: compare to original GeometryState implementation, this one does not explicitly do gpu memory alignment. + //That may lead to problems. However, pytorch does implicit memory alignment. + + int rendered = 0;//TODO: I could compute rendered here by summing up geomState.tiles_touched. + if(P != 0) + { + int M = sh.size(0) != 0 ? sh.size(1) : 0; + + rendered = CudaRasterizer::Rasterizer::preprocessForwardBatches( + reinterpret_cast(means2D.contiguous().data()),//TODO: check whether it supports float2? + depths.contiguous().data(), + radii.contiguous().data(), + cov3D.contiguous().data(), + reinterpret_cast(conic_opacity.contiguous().data()), + rgb.contiguous().data(), + clamped.contiguous().data(), + P, degree, M, + image_width, image_height, + means3D.contiguous().data(), + scales.contiguous().data_ptr(), + rotations.contiguous().data_ptr(), + sh.contiguous().data_ptr(), + opacity.contiguous().data(), + scale_modifier, + viewmatrix.contiguous().data(), + projmatrix.contiguous().data(), + campos.contiguous().data(), + tan_fovx.contiguous().data(), + tan_fovy.contiguous().data(), + prefiltered, + num_viewpoints, + debug, + args); + } + return std::make_tuple(rendered, means2D, depths, radii, cov3D, conic_opacity, rgb, clamped); +} + std::tuple PreprocessGaussiansBackwardCUDA( diff --git a/rasterize_points.h b/rasterize_points.h index 86798ec..3700126 100644 --- a/rasterize_points.h +++ b/rasterize_points.h @@ -49,6 +49,26 @@ PreprocessGaussiansCUDA( const bool debug, const pybind11::dict &args); +std::tuple +PreprocessGaussiansCUDABatches( + const torch::Tensor& means3D, + const torch::Tensor& scales, + const torch::Tensor& rotations, + const torch::Tensor& sh, + const torch::Tensor& opacity,//3dgs' parametes. + const float scale_modifier, + const torch::Tensor& viewmatrix, + const torch::Tensor& projmatrix, + const torch::Tensor& tan_fovx, + const torch::Tensor& tan_fovy, + const int image_height, + const int image_width, + const int degree, + const torch::Tensor& campos, + const bool prefiltered,//raster_settings + const bool debug, + const pybind11::dict &args); + std::tuple PreprocessGaussiansBackwardCUDA( const torch::Tensor& radii, diff --git a/setup.py b/setup.py index 8c4d011..03b2df8 100644 --- a/setup.py +++ b/setup.py @@ -9,9 +9,11 @@ # For inquiries contact george.drettakis@inria.fr # -from setuptools import setup -from torch.utils.cpp_extension import CUDAExtension, BuildExtension import os + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + os.path.dirname(os.path.abspath(__file__)) setup(