Skip to content

Commit 6dc63a8

Browse files
committed
Fix tests
1 parent e74607f commit 6dc63a8

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

test/modeling_library/product.jl

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@ discrete_product = ProductDistribution(bernoulli, binom)
2222
f = (x, p1, n, p2) -> logpdf(discrete_product, x, p1, n, p2)
2323
args = (x, p1, n, p2)
2424
actual = logpdf_grad(discrete_product, args...)
25-
for (i, b) in enumerate(grad_bools)
26-
if b
27-
@test isapprox(actual[i], finite_diff(f, args, i, dx))
28-
end
25+
for i in [2, 4]
26+
@test isapprox(actual[i], finite_diff(f, args, i, dx))
2927
end
3028
end
3129

@@ -51,12 +49,13 @@ continuous_product = ProductDistribution(uniform, normal)
5149

5250
# test logpdf_grad against finite differencing
5351
f = (x, low, high, mu, std) -> logpdf(continuous_product, x, low, high, mu, std)
54-
args = (x, low, high, mu, std)
52+
# A mutable indexable is required by `finite_diff_vec`, hence the `collect` here:
53+
args = (collect(x), low, high, mu, std)
5554
actual = logpdf_grad(continuous_product, args...)
56-
for (i, b) in enumerate(grad_bools)
57-
if b
58-
@test isapprox(actual[i], finite_diff(f, args, i, dx))
59-
end
55+
@test isapprox(actual[1][1], finite_diff_vec(f, args, 1, 1, dx))
56+
@test isapprox(actual[1][2], finite_diff_vec(f, args, 1, 2, dx))
57+
for i in 2:5
58+
@test isapprox(actual[i], finite_diff(f, args, i, dx))
6059
end
6160
end
6261

@@ -84,9 +83,7 @@ dissimilar_product = ProductDistribution(bernoulli, normal)
8483
f = (x, p, mu, std) -> logpdf(dissimilar_product, x, p, mu, std)
8584
args = (x, p, mu, std)
8685
actual = logpdf_grad(dissimilar_product, args...)
87-
for (i, b) in enumerate(grad_bools)
88-
if b
89-
@test isapprox(actual[i], finite_diff(f, args, i, dx))
90-
end
86+
for i in 2:4
87+
@test isapprox(actual[i], finite_diff(f, args, i, dx))
9188
end
9289
end

0 commit comments

Comments
 (0)