11using DocStringExtensions, LinearAlgebra, LogDensityProblems, OffsetArrays, Parameters,
22 Random, Test, TransformVariables, StaticArrays
33import Flux, ForwardDiff, ReverseDiff
4- using LogDensityProblems: Value, ValueGradient
4+ using LogDensityProblems: logdensity, logdensity_and_gradient
55using TransformVariables:
66 AbstractTransform, ScalarTransform, VectorTransform, ArrayTransform,
77 unit_triangular_dimension, logistic, logistic_logjac, logit, inverse_and_logjac
88
9- include (" test_utilities .jl" )
9+ include (" utilities .jl" )
1010
1111Random. seed! (1 )
1212
1313const CIENV = get (ENV , " TRAVIS" , " " ) == " true" || get (ENV , " CI" , " " ) == " true"
1414
15+ # ###
16+ # ### utilities
17+ # ###
18+
1519@testset " misc utilities" begin
1620 @test unit_triangular_dimension (1 ) == 0
1721 @test unit_triangular_dimension (2 ) == 1
3438 end
3539end
3640
41+ # ###
42+ # ### scalar transformations correctness checks
43+ # ###
44+
3745@testset " scalar transformations consistency" begin
3846 for _ in 1 : 100
3947 a = randn () * 100
7078 @test as_unit_interval ≡ as𝕀
7179end
7280
81+ # ###
82+ # ### special array transformation correctness checks
83+ # ###
84+
7385@testset " to unit vector" begin
7486 @testset " dimension checks" begin
7587 U = UnitVector (3 )
114126 end
115127end
116128
129+ # ###
130+ # ### aggregation
131+ # ###
132+
133+ # ##
134+ # ## array correctness checks
135+ # ##
136+
117137@testset " to array scalar" begin
118138 dims = (3 , 4 , 5 )
119139 t = as𝕀
156176 @test lj == 0
157177end
158178
179+ # ##
180+ # ## tuple correctness checks
181+ # ##
182+
159183@testset " to tuple" begin
160184 t1 = asℝ
161185 t2 = as𝕀
182206 @test lj2 ≈ ljacc
183207end
184208
209+ # ##
210+ # ## named tuple correctness checks
211+ # ##
212+
185213@testset " to named tuple" begin
186214 t1 = asℝ
187215 t2 = CorrCholeskyFactor (7 )
224252 @test_skip inverse (za, []) == []
225253end
226254
255+ # ###
256+ # ### log density correctness checks
257+ # ###
258+
227259@testset " transform logdensity: correctness" begin
228260 # the density is p(σ) = σ⁻³
229261 # let z = log(σ), so σ = exp(z)
246278 @test (@inferred transform_logdensity (t, f, z)) isa Float64
247279end
248280
281+ # ###
282+ # ### custom transformations
283+ # ###
284+
249285@testset " custom transformation: triangle below diagonal in [0,1]²" begin
250286 tfun (y) = y[1 ], y[1 ]* y[2 ] # triangle below diagonal in unit square
251287 t = CustomTransform (as (Array, as𝕀, 2 ), tfun, collect;)
304340 end
305341end
306342
343+ # ###
344+ # ### AD compatibility tests
345+ # ###
346+
307347@testset " AD tests" begin
308348 t = as ((μ = asℝ, σ = asℝ₊, β = asℝ₋, α = as (Real, 0.0 , 1.0 ),
309349 u = UnitVector (3 ), L = CorrCholeskyFactor (4 )))
@@ -313,33 +353,28 @@ end
313353 end
314354 P = TransformedLogDensity (t, f)
315355 x = zeros (dimension (t))
316- v = logdensity (Value, P, x)
356+ v = logdensity (P, x)
357+ g = ForwardDiff. gradient (x -> logdensity (P, x), x)
317358
318359 # ForwardDiff
319360 P1 = ADgradient (:ForwardDiff , P)
320- @test v == logdensity (Value, P1, x)
321- g1 = @inferred logdensity (ValueGradient, P1, x)
322- @test g1. value == v. value
361+ @test v == logdensity (P1, x)
362+ v1, g1 = @inferred logdensity_and_gradient (P1, x)
363+ @test v1 == v
364+ @test g1 ≈ g
323365
324366 # Flux # NOTE @inferred removed as it currently fails, cf
325367 # https://github.com/FluxML/Flux.jl/issues/497
326368 P2 = ADgradient (:Flux , P)
327- g2 = logdensity (ValueGradient, P2, x) #
328- @test g2. value == v. value
329- @test g2. gradient ≈ g1. gradient
330-
331- # test element type calculations for Flux
332- t2 = CorrCholeskyFactor (4 )
333- @test t2 (Flux. param (ones (dimension (t2)))) isa UpperTriangular
334-
335- t3 = UnitVector (3 )
336- @test sum (abs2, t3 (Flux. param (ones (dimension (t3))))) ≈ Flux. param (1.0 )
369+ v2, g2 = logdensity_and_gradient (P2, x)
370+ @test v2 == v
371+ @test g2 ≈ g
337372
338373 # ReverseDiff
339374 P3 = ADgradient (:ReverseDiff , P)
340- g3 = @inferred logdensity (ValueGradient, P3, x)
341- @test g3 . value == v. value
342- @test g3. gradient ≈ g1 . gradient
375+ v3, g3 = @inferred logdensity_and_gradient ( P3, x)
376+ @test v3 == v
377+ @test g3 ≈ g
343378end
344379
345380@testset " inverse_and_logjac" begin
0 commit comments