This package is meant to provide basic infrastructure for probabilistic programming, with CuArrays support and Tracker/Zygote compatibility.
Generate a normally distributed random variable and compute its log-likelihood under 21 different mean parameters, from -10 to 10, with the scale fixed at 1:
julia> x = randn()
julia> logpdf(Normal).(x, CuArray(collect(-10.:1.:10.)), 1.)
21-element CuArray{Float64,1,Nothing}:
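For reference, here is a CPU-only sketch of what that broadcast computes. It assumes the package's Normal is parametrized by mean and standard deviation. The helper normal_logpdf is hypothetical and written out by hand, not part of the package. A fixed x replaces randn() so that the result is reproducible.

```julia
# Hypothetical helper, not part of the package: log-density of Normal(μ, σ),
# assuming σ is the standard deviation:
#   log p(x | μ, σ) = -log(σ) - log(2π)/2 - (x - μ)^2 / (2σ^2)
normal_logpdf(x, μ, σ) = -log(σ) - log(2π) / 2 - (x - μ)^2 / (2σ^2)

x = 0.3                                  # fixed observation instead of randn()
means = collect(-10.0:1.0:10.0)          # the same 21 candidate means
ll = normal_logpdf.(x, means, 1.0)       # broadcast over means; x and σ held fixed
best_mean = means[argmax(ll)]            # candidate mean with the highest log-likelihood
```

The largest log-likelihood belongs to the candidate mean closest to x, because only the (x - μ)^2 term varies across the broadcast.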
Fit the two parameters of a Gamma distribution by maximum likelihood, using Zygote for the gradients and plain gradient ascent:
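using CuArrays, Zygote  # plus this package, which provides logpdf and sample
# Treat the curried logpdf(d) as constant w.r.t. the distribution type d;
# a Zygote pullback must return a tuple with one entry per argument
Zygote.@adjoint logpdf(d) = logpdf(d), _ -> (nothing,)
n = 10000
# Draw n samples from Gamma(.5, .9) and move them to the GPU
data = CuArray([sample(Gamma)(.5, .9) for _ in 1:n])
μ, σ = rand(), rand()   # random initial guess
ϵ = .01                 # step size
for i in 1:n
    global μ, σ         # update the top-level variables from inside the loop
    # Gradient of the mean log-likelihood with respect to (μ, σ)
    dμ, dσ = gradient(μ, σ) do μ, σ
        sum(logpdf(Gamma).(data, μ, σ)) / n
    end
    μ += ϵ * dμ; σ += ϵ * dσ   # gradient ascent step
    i % 1000 == 0 && @show μ, σ
end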
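To make the loop's mechanics concrete without a GPU or Zygote, here is a dependency-free sketch of the same gradient-ascent scheme. It swaps in a Normal model because that model's gradients are easy to derive by hand. This is an illustration under that assumption, not the package's method. grad_normal_loglik and fit_normal are hypothetical helpers, and the initial guess is fixed at (0, 1) rather than drawn with rand(), to keep the run reproducible.

```julia
using Random

# Hypothetical helper (Normal model swapped in for illustration).
# Mean log-likelihood:  ℓ(μ, σ) = -log(σ) - log(2π)/2 - mean((x - μ)^2) / (2σ^2)
# Hand-derived partials:
#   ∂ℓ/∂μ = mean(x - μ) / σ^2
#   ∂ℓ/∂σ = -1/σ + mean((x - μ)^2) / σ^3
function grad_normal_loglik(data, μ, σ)
    n = length(data)
    dμ = sum(x - μ for x in data) / (n * σ^2)
    dσ = -1 / σ + sum((x - μ)^2 for x in data) / (n * σ^3)
    return dμ, dσ
end

# Hypothetical helper: plain gradient ascent on the mean log-likelihood
function fit_normal(data; μ = 0.0, σ = 1.0, ϵ = 0.1, iters = 20_000)
    for _ in 1:iters
        dμ, dσ = grad_normal_loglik(data, μ, σ)
        μ += ϵ * dμ   # step uphill in μ
        σ += ϵ * dσ   # step uphill in σ
    end
    return μ, σ
end

Random.seed!(0)
data = 2.0 .+ 3.0 .* randn(1_000)   # synthetic sample from Normal(2, 3)
μ_hat, σ_hat = fit_normal(data)
```

For this model the maximum-likelihood estimates have a closed form: the sample mean and the (1/n) sample standard deviation. That makes it easy to check that the ascent has converged.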