@@ -406,9 +406,11 @@ renderCUDA(
406406 const float2 * __restrict__ points_xy_image,
407407 const float4 * __restrict__ conic_opacity,
408408 const float * __restrict__ colors,
409+ const float * __restrict__ depths,
409410 const float * __restrict__ final_Ts,
410411 const uint32_t * __restrict__ n_contrib,
411412 const float * __restrict__ dL_dpixels,
413+ const float * __restrict__ dL_depths,
412414 float3 * __restrict__ dL_dmean2D,
413415 float4 * __restrict__ dL_dconic2D,
414416 float * __restrict__ dL_dopacity,
@@ -435,6 +437,7 @@ renderCUDA(
435437 __shared__ float2 collected_xy[BLOCK_SIZE];
436438 __shared__ float4 collected_conic_opacity[BLOCK_SIZE];
437439 __shared__ float collected_colors[C * BLOCK_SIZE];
440+ __shared__ float collected_depths[BLOCK_SIZE];
438441
439442 // In the forward, we stored the final value for T, the
440443 // product of all (1 - alpha) factors.
@@ -448,12 +451,16 @@ renderCUDA(
448451
449452 float accum_rec[C] = { 0 };
450453 float dL_dpixel[C];
454+ float dL_depth;
455+ float accum_depth_rec = 0 ;
451456 if (inside)
452457 for (int i = 0 ; i < C; i++)
453458 dL_dpixel[i] = dL_dpixels[i * H * W + pix_id];
459+ dL_depth = inside ? dL_depths[pix_id] : 0.f; // the unbraced 'if (inside)' above covers only the for-loop, so guard this read explicitly like the dL_dpixels read; also gives dL_depth a defined value when not inside
454460
455461 float last_alpha = 0 ;
456462 float last_color[C] = { 0 };
463+ float last_depth = 0 ;
457464
458465 // Gradient of pixel coordinate w.r.t. normalized
459466 // screen-space viewport coordinates (-1 to 1)
@@ -475,6 +482,7 @@ renderCUDA(
475482 collected_conic_opacity[block.thread_rank ()] = conic_opacity[coll_id];
476483 for (int i = 0 ; i < C; i++)
477484 collected_colors[i * BLOCK_SIZE + block.thread_rank ()] = colors[coll_id * C + i];
485+ collected_depths[block.thread_rank ()] = depths[coll_id];
478486 }
479487 block.sync ();
480488
@@ -522,6 +530,10 @@ renderCUDA(
522530 // many that were affected by this Gaussian.
523531 atomicAdd (&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_dchannel);
524532 }
533+ const float c_d = collected_depths[j];
534+ accum_depth_rec = last_alpha * last_depth + (1 .f - last_alpha) * accum_depth_rec;
535+ last_depth = c_d;
536+ dL_dalpha += (c_d - accum_depth_rec) * dL_depth;
525537 dL_dalpha *= T;
526538 // Update last alpha (to be used in the next iteration)
527539 last_alpha = alpha;
@@ -630,9 +642,11 @@ void BACKWARD::render(
630642 const float2 * means2D,
631643 const float4 * conic_opacity,
632644 const float * colors,
645+ const float * depths,
633646 const float * final_Ts,
634647 const uint32_t * n_contrib,
635648 const float * dL_dpixels,
649+ const float * dL_depths,
636650 float3 * dL_dmean2D,
637651 float4 * dL_dconic2D,
638652 float * dL_dopacity,
@@ -646,12 +660,14 @@ void BACKWARD::render(
646660 means2D,
647661 conic_opacity,
648662 colors,
663+ depths,
649664 final_Ts,
650665 n_contrib,
651666 dL_dpixels,
667+ dL_depths,
652668 dL_dmean2D,
653669 dL_dconic2D,
654670 dL_dopacity,
655671 dL_dcolors
656672 );
657- }
673+ }
0 commit comments