|
1 |
| -# The rest of this file is unchanged, except the very end, |
2 |
| -# but IMO we should move these tests to a new file. |
| 1 | +# This file has tests written specifically for Diffractor v0.1, |
| 2 | +# which were in runtests.jl before PR 73 moved them all. |
| 3 | +# (This commit has all changes to 27 Dec 2022.) |
3 | 4 |
|
4 |
| -# Loading Diffractor: var"'" globally will break many tests above, which use it for adjoint. |
| 5 | +using Test |
| 6 | + |
| 7 | +using Diffractor |
| 8 | +using Diffractor: ∂⃖, DiffractorRuleConfig |
5 | 9 |
|
6 |
| -using Diffractor: var"'", ∂⃖, DiffractorRuleConfig |
7 | 10 | using ChainRules
|
8 | 11 | using ChainRulesCore
|
9 | 12 | using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
|
10 | 13 | using Symbolics
|
11 |
| -using LinearAlgebra |
12 | 14 |
|
| 15 | +using LinearAlgebra |
13 | 16 |
|
| 17 | +# Loading Diffractor: var"'" globally will break many tests above, which use it for adjoint. |
14 | 18 | const fwd = Diffractor.PrimeDerivativeFwd
|
15 | 19 | const bwd = Diffractor.PrimeDerivativeBack
|
16 | 20 |
|
@@ -48,8 +52,10 @@ ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent()
|
48 | 52 | @test isequal(simplify(x8), simplify((η + (α*ζ) + (β*ϵ) + (δ*(γ + (α*β))))*exp(ω)))
|
49 | 53 |
|
50 | 54 | # Minimal 2-nd order forward smoke test
|
51 |
| -@test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin), |
52 |
| - Diffractor.TangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0) |
| 55 | +let var"'" = Diffractor.PrimeDerivativeFwd |
| 56 | + @test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin), |
| 57 | + Diffractor.ExplicitTangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0) |
| 58 | +end |
53 | 59 |
|
54 | 60 | function simple_control_flow(b, x)
|
55 | 61 | if b
|
@@ -95,8 +101,8 @@ let var"'" = Diffractor.PrimeDerivativeBack
|
95 | 101 | @test @inferred(sin''(1.0)) == -sin(1.0)
|
96 | 102 | @test sin'''(1.0) == -cos(1.0)
|
97 | 103 | @test sin''''(1.0) == sin(1.0)
|
98 |
| - @test sin'''''(1.0) == cos(1.0) # broken = VERSION >= v"1.8" |
99 |
| - @test sin''''''(1.0) == -sin(1.0) # broken = VERSION >= v"1.8" |
| 104 | + @test sin'''''(1.0) == cos(1.0) |
| 105 | + @test sin''''''(1.0) == -sin(1.0) |
100 | 106 |
|
101 | 107 | f_getfield(x) = getfield((x,), 1)
|
102 | 108 | @test f_getfield'(1) == 1
|
@@ -149,7 +155,7 @@ function sin_twice_fwd(x)
|
149 | 155 | end
|
150 | 156 | end
|
151 | 157 | let var"'" = Diffractor.PrimeDerivativeFwd
|
152 |
| - @test sin_twice_fwd'(1.0) == sin'''(1.0) |
| 158 | + @test_broken sin_twice_fwd'(1.0) == sin'''(1.0) |
153 | 159 | end
|
154 | 160 |
|
155 | 161 | # Regression tests
|
@@ -262,24 +268,31 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
|
262 | 268 | @test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4]
|
263 | 269 | @test tup_adj[2] isa Transpose
|
264 | 270 | @test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal
|
265 |
| - |
| 271 | + |
266 | 272 | @test gradient(x -> sum((y -> (x*y)).([1,2,3])), 4.0) == (6.0,) # closure
|
267 | 273 | end
|
268 | 274 |
|
269 | 275 | @testset "broadcast, 2nd order" begin
|
270 | 276 | @test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] # calls "split broadcasting generic" with f = unthunk
|
271 | 277 | @test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27]
|
272 |
| - @test_broken gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12] # Control flow support not fully implemented yet for higher-order |
| 278 | + @test_broken gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12] |
273 | 279 |
|
274 | 280 | @test_broken gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6] # BoundsError: attempt to access 18-element Vector{Core.Compiler.BasicBlock} at index [0]
|
275 | 281 | @test_broken gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2]
|
276 | 282 | @test_broken gradient(x -> sum(gradient(x -> sum(x .* x ./ 2), x)[1]), [1,2,3.0])[1] == [1,1,1]
|
277 |
| - |
| 283 | + |
278 | 284 | @test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] ≈ exp.(1:3) # MethodError: no method matching copy(::Nothing)
|
279 | 285 | @test_broken gradient(x -> sum(gradient(x -> sum(atan.(x, x')), x)[1]), [1,2,3.0])[1] ≈ [0,0,0]
|
280 | 286 | @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # accum(a::Transpose{Float64, Vector{Float64}}, b::ChainRulesCore.Tangent{Transpose{Int64, Vector{Int64}}, NamedTuple{(:parent,), Tuple{ChainRulesCore.NoTangent}}})
|
281 | 287 | @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516]
|
282 |
| - |
| 288 | + |
283 | 289 | @test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,)
|
284 | 290 | end
|
285 | 291 |
|
| 292 | +# Issue 67, due to https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495 |
| 293 | +@test gradient(identity∘sqrt, 4.0) == (0.25,) |
| 294 | + |
| 295 | +# Issue #70 - Complex & getproperty |
| 296 | +@test_broken gradient(x -> x.re, 2+3im)[1] == 1 # Tangent{Complex{Int64}}(re = 1,) |
| 297 | +@test_broken gradient(x -> abs2(x * x.re), 4+5im)[1] == 456 + 160im # accum(a::ComplexF64, b::Tangent) |
| 298 | +@test gradient(x -> abs2(x * real(x)), 4+5im)[1] == 456 + 160im |
0 commit comments