Skip to content

Commit 331b87d

Browse files
committed
Fixed issue with prbvolpath gradient propagation
1 parent 7df908b commit 331b87d

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

src/emitters/volumelight.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ class VolumeLight final : public Emitter<Float, Spectrum> {
7474
m_flags = +EmitterFlags::Medium;
7575

7676
dr::set_attr(this, "flags", m_flags);
77-
dr::set_attr(this, "radiance", m_radiance);
7877
}
7978

8079
void traverse(TraversalCallback *callback) override {

src/python/python/ad/integrators/prbvolpath.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,11 +268,14 @@ def sample(self,
268268
null_scatter = (sampler.next_1d(active_medium) < (emission_prob + null_prob))
269269
act_null_scatter = null_scatter & active_medium
270270
act_medium_scatter = ~null_scatter & active_medium
271-
L[act_null_scatter] += throughput * radiance * dr.detach(weight * emission_weight)
271+
contrib = throughput * radiance * dr.detach(emission_weight * weight)
272+
L[act_null_scatter] += dr.detach(contrib if is_primal else -contrib)
272273
weight[act_null_scatter] *= mei.sigma_n * dr.detach(null_weight)
273274
else:
274275
scatter_weight = mi.Float(1.0)
275276
act_medium_scatter = active_medium
277+
contrib = dr.zeros(mi.UnpolarizedSpectrum)
278+
276279

277280
depth[act_medium_scatter] += 1
278281
last_scatter_event[act_medium_scatter] = dr.detach(mei)
@@ -291,6 +294,8 @@ def sample(self,
291294
if not is_primal and dr.grad_enabled(weight):
292295
Lo = dr.detach(dr.select(active_medium | escaped_medium, L / dr.maximum(1e-8, weight), 0.0))
293296
dr.backward(δL * weight * Lo)
297+
if not is_primal and dr.grad_enabled(contrib):
298+
dr.backward(δL * contrib)
294299

295300
phase_ctx = mi.PhaseFunctionContext(sampler)
296301
phase = mei.medium.phase_function()

0 commit comments

Comments
 (0)