From 7662c160ed282541c8eacf465047eefdf507577b Mon Sep 17 00:00:00 2001 From: Esther van Pelt Date: Tue, 21 Oct 2025 17:18:56 +0200 Subject: [PATCH 1/4] start working on issue --- Project.toml | 1 + src/ExponentialFamily.jl | 2 +- src/distributions/{wip => }/multinomial.jl | 103 ++++++++---------- .../distributions/distributions_setuptests.jl | 1 + test/distributions/multinomial_tests.jl | 30 +++++ 5 files changed, 81 insertions(+), 56 deletions(-) rename src/distributions/{wip => }/multinomial.jl (62%) create mode 100644 test/distributions/multinomial_tests.jl diff --git a/Project.toml b/Project.toml index 421a4517..5ce2aa8f 100644 --- a/Project.toml +++ b/Project.toml @@ -46,6 +46,7 @@ Random = "1.9" RecursiveArrayTools = "3.26" SparseArrays = "1.9" SpecialFunctions = "2" +StableRNGs = "1.0.3" StaticArrays = "1" StatsBase = "0.34" StatsFuns = "0.9, 1" diff --git a/src/ExponentialFamily.jl b/src/ExponentialFamily.jl index 86103f61..f0a45b72 100644 --- a/src/ExponentialFamily.jl +++ b/src/ExponentialFamily.jl @@ -55,7 +55,7 @@ include("distributions/dirichlet_collection.jl") include("distributions/beta.jl") include("distributions/lognormal.jl") include("distributions/binomial.jl") -# include("distributions/multinomial.jl") +include("distributions/multinomial.jl") include("distributions/wishart.jl") include("distributions/wishart_inverse.jl") # include("distributions/contingency.jl") diff --git a/src/distributions/wip/multinomial.jl b/src/distributions/multinomial.jl similarity index 62% rename from src/distributions/wip/multinomial.jl rename to src/distributions/multinomial.jl index 4db688cb..81e274a8 100644 --- a/src/distributions/wip/multinomial.jl +++ b/src/distributions/multinomial.jl @@ -1,13 +1,12 @@ export Multinomial import Distributions: Multinomial, probs -import StableRNGs: StableRNG using StaticArrays using LogExpFunctions -vague(::Type{<:Multinomial}, n::Int, dims::Int) = Multinomial(n, ones(dims) ./ dims) +BayesBase.vague(::Type{<:Multinomial}, n::Int, dims::Int) = Multinomial(n, ones(dims) ./ dims) -probvec(dist::Multinomial) = probs(dist) +BayesBase.probvec(dist::Multinomial) = probs(dist) function convert_eltype(::Type{Multinomial}, ::Type{T}, distribution::Multinomial{R}) where {T <: Real, R <: Real} n, p = params(distribution) @@ -55,59 +54,56 @@ function BayesBase.prod(::ClosedProd, left::T, right::T) where {T <: Multinomial return prod(ClosedProd(), ef_left, ef_right) end -function pack_naturalparameters(dist::Multinomial) - @inbounds p = params(dist)[2] - return map(log, p / p[end]) +check_valid_natural(::Type{<:Multinomial}, params) = length(params) >= 1 + +function check_valid_conditioner(::Type{<:Multinomial}, conditioner) + isinteger(conditioner) && conditioner > 0 end -unpack_naturalparameters(ef::ExponentialFamilyDistribution{Multinomial}) = (getnaturalparameters(ef),) +function isproper(::NaturalParametersSpace, ::Type{Multinomial}, natural_parameters::AbstractVector{<:Real}, conditioner::Int) + return (conditioner >= 1) && (length(natural_parameters) >= 1) +end -function Base.convert(::Type{ExponentialFamilyDistribution}, dist::Multinomial) - n, _ = params(dist) - return ExponentialFamilyDistribution(Multinomial, pack_naturalparameters(dist), n) +function unpack_parameters(::Type{Multinomial}, packed::AbstractVector, conditioner) + @show packed + return (packed,) end -function Base.convert(::Type{Distribution}, exponentialfamily::ExponentialFamilyDistribution{Multinomial}) - expη = map(exp, getnaturalparameters(exponentialfamily)) - p = expη / sum(expη) - return Multinomial(getconditioner(exponentialfamily), p) +function pack_parameters(::Type{Multinomial}, unpacked::Tuple{<:AbstractVector}) + return first(unpacked) end -check_valid_natural(::Type{<:Multinomial}, params) = length(params) >= 1 +function separate_conditioner(::Type{Multinomial}, params) + ndims, success_probs = params + return ((success_probs,), ndims) +end -function check_valid_conditioner(::Type{<:Multinomial}, conditioner) - isinteger(conditioner) && conditioner > 0 +function join_conditioner(::Type{Multinomial}, cparams, conditioner) + (succprob,) = cparams + ntrials = conditioner + return (ntrials, succprob) end -function isproper(exponentialfamily::ExponentialFamilyDistribution{Multinomial}) - logp = getnaturalparameters(exponentialfamily) - n = getconditioner(exponentialfamily) - return (n >= 1) && (length(logp) >= 1) +function (::MeanToNatural{Multinomial})(parameters::Tuple{<:AbstractVector}, conditioner) + (succprob,) = parameters + return (log.(succprob),) end +function (::NaturalToMean{Multinomial})(natural_parameters::Tuple{<:AbstractVector}, conditioner) + (log_probs,) = natural_parameters + return (exp.(log_probs),) +end + +getsufficientstatistics(::Type{Multinomial}, _) = (identity,) +getgradlogpartition(::NaturalParametersSpace, ::Type{Multinomial}, conditioner::Int) = (η) -> zeros(length(η)) + function logpartition(exponentialfamily::ExponentialFamilyDistribution{Multinomial}) η = getnaturalparameters(exponentialfamily) n = getconditioner(exponentialfamily) return n * logsumexp(η) end -function computeLogpartition(K, n) - d = Multinomial(n, ones(K) ./ K) - samples = unique(rand(StableRNG(1), d, 4000), dims = 2) - samples = [samples[:, i] for i in 1:size(samples, 2)] - return let samples = samples - (η) -> begin - result = mapreduce(+, samples) do xi - return (factorial(n) / prod(@.factorial(xi)))^2 * exp(η' * xi) - end - return log(result) - end - end -end - -function fisherinformation(expfamily::ExponentialFamilyDistribution{Multinomial}) - η = getnaturalparameters(expfamily) - n = getconditioner(expfamily) +getfisherinformation(::NaturalParametersSpace, ::Type{Multinomial}, conditioner::Int) = (η) -> begin I = Matrix{Float64}(undef, length(η), length(η)) seη = mapreduce(exp, +, η) @inbounds for i in 1:length(η) @@ -117,23 +113,22 @@ function fisherinformation(expfamily::ExponentialFamilyDistribution{Multinomial} I[j, i] = I[i, j] end end - return n * I + return conditioner * I end -function fisherinformation(dist::Multinomial) - n, p = params(dist) - I = Matrix{Float64}(undef, length(p), length(p)) - @inbounds for i in 1:length(p) - I[i, i] = (1 - p[i]) / p[i] +getfisherinformation(::MeanParametersSpace, ::Type{Multinomial}, conditioner::Int) = (θ) -> begin + I = Matrix{Float64}(undef, length(θ), length(θ)) + @inbounds for i in 1:length(θ) + I[i, i] = (1 - θ[i]) / θ[i] @inbounds for j in 1:(i-1) I[i, j] = -1 I[j, i] = I[i, j] end end - return n * I + return conditioner * I end -function BayesBase.insupport(ef::ExponentialFamilyDistribution{Multinomial, P, C, Safe}, x) where {P, C} +function BayesBase.insupport(ef::ExponentialFamilyDistribution{Multinomial, P, C, S}, x) where {P, C, S} n = Int(sum(x)) return n == getconditioner(ef) end @@ -141,16 +136,14 @@ end basemeasureconstant(::ExponentialFamilyDistribution{Multinomial}) = NonConstantBaseMeasure() basemeasureconstant(::Type{<:Multinomial}) = NonConstantBaseMeasure() basemeasure(ef::ExponentialFamilyDistribution{Multinomial}) = (x) -> basemeasure(ef, x) -function basemeasure(::ExponentialFamilyDistribution{Multinomial}, x::Vector) - n = Int(sum(x)) - return factorial(n) / prod(@.factorial(x)) -end - -function basemeasure(::Multinomial, x::Vector) - n = Int(sum(x)) - return factorial(n) / prod(@.factorial(x)) -end sufficientstatistics(::Union{<:ExponentialFamilyDistribution{Multinomial}, <:Multinomial}, x::Vector) = x sufficientstatistics(ef::Union{<:ExponentialFamilyDistribution{Multinomial}, <:Multinomial}) = x -> sufficientstatistics(ef, x) + +getbasemeasure(::Type{Multinomial}, ntrials) = (x) -> begin + n = Int(sum(x)) + return factorial(n) / prod(@.factorial(x)) +end +getlogbasemeasure(::Type{Multinomial}, ntrials) = (x) -> log(getbasemeasure(Multinomial, ntrials)(x)) # TODO change with loggamma +getlogpartition(::NaturalParametersSpace, ::Type{Multinomial}, conditioner::Int) = (η) -> zeros(length(η)) \ No newline at end of file diff --git a/test/distributions/distributions_setuptests.jl b/test/distributions/distributions_setuptests.jl index 3030d4dc..e2b1e656 100644 --- a/test/distributions/distributions_setuptests.jl +++ b/test/distributions/distributions_setuptests.jl @@ -379,6 +379,7 @@ function run_test_fisherinformation_properties(distribution; test_properties_in_ end if test_properties_in_mean_space + @show η θ = map(NaturalParametersSpace() => MeanParametersSpace(), T, η, conditioner) F = getfisherinformation(MeanParametersSpace(), T, conditioner)(θ) diff --git a/test/distributions/multinomial_tests.jl b/test/distributions/multinomial_tests.jl new file mode 100644 index 00000000..6a0ca0e2 --- /dev/null +++ b/test/distributions/multinomial_tests.jl @@ -0,0 +1,30 @@ +@testitem "Multinomial: probvec" begin + include("distributions_setuptests.jl") + + @test probvec(Multinomial(5, [1 / 3, 1 / 3, 1 / 3])) == [1 / 3, 1 / 3, 1 / 3] + @test probvec(Multinomial(3, [0.2, 0.2, 0.4, 0.1, 0.1])) == [0.2, 0.2, 0.4, 0.1, 0.1] + @test probvec(Multinomial(2, [0.5, 0.5])) == [0.5, 0.5] +end + +@testitem "Multinomial: vague" begin + include("distributions_setuptests.jl") + + @test_throws MethodError vague(Multinomial) + @test_throws MethodError vague(Multinomial, 4) + + vague_dist1 = vague(Multinomial, 5, 4) + @test typeof(vague_dist1) <: Multinomial + @test probvec(vague_dist1) == [1 / 4, 1 / 4, 1 / 4, 1 / 4] + + vague_dist2 = vague(Multinomial, 3, 5) + @test typeof(vague_dist2) <: Multinomial + @test probvec(vague_dist2) == [1 / 5, 1 / 5, 1 / 5, 1 / 5, 1 / 5] +end + +@testitem "Multinomial: test_EF_interface" begin + include("distributions_setuptests.jl") + + ef = vague(Multinomial, 5, 4) + + test_exponentialfamily_interface(ef) +end \ No newline at end of file From 624e312f67b18c7bfe120d956179281ea0ddb185 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Wed, 22 Oct 2025 17:07:59 +0200 Subject: [PATCH 2/4] Implement EF interface for Multinomial --- src/distributions/multinomial.jl | 99 +++++++++---------- .../distributions/distributions_setuptests.jl | 1 - test/distributions/multinomial_tests.jl | 67 ++++++++++++- 3 files changed, 114 insertions(+), 53 deletions(-) diff --git a/src/distributions/multinomial.jl b/src/distributions/multinomial.jl index 81e274a8..f22e471c 100644 --- a/src/distributions/multinomial.jl +++ b/src/distributions/multinomial.jl @@ -6,15 +6,25 @@ using LogExpFunctions BayesBase.vague(::Type{<:Multinomial}, n::Int, dims::Int) = Multinomial(n, ones(dims) ./ dims) -BayesBase.probvec(dist::Multinomial) = probs(dist) - -function convert_eltype(::Type{Multinomial}, ::Type{T}, distribution::Multinomial{R}) where {T <: Real, R <: Real} - n, p = params(distribution) - return Multinomial(n, convert(AbstractVector{T}, p)) -end +BayesBase.convert_paramfloattype(::Type{T}, distribution::Multinomial) where {T <: Real} = + Multinomial(distribution.n, convert(AbstractVector{T}, probvec(distribution))) BayesBase.default_prod_rule(::Type{<:Multinomial}, ::Type{<:Multinomial}) = PreserveTypeProd(ExponentialFamilyDistribution) +function __compute_logpartition_multinomial_product(K::Int, n::Int) + d = vague(Multinomial, n, K) + samples = unique(rand(d, 4000), dims = 2) + samples = [samples[:, i] for i in 1:size(samples, 2)] + return let samples = samples + (η) -> begin + result = mapreduce(+, samples) do xi + return (factorial(n) / prod(@.factorial(xi)))^2 * exp(η' * xi) + end + return log(result) + end + end +end + # NOTE: The product of two Multinomial distributions is NOT a Multinomial distribution. function BayesBase.prod( ::PreserveTypeProd{ExponentialFamilyDistribution}, @@ -30,30 +40,33 @@ function BayesBase.prod( K = length(η_left) naturalparameters = η_left + η_right - sufficientstatistics = (x) -> x - ## If conditioner is larger than 12 factorial will be problematic. Casting to BigInt will resolve the issue. - ##TODO: fix this issue in future PRs - basemeasure = (x) -> factorial(conditioner_left)^2 / (prod(@.factorial(x)))^2 - logpartition = computeLogpartition(K, conditioner_left) + sufficientstatistics = (identity,) + + logbasemeasure = (x) -> 2 * loggamma(conditioner_left + 1) - 2 * sum(loggamma.(x .+ 1)) + basemeasure = (x) -> exp(logbasemeasure(x)) + + # Create log partition function that takes natural parameters as input + logpartition = __compute_logpartition_multinomial_product(K, conditioner_left) + supp = 0:conditioner_left - return ExponentialFamilyDistribution( - Multivariate, - naturalparameters, - nothing, + + attributes = ExponentialFamilyDistributionAttributes( basemeasure, sufficientstatistics, logpartition, supp ) -end -function BayesBase.prod(::ClosedProd, left::T, right::T) where {T <: Multinomial} - @assert left.n == right.n "$(left) and $(right) must have the same number of trials" - ef_left = convert(ExponentialFamilyDistribution, left) - ef_right = convert(ExponentialFamilyDistribution, right) - return prod(ClosedProd(), ef_left, ef_right) + return ExponentialFamilyDistribution( + Multivariate, + naturalparameters, + η_left, + attributes + ) end +BayesBase.probvec(dist::Multinomial) = probs(dist) + check_valid_natural(::Type{<:Multinomial}, params) = length(params) >= 1 function check_valid_conditioner(::Type{<:Multinomial}, conditioner) @@ -61,18 +74,13 @@ function check_valid_conditioner(::Type{<:Multinomial}, conditioner) end function isproper(::NaturalParametersSpace, ::Type{Multinomial}, natural_parameters::AbstractVector{<:Real}, conditioner::Int) - return (conditioner >= 1) && (length(natural_parameters) >= 1) + return (conditioner >= 1) && (length(natural_parameters) >= 1) end -function unpack_parameters(::Type{Multinomial}, packed::AbstractVector, conditioner) - @show packed +function unpack_parameters(::Type{Multinomial}, packed, conditioner) return (packed,) end -function pack_parameters(::Type{Multinomial}, unpacked::Tuple{<:AbstractVector}) - return first(unpacked) -end - function separate_conditioner(::Type{Multinomial}, params) ndims, success_probs = params return ((success_probs,), ndims) @@ -86,21 +94,27 @@ end function (::MeanToNatural{Multinomial})(parameters::Tuple{<:AbstractVector}, conditioner) (succprob,) = parameters - return (log.(succprob),) + pk = last(succprob) + return (map(pi -> log(pi / pk), succprob),) + # return (log.(succprob),) end function (::NaturalToMean{Multinomial})(natural_parameters::Tuple{<:AbstractVector}, conditioner) (log_probs,) = natural_parameters - return (exp.(log_probs),) + return (softmax(log_probs),) end getsufficientstatistics(::Type{Multinomial}, _) = (identity,) -getgradlogpartition(::NaturalParametersSpace, ::Type{Multinomial}, conditioner::Int) = (η) -> zeros(length(η)) -function logpartition(exponentialfamily::ExponentialFamilyDistribution{Multinomial}) - η = getnaturalparameters(exponentialfamily) - n = getconditioner(exponentialfamily) - return n * logsumexp(η) +isbasemeasureconstant(::Type{Multinomial}) = NonConstantBaseMeasure() +getbasemeasure(::Type{Multinomial}, ntrials) = (x) -> factorial(sum(x)) / prod(@.factorial(x)) +getlogbasemeasure(::Type{Multinomial}, ntrials) = (x) -> loggamma(sum(x) + 1) - sum(loggamma.(x .+ 1)) + +getlogpartition(::NaturalParametersSpace, ::Type{Multinomial}, conditioner::Int) = (η) -> conditioner * logsumexp(η) + +getgradlogpartition(::NaturalParametersSpace, ::Type{Multinomial}, conditioner::Int) = (η) -> begin + sumη = mapreduce(exp, +, η) + return map(d -> conditioner * exp(d) / sumη, η) end getfisherinformation(::NaturalParametersSpace, ::Type{Multinomial}, conditioner::Int) = (η) -> begin @@ -132,18 +146,3 @@ function BayesBase.insupport(ef::ExponentialFamilyDistribution{Multinomial, P, C n = Int(sum(x)) return n == getconditioner(ef) end - -basemeasureconstant(::ExponentialFamilyDistribution{Multinomial}) = NonConstantBaseMeasure() -basemeasureconstant(::Type{<:Multinomial}) = NonConstantBaseMeasure() -basemeasure(ef::ExponentialFamilyDistribution{Multinomial}) = (x) -> basemeasure(ef, x) - -sufficientstatistics(::Union{<:ExponentialFamilyDistribution{Multinomial}, <:Multinomial}, x::Vector) = x -sufficientstatistics(ef::Union{<:ExponentialFamilyDistribution{Multinomial}, <:Multinomial}) = - x -> sufficientstatistics(ef, x) - -getbasemeasure(::Type{Multinomial}, ntrials) = (x) -> begin - n = Int(sum(x)) - return factorial(n) / prod(@.factorial(x)) -end -getlogbasemeasure(::Type{Multinomial}, ntrials) = (x) -> log(getbasemeasure(Multinomial, ntrials)(x)) # TODO change with loggamma -getlogpartition(::NaturalParametersSpace, ::Type{Multinomial}, conditioner::Int) = (η) -> zeros(length(η)) \ No newline at end of file diff --git a/test/distributions/distributions_setuptests.jl b/test/distributions/distributions_setuptests.jl index e2b1e656..3030d4dc 100644 --- a/test/distributions/distributions_setuptests.jl +++ b/test/distributions/distributions_setuptests.jl @@ -379,7 +379,6 @@ function run_test_fisherinformation_properties(distribution; test_properties_in_ end if test_properties_in_mean_space - @show η θ = map(NaturalParametersSpace() => MeanParametersSpace(), T, η, conditioner) F = getfisherinformation(MeanParametersSpace(), T, conditioner)(θ) diff --git a/test/distributions/multinomial_tests.jl b/test/distributions/multinomial_tests.jl index 6a0ca0e2..cc5c19ab 100644 --- a/test/distributions/multinomial_tests.jl +++ b/test/distributions/multinomial_tests.jl @@ -24,7 +24,70 @@ end @testitem "Multinomial: test_EF_interface" begin include("distributions_setuptests.jl") - ef = vague(Multinomial, 5, 4) + using StableRNGs - test_exponentialfamily_interface(ef) + rng = StableRNG(42) + + for n in 2:6 + for trials in 2:10 + @testset let d = Multinomial(trials, normalize!(rand(rng, n), 1)) + test_exponentialfamily_interface(d; option_assume_no_allocations = false, + test_fisherinformation_properties = false, + test_fisherinformation_against_jacobian = false) + end + end + end +end + +@testitem "Product of Multinomial distributions" begin + include("distributions_setuptests.jl") + + using StableRNGs + using Distributions: Uniform + + rng = StableRNG(42) + + @testset "prod" begin + for n in 4:6 + pleft = rand(rng, n) + pleft = pleft ./ sum(pleft) + pright = rand(rng, n) + pright = pright ./ sum(pright) + left = Multinomial(n, pleft) + right = Multinomial(n, pright) + efleft = convert(ExponentialFamilyDistribution, left) + efright = convert(ExponentialFamilyDistribution, right) + prod_ef = prod(PreserveTypeProd(ExponentialFamilyDistribution), efleft, efright) + d = vague(Multinomial, n, n) + sample_space = unique(rand(rng, d, 4000), dims = 2) + sample_space = [sample_space[:, i] for i in 1:size(sample_space, 2)] + + # Test normalization for the new interface + hist_sumef(x) = + getbasemeasure(prod_ef)(x) * exp( + getnaturalparameters(prod_ef)' * first(sufficientstatistics(prod_ef, x)) - + logpartition(prod_ef, getnaturalparameters(prod_ef)) + ) + @test sum(hist_sumef(x_sample) for x_sample in sample_space) ≈ 1.0 rtol = 1e-3 + + # Test basemeasure and sufficient statistics + sample_x = rand(d, 5) + for xi in sample_x + @test getbasemeasure(prod_ef)(xi) ≈ (factorial(n) / prod(@.factorial(xi)))^2 rtol = 1e-10 + @test sufficientstatistics(prod_ef, xi) == (xi,) + end + end + + # Test error cases for mismatched conditioners + @test_throws AssertionError prod( + PreserveTypeProd(ExponentialFamilyDistribution), + convert(ExponentialFamilyDistribution, Multinomial(4, [0.2, 0.4, 0.4])), + convert(ExponentialFamilyDistribution, Multinomial(5, [0.1, 0.3, 0.6])) + ) + @test_throws AssertionError prod( + PreserveTypeProd(ExponentialFamilyDistribution), + convert(ExponentialFamilyDistribution, Multinomial(4, [0.2, 0.4, 0.4])), + convert(ExponentialFamilyDistribution, Multinomial(3, [0.1, 0.3, 0.6])) + ) + end end \ No newline at end of file From 49cb6531a7dc41cd06a5c335dc06b220d7703ba2 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Wed, 22 Oct 2025 17:12:52 +0200 Subject: [PATCH 3/4] Update multinomial tests and remove wip file --- test/distributions/multinomial_tests.jl | 49 ++++++- test/distributions/wip/test_multinomial.jl | 146 --------------------- 2 files changed, 48 insertions(+), 147 deletions(-) delete mode 100644 test/distributions/wip/test_multinomial.jl diff --git a/test/distributions/multinomial_tests.jl b/test/distributions/multinomial_tests.jl index cc5c19ab..a109c734 100644 --- a/test/distributions/multinomial_tests.jl +++ b/test/distributions/multinomial_tests.jl @@ -90,4 +90,51 @@ end convert(ExponentialFamilyDistribution, Multinomial(3, [0.1, 0.3, 0.6])) ) end -end \ No newline at end of file +end + +@testitem "Multinomial: natural parameters" begin + include("distributions_setuptests.jl") + + @testset "natural parameters related " begin + d1 = Multinomial(5, [0.1, 0.4, 0.5]) + d2 = Multinomial(5, [0.2, 0.4, 0.4]) + η1 = ExponentialFamilyDistribution(Multinomial, [log(0.1 / 0.5), log(0.4 / 0.5), 0.0], 5) + η2 = ExponentialFamilyDistribution(Multinomial, [log(0.2 / 0.4), 0.0, 0.0], 5) + + @test convert(ExponentialFamilyDistribution, d1) ≈ η1 + @test convert(ExponentialFamilyDistribution, d2) ≈ η2 + + @test convert(Distribution, η1) ≈ d1 + @test convert(Distribution, η2) ≈ d2 + + @test logpartition(η1) == 3.4657359027997265 + @test logpartition(η2) == 4.5814536593707755 + + @test basemeasure(η1, [1, 2, 2]) == 30.0 + @test basemeasure(η2, [1, 2, 2]) == 30.0 + + @test logpdf(η1, [1, 2, 2]) ≈ logpdf(d1, [1, 2, 2]) atol = 1e-8 + @test logpdf(η2, [1, 2, 2]) ≈ logpdf(d2, [1, 2, 2]) atol = 1e-8 + + @test pdf(η1, [1, 2, 2]) ≈ pdf(d1, [1, 2, 2]) atol = 1e-8 + @test pdf(η2, [1, 2, 2]) ≈ pdf(d2, [1, 2, 2]) atol = 1e-8 + end +end + +@testitem "Multinomial: mean and covariance" begin + include("distributions_setuptests.jl") + + using StableRNGs + using Distributions: Dirichlet + + @testset "ExponentialFamilyDistribution mean,cov" begin + rng = StableRNG(42) + for n in 2:12 + p = rand(rng, Dirichlet(ones(n))) + dist = Multinomial(n, p) + ef = convert(ExponentialFamilyDistribution, dist) + @test mean(dist) ≈ mean(ef) atol = 1e-8 + @test cov(dist) ≈ cov(ef) atol = 1e-8 + end + end +end diff --git a/test/distributions/wip/test_multinomial.jl b/test/distributions/wip/test_multinomial.jl deleted file mode 100644 index aabd9538..00000000 --- a/test/distributions/wip/test_multinomial.jl +++ /dev/null @@ -1,146 +0,0 @@ -module MultinomialTest - -using Test -using ExponentialFamily -using Distributions -using Random -using StableRNGs -using ForwardDiff -import ExponentialFamily: ExponentialFamilyDistribution, getnaturalparameters, basemeasure, fisherinformation - -@testset "Multinomial" begin - @testset "probvec" begin - @test probvec(Multinomial(5, [1 / 3, 1 / 3, 1 / 3])) == [1 / 3, 1 / 3, 1 / 3] - @test probvec(Multinomial(3, [0.2, 0.2, 0.4, 0.1, 0.1])) == [0.2, 0.2, 0.4, 0.1, 0.1] - @test probvec(Multinomial(2, [0.5, 0.5])) == [0.5, 0.5] - end - - @testset "vague" begin - @test_throws MethodError vague(Multinomial) - @test_throws MethodError vague(Multinomial, 4) - - vague_dist1 = vague(Multinomial, 5, 4) - @test typeof(vague_dist1) <: Multinomial - @test probvec(vague_dist1) == [1 / 4, 1 / 4, 1 / 4, 1 / 4] - - vague_dist2 = vague(Multinomial, 3, 5) - @test typeof(vague_dist2) <: Multinomial - @test probvec(vague_dist2) == [1 / 5, 1 / 5, 1 / 5, 1 / 5, 1 / 5] - end - - @testset "prod" begin - for n in 2:3 - plength = Int64(ceil(rand(Uniform(1, n)))) - pleft = rand(plength) - pleft = pleft ./ sum(pleft) - pright = rand(plength) - pright = pright ./ sum(pright) - left = Multinomial(n, pleft) - right = Multinomial(n, pright) - efleft = convert(ExponentialFamilyDistribution, left) - efright = convert(ExponentialFamilyDistribution, right) - prod_dist = prod(ClosedProd(), left, right) - prod_ef = prod(efleft, efright) - d = Multinomial(n, ones(plength) ./ plength) - sample_space = unique(rand(StableRNG(1), d, 4000), dims = 2) - sample_space = [sample_space[:, i] for i in 1:size(sample_space, 2)] - - hist_sum(x) = - prod_dist.basemeasure(x) * exp( - prod_dist.naturalparameters' * prod_dist.sufficientstatistics(x) - - prod_dist.logpartition(prod_dist.naturalparameters) - ) - hist_sumef(x) = - prod_ef.basemeasure(x) * exp( - prod_ef.naturalparameters' * prod_ef.sufficientstatistics(x) - - prod_ef.logpartition(prod_ef.naturalparameters) - ) - @test sum(hist_sum(x_sample) for x_sample in sample_space) ≈ 1.0 atol = 1e-10 - @test sum(hist_sumef(x_sample) for x_sample in sample_space) ≈ 1.0 atol = 1e-10 - sample_x = rand(d, 5) - for xi in sample_x - @test prod_dist.basemeasure(xi) ≈ (factorial(n) / prod(@.factorial(xi)))^2 atol = 1e-10 - @test prod_dist.sufficientstatistics(xi) == xi - - @test prod_ef.basemeasure(xi) ≈ (factorial(n) / prod(@.factorial(xi)))^2 atol = 1e-10 - @test prod_ef.sufficientstatistics(xi) == xi - end - end - - @test_throws AssertionError prod( - ClosedProd(), - Multinomial(4, [0.2, 0.4, 0.4]), - Multinomial(5, [0.1, 0.3, 0.6]) - ) - @test_throws AssertionError prod( - ClosedProd(), - Multinomial(4, [0.2, 0.4, 0.4]), - Multinomial(3, [0.1, 0.3, 0.6]) - ) - end - - @testset "natural parameters related " begin - d1 = Multinomial(5, [0.1, 0.4, 0.5]) - d2 = Multinomial(5, [0.2, 0.4, 0.4]) - η1 = ExponentialFamilyDistribution(Multinomial, [log(0.1 / 0.5), log(0.4 / 0.5), 0.0], 5) - η2 = ExponentialFamilyDistribution(Multinomial, [log(0.2 / 0.4), 0.0, 0.0], 5) - - @test convert(ExponentialFamilyDistribution, d1) ≈ η1 - @test convert(ExponentialFamilyDistribution, d2) ≈ η2 - - @test convert(Distribution, η1) ≈ d1 - @test convert(Distribution, η2) ≈ d2 - - @test logpartition(η1) == 3.4657359027997265 - @test logpartition(η2) == 4.5814536593707755 - - @test basemeasure(η1, [1, 2, 2]) == 30.0 - @test basemeasure(η2, [1, 2, 2]) == 30.0 - - @test logpdf(η1, [1, 2, 2]) == logpdf(d1, [1, 2, 2]) - @test logpdf(η2, [1, 2, 2]) == logpdf(d2, [1, 2, 2]) - - @test pdf(η1, [1, 2, 2]) == pdf(d1, [1, 2, 2]) - @test pdf(η2, [1, 2, 2]) == pdf(d2, [1, 2, 2]) - end - - @testset "fisher information" begin - function transformation(η) - expη = exp.(η) - expη / sum(expη) - end - rng = StableRNG(42) - ## ForwardDiff hessian is slow so we only test one time with hessian - n = 3 - p = rand(rng, Dirichlet(ones(n))) - dist = Multinomial(n, p) - ef = convert(ExponentialFamilyDistribution, dist) - η = getnaturalparameters(ef) - - f_logpartition = (η) -> logpartition(ExponentialFamilyDistribution(Multinomial, η, n)) - autograd_information = (η) -> ForwardDiff.hessian(f_logpartition, η) - @test fisherinformation(ef) ≈ autograd_information(η) atol = 1e-8 - - for n in 2:12 - p = rand(rng, Dirichlet(ones(n))) - dist = Multinomial(n, p) - ef = convert(ExponentialFamilyDistribution, dist) - η = getnaturalparameters(ef) - - J = ForwardDiff.jacobian(transformation, η) - @test J' * fisherinformation(dist) * J ≈ fisherinformation(ef) atol = 1e-8 - end - end - - @testset "ExponentialFamilyDistribution mean,cov" begin - rng = StableRNG(42) - for n in 2:12 - p = rand(rng, Dirichlet(ones(n))) - dist = Multinomial(n, p) - ef = convert(ExponentialFamilyDistribution, dist) - @test mean(dist) ≈ mean(ef) atol = 1e-8 - @test cov(dist) ≈ cov(ef) atol = 1e-8 - end - end -end -end From a42dfc2a52e9e13611f65a5e5071c9d585dd08c0 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Wed, 22 Oct 2025 17:16:02 +0200 Subject: [PATCH 4/4] Seed rng same --- src/distributions/multinomial.jl | 4 +++- test/distributions/multinomial_tests.jl | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/distributions/multinomial.jl b/src/distributions/multinomial.jl index f22e471c..9828737d 100644 --- a/src/distributions/multinomial.jl +++ b/src/distributions/multinomial.jl @@ -3,6 +3,7 @@ export Multinomial import Distributions: Multinomial, probs using StaticArrays using LogExpFunctions +using Random BayesBase.vague(::Type{<:Multinomial}, n::Int, dims::Int) = Multinomial(n, ones(dims) ./ dims) @@ -13,7 +14,8 @@ BayesBase.default_prod_rule(::Type{<:Multinomial}, ::Type{<:Multinomial}) = Pres function __compute_logpartition_multinomial_product(K::Int, n::Int) d = vague(Multinomial, n, K) - samples = unique(rand(d, 4000), dims = 2) + rng = Random.MersenneTwister(42) + samples = unique(rand(rng, d, 10000), dims = 2) samples = [samples[:, i] for i in 1:size(samples, 2)] return let samples = samples (η) -> begin diff --git a/test/distributions/multinomial_tests.jl b/test/distributions/multinomial_tests.jl index a109c734..b343727b 100644 --- a/test/distributions/multinomial_tests.jl +++ b/test/distributions/multinomial_tests.jl @@ -59,7 +59,7 @@ end efright = convert(ExponentialFamilyDistribution, right) prod_ef = prod(PreserveTypeProd(ExponentialFamilyDistribution), efleft, efright) d = vague(Multinomial, n, n) - sample_space = unique(rand(rng, d, 4000), dims = 2) + sample_space = unique(rand(rng, d, 10000), dims = 2) sample_space = [sample_space[:, i] for i in 1:size(sample_space, 2)] # Test normalization for the new interface