Skip to content

Commit c2707bb

Browse files
scheidantpapp
authored andcommitted
add transformation to simplex (#56)
add transform for UnitSimplex
1 parent 12792da commit c2707bb

File tree

2 files changed

+85
-1
lines changed

2 files changed

+85
-1
lines changed

src/special_arrays.jl

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export UnitVector, CorrCholeskyFactor
1+
export UnitVector, UnitSimplex, CorrCholeskyFactor
22

33
####
44
#### building blocks
@@ -77,6 +77,67 @@ function inverse_at!(x::AbstractVector, index, t::UnitVector, y::AbstractVector)
7777
index
7878
end
7979

80+
81+
####
82+
#### UnitSimplex
83+
####
84+
85+
"""
86+
UnitSimplex(n)
87+
88+
Transform `n-1` real numbers to a vector of length `n` whose elements are non-negative and sum to one.
89+
"""
90+
@calltrans struct UnitSimplex <: VectorTransform
91+
n::Int
92+
function UnitSimplex(n::Int)
93+
@argcheck n 1 "Dimension should be positive."
94+
new(n)
95+
end
96+
end
97+
98+
dimension(t::UnitSimplex) = t.n - 1
99+
100+
function transform_with(flag::LogJacFlag, t::UnitSimplex, x::AbstractVector, index)
101+
@unpack n = t
102+
T = extended_eltype(x)
103+
104+
= logjac_zero(flag, T)
105+
stick = one(T)
106+
y = Vector{T}(undef, n)
107+
@inbounds for i in 1:n-1
108+
xi = x[index]
109+
index += 1
110+
z = logistic(xi - log(n-i))
111+
y[i] = z * stick
112+
113+
if !(flag isa NoLogJac)
114+
+= log(stick) - logit_logjac(z)
115+
end
116+
117+
stick *= 1 - z
118+
end
119+
120+
y[end] = stick
121+
122+
y, ℓ, index
123+
end
124+
125+
inverse_eltype(t::UnitSimplex, y::AbstractVector) = extended_eltype(y)
126+
127+
function inverse_at!(x::AbstractVector, index, t::UnitSimplex, y::AbstractVector)
128+
@unpack n = t
129+
@argcheck length(y) == n
130+
131+
stick = one(eltype(y))
132+
@inbounds for i in axes(y, 1)[1:end-1]
133+
z = y[i]/stick
134+
x[index] = logit(z) + log(n-i)
135+
stick -= y[i]
136+
index += 1
137+
end
138+
index
139+
end
140+
80141
####
81142
#### correlation cholesky factor
82143
####

test/runtests.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,29 @@ end
103103
end
104104
end
105105

106+
@testset "to unit simplex" begin
107+
@testset "dimension checks" begin
108+
S = UnitSimplex(3)
109+
x = zeros(3) # incorrect
110+
@test_throws ArgumentError S(x)
111+
@test_throws ArgumentError transform(S, x)
112+
@test_throws ArgumentError transform_and_logjac(S, x)
113+
end
114+
115+
@testset "consistency checks" begin
116+
for K in 1:10
117+
t = UnitSimplex(K)
118+
@test dimension(t) == K - 1
119+
if K > 1
120+
test_transformation(t, y -> (sum(y) 1) & (all(y.>=0)),
121+
vec_y = y -> y[1:(end-1)])
122+
end
123+
x = zeros(dimension(t))
124+
@test transform(t, x) 1 ./ fill(K, K)
125+
end
126+
end
127+
end
128+
106129
@testset "to correlation cholesky factor" begin
107130
@testset "dimension checks" begin
108131
C = CorrCholeskyFactor(3)

0 commit comments

Comments
 (0)