|
| 1 | +# The rest of this file is unchanged, except the very end, |
| 2 | +# but IMO we should move these tests to a new file. |
| 3 | + |
| 4 | +# Loading Diffractor: var"'" globally will break many tests above, which use it for adjoint. |
| 5 | + |
| 6 | +using Diffractor: var"'", ∂⃖, DiffractorRuleConfig |
| 7 | +using ChainRules |
| 8 | +using ChainRulesCore |
| 9 | +using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad |
| 10 | +using Symbolics |
| 11 | +using LinearAlgebra |
| 12 | + |
| 13 | + |
| 14 | +const fwd = Diffractor.PrimeDerivativeFwd |
| 15 | +const bwd = Diffractor.PrimeDerivativeBack |
| 16 | + |
| 17 | +# Unit tests |
| 18 | +function tup2(f) |
| 19 | + a, b = ∂⃖{2}()(f, 1) |
| 20 | + c, d = b((2,)) |
| 21 | + e, f = d(ZeroTangent(), 3) |
| 22 | + f((4,)) |
| 23 | +end |
| 24 | + |
| 25 | +@test tup2(tuple) == (NoTangent(), 4) |
| 26 | + |
| 27 | +my_tuple(args...) = args |
| 28 | +ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent(), Δ...) |
| 29 | +@test tup2(my_tuple) == (ZeroTangent(), 4) |
| 30 | + |
| 31 | +# Check characteristic of exp rule |
| 32 | +@variables ω α β γ δ ϵ ζ η |
| 33 | +(x1, c1) = ∂⃖{3}()(exp, ω) |
| 34 | +@test isequal(simplify(x1), simplify(exp(ω))) |
| 35 | +((_, x2), c2) = c1(α) |
| 36 | +@test isequal(simplify(x2), simplify(α*exp(ω))) |
| 37 | +(x3, c3) = c2(ZeroTangent(), β) |
| 38 | +@test isequal(simplify(x3), simplify(β*exp(ω))) |
| 39 | +((_, x4), c4) = c3(γ) |
| 40 | +@test isequal(simplify(x4), simplify(exp(ω)*(γ + (α*β)))) |
| 41 | +(x5, c5) = c4(ZeroTangent(), δ) |
| 42 | +@test isequal(simplify(x5), simplify(δ*exp(ω))) |
| 43 | +((_, x6), c6) = c5(ϵ) |
| 44 | +@test isequal(simplify(x6), simplify(ϵ*exp(ω) + α*δ*exp(ω))) |
| 45 | +(x7, c7) = c6(ZeroTangent(), ζ) |
| 46 | +@test isequal(simplify(x7), simplify(ζ*exp(ω) + β*δ*exp(ω))) |
| 47 | +(_, x8) = c7(η) |
| 48 | +@test isequal(simplify(x8), simplify((η + (α*ζ) + (β*ϵ) + (δ*(γ + (α*β))))*exp(ω))) |
| 49 | + |
| 50 | +# Minimal 2-nd order forward smoke test |
| 51 | +@test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin), |
| 52 | + Diffractor.TangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0) |
| 53 | + |
| 54 | +function simple_control_flow(b, x) |
| 55 | + if b |
| 56 | + return sin(x) |
| 57 | + else |
| 58 | + return cos(x) |
| 59 | + end |
| 60 | +end |
| 61 | + |
| 62 | +function myprod(xs) |
| 63 | + s = 1 |
| 64 | + for x in xs |
| 65 | + s *= x |
| 66 | + end |
| 67 | + return s |
| 68 | +end |
| 69 | + |
| 70 | +function mypow(x, n) |
| 71 | + r = one(x) |
| 72 | + while n > 0 |
| 73 | + n -= 1 |
| 74 | + r *= x |
| 75 | + end |
| 76 | + return r |
| 77 | +end |
| 78 | + |
| 79 | +function times_three_while(x) |
| 80 | + z = x |
| 81 | + i = 3 |
| 82 | + while i > 1 |
| 83 | + z += x |
| 84 | + i -= 1 |
| 85 | + end |
| 86 | + z |
| 87 | +end |
| 88 | + |
| 89 | +isa_control_flow(::Type{T}, x) where {T} = isa(x, T) ? x : T(x) |
| 90 | + |
| 91 | +# Simple Reverse Mode tests |
| 92 | +let var"'" = Diffractor.PrimeDerivativeBack |
| 93 | + # Integration tests |
| 94 | + @test @inferred(sin'(1.0)) == cos(1.0) |
| 95 | + @test @inferred(sin''(1.0)) == -sin(1.0) |
| 96 | + @test sin'''(1.0) == -cos(1.0) |
| 97 | + @test sin''''(1.0) == sin(1.0) |
| 98 | + @test sin'''''(1.0) == cos(1.0) # broken = VERSION >= v"1.8" |
| 99 | + @test sin''''''(1.0) == -sin(1.0) # broken = VERSION >= v"1.8" |
| 100 | + |
| 101 | + f_getfield(x) = getfield((x,), 1) |
| 102 | + @test f_getfield'(1) == 1 |
| 103 | + @test f_getfield''(1) == 0 |
| 104 | + @test f_getfield'''(1) == 0 |
| 105 | + |
| 106 | + # Higher order mixed mode tests |
| 107 | + |
| 108 | + complicated_2sin(x) = (x = map(sin, Diffractor.xfill(x, 2)); x[1] + x[2]) |
| 109 | + @test @inferred(complicated_2sin'(1.0)) == 2sin'(1.0) |
| 110 | + @test @inferred(complicated_2sin''(1.0)) == 2sin''(1.0) broken=true |
| 111 | + @test @inferred(complicated_2sin'''(1.0)) == 2sin'''(1.0) broken=true |
| 112 | + @test @inferred(complicated_2sin''''(1.0)) == 2sin''''(1.0) broken=true |
| 113 | + |
| 114 | + # Control flow cases |
| 115 | + @test @inferred((x->simple_control_flow(true, x))'(1.0)) == sin'(1.0) |
| 116 | + @test @inferred((x->simple_control_flow(false, x))'(1.0)) == cos'(1.0) |
| 117 | + @test (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;] |
| 118 | + @test times_three_while'(1.0) == 3.0 |
| 119 | + |
| 120 | + pow5p(x) = (x->mypow(x, 5))'(x) |
| 121 | + @test pow5p(1.0) == 5.0 |
| 122 | +end |
| 123 | + |
| 124 | +# Simple Forward Mode tests |
| 125 | +let var"'" = Diffractor.PrimeDerivativeFwd |
| 126 | + recursive_sin(x) = sin(x) |
| 127 | + ChainRulesCore.frule(∂, ::typeof(recursive_sin), x) = frule(∂, sin, x) |
| 128 | + |
| 129 | + # Integration tests |
| 130 | + @test recursive_sin'(1.0) == cos(1.0) |
| 131 | + @test recursive_sin''(1.0) == -sin(1.0) |
| 132 | + # Error: ArgumentError: Tangent for the primal Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}} |
| 133 | + # should be backed by a NamedTuple type, not by Tuple{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}. |
| 134 | + @test_broken recursive_sin'''(1.0) == -cos(1.0) |
| 135 | + @test_broken recursive_sin''''(1.0) == sin(1.0) |
| 136 | + @test_broken recursive_sin'''''(1.0) == cos(1.0) |
| 137 | + @test_broken recursive_sin''''''(1.0) == -sin(1.0) |
| 138 | + |
| 139 | + # Test the special rules for sin/cos/exp |
| 140 | + @test sin''''''(1.0) == -sin(1.0) |
| 141 | + @test cos''''''(1.0) == -cos(1.0) |
| 142 | + @test exp''''''(1.0) == exp(1.0) |
| 143 | +end |
| 144 | + |
| 145 | +# Some Basic Mixed Mode tests |
| 146 | +function sin_twice_fwd(x) |
| 147 | + let var"'" = Diffractor.PrimeDerivativeFwd |
| 148 | + sin''(x) |
| 149 | + end |
| 150 | +end |
| 151 | +let var"'" = Diffractor.PrimeDerivativeFwd |
| 152 | + @test sin_twice_fwd'(1.0) == sin'''(1.0) |
| 153 | +end |
| 154 | + |
| 155 | +# Regression tests |
| 156 | +@test gradient(x -> sum(abs2, x .+ 1.0), zeros(3))[1] == [2.0, 2.0, 2.0] |
| 157 | + |
| 158 | +function f_broadcast(a) |
| 159 | + l = a / 2.0 * [[0. 1. 1.]; [1. 0. 1.]; [1. 1. 0.]] |
| 160 | + return sum(l) |
| 161 | +end |
| 162 | +@test fwd(f_broadcast)(1.0) == bwd(f_broadcast)(1.0) |
| 163 | + |
| 164 | +# Make sure that there's no infinite recursion in kwarg calls |
| 165 | +g_kw(;x=1.0) = sin(x) |
| 166 | +f_kw(x) = g_kw(;x) |
| 167 | +@test bwd(f_kw)(1.0) == bwd(sin)(1.0) |
| 168 | + |
| 169 | +function f_crit_edge(a, b, c, x) |
| 170 | + # A function with two critical edges. This used to trigger an issue where |
| 171 | + # Diffractor would fail to insert edges for the second split critical edge. |
| 172 | + y = 1x |
| 173 | + if a && b |
| 174 | + y = 2x |
| 175 | + end |
| 176 | + if b && c |
| 177 | + y = 3x |
| 178 | + end |
| 179 | + |
| 180 | + if c |
| 181 | + y = 4y |
| 182 | + end |
| 183 | + |
| 184 | + return y |
| 185 | +end |
| 186 | +@test bwd(x->f_crit_edge(false, false, false, x))(1.0) == 1.0 |
| 187 | +@test bwd(x->f_crit_edge(true, true, false, x))(1.0) == 2.0 |
| 188 | +@test bwd(x->f_crit_edge(false, true, true, x))(1.0) == 12.0 |
| 189 | +@test bwd(x->f_crit_edge(false, false, true, x))(1.0) == 4.0 |
| 190 | +@test bwd(bwd(x->5))(1.0) == ZeroTangent() |
| 191 | +@test fwd(fwd(x->5))(1.0) == ZeroTangent() |
| 192 | + |
| 193 | +# Issue #27 - Mixup in lifting of getfield |
| 194 | +let var"'" = bwd |
| 195 | + @test (x->x^5)''(1.0) == 20. |
| 196 | + @test (x->(x*x)*(x*x)*x)'''(1.0) == 60. |
| 197 | + # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24) |
| 198 | + @test_broken (x->x^5)'''(1.0) == 60. |
| 199 | +end |
| 200 | + |
| 201 | +# Issue #38 - Splatting arrays |
| 202 | +@test gradient(x -> max(x...), (1,2,3))[1] == (0.0, 0.0, 1.0) |
| 203 | +@test gradient(x -> max(x...), [1,2,3])[1] == [0.0, 0.0, 1.0] |
| 204 | + |
| 205 | +# Issue #40 - Symbol type parameters not properly quoted |
| 206 | +@test Diffractor.∂⃖recurse{1}()(Val{:transformations})[1] === Val{:transformations}() |
| 207 | + |
| 208 | +# PR #43 |
| 209 | +loss(res, z, w) = sum(res.U * Diagonal(res.S) * res.V) + sum(res.S .* w) |
| 210 | +x43 = rand(10, 10) |
| 211 | +@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x43) isa Tuple{Matrix{Float64}} |
| 212 | + |
| 213 | +# PR # 45 - Calling back into AD from ChainRules |
| 214 | +r45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2) |
| 215 | +@test r45 isa Tuple |
| 216 | +y45, back45 = r45 |
| 217 | +@test y45 ≈ 2.0 |
| 218 | +@test back45(1) == (ZeroTangent(), 1.0) |
| 219 | + |
| 220 | +z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) |
| 221 | +@test z45 ≈ 2.0 |
| 222 | +@test delta45 ≈ 1.0 |
| 223 | + |
| 224 | +# PR #82 - getindex on non-numeric arrays |
| 225 | +@test gradient(ls -> ls[1](1.), [Base.Fix1(*, 1.)])[1][1] isa Tangent{<:Base.Fix1} |
| 226 | + |
| 227 | +@testset "broadcast" begin |
| 228 | + @test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output |
| 229 | + @test gradient(x -> sum(sqrt.(atan.(x, transpose(x)))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3 |
| 230 | + @test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) |
| 231 | + |
| 232 | + @test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # frule_via_ad |
| 233 | + exp_log(x) = exp(log(x)) |
| 234 | + @test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],) |
| 235 | + @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75]) |
| 236 | + @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4) |
| 237 | + @test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) # closure |
| 238 | + |
| 239 | + @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 # array of arrays |
| 240 | + @test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 |
| 241 | + @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 |
| 242 | + @test gradient(x -> sum(sum, (x,) .* transpose(x)), [1,2,3])[1] ≈ [12, 12, 12] # must not take the * fast path |
| 243 | + |
| 244 | + @test gradient(x -> sum(x ./ 4), [1,2,3]) == ([0.25, 0.25, 0.25],) |
| 245 | + @test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) # x/y rule |
| 246 | + @test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],) # x.^2 rule |
| 247 | + @test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) # scalar^2 rule |
| 248 | + |
| 249 | + @test gradient(x -> sum((1,2,3) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-1.0, -1.0, -1.0),) |
| 250 | + @test gradient(x -> sum(transpose([1,2,3]) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-3.0, -3.0, -3.0),) |
| 251 | + @test gradient(x -> sum([1 2 3] .+ x .^ 2), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(6.0, 12.0, 18.0),) |
| 252 | + |
| 253 | + @test gradient(x -> sum(x .> 2), [1,2,3]) |> only |> iszero # Bool output |
| 254 | + @test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) |> only |> iszero |
| 255 | + @test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (NoTangent(), NoTangent()) |
| 256 | + @test gradient(x -> sum(x .+ [1,2,3]), true) |> only |> iszero # Bool input |
| 257 | + @test gradient(x -> sum(x ./ [1,2,3]), [true false]) |> only |> iszero |
| 258 | + @test gradient(x -> sum(x .* transpose([1,2,3])), (true, false)) |> only |> iszero |
| 259 | + |
| 260 | + tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), transpose([3,4,5])) |
| 261 | + @test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0) |
| 262 | + @test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4] |
| 263 | + @test tup_adj[2] isa Transpose |
| 264 | + @test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal |
| 265 | + |
| 266 | + @test gradient(x -> sum((y -> (x*y)).([1,2,3])), 4.0) == (6.0,) # closure |
| 267 | +end |
| 268 | + |
| 269 | +@testset "broadcast, 2nd order" begin |
| 270 | + @test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] # calls "split broadcasting generic" with f = unthunk |
| 271 | + @test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27] |
| 272 | + @test_broken gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12] # Control flow support not fully implemented yet for higher-order |
| 273 | + |
| 274 | + @test_broken gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6] # BoundsError: attempt to access 18-element Vector{Core.Compiler.BasicBlock} at index [0] |
| 275 | + @test_broken gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2] |
| 276 | + @test_broken gradient(x -> sum(gradient(x -> sum(x .* x ./ 2), x)[1]), [1,2,3.0])[1] == [1,1,1] |
| 277 | + |
| 278 | + @test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] ≈ exp.(1:3) # MethodError: no method matching copy(::Nothing) |
| 279 | + @test_broken gradient(x -> sum(gradient(x -> sum(atan.(x, x')), x)[1]), [1,2,3.0])[1] ≈ [0,0,0] |
| 280 | + @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # accum(a::Transpose{Float64, Vector{Float64}}, b::ChainRulesCore.Tangent{Transpose{Int64, Vector{Int64}}, NamedTuple{(:parent,), Tuple{ChainRulesCore.NoTangent}}}) |
| 281 | + @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516] |
| 282 | + |
| 283 | + @test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,) |
| 284 | +end |
| 285 | + |
0 commit comments