Skip to content

Commit 5f78074

Browse files
committed
Add three of the DPPL models from https://arxiv.org/pdf/2002.02702
1 parent 7c3ef9d commit 5f78074

File tree

6 files changed

+53
-0
lines changed

6 files changed

+53
-0
lines changed

.github/workflows/generate_website.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ jobs:
8787
run: uv run ad.py run --model ${{ matrix.model }}
8888
env:
8989
ADTYPE_KEYS: ${{ needs.setup-keys.outputs.adtype_keys }}
90+
DATADEPS_ALWAYS_ACCEPT: "true"
9091

9192
- name: Output matrix values
9293
id: output-matrix

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
44
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
55
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
6+
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
67
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
78
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
89
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
910
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
13+
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
1214
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
15+
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
1316
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1417
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1518
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

main.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ end
3535
# These imports tend to get used a lot in models
3636
using DynamicPPL: @model, to_submodel
3737
using Distributions
38+
using DistributionsAD: filldist, arraydist
3839
using LinearAlgebra
3940

4041
include("models/assume_beta.jl")
@@ -74,6 +75,10 @@ include("models/observe_submodel.jl")
7475
include("models/pdb_eight_schools_centered.jl")
7576
include("models/pdb_eight_schools_noncentered.jl")
7677

78+
include("models/dppl_gauss_unknown.jl")
79+
include("models/dppl_high_dim_gauss.jl")
80+
include("models/dppl_naive_bayes.jl")
81+
7782
# The entry point to this script itself begins here
7883
if ARGS == ["--list-model-keys"]
7984
foreach(println, sort(collect(keys(MODELS))))

models/dppl_gauss_unknown.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
n = 10_000
2+
s = abs(rand()) + 0.5
3+
y = randn() .+ s * randn(n)
4+
5+
@model function dppl_gauss_unknown(y)
6+
N = length(y)
7+
m ~ Normal(0, 1)
8+
s ~ truncated(Cauchy(0, 5); lower=0)
9+
y ~ filldist(Normal(m, s), N)
10+
end
11+
12+
@register dppl_gauss_unknown(y)

models/dppl_high_dim_gauss.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
@model function dppl_high_dim_gauss(D)
2+
m ~ filldist(Normal(0, 1), D)
3+
end
4+
5+
@register dppl_high_dim_gauss(10_000)

models/dppl_naive_bayes.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using MLDatasets: MNIST
2+
using MultivariateStats: fit, PCA, transform
3+
4+
# Load MNIST images and labels
5+
features = MNIST(split=:train).features
6+
nrows, ncols, nimages = size(features)
7+
image_raw = Float64.(reshape(features, (nrows * ncols, nimages)))
8+
labels = MNIST(split=:train).targets .+ 1
9+
C = 10 # Number of labels
10+
11+
# Preprocess the images by reducing dimensionality
12+
D = 40
13+
pca = fit(PCA, image_raw; maxoutdim=D)
14+
image = transform(pca, image_raw)
15+
16+
# Take only the first 1000 images and vectorise
17+
N = 1000
18+
image_subset = image[:, 1:N]'
19+
image_vec = vec(image_subset[:, :])
20+
labels = labels[1:N]
21+
22+
@model dppl_naive_bayes(image_vec, labels, C, D) = begin
23+
m ~ filldist(Normal(0, 10), C, D)
24+
image_vec ~ MvNormal(vec(m[labels, :]), I)
25+
end
26+
27+
@register dppl_naive_bayes(image_vec, labels, C, D)

0 commit comments

Comments
 (0)