diff --git a/README.md b/README.md index 4cbd3326d..14278cd79 100644 --- a/README.md +++ b/README.md @@ -194,6 +194,10 @@ python train.py -s Influence of SSIM on total loss from 0 to 1, ```0.2``` by default. #### --percent_dense Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default. + #### --data_dtype + The floating-point data type (```float32```, ```float16``` or ```float64```) to which the ground-truth images and the rendered image are converted when computing the loss. ```float32``` by default; unrecognized values also fall back to ```float32```. + #### --store_images_as_uint8 + Flag that controls how images are kept in memory. If set, the images are stored as uint8 and converted to the target data type (see ```--data_dtype```) each time they are accessed; otherwise they are converted once when loaded. Not set by default.
diff --git a/arguments/__init__.py b/arguments/__init__.py index 1e13a551e..fdfe3fc65 100644 --- a/arguments/__init__.py +++ b/arguments/__init__.py @@ -53,6 +53,8 @@ def __init__(self, parser, sentinel=False): self._resolution = -1 self._white_background = False self.data_device = "cuda" + self.data_dtype = "float32" + self.store_images_as_uint8 = False self.eval = False super().__init__(parser, "Loading Parameters", sentinel) diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py index f74e336af..3efab6e03 100644 --- a/gaussian_renderer/__init__.py +++ b/gaussian_renderer/__init__.py @@ -92,6 +92,10 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, rotations = rotations, cov3D_precomp = cov3D_precomp) + # after rasterization, we convert the resulting image to the target dtype + # The rasterizer expects parameters as float32, so the result is also float32. + rendered_image = rendered_image.to(viewpoint_camera.data_dtype) + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. # They will be excluded from value updates used in the splitting criteria. 
return {"render": rendered_image, diff --git a/scene/cameras.py b/scene/cameras.py index abf6e5242..5264a04c6 100644 --- a/scene/cameras.py +++ b/scene/cameras.py @@ -17,7 +17,8 @@ class Camera(nn.Module): def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, image_name, uid, - trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" + trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32, + store_images_as_uint8=True, ): super(Camera, self).__init__() @@ -28,6 +29,7 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, self.FoVx = FoVx self.FoVy = FoVy self.image_name = image_name + self.data_dtype = data_dtype try: self.data_device = torch.device(data_device) @@ -36,14 +38,18 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) self.data_device = torch.device("cuda") - self.original_image = image.clamp(0.0, 1.0).to(self.data_device) - self.image_width = self.original_image.shape[2] - self.image_height = self.original_image.shape[1] + self.store_images_as_uint8 = store_images_as_uint8 - if gt_alpha_mask is not None: - self.original_image *= gt_alpha_mask.to(self.data_device) - else: - self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) + self._original_image = image.to(self.data_device) + self._gt_alpha_mask = gt_alpha_mask + if self._gt_alpha_mask is not None: + self._gt_alpha_mask = self._gt_alpha_mask.to(self.data_device) + + if not store_images_as_uint8: + self._original_image = self.convert_image(self._original_image) + + self.image_width = self._original_image.shape[2] + self.image_height = self._original_image.shape[1] self.zfar = 100.0 self.znear = 0.01 @@ -56,6 +62,23 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, self.full_proj_transform = 
(self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) self.camera_center = self.world_view_transform.inverse()[3, :3] + def convert_image(self, image): + image = (image / 255.0).clamp(0.0, 1.0).to(self.data_dtype) + gt_alpha_mask = self._gt_alpha_mask + + if gt_alpha_mask is not None: + gt_alpha_mask = gt_alpha_mask / 255.0 + image *= gt_alpha_mask.to(self.data_dtype) + + return image + + @property + def original_image(self): + if self.store_images_as_uint8: + return self.convert_image(self._original_image) + else: + return self._original_image + class MiniCam: def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): self.image_width = width diff --git a/utils/camera_utils.py b/utils/camera_utils.py index 1a54d0ada..6c886f8a2 100644 --- a/utils/camera_utils.py +++ b/utils/camera_utils.py @@ -11,7 +11,7 @@ from scene.cameras import Camera import numpy as np -from utils.general_utils import PILtoTorch +from utils.general_utils import PILtoTorch, get_data_dtype from utils.graphics_utils import fov2focal WARNED = False @@ -49,7 +49,9 @@ def loadCam(args, id, cam_info, resolution_scale): return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, FoVx=cam_info.FovX, FoVy=cam_info.FovY, image=gt_image, gt_alpha_mask=loaded_mask, - image_name=cam_info.image_name, uid=id, data_device=args.data_device) + image_name=cam_info.image_name, uid=id, data_device=args.data_device, + data_dtype=get_data_dtype(args.data_dtype), + store_images_as_uint8=args.store_images_as_uint8) def cameraList_from_camInfos(cam_infos, resolution_scale, args): camera_list = [] diff --git a/utils/general_utils.py b/utils/general_utils.py index 541c08252..f060e14ed 100644 --- a/utils/general_utils.py +++ b/utils/general_utils.py @@ -20,7 +20,7 @@ def inverse_sigmoid(x): def PILtoTorch(pil_image, resolution): resized_image_PIL = pil_image.resize(resolution) - resized_image = 
torch.from_numpy(np.array(resized_image_PIL)) / 255.0 + resized_image = torch.from_numpy(np.array(resized_image_PIL)) if len(resized_image.shape) == 3: return resized_image.permute(2, 0, 1) else: @@ -131,3 +131,12 @@ def flush(self): np.random.seed(0) torch.manual_seed(0) torch.cuda.set_device(torch.device("cuda:0")) + +def get_data_dtype(dtype): + if dtype == "float32": + return torch.float32 + elif dtype == "float64": + return torch.float64 + elif dtype == "float16": + return torch.float16 + return torch.float32