splatam.py
import argparse
import os
import shutil
import sys
import time
from importlib.machinery import SourceFileLoader
_BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _BASE_DIR)
print("System Paths:")
for p in sys.path:
print(p)
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import wandb
from datasets.gradslam_datasets import (
load_dataset_config,
ICLDataset,
ReplicaDataset,
ReplicaV2Dataset,
AzureKinectDataset,
ScannetDataset,
Ai2thorDataset,
Record3DDataset,
RealsenseDataset,
TUMDataset,
ScannetPPDataset,
NeRFCaptureDataset
)
from utils.common_utils import seed_everything, save_params_ckpt, save_params
from utils.eval_helpers import report_loss, report_progress, eval
from utils.keyframe_selection import keyframe_selection_overlap
from utils.recon_helpers import setup_camera
from utils.slam_helpers import (
transformed_params2rendervar, transformed_params2depthplussilhouette,
transform_to_frame, l1_loss_v1, matrix_to_quaternion
)
from utils.slam_external import calc_ssim, build_rotation, prune_gaussians, densify
from diff_gaussian_rasterization import GaussianRasterizer as Renderer
def get_dataset(config_dict, basedir, sequence, **kwargs):
if config_dict["dataset_name"].lower() in ["icl"]:
return ICLDataset(config_dict, basedir, sequence, **kwargs)
elif config_dict["dataset_name"].lower() in ["replica"]:
return ReplicaDataset(config_dict, basedir, sequence, **kwargs)
elif config_dict["dataset_name"].lower() in ["replicav2"]:
return ReplicaV2Dataset(config_dict, basedir, sequence, **kwargs)
elif config_dict["dataset_name"].lower() in ["azure", "azurekinect"]:
return AzureKinectDataset(config_dict, basedir, sequence, **kwargs)
elif config_dict["dataset_name"].lower() in ["scannet"]:
return ScannetDataset(config_dict, basedir, sequence, **kwargs)
elif config_dict["dataset_name"].lower() in ["ai2thor"]:
return Ai2thorDataset(config_dict, basedir, sequence, **kwargs)
elif config_dict["dataset_name"].lower() in ["record3d"]:
return Record3DDataset(config_dict, basedir, sequence, **kwargs)
elif config_dict["dataset_name"].lower() in ["realsense"]:
return RealsenseDataset(config_dict, basedir, sequence, **kwargs)
elif config_dict["dataset_name"].lower() in ["tum"]:
return TUMDataset(config_dict, basedir, sequence, **kwargs)
elif config_dict["dataset_name"].lower() in ["scannetpp"]:
return ScannetPPDataset(basedir, sequence, **kwargs)
elif config_dict["dataset_name"].lower() in ["nerfcapture"]:
return NeRFCaptureDataset(basedir, sequence, **kwargs)
else:
raise ValueError(f"Unknown dataset name {config_dict['dataset_name']}")
# Purpose: build a point cloud from the given color image, depth map, camera intrinsics, and world-to-camera transform (w2c)
def get_pointcloud(color, depth, intrinsics, w2c, transform_pts=True,
mask=None, compute_mean_sq_dist=False, mean_sq_dist_method="projective"):
# A. Parse the camera intrinsics
width, height = color.shape[2], color.shape[1]
CX = intrinsics[0][2]
CY = intrinsics[1][2]
FX = intrinsics[0][0]
FY = intrinsics[1][1]
# B. Generate the pixel grid
# Compute indices of pixels
x_grid, y_grid = torch.meshgrid(torch.arange(width).cuda().float(),
torch.arange(height).cuda().float(),
indexing='xy')
xx = (x_grid - CX)/FX
yy = (y_grid - CY)/FY
xx = xx.reshape(-1)
yy = yy.reshape(-1)
depth_z = depth[0].reshape(-1)
# C. Initialize the point cloud
# Initialize point cloud
pts_cam = torch.stack((xx * depth_z, yy * depth_z, depth_z), dim=-1)
if transform_pts:
pix_ones = torch.ones(height * width, 1).cuda().float()
pts4 = torch.cat((pts_cam, pix_ones), dim=1)
c2w = torch.inverse(w2c)
pts = (c2w @ pts4.T).T[:, :3]
else:
pts = pts_cam
# Optional: Compute mean squared distance for initializing the scale of the Gaussians
if compute_mean_sq_dist:
if mean_sq_dist_method == "projective":
# Projective Geometry (this is fast, farther -> larger radius)
scale_gaussian = depth_z / ((FX + FY)/2)
mean3_sq_dist = scale_gaussian**2
else:
raise ValueError(f"Unknown mean_sq_dist_method {mean_sq_dist_method}")
# D. Colorize the point cloud
# Colorize point cloud
cols = torch.permute(color, (1, 2, 0)).reshape(-1, 3) # (C, H, W) -> (H, W, C) -> (H * W, C)
point_cld = torch.cat((pts, cols), -1)
# Optional: Select points based on mask
if mask is not None:
point_cld = point_cld[mask]
if compute_mean_sq_dist:
mean3_sq_dist = mean3_sq_dist[mask]
if compute_mean_sq_dist:
return point_cld, mean3_sq_dist
else:
return point_cld
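# A minimal back-projection sketch (illustrative only, never called by this file):
# it mirrors the pinhole model used in get_pointcloud() above for a single pixel.
# The helper name _backproject_pixel and the sample numbers are assumptions.
def _backproject_pixel(u, v, z, fx, fy, cx, cy):
    # Invert the pinhole projection: pixel (u, v) at depth z -> camera-frame point
    x = (u - cx) / fx * z
    y = (v - cy) / fy * z
    return x, y, z
# e.g. _backproject_pixel(320, 240, 2.0, fx=600.0, fy=600.0, cx=320.0, cy=240.0)
# gives (0.0, 0.0, 2.0): the principal-point pixel lies on the optical axis.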
# Purpose: initialize the Gaussian parameters
def initialize_params(init_pt_cld, num_frames, mean3_sq_dist):
# A. Basic parameter setup
num_pts = init_pt_cld.shape[0]
means3D = init_pt_cld[:, :3] # [num_gaussians, 3]
unnorm_rots = np.tile([1, 0, 0, 0], (num_pts, 1)) # [num_gaussians, 4] identity quaternions
logit_opacities = torch.zeros((num_pts, 1), dtype=torch.float, device="cuda")
# 3D Gaussian parameters to be optimized
params = {
'means3D': means3D,
'rgb_colors': init_pt_cld[:, 3:6],
'unnorm_rotations': unnorm_rots,
'logit_opacities': logit_opacities,
'log_scales': torch.tile(torch.log(torch.sqrt(mean3_sq_dist))[..., None], (1, 1)),
}
# B. Camera parameter setup
# Initialize a single gaussian trajectory to model the camera poses relative to the first frame
cam_rots = np.tile([1, 0, 0, 0], (1, 1))
cam_rots = np.tile(cam_rots[:, :, None], (1, 1, num_frames))
params['cam_unnorm_rots'] = cam_rots
params['cam_trans'] = np.zeros((1, 3, num_frames))
# C. Convert parameters to trainable PyTorch tensors
for k, v in params.items():
# Check if value is already a torch tensor
if not isinstance(v, torch.Tensor):
params[k] = torch.nn.Parameter(torch.tensor(v).cuda().float().contiguous().requires_grad_(True))
else:
params[k] = torch.nn.Parameter(v.cuda().float().contiguous().requires_grad_(True))
variables = {'max_2D_radius': torch.zeros(params['means3D'].shape[0]).cuda().float(),
'means2D_gradient_accum': torch.zeros(params['means3D'].shape[0]).cuda().float(),
'denom': torch.zeros(params['means3D'].shape[0]).cuda().float(),
'timestep': torch.zeros(params['means3D'].shape[0]).cuda().float()}
return params, variables
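# Shape sketch for the structures built above (descriptive note; N = number of
# initial Gaussians, T = num_frames):
#   means3D          (N, 3)    Gaussian centers in world coordinates
#   rgb_colors       (N, 3)    per-Gaussian RGB
#   unnorm_rotations (N, 4)    quaternions, initialized to identity [1, 0, 0, 0]
#   logit_opacities  (N, 1)    pre-sigmoid opacity (0 -> opacity 0.5)
#   log_scales       (N, 1)    isotropic scale, log(sqrt(mean3_sq_dist))
#   cam_unnorm_rots  (1, 4, T) per-frame camera rotation quaternions
#   cam_trans        (1, 3, T) per-frame camera translations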
def initialize_optimizer(params, lrs_dict, tracking):
lrs = lrs_dict
param_groups = [{'params': [v], 'name': k, 'lr': lrs[k]} for k, v in params.items()]
if tracking:
return torch.optim.Adam(param_groups)
else:
return torch.optim.Adam(param_groups, lr=0.0, eps=1e-15)
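# Usage note (an assumption based on the shipped configs, not asserted by this
# file): during tracking, lrs_dict typically assigns a nonzero learning rate only
# to cam_unnorm_rots and cam_trans, so the Adam groups effectively freeze the
# Gaussians and update just the camera pose.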
# Purpose: initialize the Gaussian point cloud from the first frame (global map initialization)
def initialize_first_timestep(dataset, num_frames, scene_radius_depth_ratio, mean_sq_dist_method, densify_dataset=None):
# Get RGB-D Data & Camera Parameters
# A. Data acquisition: get the first frame's RGB-D data (color, depth), camera intrinsics, and camera pose from the dataset
color, depth, intrinsics, pose = dataset[0]
# B. Data processing, format conversion, and setup
# Process RGB-D Data
color = color.permute(2, 0, 1) / 255 # (H, W, C) -> (C, H, W)
depth = depth.permute(2, 0, 1) # (H, W, C) -> (C, H, W)
# Process Camera Parameters
intrinsics = intrinsics[:3, :3]
w2c = torch.linalg.inv(pose)
# Setup Camera
cam = setup_camera(color.shape[2], color.shape[1], intrinsics.cpu().numpy(), w2c.detach().cpu().numpy())
# C. Densification handling: if a densification dataset is passed in, process it accordingly
if densify_dataset is not None:
# Get Densification RGB-D Data & Camera Parameters
color, depth, densify_intrinsics, _ = densify_dataset[0]
color = color.permute(2, 0, 1) / 255 # (H, W, C) -> (C, H, W)
depth = depth.permute(2, 0, 1) # (H, W, C) -> (C, H, W)
densify_intrinsics = densify_intrinsics[:3, :3]
densify_cam = setup_camera(color.shape[2], color.shape[1], densify_intrinsics.cpu().numpy(), w2c.detach().cpu().numpy())
else:
densify_intrinsics = intrinsics
# D. Initialize the point cloud and parameters; the key functions are get_pointcloud() and initialize_params()
# Get Initial Point Cloud (PyTorch CUDA Tensor)
mask = (depth > 0) # Mask out invalid depth values
mask = mask.reshape(-1)
init_pt_cld, mean3_sq_dist = get_pointcloud(color, depth, densify_intrinsics, w2c,
mask=mask, compute_mean_sq_dist=True,
mean_sq_dist_method=mean_sq_dist_method)
# Initialize Parameters
params, variables = initialize_params(init_pt_cld, num_frames, mean3_sq_dist)
# Initialize an estimate of scene radius for Gaussian-Splatting Densification
variables['scene_radius'] = torch.max(depth)/scene_radius_depth_ratio
if densify_dataset is not None:
return params, variables, intrinsics, w2c, cam, densify_intrinsics, densify_cam
else:
return params, variables, intrinsics, w2c, cam
# Purpose: compute the current frame's loss during Tracking and Mapping
# Inputs: model parameters params, current frame data curr_data, intermediate variables, time index iter_time_idx, loss weights loss_weights, whether to use the silhouette for the loss use_sil_for_loss, the threshold sil_thres, etc.
def get_loss(params, curr_data, variables, iter_time_idx, loss_weights, use_sil_for_loss,
sil_thres, use_l1, ignore_outlier_depth_loss, tracking=False,
mapping=False, do_ba=False, plot_dir=None, visualize_tracking_loss=False, tracking_iteration=None):
# Initialize Loss Dictionary
losses = {}
# During tracking, the camera pose needs gradients
if tracking:
# Get current frame Gaussians, where only the camera pose gets gradient
# transform_to_frame() transforms the Gaussian centers from the world frame to the camera frame, optionally with gradients
transformed_pts = transform_to_frame(params, iter_time_idx,
gaussians_grad=False,
camera_grad=True)
elif mapping:
# During mapping with bundle adjustment, both the Gaussians and the pose are optimized
# (but do_ba is always False, so this branch never runs)
if do_ba:
# Get current frame Gaussians, where both camera pose and Gaussians get gradient
transformed_pts = transform_to_frame(params, iter_time_idx,
gaussians_grad=True,
camera_grad=True)
# Plain mapping only optimizes the Gaussians' gradients
else:
# Get current frame Gaussians, where only the Gaussians get gradient
transformed_pts = transform_to_frame(params, iter_time_idx,
gaussians_grad=True,
camera_grad=False)
else:
# Get current frame Gaussians, where only the Gaussians get gradient
transformed_pts = transform_to_frame(params, iter_time_idx,
gaussians_grad=True,
camera_grad=False)
# Initialize Render Variables
rendervar = transformed_params2rendervar(params, transformed_pts)
depth_sil_rendervar = transformed_params2depthplussilhouette(params, curr_data['w2c'],
transformed_pts)
# RGB Rendering
rendervar['means2D'].retain_grad()
# Render the current frame's RGB with the Renderer, producing the image im and per-Gaussian radii radius
im, radius, _, = Renderer(raster_settings=curr_data['cam'])(**rendervar)
variables['means2D'] = rendervar['means2D'] # Gradient only accum from colour render for densification
# Depth & Silhouette Rendering
# Render depth and silhouette for the current frame with the Renderer, producing depth_sil
depth_sil, _, _, = Renderer(raster_settings=curr_data['cam'])(**depth_sil_rendervar)
# Extract the depth, the silhouette, and the squared depth depth_sq from depth_sil
depth = depth_sil[0, :, :].unsqueeze(0)
silhouette = depth_sil[1, :, :]
presence_sil_mask = (silhouette > sil_thres)
depth_sq = depth_sil[2, :, :].unsqueeze(0)
# Depth uncertainty as the rendered variance E[d^2] - E[d]^2, detached so it carries no gradient
uncertainty = depth_sq - depth**2
uncertainty = uncertainty.detach()
# Mask with valid depth values (accounts for outlier depth values)
# Build nan_mask to flag valid depth and uncertainty values, avoiding invalid entries
nan_mask = (~torch.isnan(depth)) & (~torch.isnan(uncertainty))
# If ignore_outlier_depth_loss is enabled, build a mask from the depth error that rejects outlier depth regions
if ignore_outlier_depth_loss:
depth_error = torch.abs(curr_data['depth'] - depth) * (curr_data['depth'] > 0)
mask = (depth_error < 10*depth_error.median())
mask = mask & (curr_data['depth'] > 0)
# Taking configs/replica/splatam.py as an example, ignore_outlier_depth_loss is False for both tracking and mapping,
# so this branch applies: use the regions with positive depth as the mask
else:
mask = (curr_data['depth'] > 0)
mask = mask & nan_mask
# Mask with presence silhouette mask (accounts for empty space)
# In tracking mode with use_sil_for_loss enabled, AND the mask with the silhouette presence mask presence_sil_mask
# In configs/replica/splatam.py, tracking uses use_sil_for_loss=True and mapping uses use_sil_for_loss=False
if tracking and use_sil_for_loss:
mask = mask & presence_sil_mask
# Depth loss
# If using the L1 loss (use_l1), detach the mask so it carries no gradient
if use_l1:
mask = mask.detach()
if tracking:
# In the Tracking stage, the depth loss losses['depth'] is the sum of absolute differences between the current and rendered depth maps (within the mask)
losses['depth'] = torch.abs(curr_data['depth'] - depth)[mask].sum()
else:
# Outside Tracking, the depth loss is the mean absolute difference between the current and rendered depth maps (within the mask)
# For Mapping, the mask is not ANDed with presence_sil_mask; it just keeps the regions with positive depth
# The silhouette is not used, matching the paper's "we want to optimize over all pixels"
losses['depth'] = torch.abs(curr_data['depth'] - depth)[mask].mean()
# RGB Loss
# In tracking mode (tracking) with use_sil_for_loss or ignore_outlier_depth_loss enabled,
# the RGB loss losses['im'] is the sum of absolute differences between the current and rendered images
if tracking and (use_sil_for_loss or ignore_outlier_depth_loss):
color_mask = torch.tile(mask, (3, 1, 1))
color_mask = color_mask.detach()
losses['im'] = torch.abs(curr_data['im'] - im)[color_mask].sum()
elif tracking:
# In Tracking without the silhouette mask, the RGB loss is the sum of absolute differences between the current and rendered images
losses['im'] = torch.abs(curr_data['im'] - im).sum()
else:
# In Mapping, the RGB loss is a weighted sum of an L1 term and an SSIM (structural similarity) term,
# where l1_loss_v1 computes the L1 term and calc_ssim the SSIM term, following the original 3DGS paper
losses['im'] = 0.8 * l1_loss_v1(im, curr_data['im']) + 0.2 * (1.0 - calc_ssim(im, curr_data['im']))
# Visualization
# Visualize the Diff Images
if tracking and visualize_tracking_loss:
fig, ax = plt.subplots(2, 4, figsize=(12, 6))
weighted_render_im = im * color_mask
weighted_im = curr_data['im'] * color_mask
weighted_render_depth = depth * mask
weighted_depth = curr_data['depth'] * mask
diff_rgb = torch.abs(weighted_render_im - weighted_im).mean(dim=0).detach().cpu()
diff_depth = torch.abs(weighted_render_depth - weighted_depth).mean(dim=0).detach().cpu()
viz_img = torch.clip(weighted_im.permute(1, 2, 0).detach().cpu(), 0, 1)
ax[0, 0].imshow(viz_img)
ax[0, 0].set_title("Weighted GT RGB")
viz_render_img = torch.clip(weighted_render_im.permute(1, 2, 0).detach().cpu(), 0, 1)
ax[1, 0].imshow(viz_render_img)
ax[1, 0].set_title("Weighted Rendered RGB")
ax[0, 1].imshow(weighted_depth[0].detach().cpu(), cmap="jet", vmin=0, vmax=6)
ax[0, 1].set_title("Weighted GT Depth")
ax[1, 1].imshow(weighted_render_depth[0].detach().cpu(), cmap="jet", vmin=0, vmax=6)
ax[1, 1].set_title("Weighted Rendered Depth")
ax[0, 2].imshow(diff_rgb, cmap="jet", vmin=0, vmax=0.8)
ax[0, 2].set_title(f"Diff RGB, Loss: {torch.round(losses['im'])}")
ax[1, 2].imshow(diff_depth, cmap="jet", vmin=0, vmax=0.8)
ax[1, 2].set_title(f"Diff Depth, Loss: {torch.round(losses['depth'])}")
ax[0, 3].imshow(presence_sil_mask.detach().cpu(), cmap="gray")
ax[0, 3].set_title("Silhouette Mask")
ax[1, 3].imshow(mask[0].detach().cpu(), cmap="gray")
ax[1, 3].set_title("Loss Mask")
# Turn off axis
for i in range(2):
for j in range(4):
ax[i, j].axis('off')
# Set Title
fig.suptitle(f"Tracking Iteration: {tracking_iteration}", fontsize=16)
# Figure Tight Layout
fig.tight_layout()
os.makedirs(plot_dir, exist_ok=True)
plt.savefig(os.path.join(plot_dir, f"tmp.png"), bbox_inches='tight')
plt.close()
plot_img = cv2.imread(os.path.join(plot_dir, f"tmp.png"))
cv2.imshow('Diff Images', plot_img)
cv2.waitKey(1)
## Save Tracking Loss Viz
# save_plot_dir = os.path.join(plot_dir, f"tracking_%04d" % iter_time_idx)
# os.makedirs(save_plot_dir, exist_ok=True)
# plt.savefig(os.path.join(save_plot_dir, f"%04d.png" % tracking_iteration), bbox_inches='tight')
# plt.close()
# Weight each loss term: k is the loss name, v its value, and loss_weights holds the per-term weights
weighted_losses = {k: v * loss_weights[k] for k, v in losses.items()}
# The final loss is the sum of the weighted loss terms
loss = sum(weighted_losses.values())
seen = radius > 0 # Boolean mask: True where a Gaussian was observed in this iteration
variables['max_2D_radius'][seen] = torch.max(radius[seen], variables['max_2D_radius'][seen])
variables['seen'] = seen
weighted_losses['loss'] = loss
# Returns: the final loss, the updated variables, and the weighted loss dictionary weighted_losses
return loss, variables, weighted_losses
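# A toy sketch of the loss weighting performed at the end of get_loss() (the
# weight values below are invented for illustration, not taken from any config):
def _example_weighted_loss():
    import torch
    losses = {'im': torch.tensor(0.25), 'depth': torch.tensor(0.10)}
    loss_weights = {'im': 0.5, 'depth': 1.0}  # hypothetical weights
    weighted_losses = {k: v * loss_weights[k] for k, v in losses.items()}
    return sum(weighted_losses.values())  # tensor(0.2250)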
# Purpose: initialize parameters for newly added Gaussians
def initialize_new_params(new_pt_cld, mean3_sq_dist):
num_pts = new_pt_cld.shape[0]
means3D = new_pt_cld[:, :3] # [num_gaussians, 3]
unnorm_rots = np.tile([1, 0, 0, 0], (num_pts, 1)) # [num_gaussians, 4] identity quaternions
logit_opacities = torch.zeros((num_pts, 1), dtype=torch.float, device="cuda")
# 3D Gaussian parameters to be optimized
params = {
'means3D': means3D,
'rgb_colors': new_pt_cld[:, 3:6],
'unnorm_rotations': unnorm_rots,
'logit_opacities': logit_opacities,
'log_scales': torch.tile(torch.log(torch.sqrt(mean3_sq_dist))[..., None], (1, 1)),
}
for k, v in params.items():
# Check if value is already a torch tensor
if not isinstance(v, torch.Tensor):
params[k] = torch.nn.Parameter(torch.tensor(v).cuda().float().contiguous().requires_grad_(True))
else:
params[k] = torch.nn.Parameter(v.cuda().float().contiguous().requires_grad_(True))
return params
# Purpose: build a non-presence mask from the input depth map using the threshold config['mapping']['sil_thres'] and related checks, then add new Gaussians representing newly observed structure to the scene.
# Inputs: current model parameters params, variables, the (densification) frame data curr_data, and configuration such as the silhouette threshold and time index.
def add_new_gaussians(params, variables, curr_data, sil_thres, time_idx, mean_sq_dist_method):
# Silhouette Rendering
transformed_pts = transform_to_frame(params, time_idx, gaussians_grad=False, camera_grad=False)
depth_sil_rendervar = transformed_params2depthplussilhouette(params, curr_data['w2c'],
transformed_pts)
# Render depth and silhouette via the Renderer: depth_sil holds both channels, and silhouette extracts the silhouette
depth_sil, _, _, = Renderer(raster_settings=curr_data['cam'])(**depth_sil_rendervar)
silhouette = depth_sil[1, :, :]
non_presence_sil_mask = (silhouette < sil_thres) # S(p) < 0.5
# Check for new foreground objects by using GT depth
gt_depth = curr_data['depth'][0, :, :]
render_depth = depth_sil[0, :, :]
depth_error = torch.abs(gt_depth - render_depth) * (gt_depth > 0)
non_presence_depth_mask = (render_depth > gt_depth) * (depth_error > 50*depth_error.median())
# Determine non-presence mask
# A. Build the densification mask M(p), corresponding to Eq. (9) in the paper
non_presence_mask = non_presence_sil_mask | non_presence_depth_mask
# Flatten mask
non_presence_mask = non_presence_mask.reshape(-1)
# Get the new frame Gaussians based on the Silhouette
# B. For each pixel, use this mask to create new Gaussian parameters and append them to the existing ones
if torch.sum(non_presence_mask) > 0:
# Get the new pointcloud in the world frame
curr_cam_rot = torch.nn.functional.normalize(params['cam_unnorm_rots'][..., time_idx].detach())
curr_cam_tran = params['cam_trans'][..., time_idx].detach()
curr_w2c = torch.eye(4).cuda().float()
curr_w2c[:3, :3] = build_rotation(curr_cam_rot)
curr_w2c[:3, 3] = curr_cam_tran
valid_depth_mask = (curr_data['depth'][0, :, :] > 0)
non_presence_mask = non_presence_mask & valid_depth_mask.reshape(-1)
# Key functions: get_pointcloud() and initialize_new_params()
new_pt_cld, mean3_sq_dist = get_pointcloud(curr_data['im'], curr_data['depth'], curr_data['intrinsics'],
curr_w2c, mask=non_presence_mask, compute_mean_sq_dist=True,
mean_sq_dist_method=mean_sq_dist_method)
new_params = initialize_new_params(new_pt_cld, mean3_sq_dist)
for k, v in new_params.items():
params[k] = torch.nn.Parameter(torch.cat((params[k], v), dim=0).requires_grad_(True))
num_pts = params['means3D'].shape[0]
variables['means2D_gradient_accum'] = torch.zeros(num_pts, device="cuda").float()
variables['denom'] = torch.zeros(num_pts, device="cuda").float()
variables['max_2D_radius'] = torch.zeros(num_pts, device="cuda").float()
new_timestep = time_idx*torch.ones(new_pt_cld.shape[0],device="cuda").float()
variables['timestep'] = torch.cat((variables['timestep'],new_timestep),dim=0)
return params, variables
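# A toy sketch of the non-presence mask above (invented numbers, for illustration
# only): a pixel triggers densification when its silhouette is weak, or when it
# lies well in front of the rendered surface with a large depth error.
def _example_non_presence_mask():
    import torch
    silhouette = torch.tensor([0.9, 0.2, 0.8, 0.1])
    gt_depth = torch.tensor([2.0, 2.0, 1.0, 3.0])
    render_depth = torch.tensor([2.0, 2.1, 2.5, 3.0])
    depth_error = torch.abs(gt_depth - render_depth) * (gt_depth > 0)
    non_presence_sil = silhouette < 0.5
    non_presence_depth = (render_depth > gt_depth) & (depth_error > 50 * depth_error.median())
    return non_presence_sil | non_presence_depth  # [False, True, True, True]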
# Purpose: initialize the camera pose
# Initializes the camera rotation (cam_unnorm_rots) and translation (cam_trans) at the current time index curr_time_idx
def initialize_camera_pose(params, curr_time_idx, forward_prop):
with torch.no_grad(): # Ensure no gradients are computed in this context
if curr_time_idx > 1 and forward_prop: # Check that the time index is greater than 1 and forward propagation is enabled
# Initialize the camera pose for the current frame based on a constant velocity model
# Initialize the camera pose with a constant-velocity motion model
# Rotation
# Extrapolate the current frame's rotation from the previous two frames' rotations
prev_rot1 = F.normalize(params['cam_unnorm_rots'][..., curr_time_idx-1].detach())
prev_rot2 = F.normalize(params['cam_unnorm_rots'][..., curr_time_idx-2].detach())
new_rot = F.normalize(prev_rot1 + (prev_rot1 - prev_rot2))
params['cam_unnorm_rots'][..., curr_time_idx] = new_rot.detach()
# Translation
# Extrapolate the current frame's translation from the previous two frames' translations
prev_tran1 = params['cam_trans'][..., curr_time_idx-1].detach()
prev_tran2 = params['cam_trans'][..., curr_time_idx-2].detach()
new_tran = prev_tran1 + (prev_tran1 - prev_tran2)
params['cam_trans'][..., curr_time_idx] = new_tran.detach()
else:
# Initialize the camera pose for the current frame
# Otherwise, copy the previous frame's camera pose to the current frame
params['cam_unnorm_rots'][..., curr_time_idx] = params['cam_unnorm_rots'][..., curr_time_idx-1].detach()
params['cam_trans'][..., curr_time_idx] = params['cam_trans'][..., curr_time_idx-1].detach()
return params
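# A small numeric sketch of the constant-velocity propagation above (illustrative
# values only): linearly extrapolate the last two poses, renormalizing the
# quaternion.
def _example_forward_prop():
    import torch
    import torch.nn.functional as F
    prev_rot2 = torch.tensor([[1.00, 0.0, 0.00, 0.0]])  # frame t-2
    prev_rot1 = torch.tensor([[0.99, 0.0, 0.14, 0.0]])  # frame t-1
    new_rot = F.normalize(prev_rot1 + (prev_rot1 - prev_rot2))
    prev_tran2 = torch.tensor([0.0, 0.0, 0.0])
    prev_tran1 = torch.tensor([0.1, 0.0, 0.0])
    new_tran = prev_tran1 + (prev_tran1 - prev_tran2)  # -> tensor([0.2, 0., 0.])
    return new_rot, new_tran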
def convert_params_to_store(params):
params_to_store = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
params_to_store[k] = v.detach().clone()
else:
params_to_store[k] = v
return params_to_store
# Entry point of SplaTAM's core processing; the function body runs to 500+ lines
def rgbd_slam(config: dict):
# Print Config
print("Loaded Config:")
if "use_depth_loss_thres" not in config['tracking']:
config['tracking']['use_depth_loss_thres'] = False
config['tracking']['depth_loss_thres'] = 100000
if "visualize_tracking_loss" not in config['tracking']:
config['tracking']['visualize_tracking_loss'] = False
print(f"{config}")
# Create Output Directories
output_dir = os.path.join(config["workdir"], config["run_name"])
eval_dir = os.path.join(output_dir, "eval")
os.makedirs(eval_dir, exist_ok=True)
# Init WandB
# Initialize WandB when enabled in the config
if config['use_wandb']:
wandb_time_step = 0
wandb_tracking_step = 0
wandb_mapping_step = 0
wandb_run = wandb.init(project=config['wandb']['project'],
entity=config['wandb']['entity'],
group=config['wandb']['group'],
name=config['wandb']['name'],
config=config)
# Load the device and datasets (a long stretch of code, including several "init separate dataloader" steps)
# Get Device
device = torch.device(config["primary_device"])
# Load Dataset
print("Loading Dataset ...")
dataset_config = config["data"]
if "gradslam_data_cfg" not in dataset_config:
gradslam_data_cfg = {}
gradslam_data_cfg["dataset_name"] = dataset_config["dataset_name"]
else:
gradslam_data_cfg = load_dataset_config(dataset_config["gradslam_data_cfg"])
if "ignore_bad" not in dataset_config:
dataset_config["ignore_bad"] = False
if "use_train_split" not in dataset_config:
dataset_config["use_train_split"] = True
if "densification_image_height" not in dataset_config:
dataset_config["densification_image_height"] = dataset_config["desired_image_height"]
dataset_config["densification_image_width"] = dataset_config["desired_image_width"]
seperate_densification_res = False
else:
if dataset_config["densification_image_height"] != dataset_config["desired_image_height"] or \
dataset_config["densification_image_width"] != dataset_config["desired_image_width"]:
seperate_densification_res = True
else:
seperate_densification_res = False
if "tracking_image_height" not in dataset_config:
dataset_config["tracking_image_height"] = dataset_config["desired_image_height"]
dataset_config["tracking_image_width"] = dataset_config["desired_image_width"]
seperate_tracking_res = False
else:
if dataset_config["tracking_image_height"] != dataset_config["desired_image_height"] or \
dataset_config["tracking_image_width"] != dataset_config["desired_image_width"]:
seperate_tracking_res = True
else:
seperate_tracking_res = False
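# Worked example (hypothetical values): with desired_image_height=480,
# desired_image_width=640 and tracking_image_height=240, tracking_image_width=320,
# seperate_tracking_res becomes True and a half-resolution tracking dataloader is
# built below; if the tracking keys are absent, they default to the desired
# resolution and a single dataloader is shared.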
# Poses are relative to the first frame
dataset = get_dataset(
config_dict=gradslam_data_cfg,
basedir=dataset_config["basedir"],
sequence=os.path.basename(dataset_config["sequence"]),
start=dataset_config["start"],
end=dataset_config["end"],
stride=dataset_config["stride"],
desired_height=dataset_config["desired_image_height"],
desired_width=dataset_config["desired_image_width"],
device=device,
relative_pose=True,
ignore_bad=dataset_config["ignore_bad"],
use_train_split=dataset_config["use_train_split"],
)
num_frames = dataset_config["num_frames"]
if num_frames == -1:
num_frames = len(dataset)
# Init separate dataloader for densification if required
if seperate_densification_res:
densify_dataset = get_dataset(
config_dict=gradslam_data_cfg,
basedir=dataset_config["basedir"],
sequence=os.path.basename(dataset_config["sequence"]),
start=dataset_config["start"],
end=dataset_config["end"],
stride=dataset_config["stride"],
desired_height=dataset_config["densification_image_height"],
desired_width=dataset_config["densification_image_width"],
device=device,
relative_pose=True,
ignore_bad=dataset_config["ignore_bad"],
use_train_split=dataset_config["use_train_split"],
)
# Initialize Parameters, Canonical & Densification Camera parameters
params, variables, intrinsics, first_frame_w2c, cam, \
densify_intrinsics, densify_cam = initialize_first_timestep(dataset, num_frames,
config['scene_radius_depth_ratio'],
config['mean_sq_dist_method'],
densify_dataset=densify_dataset)
# initialize_first_timestep() performs Map Initialization on the first frame;
# it is called both above and below with different arguments:
# if the check passes, densify_dataset is passed in as the dataset to densify;
# otherwise densify_dataset defaults to None
else:
# Initialize Parameters & Canonical Camera parameters
params, variables, intrinsics, first_frame_w2c, cam = initialize_first_timestep(dataset, num_frames,
config['scene_radius_depth_ratio'],
config['mean_sq_dist_method'])
# Init separate dataloader for tracking if required
if seperate_tracking_res:
tracking_dataset = get_dataset(
config_dict=gradslam_data_cfg,
basedir=dataset_config["basedir"],
sequence=os.path.basename(dataset_config["sequence"]),
start=dataset_config["start"],
end=dataset_config["end"],
stride=dataset_config["stride"],
desired_height=dataset_config["tracking_image_height"],
desired_width=dataset_config["tracking_image_width"],
device=device,
relative_pose=True,
ignore_bad=dataset_config["ignore_bad"],
use_train_split=dataset_config["use_train_split"],
)
tracking_color, _, tracking_intrinsics, _ = tracking_dataset[0]
tracking_color = tracking_color.permute(2, 0, 1) / 255 # (H, W, C) -> (C, H, W)
tracking_intrinsics = tracking_intrinsics[:3, :3]
tracking_cam = setup_camera(tracking_color.shape[2], tracking_color.shape[1],
tracking_intrinsics.cpu().numpy(), first_frame_w2c.detach().cpu().numpy())
# Initialize list to keep track of Keyframes
keyframe_list = []
keyframe_time_indices = []
# Init Variables to keep track of ground truth poses and runtimes
gt_w2c_all_frames = []
tracking_iter_time_sum = 0
tracking_iter_time_count = 0
mapping_iter_time_sum = 0
mapping_iter_time_count = 0
tracking_frame_time_sum = 0
tracking_frame_time_count = 0
mapping_frame_time_sum = 0
mapping_frame_time_count = 0
# Load Checkpoint
if config['load_checkpoint']:
checkpoint_time_idx = config['checkpoint_time_idx']
print(f"Loading Checkpoint for Frame {checkpoint_time_idx}")
ckpt_path = os.path.join(config['workdir'], config['run_name'], f"params{checkpoint_time_idx}.npz")
params = dict(np.load(ckpt_path, allow_pickle=True))
params = {k: torch.tensor(params[k]).cuda().float().requires_grad_(True) for k in params.keys()}
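# Checkpoint layout note (an assumption from the load/save calls used here):
# params{N}.npz stores one array per parameter key; np.load(..., allow_pickle=True)
# returns an NpzFile, and dict() materializes it so each array can be wrapped back
# into a trainable CUDA tensor.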
variables['max_2D_radius'] = torch.zeros(params['means3D'].shape[0]).cuda().float()
variables['means2D_gradient_accum'] = torch.zeros(params['means3D'].shape[0]).cuda().float()
variables['denom'] = torch.zeros(params['means3D'].shape[0]).cuda().float()
variables['timestep'] = torch.zeros(params['means3D'].shape[0]).cuda().float()
# Load the keyframe time idx list
keyframe_time_indices = np.load(os.path.join(config['workdir'], config['run_name'], f"keyframe_time_indices{checkpoint_time_idx}.npy"))
keyframe_time_indices = keyframe_time_indices.tolist()
# Update the ground truth poses list
for time_idx in range(checkpoint_time_idx):
# Load RGBD frames incrementally instead of all frames
color, depth, _, gt_pose = dataset[time_idx]
# Process poses
gt_w2c = torch.linalg.inv(gt_pose)
gt_w2c_all_frames.append(gt_w2c)
# Initialize Keyframe List
if time_idx in keyframe_time_indices:
# Get the estimated rotation & translation
curr_cam_rot = F.normalize(params['cam_unnorm_rots'][..., time_idx].detach())
curr_cam_tran = params['cam_trans'][..., time_idx].detach()
curr_w2c = torch.eye(4).cuda().float()
curr_w2c[:3, :3] = build_rotation(curr_cam_rot)
curr_w2c[:3, 3] = curr_cam_tran
# Initialize Keyframe Info
color = color.permute(2, 0, 1) / 255
depth = depth.permute(2, 0, 1)
curr_keyframe = {'id': time_idx, 'est_w2c': curr_w2c, 'color': color, 'depth': depth}
# Add to keyframe list
keyframe_list.append(curr_keyframe)
else:
checkpoint_time_idx = 0
# ******************* Key loop: iterate over RGB-D frames, performing Tracking and Mapping *******************
# Iterate over Scan
for time_idx in tqdm(range(checkpoint_time_idx, num_frames)): # Iterate over RGB-D frames, from checkpoint_time_idx (usually 0) up to num_frames
# Load RGBD frames incrementally instead of all frames
color, depth, _, gt_pose = dataset[time_idx] # Load the frame's color, depth, and pose from the dataset
# Process poses
gt_w2c = torch.linalg.inv(gt_pose) # Invert gt_pose to get the world-to-camera transform gt_w2c
# Process RGB-D Data
color = color.permute(2, 0, 1) / 255 # Normalize color to [0, 1]
depth = depth.permute(2, 0, 1)
gt_w2c_all_frames.append(gt_w2c)
curr_gt_w2c = gt_w2c_all_frames
# Optimize only current time step for tracking
iter_time_idx = time_idx
# Initialize Mapping Data for selected frame
# Build the current frame's data curr_data: camera parameters, color, depth, and related fields
curr_data = {'cam': cam, 'im': color, 'depth': depth, 'id': iter_time_idx, 'intrinsics': intrinsics,
'w2c': first_frame_w2c, 'iter_gt_w2c_list': curr_gt_w2c}
# Initialize Data for Tracking
# Depending on the config, build the tracking data tracking_curr_data
if seperate_tracking_res:
tracking_color, tracking_depth, _, _ = tracking_dataset[time_idx]
tracking_color = tracking_color.permute(2, 0, 1) / 255
tracking_depth = tracking_depth.permute(2, 0, 1)
tracking_curr_data = {'cam': tracking_cam, 'im': tracking_color, 'depth': tracking_depth, 'id': iter_time_idx,
'intrinsics': tracking_intrinsics, 'w2c': first_frame_w2c, 'iter_gt_w2c_list': curr_gt_w2c}
else:
tracking_curr_data = curr_data
# Optimization Iterations
# Set the number of mapping iterations
num_iters_mapping = config['mapping']['num_iters']
# ******************* Sec. 1 Camera Tracking *******************
# ** Sec 1.1 Camera Pose Initialization **
# Initialize the camera pose for the current frame
# If the current frame index is greater than 0, initialize the camera pose parameters
if time_idx > 0:
params = initialize_camera_pose(params, time_idx, forward_prop=config['tracking']['forward_prop']) # In configs/replica/splatam.py, forward_prop is True
# Tracking
tracking_start_time = time.time()
# If the current time index time_idx is greater than 0 and ground-truth poses are not used
if time_idx > 0 and not config['tracking']['use_gt_poses']:
# ** Sec 1.2 Reset and initialize the optimizer and tracking state **
# Reset Optimizer & Learning Rates for tracking
optimizer = initialize_optimizer(params, config['tracking']['lrs'], tracking=True)
# Keep Track of Best Candidate Rotation & Translation
# Initialize candidate_cam_unnorm_rot and candidate_cam_tran to track the best camera rotation and translation
candidate_cam_unnorm_rot = params['cam_unnorm_rots'][..., time_idx].detach().clone()
candidate_cam_tran = params['cam_trans'][..., time_idx].detach().clone()
# Initialize current_min_loss to track the minimum loss across iterations
current_min_loss = float(1e20)
# Tracking Optimization
iter = 0
do_continue_slam = False
num_iters_tracking = config['tracking']['num_iters']
# Create a tqdm progress bar for the tracking iterations at this time step
progress_bar = tqdm(range(num_iters_tracking), desc=f"Tracking Time Step: {time_idx}")
# ** Sec 1.3 Iterative optimization loop **
while True:
iter_start_time = time.time() # Record the iteration start time
# Loss for current frame
# Key function: compute the loss for the current frame
loss, variables, losses = get_loss(params, tracking_curr_data, variables, iter_time_idx, config['tracking']['loss_weights'],
config['tracking']['use_sil_for_loss'], config['tracking']['sil_thres'],
config['tracking']['use_l1'], config['tracking']['ignore_outlier_depth_loss'], tracking=True,
plot_dir=eval_dir, visualize_tracking_loss=config['tracking']['visualize_tracking_loss'],
tracking_iteration=iter)
if config['use_wandb']:
# Report Loss
wandb_tracking_step = report_loss(losses, wandb_run, wandb_tracking_step, tracking=True)
# Backprop: backpropagate the loss to compute gradients
loss.backward()
# Optimizer Update
# Update the model parameters
optimizer.step()
# Clear the accumulated gradients
optimizer.zero_grad(set_to_none=True)
with torch.no_grad():
# Save the best candidate rotation & translation
# If the current loss is below current_min_loss, update the best candidate rotation and translation
if loss < current_min_loss:
current_min_loss = loss
candidate_cam_unnorm_rot = params['cam_unnorm_rots'][..., time_idx].detach().clone()
candidate_cam_tran = params['cam_trans'][..., time_idx].detach().clone()
# Report Progress
if config['report_iter_progress']:
if config['use_wandb']:
report_progress(params, tracking_curr_data, iter+1, progress_bar, iter_time_idx, sil_thres=config['tracking']['sil_thres'], tracking=True,
wandb_run=wandb_run, wandb_step=wandb_tracking_step, wandb_save_qual=config['wandb']['save_qual'])
else:
report_progress(params, tracking_curr_data, iter+1, progress_bar, iter_time_idx, sil_thres=config['tracking']['sil_thres'], tracking=True)
else:
progress_bar.update(1)
# Update the runtime numbers: iteration count and per-iteration time
iter_end_time = time.time()
tracking_iter_time_sum += iter_end_time - iter_start_time
tracking_iter_time_count += 1
# Check if we should stop tracking: the maximum number of iterations has been reached
iter += 1
if iter == num_iters_tracking:
if losses['depth'] < config['tracking']['depth_loss_thres'] and config['tracking']['use_depth_loss_thres']:
break
elif config['tracking']['use_depth_loss_thres'] and not do_continue_slam:
do_continue_slam = True
progress_bar = tqdm(range(num_iters_tracking), desc=f"Tracking Time Step: {time_idx}")
num_iters_tracking = 2*num_iters_tracking
if config['use_wandb']:
wandb_run.log({"Tracking/Extra Tracking Iters Frames": time_idx,
"Tracking/step": wandb_time_step})
else:
break
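# Termination sketch (hypothetical numbers): with num_iters_tracking=40 and
# use_depth_loss_thres=True, the loop runs 40 iterations; if the depth loss is
# still above depth_loss_thres, the budget is doubled once (do_continue_slam)
# to 80 iterations before the loop exits unconditionally.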
# ** Sec 1.4 Data update and progress tracking **
# Out of the while loop: commit the best candidate pose
progress_bar.close()
# Copy over the best candidate rotation & translation
with torch.no_grad():
params['cam_unnorm_rots'][..., time_idx] = candidate_cam_unnorm_rot
params['cam_trans'][..., time_idx] = candidate_cam_tran
# The other branch: time_idx is greater than 0 and ground-truth poses are used
elif time_idx > 0 and config['tracking']['use_gt_poses']:
with torch.no_grad():
# Get the ground truth pose relative to frame 0
rel_w2c = curr_gt_w2c[-1]
rel_w2c_rot = rel_w2c[:3, :3].unsqueeze(0).detach()
rel_w2c_rot_quat = matrix_to_quaternion(rel_w2c_rot)
rel_w2c_tran = rel_w2c[:3, 3].detach()
# Update the camera parameters
params['cam_unnorm_rots'][..., time_idx] = rel_w2c_rot_quat
params['cam_trans'][..., time_idx] = rel_w2c_tran
# Update the runtime numbers
tracking_end_time = time.time()
tracking_frame_time_sum += tracking_end_time - tracking_start_time
tracking_frame_time_count += 1
# Report tracking progress with a progress bar; on failure, auto-save a parameter checkpoint
if time_idx == 0 or (time_idx+1) % config['report_global_progress_every'] == 0:
try:
# Report Final Tracking Progress
progress_bar = tqdm(range(1), desc=f"Tracking Result Time Step: {time_idx}")
with torch.no_grad():
if config['use_wandb']:
report_progress(params, tracking_curr_data, 1, progress_bar, iter_time_idx, sil_thres=config['tracking']['sil_thres'], tracking=True,
wandb_run=wandb_run, wandb_step=wandb_time_step, wandb_save_qual=config['wandb']['save_qual'], global_logging=True)
else:
report_progress(params, tracking_curr_data, 1, progress_bar, iter_time_idx, sil_thres=config['tracking']['sil_thres'], tracking=True)
progress_bar.close()
except:
ckpt_output_dir = os.path.join(config["workdir"], config["run_name"])
save_params_ckpt(params, ckpt_output_dir, time_idx)
print('Failed to evaluate trajectory.')
# ******************* Sec. 2 and Sec. 3: Densification and KeyFrame-based Mapping *******************
# Densification & KeyFrame-based Mapping
if time_idx == 0 or (time_idx+1) % config['map_every'] == 0:
# Densification
# ******************* Sec. 2 Densification *******************
if config['mapping']['add_new_gaussians'] and time_idx > 0:
# Setup Data for Densification
if seperate_densification_res:
# If the check passes, load the densification RGB-D frames incrementally instead of all at once
densify_color, densify_depth, _, _ = densify_dataset[time_idx]
densify_color = densify_color.permute(2, 0, 1) / 255
densify_depth = densify_depth.permute(2, 0, 1)
densify_curr_data = {'cam': densify_cam, 'im': densify_color, 'depth': densify_depth, 'id': time_idx,
'intrinsics': densify_intrinsics, 'w2c': first_frame_w2c, 'iter_gt_w2c_list': curr_gt_w2c}
else:
# Otherwise use the current frame data curr_data
densify_curr_data = curr_data
# Add new Gaussians to the scene based on the Silhouette
# Key function: add new Gaussians
params, variables = add_new_gaussians(params, variables, densify_curr_data,
config['mapping']['sil_thres'], time_idx,
config['mean_sq_dist_method'])
# post_num_pts is the number of Gaussians after adding the new ones
post_num_pts = params['means3D'].shape[0]
if config['use_wandb']:
wandb_run.log({"Mapping/Number of Gaussians": post_num_pts,
"Mapping/step": wandb_time_step})
# ******************* Sec. 3 Keyframe-based Map Update *******************
# KeyFrame Selection
with torch.no_grad(): # Computations inside this block do not track gradients
# Get the current estimated rotation & translation
# Extract the current frame's camera pose at this time index and convert it to a w2c matrix
curr_cam_rot = F.normalize(params['cam_unnorm_rots'][..., time_idx].detach())
curr_cam_tran = params['cam_trans'][..., time_idx].detach()
curr_w2c = torch.eye(4).cuda().float()
curr_w2c[:3, :3] = build_rotation(curr_cam_rot)
curr_w2c[:3, 3] = curr_cam_tran
# Select Keyframes for Mapping
# Compute the number of keyframes to select, num_keyframes, from mapping_window_size in the config
# The "-2" matches the paper's "k-2 previous keyframes"; the limit is applied when the argument is passed in
num_keyframes = config['mapping_window_size']-2
# Key function: keyframe_selection_overlap selects keyframes by depth overlap
selected_keyframes = keyframe_selection_overlap(depth, curr_w2c, intrinsics, keyframe_list[:-1], num_keyframes)
selected_time_idx = [keyframe_list[frame_idx]['id'] for frame_idx in selected_keyframes]
# Add the last keyframe and the current frame to the selection
if len(keyframe_list) > 0:
# Add last keyframe to the selected keyframes
selected_time_idx.append(keyframe_list[-1]['id'])
selected_keyframes.append(len(keyframe_list)-1)
# Add current frame to the selected keyframes
selected_time_idx.append(time_idx)
selected_keyframes.append(-1)
# Print the selected keyframes
print(f"\nSelected Keyframes at Frame {time_idx}: {selected_time_idx}")
# Reset Optimizer & Learning Rates for Full Map Optimization
# Initialize the optimizer before running the Mapping optimization
optimizer = initialize_optimizer(params, config['mapping']['lrs'], tracking=False)
# Mapping
mapping_start_time = time.time()
if num_iters_mapping > 0:
progress_bar = tqdm(range(num_iters_mapping), desc=f"Mapping Time Step: {time_idx}")
for iter in range(num_iters_mapping):
iter_start_time = time.time()
# Randomly select a frame until current time step amongst keyframes
rand_idx = np.random.randint(0, len(selected_keyframes))
selected_rand_keyframe_idx = selected_keyframes[rand_idx]
if selected_rand_keyframe_idx == -1:
# Use Current Frame Data
iter_time_idx = time_idx
iter_color = color
iter_depth = depth
else:
# Use Keyframe Data
iter_time_idx = keyframe_list[selected_rand_keyframe_idx]['id']
iter_color = keyframe_list[selected_rand_keyframe_idx]['color']
iter_depth = keyframe_list[selected_rand_keyframe_idx]['depth']
iter_gt_w2c = gt_w2c_all_frames[:iter_time_idx+1]
iter_data = {'cam': cam, 'im': iter_color, 'depth': iter_depth, 'id': iter_time_idx,
'intrinsics': intrinsics, 'w2c': first_frame_w2c, 'iter_gt_w2c_list': iter_gt_w2c}
# Loss for current frame
# Key function: compute the loss for the selected frame
loss, variables, losses = get_loss(params, iter_data, variables, iter_time_idx, config['mapping']['loss_weights'],
config['mapping']['use_sil_for_loss'], config['mapping']['sil_thres'],
config['mapping']['use_l1'], config['mapping']['ignore_outlier_depth_loss'], mapping=True)
if config['use_wandb']:
# Report Loss
wandb_mapping_step = report_loss(losses, wandb_run, wandb_mapping_step, mapping=True)
# Backprop
loss.backward()
with torch.no_grad():
# Prune Gaussians
# In configs/replica/splatam.py, config['mapping']['prune_gaussians']=True, so this branch runs
if config['mapping']['prune_gaussians']:
params, variables = prune_gaussians(params, variables, optimizer, iter, config['mapping']['pruning_dict'])
if config['use_wandb']:
wandb_run.log({"Mapping/Number of Gaussians - Pruning": params['means3D'].shape[0],
"Mapping/step": wandb_mapping_step})
# Gaussian-Splatting's Gradient-based Densification
# In configs/replica/splatam.py, config['mapping']['use_gaussian_splatting_densification']=False, so this branch is skipped
if config['mapping']['use_gaussian_splatting_densification']:
params, variables = densify(params, variables, optimizer, iter, config['mapping']['densify_dict'])
if config['use_wandb']:
wandb_run.log({"Mapping/Number of Gaussians - Densification": params['means3D'].shape[0],
"Mapping/step": wandb_mapping_step})
# Optimizer Update
optimizer.step()
optimizer.zero_grad(set_to_none=True)
# Report Progress
if config['report_iter_progress']:
if config['use_wandb']:
report_progress(params, iter_data, iter+1, progress_bar, iter_time_idx, sil_thres=config['mapping']['sil_thres'],
wandb_run=wandb_run, wandb_step=wandb_mapping_step, wandb_save_qual=config['wandb']['save_qual'],
mapping=True, online_time_idx=time_idx)