Commit 9da0c16

Add Reading of Depth Image
1 parent 179d768 commit 9da0c16

4 files changed: +23 −11 lines changed

4 files changed

+23
-11
lines changed

scene/dataset_readers.py

Lines changed: 5 additions & 5 deletions
@@ -18,6 +18,7 @@
 from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
 import numpy as np
 import json
+import imageio
 from pathlib import Path
 from plyfile import PlyData, PlyElement
 from utils.sh_utils import SH2RGB
@@ -190,8 +191,7 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
         for idx, frame in enumerate(frames):
             cam_name = os.path.join(path, frame["file_path"] + extension)

-            if is_test:
-                depth_name = os.path.join(path, frame["file_path"] + "_depth_0001" + extension)
+            depth_name = os.path.join(path, frame["file_path"] + "_depth0000" + '.exr')

             matrix = np.linalg.inv(np.array(frame["transform_matrix"]))
             R = -np.transpose(matrix[:3,:3])
@@ -201,7 +201,7 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
             image_path = os.path.join(path, cam_name)
             image_name = Path(cam_name).stem
             image = Image.open(image_path)
-            depth = Image.open(depth_name).convert('RGBA') if is_test else None
+            depth = imageio.imread(depth_name)

             im_data = np.array(image.convert("RGBA"))

@@ -222,9 +222,9 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=

 def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
     print("Reading Training Transforms")
-    train_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
+    train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
     print("Reading Test Transforms")
-    test_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
+    test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)

     if not eval:
         train_cam_infos.extend(test_cam_infos)
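
The reader now loads a per-frame depth map from a Blender-style "_depth0000.exr" file with imageio instead of PIL (and no longer only for test cameras), and the train/test transforms are read from the correct JSON files again. Below is a minimal, hedged sketch of just the depth-loading step; the dataset root and frame entry are placeholders, and depending on the installed imageio version, reading .exr files may require the FreeImage plugin.

# Sketch only: placeholder paths, not the repo's actual data. Depending on the
# imageio version, EXR support may need the FreeImage plugin
# (e.g. imageio.plugins.freeimage.download()).
import os
import imageio

frame = {"file_path": "./train/r_0"}        # hypothetical transforms.json entry
path = "/data/nerf_synthetic/lego"          # hypothetical dataset root

depth_name = os.path.join(path, frame["file_path"] + "_depth0000" + '.exr')
depth = imageio.imread(depth_name)          # float array, HxW or HxWxC
print(depth.shape, depth.dtype, depth.min(), depth.max())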

train.py

Lines changed: 2 additions & 2 deletions
@@ -71,7 +71,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
             viewpoint_stack = scene.getTrainCameras().copy()

         viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
-        gt_depth = viewpoint_cam.depth.unsqueeze(0)
+        gt_depth = viewpoint_cam.depth
         # Render
         render_pkg = render(viewpoint_cam, gaussians, pipe, background)
         image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
@@ -82,7 +82,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
         Ll1 = l1_loss(image, gt_image)
         loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
         depth_loss = l1_loss(depth, gt_depth) * 0.1
-        loss = loss + depth_loss
+        # loss = loss + depth_loss
         loss.backward()

         iter_end.record()
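
In the training loop, the ground-truth depth is now taken as stored (without adding a channel dimension), and the weighted depth L1 term is computed but no longer added to the total loss. As a rough sketch of what re-enabling it would look like, the snippet below combines the photometric and depth terms with the same 0.1 weight; the tensor names and shapes are assumed, torch's functional L1 stands in for the repo's own l1_loss, and the DSSIM term is passed in precomputed.

# Sketch with assumed shapes; F.l1_loss stands in for the repo's l1_loss and
# the (1 - ssim) term is supplied precomputed. The depth term is the line
# this commit comments out of the objective.
import torch
import torch.nn.functional as F

def combined_loss(image, gt_image, depth, gt_depth,
                  lambda_dssim=0.2, dssim_term=0.0, depth_weight=0.1):
    l1_rgb = F.l1_loss(image, gt_image)
    loss = (1.0 - lambda_dssim) * l1_rgb + lambda_dssim * dssim_term
    depth_loss = depth_weight * F.l1_loss(depth, gt_depth)
    return loss + depth_loss   # re-enabling the depth term would look like this

# hypothetical example with random tensors
img, gt_img = torch.rand(3, 64, 64), torch.rand(3, 64, 64)
d, gt_d = torch.rand(64, 64), torch.rand(64, 64)
print(combined_loss(img, gt_img, d, gt_d))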

utils/camera_utils.py

Lines changed: 7 additions & 4 deletions
@@ -11,7 +11,7 @@

 from scene.cameras import Camera
 import numpy as np
-from utils.general_utils import PILtoTorch
+from utils.general_utils import PILtoTorch, ArrayToTorch
 from utils.graphics_utils import fov2focal

 WARNED = False
@@ -39,13 +39,16 @@ def loadCam(args, id, cam_info, resolution_scale):
         resolution = (int(orig_w / scale), int(orig_h / scale))

     resized_image_rgb = PILtoTorch(cam_info.image, resolution)
-    resized_depth_rgb = PILtoTorch(cam_info.depth, resolution) if cam_info.depth is not None else None
+    if cam_info.depth is not None:
+        resized_depth_rgb = ArrayToTorch(cam_info.depth, resolution)
+    else:
+        resized_depth_rgb = None

     gt_image = resized_image_rgb[:3, ...]
     if resized_depth_rgb is not None:
-        depth_mask = resized_depth_rgb[3, ...] > 0
+        depth_mask = resized_depth_rgb[0, ...] > 60000
         gt_depth = resized_depth_rgb[0, ...]
-        gt_depth[depth_mask] = 2. + 6. * (1 - gt_depth[depth_mask])
+        gt_depth[depth_mask] = 0
     else:
         gt_depth = None
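
loadCam now converts the NumPy depth array with ArrayToTorch and builds the validity mask from the depth channel itself: values above 60000 are treated as background and zeroed, replacing the earlier alpha-based mask and the 2 + 6·(1 − d) remapping. A small illustrative sketch of that masking on a dummy channel-first tensor (the 60000 cutoff is the value from the diff; everything else is made up):

# Illustrative only: how the new masking in loadCam zeroes out large ("far"/
# background) depth values. Dummy tensor, channel-first layout assumed.
import torch

resized_depth_rgb = torch.tensor([[[10_000., 65_504.],
                                   [   500., 61_000.]]])   # shape (1, 2, 2)

depth_mask = resized_depth_rgb[0, ...] > 60000   # True where depth is treated as background
gt_depth = resized_depth_rgb[0, ...]
gt_depth[depth_mask] = 0                         # zero out background pixels
print(gt_depth)                                  # tensor([[10000., 0.], [500., 0.]])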

utils/general_utils.py

Lines changed: 9 additions & 0 deletions
@@ -26,6 +26,15 @@ def PILtoTorch(pil_image, resolution):
     else:
         return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)

+def ArrayToTorch(array, resolution):
+    # resized_image = np.resize(array, resolution)
+    resized_image_torch = torch.from_numpy(array)
+
+    if len(resized_image_torch.shape) == 3:
+        return resized_image_torch.permute(2, 0, 1)
+    else:
+        return resized_image_torch.unsqueeze(dim=-1).permute(2, 0, 1)
+
 def get_expon_lr_func(
     lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
 ):
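
ArrayToTorch mirrors PILtoTorch but starts from a NumPy array and, with the np.resize call commented out, currently ignores the resolution argument, so the depth map keeps its original size; the output is channel-first in both the HxWxC and HxW cases. A quick self-contained usage sketch with dummy arrays (shapes are illustrative):

# Usage sketch of ArrayToTorch with dummy arrays. Note the resolution argument
# is unused here, matching the commented-out resize in the commit.
import numpy as np
import torch

def ArrayToTorch(array, resolution):
    # resized_image = np.resize(array, resolution)   # resize currently disabled
    resized_image_torch = torch.from_numpy(array)
    if len(resized_image_torch.shape) == 3:
        return resized_image_torch.permute(2, 0, 1)                    # HWC -> CHW
    else:
        return resized_image_torch.unsqueeze(dim=-1).permute(2, 0, 1)  # HW -> 1HW

depth_hw  = np.random.rand(4, 6).astype(np.float32)      # single-channel depth
depth_hwc = np.random.rand(4, 6, 3).astype(np.float32)   # three-channel EXR
print(ArrayToTorch(depth_hw,  (6, 4)).shape)   # torch.Size([1, 4, 6])
print(ArrayToTorch(depth_hwc, (6, 4)).shape)   # torch.Size([3, 4, 6])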
