Skip to content

Commit 18c06fd

Browse files
authored
Merge pull request #521 from sharlaon/product
Add product distribution combinator
2 parents 73d3790 + 6dc63a8 commit 18c06fd

File tree

5 files changed

+206
-2
lines changed

5 files changed

+206
-2
lines changed

docs/src/ref/distributions.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Probability Distributions
22

3-
Gen provides a library of built-in probability distributions, and three ways of
3+
Gen provides a library of built-in probability distributions, and four ways of
44
defining custom distributions, each of which are explained below:
55

66
1. The [`@dist` constructor](@ref dist_dsl), for a distribution that can be expressed as a
@@ -11,7 +11,10 @@ defining custom distributions, each of which are explained below:
1111
2. The [`HeterogeneousMixture`](@ref) and [`HomogeneousMixture`](@ref) constructors
1212
for distributions that are mixtures of other distributions.
1313

14-
3. An API for defining arbitrary [custom distributions](@ref
14+
3. The [`ProductDistribution`](@ref) constructor for distributions that are products of
15+
other distributions.
16+
17+
4. An API for defining arbitrary [custom distributions](@ref
1518
custom_distributions) in plain Julia code.
1619

1720
## Built-In Distributions
@@ -220,6 +223,13 @@ HomogeneousMixture
220223
HeterogeneousMixture
221224
```
222225

226+
## Product Distribution Constructors
227+
228+
There is a built-in constructor for defining product distributions:
229+
```@docs
230+
ProductDistribution
231+
```
232+
223233
## Defining New Distributions From Scratch
224234

225235
For distributions that cannot be expressed in the `@dist` DSL, users can define

src/modeling_library/modeling_library.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ include("dist_dsl/dist_dsl.jl")
6262
# mixtures of distributions
6363
include("mixture.jl")
6464

65+
# products of distributions
66+
include("product.jl")
67+
6568
###############
6669
# combinators #
6770
###############

src/modeling_library/product.jl

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
########################################################################
2+
# ProductDistribution: product of fixed distributions of similar types #
3+
########################################################################
4+
5+
"""
6+
ProductDistribution(distributions::Vararg{<:Distribution})
7+
8+
Define new distribution that is the product of the given nonempty list of distributions having a common type.
9+
10+
The arguments comprise the list of base distributions.
11+
12+
Example:
13+
```julia
14+
normal_strip = ProductDistribution(uniform, normal)
15+
```
16+
17+
The resulting product distribution takes `n` arguments, where `n` is the sum of the numbers of arguments taken by each distribution in the list.
18+
These arguments are the arguments to each component distribution, in the order in which the distributions are passed to the constructor.
19+
20+
Example:
21+
```julia
22+
@gen function unit_strip_and_near_seven()
23+
x ~ flip_and_number(0.0, 0.1, 7.0, 0.01)
24+
end
25+
```
26+
"""
27+
struct ProductDistribution{T, Ds} <: Distribution{T}
28+
K::Int
29+
distributions::Ds
30+
has_output_grad::Bool
31+
has_argument_grads::Tuple
32+
is_discrete::Bool
33+
num_args::Vector{Int}
34+
starting_args::Vector{Int}
35+
end
36+
37+
(dist::ProductDistribution)(args...) = random(dist, args...)
38+
39+
Gen.has_output_grad(dist::ProductDistribution) = dist.has_output_grad
40+
Gen.has_argument_grads(dist::ProductDistribution) = dist.has_argument_grads
41+
Gen.is_discrete(dist::ProductDistribution) = dist.is_discrete
42+
43+
function ProductDistribution(distributions::Vararg{<:Distribution})
44+
_has_output_grads = true
45+
_is_discrete = true
46+
47+
types = Type[]
48+
49+
_has_argument_grads = Bool[]
50+
_num_args = Int[]
51+
_starting_args = Int[]
52+
start_pos = 1
53+
54+
for dist in distributions
55+
push!(types, Gen.get_return_type(dist))
56+
57+
_has_output_grads = _has_output_grads && has_output_grad(dist)
58+
_is_discrete = _is_discrete && is_discrete(dist)
59+
60+
grads_data = has_argument_grads(dist)
61+
append!(_has_argument_grads, grads_data)
62+
push!(_num_args, length(grads_data))
63+
push!(_starting_args, start_pos)
64+
start_pos += length(grads_data)
65+
end
66+
67+
return ProductDistribution{Tuple{types...}, typeof(distributions)}(
68+
length(distributions),
69+
distributions,
70+
_has_output_grads,
71+
Tuple(_has_argument_grads),
72+
_is_discrete,
73+
_num_args,
74+
_starting_args)
75+
end
76+
77+
function extract_args_for_component(dist::ProductDistribution, component_args_flat, k::Int)
78+
start_arg = dist.starting_args[k]
79+
n = dist.num_args[k]
80+
return component_args_flat[start_arg:start_arg+n-1]
81+
end
82+
83+
Gen.random(dist::ProductDistribution, args...) =
84+
Tuple(random(d, extract_args_for_component(dist, args, k)...) for (k, d) in enumerate(dist.distributions))
85+
86+
Gen.logpdf(dist::ProductDistribution, x, args...) =
87+
sum(Gen.logpdf(d, x[k], extract_args_for_component(dist, args, k)...) for (k, d) in enumerate(dist.distributions))
88+
89+
function Gen.logpdf_grad(dist::ProductDistribution, x, args...)
90+
x_grad = ()
91+
arg_grads = ()
92+
for (k, d) in enumerate(dist.distributions)
93+
grads = Gen.logpdf_grad(d, x[k], extract_args_for_component(dist, args, k)...)
94+
x_grad = (x_grad..., grads[1])
95+
arg_grads = (arg_grads..., grads[2:end]...)
96+
end
97+
x_grad = dist.has_output_grad ? x_grad : nothing
98+
return (x_grad, arg_grads...)
99+
end
100+
101+
export ProductDistribution

test/modeling_library/modeling_library.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ include("recurse.jl")
88
include("switch.jl")
99
include("dist_dsl.jl")
1010
include("mixture.jl")
11+
include("product.jl")

test/modeling_library/product.jl

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
discrete_product = ProductDistribution(bernoulli, binom)
2+
3+
@testset "product of discrete distributions" begin
4+
@test is_discrete(discrete_product)
5+
grad_bools = (has_output_grad(discrete_product), has_argument_grads(discrete_product)...)
6+
@test grad_bools == (false, true, false, true)
7+
8+
p1 = 0.5
9+
(n, p2) = (3, 0.9)
10+
11+
# random
12+
x = discrete_product(p1, n, p2)
13+
@assert typeof(x) == Gen.get_return_type(discrete_product) == Tuple{Bool, Int}
14+
15+
# logpdf
16+
x = (true, 2)
17+
actual = logpdf(discrete_product, x, p1, n, p2)
18+
expected = logpdf(bernoulli, x[1], p1) + logpdf(binom, x[2], n, p2)
19+
@test isapprox(actual, expected)
20+
21+
# test logpdf_grad against finite differencing
22+
f = (x, p1, n, p2) -> logpdf(discrete_product, x, p1, n, p2)
23+
args = (x, p1, n, p2)
24+
actual = logpdf_grad(discrete_product, args...)
25+
for i in [2, 4]
26+
@test isapprox(actual[i], finite_diff(f, args, i, dx))
27+
end
28+
end
29+
30+
continuous_product = ProductDistribution(uniform, normal)
31+
32+
@testset "product of continuous distributions" begin
33+
@test !is_discrete(continuous_product)
34+
grad_bools = (has_output_grad(continuous_product), has_argument_grads(continuous_product)...)
35+
@test grad_bools == (true, true, true, true, true)
36+
37+
(low, high) = (-0.5, 0.5)
38+
(mu, std) = (0.0, 1.0)
39+
40+
# random
41+
x = continuous_product(low, high, mu, std)
42+
@assert typeof(x) == Gen.get_return_type(continuous_product) == Tuple{Float64, Float64}
43+
44+
# logpdf
45+
x = (0.1, 0.7)
46+
actual = logpdf(continuous_product, x, low, high, mu, std)
47+
expected = logpdf(uniform, x[1], low, high) + logpdf(normal, x[2], mu, std)
48+
@test isapprox(actual, expected)
49+
50+
# test logpdf_grad against finite differencing
51+
f = (x, low, high, mu, std) -> logpdf(continuous_product, x, low, high, mu, std)
52+
# A mutable indexable is required by `finite_diff_vec`, hence the `collect` here:
53+
args = (collect(x), low, high, mu, std)
54+
actual = logpdf_grad(continuous_product, args...)
55+
@test isapprox(actual[1][1], finite_diff_vec(f, args, 1, 1, dx))
56+
@test isapprox(actual[1][2], finite_diff_vec(f, args, 1, 2, dx))
57+
for i in 2:5
58+
@test isapprox(actual[i], finite_diff(f, args, i, dx))
59+
end
60+
end
61+
62+
dissimilar_product = ProductDistribution(bernoulli, normal)
63+
64+
@testset "product of dissimilarly-typed distributions" begin
65+
@test !is_discrete(dissimilar_product)
66+
grad_bools = (has_output_grad(dissimilar_product), has_argument_grads(dissimilar_product)...)
67+
@test grad_bools == (false, true, true, true)
68+
69+
p = 0.5
70+
(mu, std) = (0.0, 1.0)
71+
72+
# random
73+
x = dissimilar_product(p, mu, std)
74+
@assert typeof(x) == Gen.get_return_type(dissimilar_product) == Tuple{Bool, Float64}
75+
76+
# logpdf
77+
x = (false, 0.3)
78+
actual = logpdf(dissimilar_product, x, p, mu, std)
79+
expected = logpdf(bernoulli, x[1], p) + logpdf(normal, x[2], mu, std)
80+
@test isapprox(actual, expected)
81+
82+
# test logpdf_grad against finite differencing
83+
f = (x, p, mu, std) -> logpdf(dissimilar_product, x, p, mu, std)
84+
args = (x, p, mu, std)
85+
actual = logpdf_grad(dissimilar_product, args...)
86+
for i in 2:4
87+
@test isapprox(actual[i], finite_diff(f, args, i, dx))
88+
end
89+
end

0 commit comments

Comments
 (0)