Skip to content

Commit 179d768

Browse files
committed
Add depth Loss for 3D gaussian
1 parent 13c6602 commit 179d768

File tree

6 files changed

+30
-10
lines changed

6 files changed

+30
-10
lines changed

SIBR_viewers

Submodule SIBR_viewers updated from 440bd4c to 29dd2f3

scene/cameras.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
class Camera(nn.Module):
1818
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
1919
image_name, uid,
20-
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
20+
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", gt_depth=None
2121
):
2222
super(Camera, self).__init__()
2323

@@ -37,6 +37,7 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
3737
self.data_device = torch.device("cuda")
3838

3939
self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
40+
self.depth = gt_depth.to(self.data_device) if gt_depth is not None else None
4041
self.image_width = self.original_image.shape[2]
4142
self.image_height = self.original_image.shape[1]
4243

scene/dataset_readers.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class CameraInfo(NamedTuple):
3434
image_name: str
3535
width: int
3636
height: int
37+
depth: np.array
3738

3839
class SceneInfo(NamedTuple):
3940
point_cloud: BasicPointCloud
@@ -183,10 +184,15 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
183184
contents = json.load(json_file)
184185
fovx = contents["camera_angle_x"]
185186

187+
is_test = 'test' in transformsfile
188+
186189
frames = contents["frames"]
187190
for idx, frame in enumerate(frames):
188191
cam_name = os.path.join(path, frame["file_path"] + extension)
189192

193+
if is_test:
194+
depth_name = os.path.join(path, frame["file_path"] + "_depth_0001" + extension)
195+
190196
matrix = np.linalg.inv(np.array(frame["transform_matrix"]))
191197
R = -np.transpose(matrix[:3,:3])
192198
R[:,0] = -R[:,0]
@@ -195,6 +201,7 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
195201
image_path = os.path.join(path, cam_name)
196202
image_name = Path(cam_name).stem
197203
image = Image.open(image_path)
204+
depth = Image.open(depth_name).convert('RGBA') if is_test else None
198205

199206
im_data = np.array(image.convert("RGBA"))
200207

@@ -209,15 +216,15 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
209216
FovX = fovy
210217

211218
cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
212-
image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
219+
image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1], depth=depth))
213220

214221
return cam_infos
215222

216223
def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
217224
print("Reading Training Transforms")
218-
train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
225+
train_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
219226
print("Reading Test Transforms")
220-
test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
227+
test_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
221228

222229
if not eval:
223230
train_cam_infos.extend(test_cam_infos)

train.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#
1+
7#
22
# Copyright (C) 2023, Inria
33
# GRAPHDECO research group, https://team.inria.fr/graphdeco
44
# All rights reserved.
@@ -69,16 +69,20 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
6969
# Pick a random Camera
7070
if not viewpoint_stack:
7171
viewpoint_stack = scene.getTrainCameras().copy()
72-
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
7372

73+
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
74+
gt_depth = viewpoint_cam.depth.unsqueeze(0)
7475
# Render
7576
render_pkg = render(viewpoint_cam, gaussians, pipe, background)
7677
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
78+
depth = render_pkg["depth"]
7779

7880
# Loss
7981
gt_image = viewpoint_cam.original_image.cuda()
8082
Ll1 = l1_loss(image, gt_image)
8183
loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
84+
depth_loss = l1_loss(depth, gt_depth) * 0.1
85+
loss = loss + depth_loss
8286
loss.backward()
8387

8488
iter_end.record()
@@ -199,7 +203,7 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
199203
safe_state(args.quiet)
200204

201205
# Start GUI server, configure and run training
202-
network_gui.init(args.ip, args.port)
206+
# network_gui.init(args.ip, args.port)
203207
torch.autograd.set_detect_anomaly(args.detect_anomaly)
204208
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations)
205209

utils/camera_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,16 @@ def loadCam(args, id, cam_info, resolution_scale):
3939
resolution = (int(orig_w / scale), int(orig_h / scale))
4040

4141
resized_image_rgb = PILtoTorch(cam_info.image, resolution)
42+
resized_depth_rgb = PILtoTorch(cam_info.depth, resolution) if cam_info.depth is not None else None
4243

4344
gt_image = resized_image_rgb[:3, ...]
45+
if resized_depth_rgb is not None:
46+
depth_mask = resized_depth_rgb[3, ...] > 0
47+
gt_depth = resized_depth_rgb[0, ...]
48+
gt_depth[depth_mask] = 2. + 6. * (1 - gt_depth[depth_mask])
49+
else:
50+
gt_depth = None
51+
4452
loaded_mask = None
4553

4654
if resized_image_rgb.shape[1] == 4:
@@ -49,7 +57,7 @@ def loadCam(args, id, cam_info, resolution_scale):
4957
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
5058
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
5159
image=gt_image, gt_alpha_mask=loaded_mask,
52-
image_name=cam_info.image_name, uid=id, data_device=args.data_device)
60+
image_name=cam_info.image_name, uid=id, data_device=args.data_device, gt_depth=gt_depth)
5361

5462
def cameraList_from_camInfos(cam_infos, resolution_scale, args):
5563
camera_list = []

0 commit comments

Comments
 (0)