diff --git a/src/extra_rules.jl b/src/extra_rules.jl
index 8182d70c..66990cbd 100644
--- a/src/extra_rules.jl
+++ b/src/extra_rules.jl
@@ -246,6 +246,7 @@ end
 @ChainRules.non_differentiable Base.:(|)(a::Integer, b::Integer)
 @ChainRules.non_differentiable Base.throw(err)
 @ChainRules.non_differentiable Core.Compiler.return_type(args...)
+ChainRulesCore.canonicalize(::NoTangent) = NoTangent()

 # Disable thunking at higher order (TODO: These should go into ChainRulesCore)
@@ -262,3 +263,12 @@ Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDi
 # Avoid https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495
 ChainRulesCore._backing_error(P::Type{<:Base.Pairs}, G::Type{<:NamedTuple}, E::Type{<:AbstractDict}) = nothing
+
+# For gradient(pow_simd, 2, 3)[1] in zygote_features.jl
+ChainRulesCore.@non_differentiable Base.SimdLoop.simd_inner_length(::Any, ::Any)
+
+# This allows fill!(similar([1,2,3], ZeroTangent), false)
+function Base.convert(::Type{ZeroTangent}, x::Number)
+    iszero(x) || throw(InexactError(:convert, ZeroTangent, x))
+    ZeroTangent()
+end
diff --git a/src/runtime.jl b/src/runtime.jl
index 48776239..2c447f50 100644
--- a/src/runtime.jl
+++ b/src/runtime.jl
@@ -28,3 +28,9 @@ accum(x::Tangent{T}, y::Tangent) where T = _tangent(T, accum(backing(x), backing
 _tangent(::Type{T}, z) where T = Tangent{T,typeof(z)}(z)
 _tangent(::Type, ::NamedTuple{()}) = NoTangent()
 _tangent(::Type, ::NamedTuple{<:Any, <:Tuple{Vararg{AbstractZero}}}) = NoTangent()
+
+function accum(x::Tangent{T}, y::Tuple) where {T<:Tuple}
+    # @warn "gradient is both a Tangent and a Tuple" x y
+    _tangent(T, accum(backing(x), y))
+end
+accum(x::Tuple, y::Tangent{<:Tuple}) = accum(y, x)
diff --git a/test/Project.toml b/test/Project.toml
index 6135bbbb..1ab6f008 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -2,11 +2,15 @@
 ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
+DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
+Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
 Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
@@ -16,6 +20,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 ChainRules = "1.44.5"
 ChainRulesCore = "1.15.3"
 Combinatorics = "1"
+DiffTests = "0.1.1"
 StaticArrays = "1"
 StatsBase = "0.33"
 StructArrays = "0.6.12"
diff --git a/test/chainrules.jl b/test/chainrules.jl
new file mode 100644
index 00000000..311ec242
--- /dev/null
+++ b/test/chainrules.jl
@@ -0,0 +1,97 @@
+
+# This file has integration tests for some rules defined in ChainRules.jl,
+# especially those which aim to support higher derivatives, as properly
+# testing those is difficult. Organised according to the files in CR.jl.
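# A minimal sketch of what the new methods in src/extra_rules.jl and src/runtime.jl
# above enable (illustrative only; `Diffractor.accum` is an internal helper, not API):
#
#   fill!(similar([1,2,3], ZeroTangent), false)                   # convert(ZeroTangent, 0) now succeeds
#   Diffractor.accum(Tangent{Tuple{Int,Int}}(1.0, 2.0), (3.0, 4.0))
#   # should merge the Tangent{<:Tuple} with the plain Tuple cotangent,
#   # giving a Tangent backed by (4.0, 6.0)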
+ +using Diffractor, ForwardDiff, ChainRulesCore +using Test, LinearAlgebra + +using Test: Threw, eval_test + + +##### +##### Base/array.jl +##### + + + + + +##### +##### Base/arraymath.jl +##### + + + + +##### +##### Base/base.jl +##### + + + + + +##### +##### Base/indexing.jl +##### + +@testset "getindex, first" begin + @test_broken gradient(x -> sum(abs2, gradient(first, x)[1]), [1,2,3])[1] == [0, 0, 0] # MethodError: no method matching +(::Tuple{ZeroTangent, ZeroTangent}, ::Tuple{ZeroTangent, ZeroTangent}) + @test_broken gradient(x -> sum(abs2, gradient(sqrt∘first, x)[1]), [1,2,3])[1] ≈ [-0.25, 0, 0] # error() in perform_optic_transform(ff::Type{Diffractor.∂⃖recurse{2}}, args::Any) + @test gradient(x -> sum(abs2, gradient(x -> x[1]^2, x)[1]), [1,2,3])[1] == [8, 0, 0] + @test_broken gradient(x -> sum(abs2, gradient(x -> sum(x[1:2])^2, x)[1]), [1,2,3])[1] == [48, 0, 0] # MethodError: no method matching +(::Tuple{ZeroTangent, ZeroTangent}, ::Tuple{ZeroTangent, ZeroTangent}) +end + +@testset "eachcol etc" begin + @test gradient(m -> sum(prod, eachcol(m)), [1 2 3; 4 5 6])[1] == [4 5 6; 1 2 3] + @test gradient(m -> sum(first, eachcol(m)), [1 2 3; 4 5 6])[1] == [1 1 1; 0 0 0] + @test gradient(m -> sum(first(eachcol(m))), [1 2 3; 4 5 6])[1] == [1 0 0; 1 0 0] + @test_skip gradient(x -> sum(sin, gradient(m -> sum(first(eachcol(m))), x)[1]), [1 2 3; 4 5 6])[1] # MethodError: no method matching one(::Base.OneTo{Int64}), unzip_broadcast, split_bc_forwards +end + +##### +##### Base/mapreduce.jl +##### + +@testset "sum" begin + @test gradient(x -> sum(abs2, gradient(sum, x)[1]), [1,2,3])[1] == [0,0,0] + @test gradient(x -> sum(abs2, gradient(x -> sum(abs2, x), x)[1]), [1,2,3])[1] == [8,16,24] + + @test gradient(x -> sum(abs2, gradient(sum, x .^ 2)[1]), [1,2,3])[1] == [0,0,0] + @test gradient(x -> sum(abs2, gradient(sum, x .^ 3)[1]), [1,2,3])[1] == [0,0,0] +end + +@testset "foldl" begin + + @test gradient(x -> foldl(*, x), [1,2,3,4])[1] == [24.0, 12.0, 8.0, 6.0] + @test gradient(x -> foldl(*, x; init=5), [1,2,3,4])[1] == [120.0, 60.0, 40.0, 30.0] + @test gradient(x -> foldr(*, x), [1,2,3,4])[1] == [24, 12, 8, 6] + + @test gradient(x -> foldl(*, x), (1,2,3,4))[1] == Tangent{NTuple{4,Int}}(24.0, 12.0, 8.0, 6.0) + @test_broken gradient(x -> foldl(*, x; init=5), (1,2,3,4))[1] == Tangent{NTuple{4,Int}}(120.0, 60.0, 40.0, 30.0) # does not return a Tangent + @test gradient(x -> foldl(*, x; init=5), (1,2,3,4)) |> only |> Tuple == (120.0, 60.0, 40.0, 30.0) + @test_broken gradient(x -> foldr(*, x), (1,2,3,4))[1] == Tangent{NTuple{4,Int}}(24, 12, 8, 6) + @test gradient(x -> foldr(*, x), (1,2,3,4)) |> only |> Tuple == (24, 12, 8, 6) + +end + + +##### +##### LinearAlgebra/dense.jl +##### + + +@testset "dot" begin + + @test gradient(x -> dot(x, [1,2,3])^2, [4,5,6])[1] == [64,128,192] + @test_broken gradient(x -> sum(gradient(x -> dot(x, [1,2,3])^2, x)[1]), [4,5,6])[1] == [12,24,36] # MethodError: no method matching +(::Tuple{Tangent{ChainRules.var + +end + + +##### +##### LinearAlgebra/norm.jl +##### + + diff --git a/test/diffractor_01.jl b/test/diffractor_01.jl new file mode 100644 index 00000000..19860474 --- /dev/null +++ b/test/diffractor_01.jl @@ -0,0 +1,298 @@ +# This file has tests written specifically for Diffractor v0.1, +# which were in runtests.jl before PR 73 moved them all. +# (This commit has all changes to 27 Dec 2022.) 
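# For orientation, the ∂⃖ optic used heavily below returns the primal together with a
# pullback; at first order the pattern is roughly (a sketch, cotangents may be thunked):
#
#   y, back = Diffractor.∂⃖{1}()(sin, 1.0)   # y == sin(1.0)
#   back(1.0)                                # ≈ (NoTangent(), cos(1.0))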
+ +using Test + +using Diffractor +using Diffractor: ∂⃖, DiffractorRuleConfig + +using ChainRules +using ChainRulesCore +using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad +using Symbolics + +using LinearAlgebra + +# Loading Diffractor: var"'" globally will break many tests above, which use it for adjoint. +const fwd = Diffractor.PrimeDerivativeFwd +const bwd = Diffractor.PrimeDerivativeBack + +# Unit tests +function tup2(f) + a, b = ∂⃖{2}()(f, 1) + c, d = b((2,)) + e, f = d(ZeroTangent(), 3) + f((4,)) +end + +@test tup2(tuple) == (NoTangent(), 4) + +my_tuple(args...) = args +ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent(), Δ...) +@test tup2(my_tuple) == (ZeroTangent(), 4) + +# Check characteristic of exp rule +@variables ω α β γ δ ϵ ζ η +(x1, c1) = ∂⃖{3}()(exp, ω) +@test isequal(simplify(x1), simplify(exp(ω))) +((_, x2), c2) = c1(α) +@test isequal(simplify(x2), simplify(α*exp(ω))) +(x3, c3) = c2(ZeroTangent(), β) +@test isequal(simplify(x3), simplify(β*exp(ω))) +((_, x4), c4) = c3(γ) +@test isequal(simplify(x4), simplify(exp(ω)*(γ + (α*β)))) +(x5, c5) = c4(ZeroTangent(), δ) +@test isequal(simplify(x5), simplify(δ*exp(ω))) +((_, x6), c6) = c5(ϵ) +@test isequal(simplify(x6), simplify(ϵ*exp(ω) + α*δ*exp(ω))) +(x7, c7) = c6(ZeroTangent(), ζ) +@test isequal(simplify(x7), simplify(ζ*exp(ω) + β*δ*exp(ω))) +(_, x8) = c7(η) +@test isequal(simplify(x8), simplify((η + (α*ζ) + (β*ϵ) + (δ*(γ + (α*β))))*exp(ω))) + +# Minimal 2-nd order forward smoke test +let var"'" = Diffractor.PrimeDerivativeFwd + @test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin), + Diffractor.ExplicitTangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0) +end + +function simple_control_flow(b, x) + if b + return sin(x) + else + return cos(x) + end +end + +function myprod(xs) + s = 1 + for x in xs + s *= x + end + return s +end + +function mypow(x, n) + r = one(x) + while n > 0 + n -= 1 + r *= x + end + return r +end + +function times_three_while(x) + z = x + i = 3 + while i > 1 + z += x + i -= 1 + end + z +end + +isa_control_flow(::Type{T}, x) where {T} = isa(x, T) ? 
x : T(x) + +# Simple Reverse Mode tests +let var"'" = Diffractor.PrimeDerivativeBack + # Integration tests + @test @inferred(sin'(1.0)) == cos(1.0) + @test @inferred(sin''(1.0)) == -sin(1.0) + @test sin'''(1.0) == -cos(1.0) + @test sin''''(1.0) == sin(1.0) + @test sin'''''(1.0) == cos(1.0) + @test sin''''''(1.0) == -sin(1.0) + + f_getfield(x) = getfield((x,), 1) + @test f_getfield'(1) == 1 + @test f_getfield''(1) == 0 + @test f_getfield'''(1) == 0 + + # Higher order mixed mode tests + + complicated_2sin(x) = (x = map(sin, Diffractor.xfill(x, 2)); x[1] + x[2]) + @test @inferred(complicated_2sin'(1.0)) == 2sin'(1.0) + @test @inferred(complicated_2sin''(1.0)) == 2sin''(1.0) broken=true + @test @inferred(complicated_2sin'''(1.0)) == 2sin'''(1.0) broken=true + @test @inferred(complicated_2sin''''(1.0)) == 2sin''''(1.0) broken=true + + # Control flow cases + @test @inferred((x->simple_control_flow(true, x))'(1.0)) == sin'(1.0) + @test @inferred((x->simple_control_flow(false, x))'(1.0)) == cos'(1.0) + @test (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;] + @test times_three_while'(1.0) == 3.0 + + pow5p(x) = (x->mypow(x, 5))'(x) + @test pow5p(1.0) == 5.0 +end + +# Simple Forward Mode tests +let var"'" = Diffractor.PrimeDerivativeFwd + recursive_sin(x) = sin(x) + ChainRulesCore.frule(∂, ::typeof(recursive_sin), x) = frule(∂, sin, x) + + # Integration tests + @test recursive_sin'(1.0) == cos(1.0) + @test recursive_sin''(1.0) == -sin(1.0) + # Error: ArgumentError: Tangent for the primal Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}} + # should be backed by a NamedTuple type, not by Tuple{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}. + @test_broken recursive_sin'''(1.0) == -cos(1.0) + @test_broken recursive_sin''''(1.0) == sin(1.0) + @test_broken recursive_sin'''''(1.0) == cos(1.0) + @test_broken recursive_sin''''''(1.0) == -sin(1.0) + + # Test the special rules for sin/cos/exp + @test sin''''''(1.0) == -sin(1.0) + @test cos''''''(1.0) == -cos(1.0) + @test exp''''''(1.0) == exp(1.0) +end + +# Some Basic Mixed Mode tests +function sin_twice_fwd(x) + let var"'" = Diffractor.PrimeDerivativeFwd + sin''(x) + end +end +let var"'" = Diffractor.PrimeDerivativeFwd + @test_broken sin_twice_fwd'(1.0) == sin'''(1.0) +end + +# Regression tests +@test gradient(x -> sum(abs2, x .+ 1.0), zeros(3))[1] == [2.0, 2.0, 2.0] + +function f_broadcast(a) + l = a / 2.0 * [[0. 1. 1.]; [1. 0. 1.]; [1. 1. 0.]] + return sum(l) +end +@test fwd(f_broadcast)(1.0) == bwd(f_broadcast)(1.0) + +# Make sure that there's no infinite recursion in kwarg calls +g_kw(;x=1.0) = sin(x) +f_kw(x) = g_kw(;x) +@test bwd(f_kw)(1.0) == bwd(sin)(1.0) + +function f_crit_edge(a, b, c, x) + # A function with two critical edges. This used to trigger an issue where + # Diffractor would fail to insert edges for the second split critical edge. + y = 1x + if a && b + y = 2x + end + if b && c + y = 3x + end + + if c + y = 4y + end + + return y +end +@test bwd(x->f_crit_edge(false, false, false, x))(1.0) == 1.0 +@test bwd(x->f_crit_edge(true, true, false, x))(1.0) == 2.0 +@test bwd(x->f_crit_edge(false, true, true, x))(1.0) == 12.0 +@test bwd(x->f_crit_edge(false, false, true, x))(1.0) == 4.0 +@test bwd(bwd(x->5))(1.0) == ZeroTangent() +@test fwd(fwd(x->5))(1.0) == ZeroTangent() + +# Issue #27 - Mixup in lifting of getfield +let var"'" = bwd + @test (x->x^5)''(1.0) == 20. + @test (x->(x*x)*(x*x)*x)'''(1.0) == 60. 
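    # Expected values: (d/dx)^2 x^5 = 20x^3 and (d/dx)^3 x^5 = 60x^2, hence 20 and 60 at x = 1.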
+ # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24) + @test_broken (x->x^5)'''(1.0) == 60. +end + +# Issue #38 - Splatting arrays +@test gradient(x -> max(x...), (1,2,3))[1] == (0.0, 0.0, 1.0) +@test gradient(x -> max(x...), [1,2,3])[1] == [0.0, 0.0, 1.0] + +# Issue #40 - Symbol type parameters not properly quoted +@test Diffractor.∂⃖recurse{1}()(Val{:transformations})[1] === Val{:transformations}() + +# PR #43 +loss(res, z, w) = sum(res.U * Diagonal(res.S) * res.V) + sum(res.S .* w) +x43 = rand(10, 10) +@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x43) isa Tuple{Matrix{Float64}} + +# PR # 45 - Calling back into AD from ChainRules +r45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2) +@test r45 isa Tuple +y45, back45 = r45 +@test y45 ≈ 2.0 +@test back45(1) == (ZeroTangent(), 1.0) + +z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) +@test z45 ≈ 2.0 +@test delta45 ≈ 1.0 + +# PR #82 - getindex on non-numeric arrays +@test gradient(ls -> ls[1](1.), [Base.Fix1(*, 1.)])[1][1] isa Tangent{<:Base.Fix1} + +@testset "broadcast" begin + @test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output + @test gradient(x -> sum(sqrt.(atan.(x, transpose(x)))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3 + @test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) + + @test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # frule_via_ad + exp_log(x) = exp(log(x)) + @test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],) + @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75]) + @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4) + @test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) # closure + + @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 # array of arrays + @test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 + @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 + @test gradient(x -> sum(sum, (x,) .* transpose(x)), [1,2,3])[1] ≈ [12, 12, 12] # must not take the * fast path + + @test gradient(x -> sum(x ./ 4), [1,2,3]) == ([0.25, 0.25, 0.25],) + @test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) # x/y rule + @test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],) # x.^2 rule + @test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) # scalar^2 rule + + @test gradient(x -> sum((1,2,3) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-1.0, -1.0, -1.0),) + @test gradient(x -> sum(transpose([1,2,3]) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-3.0, -3.0, -3.0),) + @test gradient(x -> sum([1 2 3] .+ x .^ 2), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(6.0, 12.0, 18.0),) + + @test gradient(x -> sum(x .> 2), [1,2,3]) |> only |> iszero # Bool output + @test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) |> only |> iszero + @test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (NoTangent(), NoTangent()) + @test gradient(x -> sum(x .+ [1,2,3]), true) |> only |> iszero # Bool input + @test gradient(x -> sum(x ./ [1,2,3]), [true false]) |> only |> iszero + @test gradient(x -> sum(x .* transpose([1,2,3])), (true, false)) |> only |> iszero + + tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), transpose([3,4,5])) + @test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0) + @test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4] + @test tup_adj[2] isa 
Transpose + @test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal + + @test gradient(x -> sum((y -> (x*y)).([1,2,3])), 4.0) == (6.0,) # closure +end + +@testset "broadcast, 2nd order" begin + @test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] # calls "split broadcasting generic" with f = unthunk + @test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27] + @test_broken gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12] + + @test_broken gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6] # BoundsError: attempt to access 18-element Vector{Core.Compiler.BasicBlock} at index [0] + @test_broken gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2] + @test_broken gradient(x -> sum(gradient(x -> sum(x .* x ./ 2), x)[1]), [1,2,3.0])[1] == [1,1,1] + + @test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] ≈ exp.(1:3) # MethodError: no method matching copy(::Nothing) + @test_broken gradient(x -> sum(gradient(x -> sum(atan.(x, x')), x)[1]), [1,2,3.0])[1] ≈ [0,0,0] + @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # accum(a::Transpose{Float64, Vector{Float64}}, b::ChainRulesCore.Tangent{Transpose{Int64, Vector{Int64}}, NamedTuple{(:parent,), Tuple{ChainRulesCore.NoTangent}}}) + @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516] + + @test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,) +end + +# Issue 67, due to https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495 +@test gradient(identity∘sqrt, 4.0) == (0.25,) + +# Issue #70 - Complex & getproperty +@test_broken gradient(x -> x.re, 2+3im)[1] == 1 # Tangent{Complex{Int64}}(re = 1,) +@test_broken gradient(x -> abs2(x * x.re), 4+5im)[1] == 456 + 160im # accum(a::ComplexF64, b::Tangent) +@test gradient(x -> abs2(x * real(x)), 4+5im)[1] == 456 + 160im diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl new file mode 100644 index 00000000..f7532edc --- /dev/null +++ b/test/forwarddiff.jl @@ -0,0 +1,767 @@ + +# This file contains tests adapted from FowardDiff.jl, many of them using DiffTests.jl +# Organised here file-by-file in alphabetical order, after some definitions, fully self-contained. + +##### +##### setup +##### + +using Test, LinearAlgebra +using ForwardDiff, DiffTests +using Diffractor, ChainRulesCore + +# Define functions which behave like the ones from ForwardDiff. +# These have plenty of sharp edges! 
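# (Roughly how the helpers below work: fwd_gradient pushes one basis direction through
#  per input coordinate, so fwd_gradient(prod, [1,2,3]) makes three forward passes;
#  rev_jacobian assembles the Jacobian one output row at a time from rev_gradient of
#  each component x -> f(x)[i].)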
+begin + + fwd_derivative(f, x::Number) = Diffractor.PrimeDerivativeFwd(f)(float(x)) |> unthunk + function rev_derivative(f, x::Number) + y = f(x) + if y isa Number || y isa AbstractZero + Diffractor.PrimeDerivativeBack(f)(float(x)) |> unthunk + elseif y isa AbstractArray + map(CartesianIndices(y)) do I + Diffractor.PrimeDerivativeBack(x -> f(x)[I])(float(x)) |> unthunk + end + else + throw("rev_derivative can't handle f(x)::$(typeof(y))") + end + end + + @test ForwardDiff.derivative(abs2, 3) == 6 + @test fwd_derivative(abs2, 3) == 6 + @test rev_derivative(abs2, 3) == 6 + + @test ForwardDiff.derivative(x -> fill(x,2,3), 7) == [1 1 1; 1 1 1] + @test fwd_derivative(x -> fill(x,2,3), 7) == [1 1 1; 1 1 1] + @test rev_derivative(x -> fill(x,2,3), 7) == [1 1 1; 1 1 1] + + DERIVATIVES = (ForwardDiff.derivative, fwd_derivative, rev_derivative) + + function fwd_gradient(f, x::AbstractVector) + map(eachindex(x)) do i + fwd_derivative(ξ -> f(vcat(x[begin:i-1], ξ, x[i+1:end])), x[i]) + end + end + fwd_gradient(f, x::AbstractArray) = reshape(fwd_gradient(v -> f(reshape(v, size(x))), vec(x)), size(x)) + rev_gradient(f, x::AbstractArray) = ChainRulesCore.unthunk(Diffractor.PrimeDerivativeBack(f)(float(x))) + + @test ForwardDiff.gradient(prod, [1,2,3]) == [6,3,2] + @test fwd_gradient(prod, [1,2,3]) == [6,3,2] + @test rev_gradient(prod, [1,2,3]) == [6,3,2] + + @test fwd_gradient(sum, [1,2]) == [1,1] + @test fwd_gradient(first, [1,1]) == [1,0] + + GRADIENTS = (ForwardDiff.gradient, rev_gradient) + + fwd_jacobian(f, x::AbstractArray) = hcat(vec.(fwd_gradient(f, x))...) + function rev_jacobian(f, x::AbstractArray) + y = f(x) + slices = map(LinearIndices(y)) do i # fails if y isa Number, just like ForwardDiff.jacobian + vec(rev_gradient(x -> f(x)[i], x)) + end + vcat(transpose(slices)...) 
+ # permutedims(hcat(slices...)) + end + + @test ForwardDiff.jacobian(x -> x[1:2], [1,2,3]) == [1 0 0; 0 1 0] + @test fwd_jacobian(x -> x[1:2], [1,2,3]) == [1 0 0; 0 1 0] + @test rev_jacobian(x -> x[1:2], [1,2,3]) == [1 0 0; 0 1 0] + + JACOBIANS = (ForwardDiff.jacobian, fwd_jacobian, rev_jacobian) + + fwd_hessian(f, x::AbstractArray) = fwd_jacobian(y -> fwd_gradient(f, y), float(x)) + rev_hessian(f, x::AbstractArray) = rev_jacobian(y -> rev_gradient(f, y), float(x)) + fwd_rev_hessian(f, x::AbstractArray) = fwd_jacobian(y -> rev_gradient(f, y), float(x)) + rev_fwd_hessian(f, x::AbstractArray) = rev_jacobian(y -> fwd_gradient(f, y), float(x)) + + @test ForwardDiff.hessian(x -> -log(x[1]), [2,3]) == [0.25 0; 0 0] + @test_broken fwd_hessian(x -> -log(x[1]), [2,3]) == [0.25 0; 0 0] # TypeError: in new, expected DataType, got Type{Diffractor.TangentBundle{1, Type{Diffractor.UniformBundle{1, var"#s305", ZeroTangent}}, Tuple{NoTangent}}} + @test rev_hessian(x -> -log(x[1]), [2,3]) == [0.25 0; 0 0] + @test_broken rev_fwd_hessian(x -> -log(x[1]), [2,3]) == [0.25 0; 0 0] # MethodError: no method matching (::Diffractor.∂⃖recurse{1})(::typeof(Core.arrayset), ::Bool, ::Vector{Float64}, ::Float64, ::Int64) + @test_skip fwd_rev_hessian(x -> -log(x[1]), [2,3]) # Internal error: encountered unexpected error in runtime: AssertionError(msg="argextype only works on argument-position values") MethodError: no method matching (::Core.OpaqueClosure{Tuple{Any}, Any})(::Float64) + + HESSIANS = (ForwardDiff.hessian, rev_hessian) # omitting fwd_hessian, rev_fwd_hessian, fwd_rev_hessian + +end + +##### +##### ConfusionTest +##### + +@testset verbose=true "ConfusionTest" begin + + # Perturbation Confusion (Issue #83) # + #------------------------------------# + + @testset "issue 83: perturbation confusion 1" begin # for D in DERIVATIVES + + g = [2.0] + @test g == ForwardDiff.gradient(v -> sum(v) * norm(v), [1]) + + @testset "ForwardDiff.derivative" begin + D = ForwardDiff.derivative + + @test D(x -> x * D(y -> x + y, 1), 1) == 1 + + @test ForwardDiff.gradient(v -> sum(v) * D(y -> y * norm(v), 1), [1]) == g + @test_broken fwd_gradient(v -> sum(v) * D(y -> y * norm(v), 1), [1]) == g + @test_broken rev_gradient(v -> sum(v) * D(y -> y * norm(v), 1), [1]) == g + + end + @testset "fwd_derivative" begin + D = fwd_derivative + + @test_broken D(x -> x * D(y -> x + y, 1), 1) == 1 # UndefVarError: B not defined + + @test ForwardDiff.gradient(v -> sum(v) * D(y -> y * norm(v), 1), [1]) == g + @test_broken fwd_gradient(v -> sum(v) * D(y -> y * norm(v), 1), [1]) == g + @test_broken rev_gradient(v -> sum(v) * D(y -> y * norm(v), 1), [1]) == g + + end + @testset "rev_derivative" begin + D = rev_derivative + + @test_broken D(x -> x * D(y -> x + y, 1), 1) == 1 # MethodError: no method matching +(::Float64, ::Tangent{var"#269#271"{Float64}, NamedTuple{(:x,), Tuple{ZeroTangent}}}) + + @test ForwardDiff.gradient(v -> sum(v) * D(y -> y * norm(v), 1), [1]) == g + @test_skip fwd_gradient(v -> sum(v) * D(y -> y * norm(v), 1), [1]) == g # Internal error: encountered unexpected error in runtime: AssertionError(msg="argextype only works on argument-position values") + @test_broken rev_gradient(v -> sum(v) * D(y -> y * norm(v), 1), [1]) == g # accum(a::Vector{Float64}, b::Tangent{var"#705#707"{Vector{Float64}}, NamedTuple{(:v,), Tuple{InplaceableThunk{Thunk{ChainRules.var"#1976#1979"{NoTangent, Vector{Float64}, Float64}}, ChainRules.var"#1975#1978"{NoTangent, Vector{Float64}, Float64}}}}}) + end + + end + + @testset "issue 83: perturbation confusion 
2, $jacobian + $gradient" for jacobian in JACOBIANS, gradient in GRADIENTS + + A = rand(10,8) + y = rand(10) + x = rand(8) + + (jacobian, gradient) == (fwd_jacobian, rev_gradient) && continue # avoids Internal error: encountered unexpected error in runtime: AssertionError(msg="argextype only works on argument-position values") + + @test A == jacobian(x) do x + gradient(y) do y + dot(y, A*x) + end + end broken = jacobian != ForwardDiff.jacobian + + end + + # Issue #238 # + #------------------------------------# + + @testset "issue 238: legendre transformation 1, $jacobian + $gradient" for jacobian in JACOBIANS, gradient in GRADIENTS + + m,g = 1, 9.8 + t = 1 + q = [1,2] + q̇ = [3,4] + L(t,q,q̇) = m/2 * dot(q̇,q̇) - m*g*q[2] + + ∂L∂q̇(L, t, q, q̇) = ForwardDiff.gradient(a->L(t,q,a), q̇) + Dqq̇(L, t, q, q̇) = ForwardDiff.jacobian(a->∂L∂q̇(L,t,a,q̇), q) + @test Dqq̇(L, t, q, q̇) == fill(0.0, 2, 2) + + end + + @testset "issue 238: legendre transformation 2, $hessian + $gradient" for hessian in HESSIANS, gradient in GRADIENTS + + m,g = 1, 9.8 + t = 1 + q = [1,2] + q̇ = [3,4] .+ 0.0 + L(t,q,q̇) = m/2 * dot(q̇,q̇) - m*g*q[2] + + q = [1,2] .+ 0.0 + p = [5,6] .+ 0.0 + function Legendre_transformation(F, w) + z = fill(0.0, size(w)) + M = hessian(F, z) + b = gradient(F, z) + v = cholesky(M)\(w-b) + dot(w,v) - F(v) + end + function Lagrangian2Hamiltonian(Lagrangian, t, q, p) + L = q̇ -> Lagrangian(t, q, q̇) + Legendre_transformation(L, p) + end + + @test Lagrangian2Hamiltonian(L, t, q, p) isa Number + @test_broken gradient(a->Lagrangian2Hamiltonian(L, t, a, p), q) == [0.0,g] + + end + + @testset "issue 267: let scoping $hessian" for hessian in HESSIANS + + @noinline f83a(z, x) = x[1] + z83a = ([(1, (2), [(3, (4, 5, [1, 2, (3, (4, 5), [5])]), (5))])]) + let z = z83a + g = x -> f83a(z, x) + h = x -> g(x) + @test hessian(h, [1.]) == zeros(1, 1) broken = hessian != ForwardDiff.hessian + end + + end + + @testset "simple 2nd order $derivative" for derivative in DERIVATIVES + + @test derivative(1.0) do x + derivative(x) do y + x + end + end == 0.0 broken = derivative != ForwardDiff.derivative + # Cotangent space not defined for `ZeroTangent`. Try a real-valued function. + + end + + @testset "assignment within closure: $D1, $D2" for D1 in DERIVATIVES, D2 in DERIVATIVES + # https://github.com/JuliaDiff/ForwardDiff.jl/issues/443 + # https://github.com/JuliaDiff/Diffractor.jl/issues/23 + + function evil443(x) + f(y) = begin x = x*y end + D1(f, 1) + D1(f, 1) + end + + # Answers with D = ForwardDiff.derivative + # 2 ≈ evil443(1) + # 2 ≈ (evil443(1.0001) - evil443(1.0)) / 0.0001 + # 2 ≈ D2(evil443, 1) + + # The claim is that the correct answer here is 1. 
+ # Certainly if that disagrees with the next, then + # something fishy has occurred: + function easy443(x) + f(y) = begin x = x*y end + ff(y) = begin x = x*y end # different tag + D1(f, 1) + D1(ff, 1) + end + # ForwardDiff: + # easy443(1) isa ForwardDiff.Dual + # @test_throws Exception D(easy443, 1) # throws DualMismatchError + + @test_broken easy443(1) ≈ evil443(1) + @test_broken D2(easy443, 1) ≈ D2(evil443, 1) + + # Diffractor: + # "UndefVarError: B not defined" Diffractor.TangentBundle src/tangent.jl:88 + # MethodError: no method matching setfield!(::Core.Box, ::Symbol, ::Float64) + + end + +end + +##### +##### DerivativeTest +##### + +@testset verbose=true "DerivativeTest" begin + + x = 1 + + @testset "scalar derivative of DiffTests.$f" for f in DiffTests.NUMBER_TO_NUMBER_FUNCS + v = f(x) + d = ForwardDiff.derivative(f, x) + # @test isapprox(d, Calculus.derivative(f, x), atol=FINITEDIFF_ERROR) + + @test d ≈ fwd_derivative(f, x) broken=(f==DiffTests.num2num_4) + @test d ≈ rev_derivative(f, x) broken=(f==DiffTests.num2num_4) + end + + @testset "array derivative of DiffTests.$f" for f in DiffTests.NUMBER_TO_ARRAY_FUNCS + v = f(x) + d = ForwardDiff.derivative(f, x) + # @test isapprox(d, Calculus.derivative(f, x), atol=FINITEDIFF_ERROR) + + @test d ≈ fwd_derivative(f, x) + @test d ≈ rev_derivative(f, x) + end + + @testset "exponential function at base zero: $derivative" for derivative in DERIVATIVES + @test (x -> derivative(y -> x^y, -0.5))(0.0) === -Inf broken = (derivative != ForwardDiff.derivative) + @test (x -> derivative(y -> x^y, 0.0))(0.0) === -Inf broken = (derivative != ForwardDiff.derivative) + @test (x -> derivative(y -> x^y, 0.5))(0.0) === 0.0 + @test (x -> derivative(y -> x^y, 1.5))(0.0) === 0.0 + end + +end + +##### +##### GradientTest +##### + +@testset verbose=true "GradientTest" begin + + @testset "hardcoded rosenbrock gradient" begin + f = DiffTests.rosenbrock_1 + x = [0.1, 0.2, 0.3] + v = f(x) + g = [-9.4, 15.6, 52.0] + + @test g ≈ ForwardDiff.gradient(f, x) + @test g ≈ fwd_gradient(f, x) + @test g ≈ rev_gradient(f, x) + end + + @testset "gradient of DiffTests.$f" for f in DiffTests.VECTOR_TO_NUMBER_FUNCS + X, Y = rand(13), rand(7) + FINITEDIFF_ERROR = 3e-5 + + v = f(X) + g = ForwardDiff.gradient(f, X) + # @test isapprox(g, Calculus.gradient(f, X), atol=FINITEDIFF_ERROR) + + @test_skip g ≈ fwd_gradient(f, X) + @test_skip g ≈ rev_gradient(f, X) + # Many of these fail. 
They don't involve mutation: + # https://github.com/JuliaDiff/DiffTests.jl/blob/master/src/DiffTests.jl#L64-L121 + end + + @testset "exponential function at base zero: $gradient" for gradient in GRADIENTS + @test isequal(gradient(t -> t[1]^t[2], [0.0, -0.5]), [NaN, NaN]) broken = (gradient != ForwardDiff.gradient) + @test isequal(gradient(t -> t[1]^t[2], [0.0, 0.0]), [NaN, NaN]) broken = (gradient != ForwardDiff.gradient) + @test isequal(gradient(t -> t[1]^t[2], [0.0, 0.5]), [Inf, NaN]) broken = (gradient != ForwardDiff.gradient) + @test isequal(gradient(t -> t[1]^t[2], [0.0, 1.5]), [0.0, 0.0]) + end + + @testset "chunk size zero - issue 399: $gradient" for gradient in GRADIENTS + f_const(x) = 1.0 + g_grad_const = x -> gradient(f_const, x) + @test g_grad_const([1.0]) |> iszero + @test g_grad_const(zeros(Float64, 0)) |> (g -> isempty(g) || g isa AbstractZero) + end + + # Issue 548 + @testset "ArithmeticStyle: $gradient" for gradient in GRADIENTS + function f(p) + sum(collect(0.0:p[1]:p[2])) + end + @test gradient(f, [0.2,25.0]) == [7875.0, 0.0] broken = gradient==rev_gradient # Rewrite reached intrinsic function fptosi. Missing rule? + end + + # Issue 197 + @testset "det with branches" begin + det2(A) = return ( + A[1,1]*(A[2,2]*A[3,3]-A[2,3]*A[3,2]) - + A[1,2]*(A[2,1]*A[3,3]-A[2,3]*A[3,1]) + + A[1,3]*(A[2,1]*A[3,2]-A[2,2]*A[3,1]) + ) + + A = [1 0 0; 0 2 0; 0 pi 3] + @test det2(A) == det(A) == 6 + @test istril(A) + + ∇A = [6 0 0; 0 3 -pi; 0 0 2] + @test ForwardDiff.gradient(det2, A) ≈ ∇A + @test_broken ForwardDiff.gradient(det, A) ≈ ∇A + + @test fwd_gradient(det2, A) ≈ ∇A + @test fwd_gradient(det, A) ≈ ∇A + + @test rev_gradient(det2, A) ≈ ∇A + @test rev_gradient(det, A) ≈ ∇A + + # And issue 407 + @test_broken ForwardDiff.hessian(det, A) ≈ ForwardDiff.hessian(det2, A) + + H = ForwardDiff.hessian(det2, A) + + @test_broken fwd_hessian(det, A) ≈ H + @test_broken rev_hessian(det, A) ≈ H + @test_broken fwd_rev_hessian(det, A) ≈ H + @test_broken rev_fwd_hessian(det, A) ≈ H + + @test_broken fwd_hessian(det2, A) ≈ H # UndefVarError: B not defined + @test_broken rev_hessian(det2, A) ≈ H + @test_skip fwd_rev_hessian(det2, A) ≈ H # Internal error: encountered unexpected error in runtime: AssertionError(msg="argextype only works on argument-position values") + @test_broken rev_fwd_hessian(det2, A) ≈ H # MethodError: no method matching (::Diffractor.∂⃖recurse{1})(::typeof(Core.arrayset) + end + + @testset "branches in mul!" begin + a, b = rand(3,3), rand(3,3) + + # Issue 536, version with 3-arg *, Julia 1.7: + @test_broken ForwardDiff.derivative(x -> sum(x*a*b), 0.0) ≈ sum(a * b) + @test fwd_derivative(x -> sum(x*a*b), 0.0) ≈ sum(a * b) + @test rev_derivative(x -> sum(x*a*b), 0.0) ≈ sum(a * b) + + # version with just mul! + function f536(x) + c = similar(a, typeof(x)) + mul!(c, a, b, x, false) + sum(c) + end + @test_broken ForwardDiff.derivative(f536, 0.0) ≈ sum(a * b) + @test_broken fwd_derivative(f536, 0.0) ≈ sum(a * b) + @test_broken rev_derivative(f536, 0.0) ≈ sum(a * b) # maybe no hope... 
+ end + +end + +##### +##### HessianTest +##### + +@testset verbose=true "HessianTest" begin + + @testset "hardcoded rosenbrock hessian" begin + + f = DiffTests.rosenbrock_1 + x = [0.1, 0.2, 0.3] + v = f(x) + g = [-9.4, 15.6, 52.0] + h = [-66.0 -40.0 0.0; + -40.0 130.0 -80.0; + 0.0 -80.0 200.0] + + @test isapprox(h, ForwardDiff.hessian(f, x)) + + @test_skip h ≈ fwd_hessian(f, x) + @test_broken h ≈ rev_hessian(f, x) # Control flow support not fully implemented yet for higher-order reverse mode + @test_skip h ≈ rev_fwd_hessian(f, x) + @test_skip h ≈ fwd_rev_hessian(f, x) + end + + @testset "hessians for DiffTests.$f" for f in DiffTests.VECTOR_TO_NUMBER_FUNCS + X, Y = rand(13), rand(7) + + v = f(X) + g = ForwardDiff.gradient(f, X) + h = ForwardDiff.hessian(f, X) + + @test_broken g ≈ rev_gradient(f, x) + @test_broken h ≈ rev_hessian(f, x) + end + + @testset "branches in dot" begin # $hessian" for hessian in HESSIANS + # https://github.com/JuliaDiff/ForwardDiff.jl/issues/551 + M = [1 2 3; 4 5 6; 7 8 9]; + H = [2 6 10; 6 10 14; 10 14 18] + @test ForwardDiff.hessian(x->dot(x,M,x), fill(0.00001, 3)) ≈ H + @test_broken ForwardDiff.hessian(x->dot(x,M,x), zeros(3)) ≈ H + + @test_broken rev_hessian(x->dot(x,M,x), fill(0.00001, 3)) ≈ H # DimensionMismatch("variable with size(x) == (1, 3) cannot have a gradient with size(dx) == (3,)") + @test rev_hessian(x->(x'*M*x), fill(0.00001, 3)) ≈ H + @test_broken rev_hessian(x->dot(x,M,x), zeros(3)) ≈ H + end + +end + +##### +##### JacobianTest +##### + +@testset verbose=true "JacobianTest" begin + + @testset "hardcoded jacobian" begin + + f(x) = begin + y1 = x[1] * x[2] * sin(x[3]^2) + y2 = y1 + x[3] + y3 = y1 / y2 + y4 = x[3] + [y1, y2, y3, y4] + end + x = [1, 2, 3] + v = f(x) + j = [0.8242369704835132 0.4121184852417566 -10.933563142616123 + 0.8242369704835132 0.4121184852417566 -9.933563142616123 + 0.169076696546684 0.084538348273342 -2.299173530851733 + 0.0 0.0 1.0] + + @test isapprox(j, ForwardDiff.jacobian(f, x)) + @test isapprox(j, fwd_jacobian(f, x)) + @test isapprox(j, rev_jacobian(f, x)) + + end + + @testset "jacobians of DiffTests.$f" for f in DiffTests.ARRAY_TO_ARRAY_FUNCS + X, Y = rand(13), rand(7) + FINITEDIFF_ERROR = 3e-5 + + v = f(X) + j = ForwardDiff.jacobian(f, X) + # @test isapprox(j, Calculus.jacobian(x -> vec(f(x)), X, :forward), atol=1.3FINITEDIFF_ERROR) + + @test j ≈ fwd_jacobian(f, X) broken = f ∉ [-, identity, DiffTests.arr2arr_2] + @test j ≈ rev_jacobian(f, X) broken = f ∉ [-, identity, DiffTests.arr2arr_2] + # Most of these involve mutation: + # https://github.com/JuliaDiff/DiffTests.jl/blob/master/src/DiffTests.jl#L252-L272 + + end + + # for f! 
in DiffTests.INPLACE_ARRAY_TO_ARRAY_FUNCS + # v = fill!(similar(Y), 0.0) + # f!(v, X) + # j = ForwardDiff.jacobian(f!, fill!(similar(Y), 0.0), X) + # @test isapprox(j, Calculus.jacobian(x -> (y = fill!(similar(Y), 0.0); f!(y, x); vec(y)), X, :forward), atol=FINITEDIFF_ERROR) + # end + + # @testset "dimension errors for jacobian" begin + # @test_throws DimensionMismatch ForwardDiff.jacobian(identity, 2pi) # input + # @test_throws DimensionMismatch ForwardDiff.jacobian(sum, fill(2pi, 2)) # vector_mode_jacobian + # @test_throws DimensionMismatch ForwardDiff.jacobian(sum, fill(2pi, 10^6)) # chunk_mode_jacobian + # end + + @testset "eigen" begin + @test ForwardDiff.jacobian(x -> eigvals(SymTridiagonal(x, x[1:end-1])), [1.,2.]) ≈ [(1 - 3/sqrt(5))/2 (1 - 1/sqrt(5))/2 ; (1 + 3/sqrt(5))/2 (1 + 1/sqrt(5))/2] + @test ForwardDiff.jacobian(x -> eigvals(Symmetric(x*x')), [1.,2.]) ≈ [0 0; 2 4] + + @test_broken fwd_jacobian(x -> eigvals(SymTridiagonal(x, x[1:end-1])), [1.,2.]) ≈ [(1 - 3/sqrt(5))/2 (1 - 1/sqrt(5))/2 ; (1 + 3/sqrt(5))/2 (1 + 1/sqrt(5))/2] + @test_broken fwd_jacobian(x -> eigvals(Symmetric(x*x')), [1.,2.]) ≈ [0 0; 2 4] + + @test_broken rev_jacobian(x -> eigvals(SymTridiagonal(x, x[1:end-1])), [1.,2.]) ≈ [(1 - 3/sqrt(5))/2 (1 - 1/sqrt(5))/2 ; (1 + 3/sqrt(5))/2 (1 + 1/sqrt(5))/2] + @test rev_jacobian(x -> eigvals(Symmetric(x*x')), [1.,2.]) ≈ [0 0; 2 4] + end + +end + +##### +##### MiscTest +##### + +@testset verbose=true "MiscTest" begin + + ########################## + # Nested Differentiation # + ########################## + + @testset "nested README example" begin # , $jacobian + $gradient" for jacobian in JACOBIANS, gradient in GRADIENTS + + # README example # + #----------------# + + x = rand(5) + f = x -> sum(sin, x) + prod(tan, x) * sum(sqrt, x) + + g = x -> ForwardDiff.gradient(f, x) + j = x -> ForwardDiff.jacobian(g, x) + @test isapprox(ForwardDiff.hessian(f, x), j(x)) + + # Trying to run that in a loop of cases, and mark some broken, is confusing. 
+ + H = ForwardDiff.hessian(f, x) + @test H ≈ ForwardDiff.jacobian(x -> fwd_gradient(f, x), x) + @test H ≈ ForwardDiff.jacobian(x -> rev_gradient(f, x), x) + + @test_skip H ≈ fwd_jacobian(x -> fwd_gradient(f, x), x) # TypeError: in new, expected DataType, got Type{Diffractor.TangentBundle{1, Type{Diffractor.UniformBundle{1, var"#s305", ZeroTangent}}, Tuple{NoTangent}}} + @test_skip H ≈ rev_jacobian(x -> rev_gradient(f, x), x) # error() in perform_optic_transform(ff::Type{Diffractor.∂⃖recurse{2}}, args::Any) + end + + # higher-order derivatives # + #--------------------------# + + @test_skip @testset "tensor 3rd order, $jacobian + $hessian" for jacobian in JACOBIANS, hessian in HESSIANS + + function tensor(f, x) + n = length(x) + out = jacobian(y -> hessian(f, y), x) + return reshape(out, n, n, n) + end + + test_tensor_output = reshape([240.0 -400.0 0.0; + -400.0 0.0 0.0; + 0.0 0.0 0.0; + -400.0 0.0 0.0; + 0.0 480.0 -400.0; + 0.0 -400.0 0.0; + 0.0 0.0 0.0; + 0.0 -400.0 0.0; + 0.0 0.0 0.0], 3, 3, 3) + + @test isapprox(tensor(DiffTests.rosenbrock_1, [0.1, 0.2, 0.3]), test_tensor_output) + + end + + @test_skip @testset "broadcast 4rd order, $jacobian + $jacobian2" for jacobian in JACOBIANS, jacobian2 in JACOBIANS + + test_nested_jacobian_output = [-sin(1) 0.0 0.0; + -0.0 -0.0 -0.0; + -0.0 -0.0 -0.0; + 0.0 0.0 0.0; + -0.0 -sin(2) -0.0; + -0.0 -0.0 -0.0; + 0.0 0.0 0.0; + -0.0 -0.0 -0.0; + -0.0 -0.0 -sin(3)] + + sin_jacobian = x -> jacobian2(y -> broadcast(sin, y), x) + + @test isapprox(jacobian(sin_jacobian, [1., 2., 3.]), test_nested_jacobian_output) broken = jacobian != ForwardDiff.jacobian + # segmentation fault julia + + end + + @testset "trig 2rd order, some gradient + $derivative" for derivative in DERIVATIVES + # Issue #59 example # + #-------------------# + + x = rand(2) + + f = x -> sin(x)/3 * cos(x)/2 + df = x -> derivative(f, x) + testdf = x -> (((cos(x)^2)/3) - (sin(x)^2)/3) / 2 + + @test df(x[1]) ≈ testdf(x[1]) + + f2 = x -> df(x[1]) * f(x[2]) + testf2 = x -> testdf(x[1]) * f(x[2]) + + @test isapprox(ForwardDiff.gradient(f2, x), ForwardDiff.gradient(testf2, x)) + g = ForwardDiff.gradient(testf2, x) + + @test g ≈ fwd_gradient(f2, x) broken = derivative != fwd_derivative + @test g ≈ rev_gradient(f2, x) broken = derivative != rev_derivative + + # MethodError: no method matching *(::ForwardDiff.Dual{ForwardDiff.Tag{var"#139#140", Float64}, Float64, 1}, ::Tuple{Float64, Tuple{Tuple{Float64}}}) + end + + ###################################### + # Higher-Dimensional Differentiation # + ###################################### + + @testset "inv & kron, $jacobian" for jacobian in JACOBIANS + + x = rand(5, 5) + + @test isapprox(ForwardDiff.jacobian(inv, x), -kron(inv(x'), inv(x))) + + end + + ######################################### + # Differentiation with non-Array inputs # + ######################################### + + # x = rand(5,5) + + # # Sparse + # f = x -> sum(sin, x) + prod(tan, x) * sum(sqrt, x) + # gfx = ForwardDiff.gradient(f, x) + # @test isapprox(gfx, ForwardDiff.gradient(f, sparse(x))) + + # # Views + # jinvx = ForwardDiff.jacobian(inv, x) + # @test isapprox(jinvx, ForwardDiff.jacobian(inv, view(x, 1:5, 1:5))) + + ######################## + # Conversion/Promotion # + ######################## + + + @testset "issue 71: target function returns a literal" begin + + # target function returns a literal (Issue #71) # + #-----------------------------------------------# + + # @test ForwardDiff.derivative(x->zero(x), rand()) == ForwardDiff.derivative(x->1.0, rand()) + # @test 
ForwardDiff.gradient(x->zero(x[1]), [rand()]) == ForwardDiff.gradient(x->1.0, [rand()]) + # @test ForwardDiff.hessian(x->zero(x[1]), [rand()]) == ForwardDiff.hessian(x->1.0, [rand()]) + # @test ForwardDiff.jacobian(x->[zero(x[1])], [rand()]) == ForwardDiff.jacobian(x->[1.0], [rand()]) + + for derivative in DERIVATIVES + @test derivative(x->zero(x), rand()) |> iszero + end + for gradient in GRADIENTS + @test gradient(x->zero(x[1]), [rand()]) |> iszero + end + for hessian in HESSIANS + @test hessian(x->zero(x[1]), [rand()]) |> iszero + end + for jacobian in JACOBIANS + @test jacobian(x->[zero(x[1])], [rand()]) |> iszero + end + + end + + @testset "arithmetic element-wise functions, $jacobian" for jacobian in JACOBIANS + + if jacobian != ForwardDiff.jacobian + @test_broken false + # Got exception outside of a @test + # DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 2 and 4") + + # Later: + # ArgumentError: Tangent for the primal Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Float64, Vector{Float64}}} should be backed by a NamedTuple type, not by Tuple{ZeroTangent, Tangent{Tuple{Float64, Vector{Float64}}, Tuple{Float64, ZeroTangent}}, ZeroTangent}. + # Stacktrace: + # [1] _backing_error(P::Type, G::Type, E::Type) + # @ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/tangent.jl:62 + # [2] Tangent{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Float64, Vector{Float64}}}, Tuple{ZeroTangent, Tangent{Tuple{Float64, Vector{Float64}}, Tuple{Float64, ZeroTangent}}, ZeroTangent}}(backing::Tuple{ZeroTangent, Tangent{Tuple{Float64, Vector{Float64}}, Tuple{Float64, ZeroTangent}}, ZeroTangent}) + # @ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/tangent.jl:36 + # [3] (Tangent{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Float64, Vector{Float64}}}})(::ZeroTangent, ::Vararg{Any}) + # @ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/tangent.jl:48 + # [4] partial(x::Diffractor.CompositeBundle{1, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Float64, Vector{Float64}}}, Tuple{Diffractor.TangentBundle{1, typeof(-), Diffractor.UniformTangent{ZeroTangent}}, Diffractor.CompositeBundle{1, Tuple{Float64, Vector{Float64}}, Tuple{Diffractor.TangentBundle{1, Float64, Diffractor.ExplicitTangent{Tuple{Float64}}}, Diffractor.TangentBundle{1, Vector{Float64}, Diffractor.UniformTangent{ZeroTangent}}}}, Diffractor.TangentBundle{1, Nothing, Diffractor.UniformTangent{ZeroTangent}}}}, i::Int64) + # @ Diffractor ~/.julia/dev/Diffractor/src/stage1/forward.jl:7 + continue + end + + # arithmetic element-wise functions # + #-----------------------------------# + + for op in (:.-, :.+, :./, :.*) + @eval begin + N = 4 + a = fill(1.0, N) + jac0 = reshape(vcat([[fill(0.0, N*(i-1)); a; fill(0.0, N^2-N*i)] for i = 1:N]...), N^2, N) + + f = x -> [$op(x[1], a); $op(x[2], a); $op(x[3], a); $op(x[4], a)] + jac = $jacobian(f, a) + @test isapprox(jac0, jac) + end + end + + end + + # NaNs # + #------# + + # @test ForwardDiff.partials(NaNMath.pow(ForwardDiff.Dual(-2.0,1.0),ForwardDiff.Dual(2.0,0.0)),1) == -4.0 + + # Partials{0} # + #-------------# + + @testset "Partials? 
$hessian" for hessian in HESSIANS + + if hessian == rev_hessian + x, y = rand(3), rand(3) + @test_broken hessian(y -> sum(hypot.(x, y)), y) isa AbstractMatrix + # Control flow support not fully implemented yet for higher-order reverse mode (TODO) + continue + end + + x, y = rand(3), rand(3) + h = hessian(y -> sum(hypot.(x, y)), y) + + @test h[1, 1] ≈ (x[1]^2) / (x[1]^2 + y[1]^2)^(3/2) + @test h[2, 2] ≈ (x[2]^2) / (x[2]^2 + y[2]^2)^(3/2) + @test h[3, 3] ≈ (x[3]^2) / (x[3]^2 + y[3]^2)^(3/2) + let i, j + for i in 1:3, j in 1:3 + i != j && @test h[i, j] ≈ 0.0 + end + end + + end + + @testset "issue 267: $hessian" for hessian in HESSIANS + + @noinline f267(z, x) = x[1] + z267 = ([(1, (2), [(3, (4, 5, [1, 2, (3, (4, 5), [5])]), (5))])]) + let z = z267, + g = x -> f267(z, x), + h = x -> g(x) + @test hessian(h, [1.]) == fill(0.0, 1, 1) broken = hessian == rev_hessian + end + + end + + @testset "issue 290: rem2pi & rounding modes, $derivative" for derivative in DERIVATIVES + + @test derivative(x -> rem2pi(x, RoundUp), rand()) == 1 + @test derivative(x -> rem2pi(x, RoundDown), rand()) == 1 + + end + +end diff --git a/test/pinn.jl b/test/pinn.jl index 84035877..90c259cf 100644 --- a/test/pinn.jl +++ b/test/pinn.jl @@ -1,5 +1,5 @@ using Diffractor -using Diffractor: var"'", ∂⃖ +using Diffractor: ∂⃖ using ForwardDiff using StaticArrays using Random @@ -39,14 +39,20 @@ function (a::Dense)(x::AbstractArray) z = map(+, W*x, b) map(σ, z) end -#g(NNODE,t,x,y) = ((((t*(1-x))*x)*(1-y))*y)*NNODE(@SVector [t,x,y]) + sin(2π*y)*sin(2π*x) -g(NNODE, t, x, y) = NNODE(@SVector [t,x,y]) -loss(NNODE, at=0.5) = (x->g(NNODE, -0.1, 0.1, x))''(at) + +#g_pinn(NNODE,t,x,y) = ((((t*(1-x))*x)*(1-y))*y)*NNODE(@SVector [t,x,y]) + sin(2π*y)*sin(2π*x) +g_pinn(NNODE, t, x, y) = NNODE(@SVector [t,x,y]) + +let var"'" = Diffractor.var"'" + global loss + loss(NNODE, at=0.5) = (x->g_pinn(NNODE, -0.1, 0.1, x))''(at) +end let var"'" = Diffractor.PrimeDerivativeFwd global loss_fwd_diff - loss_fwd_diff(NNODE, at=0.5) = (x->g(NNODE, -0.1, 0.1, x))''(at) + loss_fwd_diff(NNODE, at=0.5) = (x->g_pinn(NNODE, -0.1, 0.1, x))''(at) end -loss_fwd(NNODE, at=0.5) = ForwardDiff.derivative(x->ForwardDiff.derivative(x->g(NNODE, -0.1, 0.1, x), x), at) +loss_fwd(NNODE, at=0.5) = ForwardDiff.derivative(x->ForwardDiff.derivative(x->g_pinn(NNODE, -0.1, 0.1, x), x), at) + NNODE = Chain(Dense(3,256,tanh), Dense(256,256,tanh), Dense(256,256,tanh), @@ -54,11 +60,11 @@ NNODE = Chain(Dense(3,256,tanh), # Don't fall over on this semi-complicated nested AD case training_step(NNODE) = gradient(NNODE->loss(NNODE), NNODE) -@test loss(NNODE, 0.1) ≈ loss_fwd(NNODE, 0.1) -@test loss(NNODE, 0.5) ≈ loss_fwd(NNODE, 0.5) -@test loss(NNODE, 0.1) ≈ loss_fwd_diff(NNODE, 0.1) -@test loss(NNODE, 0.5) ≈ loss_fwd_diff(NNODE, 0.5) +@test_broken loss(NNODE, 0.1) ≈ loss_fwd(NNODE, 0.1) +@test_broken loss(NNODE, 0.5) ≈ loss_fwd(NNODE, 0.5) +@test_broken loss(NNODE, 0.1) ≈ loss_fwd_diff(NNODE, 0.1) +@test_broken loss(NNODE, 0.5) ≈ loss_fwd_diff(NNODE, 0.5) # How to test that this is actually the right answer? 
-training_step(NNODE) +@test_skip training_step(NNODE) #gradient(NNODE->loss_fwd_diff(NNODE), NNODE) diff --git a/test/runtests.jl b/test/runtests.jl index eca14ae5..6a8643b4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,282 +1,31 @@ using Diffractor -using Diffractor: var"'", ∂⃖, DiffractorRuleConfig -using ChainRules -using ChainRulesCore -using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad -using Symbolics -using LinearAlgebra - using Test -const fwd = Diffractor.PrimeDerivativeFwd -const bwd = Diffractor.PrimeDerivativeBack - -@testset verbose=true "Diffractor.jl" begin # overall testset, ensures all tests run - -# Unit tests -function tup2(f) - a, b = ∂⃖{2}()(f, 1) - c, d = b((2,)) - e, f = d(ZeroTangent(), 3) - f((4,)) -end - -@test tup2(tuple) == (NoTangent(), 4) - -my_tuple(args...) = args -ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent(), Δ...) -@test tup2(my_tuple) == (ZeroTangent(), 4) - -# Check characteristic of exp rule -@variables ω α β γ δ ϵ ζ η -(x1, c1) = ∂⃖{3}()(exp, ω) -@test isequal(simplify(x1), simplify(exp(ω))) -((_, x2), c2) = c1(α) -@test isequal(simplify(x2), simplify(α*exp(ω))) -(x3, c3) = c2(ZeroTangent(), β) -@test isequal(simplify(x3), simplify(β*exp(ω))) -((_, x4), c4) = c3(γ) -@test isequal(simplify(x4), simplify(exp(ω)*(γ + (α*β)))) -(x5, c5) = c4(ZeroTangent(), δ) -@test isequal(simplify(x5), simplify(δ*exp(ω))) -((_, x6), c6) = c5(ϵ) -@test isequal(simplify(x6), simplify(ϵ*exp(ω) + α*δ*exp(ω))) -(x7, c7) = c6(ZeroTangent(), ζ) -@test isequal(simplify(x7), simplify(ζ*exp(ω) + β*δ*exp(ω))) -(_, x8) = c7(η) -@test isequal(simplify(x8), simplify((η + (α*ζ) + (β*ϵ) + (δ*(γ + (α*β))))*exp(ω))) - -# Minimal 2-nd order forward smoke test -@test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin), - Diffractor.ExplicitTangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0) +@testset verbose=true "Diffractor.jl" begin -function simple_control_flow(b, x) - if b - return sin(x) - else - return cos(x) + @testset verbose=true "Diffractor 0.1's own unit tests" begin + include("diffractor_01.jl") end -end -function myprod(xs) - s = 1 - for x in xs - s *= x + @testset verbose=true "pseudo-Flux higher-order" begin + # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24) + # include("pinn.jl") end - return s -end -function mypow(x, n) - r = one(x) - while n > 0 - n -= 1 - r *= x + @testset verbose=true "ChainRules integration" begin + include("chainrules.jl") end - return r -end -function times_three_while(x) - z = x - i = 3 - while i > 1 - z += x - i -= 1 + @testset verbose=true "from ForwardDiff" begin + include("forwarddiff.jl") end - z -end - -isa_control_flow(::Type{T}, x) where {T} = isa(x, T) ? 
x : T(x) - -# Simple Reverse Mode tests -let var"'" = Diffractor.PrimeDerivativeBack - # Integration tests - @test @inferred(sin'(1.0)) == cos(1.0) - @test @inferred(sin''(1.0)) == -sin(1.0) - @test sin'''(1.0) == -cos(1.0) - @test sin''''(1.0) == sin(1.0) - @test sin'''''(1.0) == cos(1.0) - @test sin''''''(1.0) == -sin(1.0) - - f_getfield(x) = getfield((x,), 1) - @test f_getfield'(1) == 1 - @test f_getfield''(1) == 0 - @test f_getfield'''(1) == 0 - # Higher order mixed mode tests - - complicated_2sin(x) = (x = map(sin, Diffractor.xfill(x, 2)); x[1] + x[2]) - @test @inferred(complicated_2sin'(1.0)) == 2sin'(1.0) - @test @inferred(complicated_2sin''(1.0)) == 2sin''(1.0) broken=true - @test @inferred(complicated_2sin'''(1.0)) == 2sin'''(1.0) broken=true - @test @inferred(complicated_2sin''''(1.0)) == 2sin''''(1.0) broken=true - - # Control flow cases - @test @inferred((x->simple_control_flow(true, x))'(1.0)) == sin'(1.0) - @test @inferred((x->simple_control_flow(false, x))'(1.0)) == cos'(1.0) - @test (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;] - @test times_three_while'(1.0) == 3.0 - - pow5p(x) = (x->mypow(x, 5))'(x) - @test pow5p(1.0) == 5.0 -end - -# Simple Forward Mode tests -let var"'" = Diffractor.PrimeDerivativeFwd - recursive_sin(x) = sin(x) - ChainRulesCore.frule(∂, ::typeof(recursive_sin), x) = frule(∂, sin, x) - - # Integration tests - @test recursive_sin'(1.0) == cos(1.0) - @test recursive_sin''(1.0) == -sin(1.0) - # Error: ArgumentError: Tangent for the primal Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}} - # should be backed by a NamedTuple type, not by Tuple{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}. - @test_broken recursive_sin'''(1.0) == -cos(1.0) - @test_broken recursive_sin''''(1.0) == sin(1.0) - @test_broken recursive_sin'''''(1.0) == cos(1.0) - @test_broken recursive_sin''''''(1.0) == -sin(1.0) - - # Test the special rules for sin/cos/exp - @test sin''''''(1.0) == -sin(1.0) - @test cos''''''(1.0) == -cos(1.0) - @test exp''''''(1.0) == exp(1.0) -end - -# Some Basic Mixed Mode tests -function sin_twice_fwd(x) - let var"'" = Diffractor.PrimeDerivativeFwd - sin''(x) + @testset verbose=true "from Zygote's features.jl" begin + include("zygote_features.jl") end -end -let var"'" = Diffractor.PrimeDerivativeFwd - @test_broken sin_twice_fwd'(1.0) == sin'''(1.0) -end - -# Regression tests -@test gradient(x -> sum(abs2, x .+ 1.0), zeros(3))[1] == [2.0, 2.0, 2.0] - -function f_broadcast(a) - l = a / 2.0 * [[0. 1. 1.]; [1. 0. 1.]; [1. 1. 0.]] - return sum(l) -end -@test fwd(f_broadcast)(1.0) == bwd(f_broadcast)(1.0) -# Make sure that there's no infinite recursion in kwarg calls -g_kw(;x=1.0) = sin(x) -f_kw(x) = g_kw(;x) -@test bwd(f_kw)(1.0) == bwd(sin)(1.0) - -function f_crit_edge(a, b, c, x) - # A function with two critical edges. This used to trigger an issue where - # Diffractor would fail to insert edges for the second split critical edge. - y = 1x - if a && b - y = 2x - end - if b && c - y = 3x - end - - if c - y = 4y + @testset verbose=true "from Zygote's gradcheck.jl" begin + include("zygote_gradcheck.jl") end - return y -end -@test bwd(x->f_crit_edge(false, false, false, x))(1.0) == 1.0 -@test bwd(x->f_crit_edge(true, true, false, x))(1.0) == 2.0 -@test bwd(x->f_crit_edge(false, true, true, x))(1.0) == 12.0 -@test bwd(x->f_crit_edge(false, false, true, x))(1.0) == 4.0 - -# Issue #27 - Mixup in lifting of getfield -let var"'" = bwd - @test (x->x^5)''(1.0) == 20. - @test_broken (x->x^5)'''(1.0) == 60. 
-end - -# Issue #38 - Splatting arrays -@test gradient(x -> max(x...), (1,2,3))[1] == (0.0, 0.0, 1.0) -@test gradient(x -> max(x...), [1,2,3])[1] == [0.0, 0.0, 1.0] - -# Issue #40 - Symbol type parameters not properly quoted -@test Diffractor.∂⃖recurse{1}()(Val{:transformations})[1] === Val{:transformations}() - -# PR #43 -loss(res, z, w) = sum(res.U * Diagonal(res.S) * res.V) + sum(res.S .* w) -x43 = rand(10, 10) -@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x43) isa Tuple{Matrix{Float64}} - -# PR # 45 - Calling back into AD from ChainRules -y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2) -@test y45 ≈ 2.0 -@test back45(1) == (ZeroTangent(), 1.0) - -z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) -@test z45 ≈ 2.0 -@test delta45 ≈ 1.0 - -# PR #82 - getindex on non-numeric arrays -@test gradient(ls -> ls[1](1.), [Base.Fix1(*, 1.)])[1][1] isa Tangent{<:Base.Fix1} - -@testset "broadcast" begin - @test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output - @test gradient(x -> sum(sqrt.(atan.(x, transpose(x)))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3 - @test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) - - @test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # frule_via_ad - exp_log(x) = exp(log(x)) - @test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],) - @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75]) - @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4) - @test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) # closure - - @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 # array of arrays - @test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 - @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 - @test gradient(x -> sum(sum, (x,) .* transpose(x)), [1,2,3])[1] ≈ [12, 12, 12] # must not take the * fast path - - @test gradient(x -> sum(x ./ 4), [1,2,3]) == ([0.25, 0.25, 0.25],) - @test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) # x/y rule - @test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],) # x.^2 rule - @test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) # scalar^2 rule - - @test gradient(x -> sum((1,2,3) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-1.0, -1.0, -1.0),) - @test gradient(x -> sum(transpose([1,2,3]) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-3.0, -3.0, -3.0),) - @test gradient(x -> sum([1 2 3] .+ x .^ 2), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(6.0, 12.0, 18.0),) - - @test gradient(x -> sum(x .> 2), [1,2,3]) |> only |> iszero # Bool output - @test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) |> only |> iszero - @test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (NoTangent(), NoTangent()) - @test gradient(x -> sum(x .+ [1,2,3]), true) |> only |> iszero # Bool input - @test gradient(x -> sum(x ./ [1,2,3]), [true false]) |> only |> iszero - @test gradient(x -> sum(x .* transpose([1,2,3])), (true, false)) |> only |> iszero - - tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), transpose([3,4,5])) - @test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0) - @test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4] - @test tup_adj[2] isa Transpose - @test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal - - @test gradient(x -> sum((y -> (x*y)).([1,2,3])), 4.0) == (6.0,) # closure end - 
-@testset "broadcast, 2nd order" begin - @test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] # calls "split broadcasting generic" with f = unthunk - @test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27] - @test_broken gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12] # Control flow support not fully implemented yet for higher-order - - @test_broken gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6] # BoundsError: attempt to access 18-element Vector{Core.Compiler.BasicBlock} at index [0] - @test_broken gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2] - @test_broken gradient(x -> sum(gradient(x -> sum(x .* x ./ 2), x)[1]), [1,2,3.0])[1] == [1,1,1] - - @test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] ≈ exp.(1:3) # MethodError: no method matching copy(::Nothing) - @test_broken gradient(x -> sum(gradient(x -> sum(atan.(x, x')), x)[1]), [1,2,3.0])[1] ≈ [0,0,0] - @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # accum(a::Transpose{Float64, Vector{Float64}}, b::ChainRulesCore.Tangent{Transpose{Int64, Vector{Int64}}, NamedTuple{(:parent,), Tuple{ChainRulesCore.NoTangent}}}) - @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516] - - @test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,) -end - -# Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24) -#include("pinn.jl") - -end # overall testset diff --git a/test/zygote_features.jl b/test/zygote_features.jl new file mode 100644 index 00000000..23c97a0a --- /dev/null +++ b/test/zygote_features.jl @@ -0,0 +1,1173 @@ + +# This file contains many examples borrowed from Zygote's tests, +# all files except its "gradcheck.jl" which is elsewhere. + +# Ideally this will catch the awkward cases, and should be extended +# to test forward mode and higher derivatives. + +using Diffractor, ChainRulesCore +using Test, LinearAlgebra + +isZero(x) = x isa AbstractZero + +# Zygote's misnamed hobbit function: +function pullback(f, x...) + y, b = Diffractor.∂⃖{1}()(f, x...) + back(dy) = map(unthunk, Base.tail(b(dy))) + y, back +end + +function withgradient(f, x...) + y, b = Diffractor.∂⃖{1}()(f, x...) 
+    (val=y, grad=map(unthunk, Base.tail(b(one(y)))))
+end
+
+using ChainRulesCore: backing, canonicalize, AbstractTangent
+
+# These pirate methods greatly simplify RHS of tests:
+Base.isapprox(x::Tangent, y::NamedTuple) = all(map(isapprox, backing(canonicalize(x)), y))
+Base.isapprox(x::Tangent{<:Tuple}, y::Tuple) = all(map(isapprox, backing(x), y))
+Base.isapprox(x::Array{<:AbstractTangent}, y::Array) = all(map(isapprox, x, y))
+Base.isapprox(::AbstractZero, ::Nothing) = true
+Base.isapprox(::AbstractZero, y) = iszero(y)
+
+#####
+##### Zygote/test/complex.jl
+#####
+
+gradfwd(f, x::Number) = (Diffractor.PrimeDerivativeFwd(f)(x),)
+gradback(f, x::Number) = (Diffractor.PrimeDerivativeBack(f)(x),)
+
+@testset "complex numbers" begin # : $gradient" for gradient in (gradfwd, gradback)
+
+    @test gradient(x -> real(abs(x)*exp(im*angle(x))), 10+20im)[1] ≈ 1
+    @test gradient(x -> imag(real(x)+0.3im), 0.3)[1] ≈ 0
+    @test gradient(x -> imag(conj(x)+0.3im), 0.3 + 0im)[1] ≈ -1im
+    @test gradient(x -> abs((imag(x)+0.3)), 0.3 + 0im)[1] ≈ 1im
+
+    @test gradient(x -> norm((im*x) / (im)), 2)[1] == 1
+    @test gradient(x -> norm((im) / (im*x)), 2)[1] == -1/4
+
+    fs_C_to_R = (
+        real,
+        imag,
+        abs,
+        abs2,
+        z -> abs(z)*cos(im*angle(z)),
+        z->abs(cos(exp(z))),
+        z->3*real(z)^3-2*imag(z)^5
+    )
+    @testset "C->R: $i" for (i,f) in enumerate(fs_C_to_R)
+        for z in (1.0+2.0im, -2.0+pi*im)
+            ε = 1e-8
+            grad_fd = (f(z+ε)-f(z))/ε + im*(f(z+ε*im)-f(z))/ε
+            @test abs(gradient(x -> real(f(x)), z)[1] - grad_fd) < sqrt(ε)
+        end
+    end
+
+    fs_C_to_C_holomorphic = (
+        cos,
+        exp,
+        log,
+        z->z^2,
+        z->(real(z)+im*imag(z))^2,
+        z->real(z)^2 - imag(z)^2 +2im*(real(z)*imag(z)),
+        z->exp(cos(log(z))),
+        z->abs(z)*exp(im*angle(z)),
+    )
+    @testset "C->C holomorphic: $i" for (i,f) in enumerate(fs_C_to_C_holomorphic)
+        for z in (1.0+2.0im, -2.0+pi*im)
+            ε = 1e-8
+            grad_fd_r = (f(z+ε)-f(z))/ε
+            grad_fd_i = (f(z+ε*im)-f(z))/(ε*im)
+            @assert abs(grad_fd_r - grad_fd_i) < sqrt(ε) # check the function is indeed holomorphic
+            @test abs(gradient(x -> real(f(x)), z)[1] - conj(grad_fd_r)) < sqrt(ε)
+        end
+    end
+
+    fs_C_to_C_non_holomorphic = (
+        conj,
+        z->abs(z)+0im,
+        z->im*abs(z),
+        z->abs2(z)+0im,
+        z->im*abs2(z),
+        z->z'z,
+        z->conj(z)*z^2,
+    )
+    @testset "C->C non-holomorphic: $i" for (i,f) in enumerate((fs_C_to_C_holomorphic..., fs_C_to_C_non_holomorphic...))
+        for z in (1.0+2.0im, -2.0+pi*im)
+            ε = 1e-8
+            grad_fd = real(f(z+ε)-f(z))/ε + im*real(f(z+ε*im)-f(z))/ε
+            @test abs(gradient(x -> real(f(x)), z)[1] - grad_fd) < sqrt(ε)
+        end
+    end
+
+    # Zygote issue 342
+    @test gradient(x->real(x + 2.0*im), 3.0) == (1.0,)
+    @test gradient(x->imag(x + 2.0*im), 3.0) == (0.0,)
+
+end
+
+@testset "complex arrays" begin
+
+    # Zygote issue 705
+    @test gradient(x -> imag(sum(exp, x)), [1,2,3]) |> only |> isZero
+    @test gradient(x -> imag(sum(exp, x)), [1+0im,2,3])[1] ≈ im .* exp.(1:3)
+
+end
+
+#####
+##### Zygote/test/features.jl
+#####
+
+# This file isn't really organised; here it's broken arbitrarily into testsets each about a page long.
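+
+# (A quick illustrative check of the helpers defined above; added for clarity, not taken from
+# Zygote's test suite. It relies only on the `pullback`, `withgradient`, and pirate `isapprox`
+# definitions earlier in this file.)
+@testset "helper functions" begin
+    y, back = pullback(x -> x^2, 3.0)
+    @test y == 9.0
+    @test back(1.0) == (6.0,)
+    @test withgradient(x -> x^2, 3.0) == (val = 9.0, grad = (6.0,))
+    # The pirate isapprox methods let structural tangents be compared to plain values:
+    @test Tangent{Complex{Int}}(re = 1.0) ≈ (re = 1, im = nothing)
+    @test ZeroTangent() ≈ nothing
+end
+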
+@testset "features I" begin + + # power functions + + function pow(x, n) + r = 1 + while n > 0 + n -= 1 + r *= x + end + return r + end + @test gradient(pow, 2, 3)[1] == 12 + @test gradient(pow, 2, 3)[2] |> isZero + + function pow_mut(x, n) + r = Ref(one(x)) + while n > 0 + n -= 1 + r[] = r[] * x # not sure Diffractor supports this, if not it could give a helpful error + end + return r[] + end + @test_broken gradient(pow_mut, 2, 3)[1] == 12 # no method matching (::Diffractor.∂⃖recurse{1})(::typeof(setfield!), ::Base.RefValue{Int64}, ::Symbol, ::Int64) + @test_broken gradient(pow_mut, 2, 3)[2] |> isZero + + global r163 = 1 + function pow_global(x, n) + global r163 + while n > 0 + r163 *= x + n -= 1 + end + return r163 + end + @test_broken gradient(pow_global, 2, 3)[1] == 12 # transform!(ci::Core.CodeInfo, meth::Method, nargs::Int64, sparams::Core.SimpleVector, N::Int64) + @test_broken gradient(pow_global, 2, 3)[2] |> isZero + + # misc. + + @test gradient(x -> 1, 2) |> only |> isZero + + @test gradient(t -> t[1]*t[2], (2, 3)) |> only |> Tuple == (3, 2) + @test_broken gradient(t -> t[1]*t[2], (2, 3)) isa Tangent # should be! + + # complex & getproperty -- https://github.com/JuliaDiff/Diffractor.jl/issues/71 + + @test_broken gradient(x -> x.re, 2+3im) === (1.0 + 0.0im,) # one NamedTuple + @test_broken gradient(x -> x.re*x.im, 2+3im) == (3.0 + 2.0im,) # two, different fields + @test_broken gradient(x -> x.re*x.im + x.re, 2+3im) == (4.0 + 2.0im,) # three, with accumulation + + @test_skip gradient(x -> abs2(x * x.re), 4+5im) == (456.0 + 160.0im,) # gradient participates + @test gradient(x -> abs2(x * real(x)), 4+5im) == (456.0 + 160.0im,) # function not getproperty + @test_skip gradient(x -> abs2(x * getfield(x, :re)), 4+5im) == (456.0 + 160.0im,) + +end +@testset "features II" begin + + # structs + + struct Bar{T} + a::T + b::T + end + function mul_struct(a, b) + c = Bar(a, b) + c.a * c.b + end + @test gradient(mul_struct, 2, 3) == (3, 2) + + @test_broken gradient(x -> [x][1].a, Bar(1, 1)) == ((a=1, b=NoTangent()),) # MethodError: no method matching zero(::Type{Bar{Int64}}) + + function mul_tuple(a, b) + c = (a, b) + c[1] * c[2] + end + @test gradient(mul_tuple, 2, 3) == (3, 2) + + function mul_lambda(x, y) + g = z -> x * z + g(y) + end + @test gradient(mul_lambda, 2, 3) == (3, 2) + + # splats + + @test gradient((a, b...) -> *(a, b...), 2, 3) == (3, 2) + + @test_broken gradient((x, a...) -> x, 1) == (1,) + @test gradient((x, a...) -> x, 1, 1) == (1, ZeroTangent()) + @test_broken gradient((x, a...) -> x == a, 1) == (NoTangent(),) + @test gradient((x, a...) -> x == a, 1, 2) == (NoTangent(), NoTangent()) + + # keywords + + kwmul(; a = 1, b) = a*b + mul_kw(a, b) = kwmul(a = a, b = b) + @test gradient(mul_kw, 2, 3) == (3, 2) # passes at REPL, not in testset? + +end +@testset "features III" begin + + function myprod(xs) + s = 1 + for x in xs + s *= x + end + return s + end + @test gradient(myprod, [1,2,3])[1] == [6,3,2] + + function mul_vec(a, b) + xs = [a, b] + xs[1] * xs[2] + end + @test gradient(mul_vec, 2, 3) == (3, 2) + + # dictionary + + @test_skip gradient(2) do x + d = Dict() + d[:x] = x + x * d[:x] + end == (4,) + + # keywords + + f249(args...; a=nothing, kwargs...) = g250(a,args...; kwargs...) + g250(args...; x=1, idx=Colon(), kwargs...) = x[idx] + @test gradient(x -> sum(f249(; x=x, idx=1:1)), ones(2))[1] == [1, 0] + + # recursion + + pow_rec(x, n) = n == 0 ? 
1 : x*pow_rec(x, n-1) + @test gradient(pow_rec, 2, 3)[1] == 12 + @test gradient(pow_rec, 2, 3)[2] |> isZero + + # second derivatives + + function grad258(f, args...) + y, back = pullback(f, args...) + return back(1) + end + D263(f, x) = grad258(f, x)[1] + + @test D263(x -> D263(sin, x), 0.5) == -sin(0.5) + @test D263(x -> x*D263(y -> x+y, 1), 1) == 1 + @test D263(x -> x*D263(y -> x*y, 1), 4) == 8 + + # throw + + f272(x) = throw(DimensionMismatch("fubar")) + @test_throws DimensionMismatch gradient(f272, 1) + + # hvcat + + @test gradient(2) do x + H = [1 x; 3 4] + sum(H) + end[1] == 1 + + @test gradient(x -> one(eltype(x)), rand(10)) |> only |> isZero + + # three-way control flow merge + + @test gradient(1) do x + if x > 0 + x *= 2 + elseif x < 0 + x *= 3 + end + x + end[1] == 2 + +end +@testset "features IV" begin + + @test gradient(1) do x + if true + elseif true + nothing + end + x + x + end == (2,) + + # try + + function pow_try(x) + try + 2x + catch e + println("error") + end + end + + @test_broken gradient(pow_try, 1) == (2,) # BoundsError: attempt to access 6-element Vector{Core.Compiler.BasicBlock} at index [0] + + # @simd + + function pow_simd(x, n) + r = 1 + @simd for i = 1:n + r *= x + end + return r + end + @test gradient(pow_simd, 2, 3)[1] == 12 + @test gradient(pow_simd, 2, 3)[2] |> isZero + + # @timed + + @test_broken gradient(x -> first(@timed x), 0) == (1,) # transform!(ci::Core.CodeInfo, meth::Method, nargs::Int64, sparams::Core.SimpleVector, N::Int64) + @test_broken gradient(x -> (@time x^2), 3) == (6,) # BoundsError: attempt to access 12-element Vector{Core.Compiler.BasicBlock} at index [0] + + # kwarg splat + + g516(; kwargs...) = kwargs[:x] * kwargs[:z] + h517(somedata) = g516(; somedata...) + @test gradient(h517, (; x=3.0, y=4.0, z=2.3)) |> only ≈ (; x=2.3, y=0.0, z=3.0) + @test_broken gradient(h517, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == (Tangent{NamedTuple{(:x, :y, :z), Tuple{Float64,Float64,Float64}}}(; x=2.3, y=0.0, z=3.0),) # ERROR: (1, get(d::IdDict{K, V}, key, default) where {K, V} @ Base iddict.jl:101, :($(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), :(%1), Core.Argument(3), Core.Argument(4))))) + +end + +@testset "mutable structs" begin + + mutable struct MyMutable + value::Float64 + end + function foo!(m::MyMutable, x) + m.value = x + end + function baz(args) + m = MyMutable(0.0) + foo!(m, args...) + m.value + end + @test_broken gradient(baz, (1.0,)) |> only ≈ (1.0,) + + # ChainRules represents these as the same Tangent as immutable structs, but is that ideal? 
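+    # Added illustration (not from Zygote's tests): the structural cotangent checked below is
+    # just a NamedTuple-backed Tangent, the same representation used for an immutable struct.
+    @test backing(Tangent{MyMutable}(value = 7.0)) == (value = 7.0,)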
+ + @test gradient(x -> x.value^2 + x.value, MyMutable(3)) === (Tangent{MyMutable}(value = 7.0,),) + @test gradient(x -> x.value^2 + x.value, MyMutable(3)) |> only ≈ (value = 7.0,) # with new isapprox methods + + @test gradient(x -> x.x^2 + x.x, Ref(3)) == (Tangent{Base.RefValue{Int}}(x = 7.0,),) + @test gradient(x -> real(x.x^2 + im * x.x), Ref(4)) == (Tangent{Base.RefValue{Int}}(x = 8.0,),) + + # Field access of contents: + @test_broken gradient(x -> abs2(x.x) + 7 * x.x.re, Ref(1+im)) |> only ≈ (; x = 9.0 + 2.0im) + @test gradient(x -> abs2(x[1].x) + 7 * x[1].x.re, [Ref(1+im)]) |> only ≈ [(x = 9.0 + 2.0im,)] + @test gradient(x -> abs2(x[1].x) + 7 * real(x[1].x), [Ref(1+im)]) |> only ≈ [(x = 9.0 + 2.0im,)] # worked on Zygote 0.6.0, 0.6.20 + @test gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im))) |> only ≈ (x = (x = 9.0 + 2.0im,),) # Zygote gives nothing, same in 0.6.0 + + # Array of mutables: + @test_broken gradient(x -> sum(getindex.(x).^2), Ref.(1:3)) |> only ≈ [(;x=2i) for i in 1:3] # MethodError: no method matching one(::Base.RefValue{Int64}) + @test gradient(x -> sum(abs2∘getindex, x), Ref.(1:3)) |> only ≈ [(;x=2i) for i in 1:3] # Tangent for the primal Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}} should be backed by a AbstractDict type, not by NamedTuple + + @test_broken gradient(x -> (getindex.(x).^2)[1], Ref.(1:3))[1][1] ≈ (x=2.0,) # rest are (x = 0.0,), but nothing would be OK too + @test_broken gradient(x -> (prod.(getindex.(x)))[1], Ref.(eachcol([1 2; 3 4])))[1][1] ≈ (x = [3.0, 1.0],) # MethodError: no method matching one(::SubArray{Int64, 1, Matrix{Int64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}) + + # Broadcasting over Ref: + @test gradient(x -> sum(sum, x .* [1,2,3]), Ref([4,5])) |> only ≈ (x = [6.0, 6.0],) + @test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],) + +end + +@testset "NamedTuples" begin + + @test gradient(x -> x.a, (a=1, b=2)) == (Tangent{NamedTuple{(:a, :b), Tuple{Int,Int}}}(a = 1,),) + @test gradient(x -> x[1].a, [(a=1, b=2)]) |> only ≈ [(a = 1, b = nothing)] + + @test gradient(x -> x[1].a, [(a=1, b=2), (a=3, b=4)]) |> only == [Tangent{NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(a = 1.0,), ZeroTangent()] + + # Mix with Ref + @test gradient(x -> x[].a, Ref((a=1, b=2))) |> only ≈ (x = (a = 1, b = nothing),) + @test gradient(x -> x[1][].a, [Ref((a=1, b=2)), Ref((a=3, b=4))]) |> only |> first ≈ (x = (a = 1, b = nothing),) + @test gradient(x -> x[1].a, [(a=1, b=2), "three"]) |> only |> first ≈ (a = 1, b = nothing) + + @testset "indexing kwargs: PR 1286" begin + # https://github.com/FluxML/Zygote.jl/pull/1286 + inner_lit_index(; kwargs...) = kwargs[:x] + outer_lit_index(; kwargs...) = inner_lit_index(; x=kwargs[:x]) + + inner_dyn_index(k; kwargs...) = kwargs[k] + outer_dyn_index(k; kwargs...) = inner_dyn_index(k; x=kwargs[k]) + + @test gradient(x -> outer_lit_index(; x), 0.0) == (1.0,) + @test gradient((x, k) -> outer_dyn_index(k; x), 0.0, :x) == (1.0, NoTangent()) + end +end + +@testset "Pairs" begin + + @test gradient(x->10*pairs((a=x, b=2))[1], 100)[1] === 10.0 + @test gradient(x->10*pairs((a=x, b=2))[2], 100) |> only |> isZero + + foo387(; kw...) 
= 1 + @test_skip gradient(() -> foo387(a=1,b=2.0)) === () # Here Diffractor returns a function, by design + + @test gradient(x->10*(x => 2)[1], 100) === (10.0,) + @test gradient(x->10*(x => 2)[2], 100) |> only |> isZero + + @test gradient(x-> (:x => x)[2], 17) == (1,) + + d = Dict(:x=>1.0, :y=>3.0); + @test_broken gradient(d -> Dict(:x => d[:x])[:x], d) == (Dict(:x => 1),) # BoundsError: attempt to access 3-element Vector{Core.Compiler.BasicBlock} at index [4] + + # https://github.com/FluxML/Zygote.jl/pull/1295 + no_kwarg_grad(x; kwargs...) = x[kwargs[:i]] + @test gradient(x -> no_kwarg_grad(x; i=1), [1]) == ([1],) +end + +@testset "Iterators" begin + + # enumerate + + @test_broken gradient(1:5) do xs + sum([x^i for (i,x) in enumerate(xs)]) + end == ([1, 4, 27, 256, 3125],) + + @test_broken gradient([1,10,100]) do xs + sum([xs[i]^i for (i,x) in enumerate(xs)]) + end == ([1, 2 * 10^1, 3 * 100^2],) + + @test gradient([1,10,100]) do xs + sum((xs[i]^i for (i,x) in enumerate(xs))) # same without collect + end == ([1, 2 * 10^1, 3 * 100^2],) + + # zip + # On Julia 1.4 and earlier, [x/y for (x,y) in zip(10:14, 1:10)] is a DimensionMismatch, + # while on 1.5 - 1.7 it stops early. + + @test_broken gradient(10:14, 1:10) do xs, ys + sum([x/y for (x,y) in zip(xs, ys)]) + end[2] ≈ vcat(.-(10:14) ./ (1:5).^2, zeros(5)) + + @test_broken gradient(10:14, 1:10) do xs, ys + sum(x/y for (x,y) in zip(xs, ys)) # same without collect + end[2] ≈ vcat(.-(10:14) ./ (1:5).^2, zeros(5)) + + @test_skip begin + bk_z = pullback((xs,ys) -> sum([abs2(x*y) for (x,y) in zip(xs,ys)]), [1,2], [3im,4im])[2] + @test bk_z(1.0)[1] isa AbstractVector{<:Real} # projection + end + + # Iterators.Filter + + @test_broken gradient(2:9) do xs + sum([x^2 for x in xs if iseven(x)]) + end == ([4, 0, 8, 0, 12, 0, 16, 0],) + + @test_broken gradient(2:9) do xs + sum(x^2 for x in xs if iseven(x)) # same without collect + end == ([4, 0, 8, 0, 12, 0, 16, 0],) + + # Iterators.Product + + @test_broken gradient(1:10, 3:7) do xs, ys + sum([x^2+y for x in xs, y in ys]) + end == (10:10:100, fill(10, 5)) + + @test_broken gradient(1:10, 3:7) do xs, ys + sum(x^2+y for x in xs, y in ys) # same without collect + end == (10:10:100, fill(10, 5)) + + # Repeat that test without sum(iterator) + function prod_acc(xs, ys) + out = 0 + for xy in Iterators.product(xs, ys) + out += xy[1]^2 + xy[2] + end + out + end + @test prod_acc(1:10, 3:7) == sum(x^2+y for x in 1:10, y in 3:7) + @test_broken gradient(prod_acc, 1:10, 3:7) == (10:10:100, fill(10, 5)) + + @test_broken gradient(rand(2,3)) do A + sum([A[i,j] for i in 1:1, j in 1:2]) + end == ([1 1 0; 0 0 0],) + + @test_broken gradient(ones(3,5), 1:7) do xs, ys + sum([x+y for x in xs, y in ys]) + end == (fill(7, 3,5), fill(15, 7)) + + @test_skip begin + bk_p = pullback((xs,ys) -> sum([x/y for x in xs, y in ys]), Diagonal([3,4,5]), [6,7]')[2] + @test bk_p(1.0)[1] isa Diagonal # projection + @test bk_p(1.0)[2] isa Adjoint + end + + # Iterators.Product with enumerate + + @test_broken gradient([2 3; 4 5]) do xs + sum([x^i+y for (i,x) in enumerate(xs), y in xs]) + end == ([8 112; 36 2004],) + + # Issue 1150 + + @test_broken gradient(x -> sum([x[i] for i in 1:3 if i != 100]), [1,2,3])[1] == [1,1,1] + @test_broken gradient(x -> sum(map(i -> x[i], filter(i -> i != 100, 1:3))), [1,2,3])[1] == [1,1,1] + +end + +@testset "adjoints of Iterators.product, PR 1170" begin + # Adapted from Zygote's file test/lib/array.jl + + y, back = pullback(Iterators.product, 1:5, 1:3, 1:2) + @test_broken back(collect(y)) == (NoTangent(), [6.0, 12.0, 
18.0, 24.0, 30.0], [10.0, 20.0, 30.0], [15.0, 30.0]) + @test_broken back([(NoTangent(), j, k) for i in 1:5, j in 1:3, k in 1:2]) == (NoTangent(), [10.0, 20.0, 30.0], [15.0, 30.0]) + @test_broken back([(i, NoTangent(), k) for i in 1:5, j in 1:3, k in 1:2]) == ([6.0, 12.0, 18.0, 24.0, 30.0], NoTangent(), [15.0, 30.0]) + @test_broken back([(i, j, NoTangent()) for i in 1:5, j in 1:3, k in 1:2]) == ([6.0, 12.0, 18.0, 24.0, 30.0], [10.0, 20.0, 30.0], NoTangent()) + + # This was wrong before https://github.com/FluxML/Zygote.jl/pull/1170 + @test_broken gradient(x -> sum([y[2] * y[3] for y in Iterators.product(x, x, x, x)]), [1,2,3,4])[1] ≈ [320, 320, 320, 320] # MethodError: no method matching copy(::Nothing) + @test_broken gradient(x -> sum(y[2] * y[3] for y in Iterators.product(x, x, x, x)), [1,2,3,4])[1] ≈ [320, 320, 320, 320] # accum(a::Tuple{NoTangent, NoTangent}, b::Tuple{Tuple{ZeroTangent, NoTangent}}), tail_pullback +end + +@testset "keyword passing" begin + # https://github.com/JuliaDiff/ChainRules.jl/issues/257 + + struct Type1{VJP} + x::VJP + end + + struct Type2{compile} + Type2(compile=false) = new{compile}() + end + + function loss_adjoint(θ) + sum(f623(sensealg=Type1(Type2(true)))) + end + + i = 1 + + global x620 = Any[nothing, nothing] + + g622(x, i, sensealg) = Main.x620[i] = sensealg + ChainRulesCore.@non_differentiable g622(x, i, sensealg) + + function f623(; sensealg=nothing) + g622(x620, i, sensealg) + return rand(100) + end + + loss_adjoint([1.0]) + i = 2 + + @test_skip gradient(loss_adjoint, [1.0]) + + @test_broken x620[1] == x620[2] + +end + +@testset "splats" begin + + @test gradient(x -> max(x...), [1,2,3])[1] == [0,0,1] + @test gradient(x -> min(x...), (1,2,3))[1] === (1.0, 0.0, 0.0) + + @test_broken gradient(x -> max(x...), [1 2; 3 4])[1] == [0 0; 0 1] + @test_broken gradient(x -> max(x...), [1,2,3]')[1] == [0 0 1] + @test_broken gradient(x -> max(x...), [1,2,3]')[1] isa Adjoint + + # https://github.com/FluxML/Zygote.jl/issues/599 + @test gradient(w -> sum([w...]), [1,1])[1] isa AbstractVector + + # https://github.com/FluxML/Zygote.jl/issues/866 + f866(x) = reshape(x, fill(2, 2)...) 
+ @test gradient(x->sum(f866(x)), rand(4))[1] == [1,1,1,1] + + # https://github.com/FluxML/Zygote.jl/issues/731 + f731(x) = sum([x' * x, x...]) + @test gradient(f731, ones(3)) == ([3,3,3],) + +end + +@testset "accumulation" begin + + # from https://github.com/FluxML/Zygote.jl/issues/905 + function net905(x1) + x2 = x1 + x3 = x1 + x2 + x4 = x1 + x2 + x3 + x5 = x1 + x2 + x3 + x4 + x6 = x1 + x2 + x3 + x4 + x5 + x7 = x1 + x2 + x3 + x4 + x5 + x6 + x8 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x9 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x10 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + end + loss(x) = sum(abs2, net905(x)) + + @test gradient(loss, ones(10,10))[1] == fill(131072, 10, 10) + @test_broken 150_000_000 > @allocated gradient(loss, ones(1000,1000)) + +end + +@testset "tricky broadcasting" begin + + @test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == (Tangent{Tuple{Int, Int}}(2,2),) + @test gradient(x -> sum(x .+ ones(2,2)), (1,)) == (Tangent{Tuple{Int}}(4),) + @test gradient(x -> sum(x .+ ones(2,1)), (1,2)) == (Tangent{Tuple{Int, Int}}(1,1),) + + # https://github.com/FluxML/Zygote.jl/issues/975 + gt = gradient((x,p) -> prod(x .^ p), [3,4], (1,2)) + gv = gradient((x,p) -> prod(x .^ p), [3,4], [1,2]) + @test gt[1] == gv[1] + @test collect(gt[2]) ≈ gv[2] + + # closure captures y -- can't use ForwardDiff + @test gradient((x,y) -> sum((z->z^2+y[1]).(x)), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0]) + @test gradient((x,y) -> sum((z->z^2+y[1]), x), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0]) + @test_broken gradient((x,y) -> sum(map((z->z^2+y[1]), x)), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0]) # AssertionError: Base.issingletontype(typeof(f)) + @test gradient((x,y) -> mapreduce((z->z^2+y[1]), +, x), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0]) + + # type unstable + @test gradient(xs -> sum((x -> x<2 ? false : x^2).(xs)), [1,2,3])[1][2:3] == [4, 6] + @test gradient(xs -> sum((x -> x<2 ? false : x^2), xs), [1,2,3])[1][2:3] == [4, 6] + @test_broken gradient(xs -> sum(map((x -> x<2 ? false : x^2), xs)), [1,2,3])[1][2:3] == [4, 6] # AssertionError: ∂f isa TaylorBundle || ∂f isa TangentBundle{1} + @test gradient(xs -> mapreduce((x -> x<2 ? 
false : x^2), +, xs), [1,2,3])[1][2:3] == [4, 6] + + # with Ref, Val, Symbol + @test gradient(x -> sum(x .+ Ref(x[1])), [1,2,3]) == ([4,1,1],) + @test gradient(x -> sum(x .+ (x[1],)), [1,2,3]) == ([4,1,1],) + @test gradient(x -> sum((first∘tuple).(x, :ignore)), [1,2,3]) == ([1,1,1],) + @test gradient(x -> sum((first∘tuple).(x, Symbol)), [1,2,3]) == ([1,1,1],) + _f(x,::Val{y}=Val(2)) where {y} = x/y + @test gradient(x -> sum(_f.(x, Val(2))), [1,2,3]) == ([0.5, 0.5, 0.5],) + @test gradient(x -> sum(_f.(x)), [1,2,3]) == ([0.5, 0.5, 0.5],) + @test_broken gradient(x -> sum(map(_f, x)), [1,2,3]) == ([0.5, 0.5, 0.5],) # InexactError + @test gradient(x -> sum(map(_f, x)), [1,2,3.0]) == ([0.5, 0.5, 0.5],) + + # with Bool + @test gradient(x -> sum(1 .- (x .> 0)), randn(5)) == (NoTangent(),) + + @test gradient(x -> sum((y->1-y).(x .> 0)), randn(5)) == (NoTangent(),) + @test gradient(x -> sum(x .- (x .> 0)), randn(5)) == ([1,1,1,1,1],) + + @test gradient(x -> sum(x ./ [1,2,4]), [1,2,pi]) == ([1.0, 0.5, 0.25],) + @test_broken gradient(x -> sum(map(/, x, [1,2,4])), [1,2,pi]) == ([1.0, 0.5, 0.25],) # MethodError: no method matching (::Diffractor.∂⃖recurse{1})(::typeof(Core.arrayset), ::Bool, ::Vector{Float64}, ::Float64, ::Int64) + + # negative powers + @test gradient((x,p) -> sum(x .^ p), [1.0,2.0,4.0], [1,-1,2])[1] ≈ [1.0, -0.25, 8.0] + @test gradient((x,p) -> sum(x .^ p), [1.0,2.0,4.0], -1)[1] ≈ [-1.0, -0.25, -0.0625] + @test gradient((x,p) -> sum(z -> z^p, x), [1.0,2.0,4.0], -1)[1] ≈ [-1.0, -0.25, -0.0625] + @test gradient((x,p) -> mapreduce(z -> z^p, +, x), [1.0,2.0,4.0], -1)[1] ≈ [-1.0, -0.25, -0.0625] + + # second order + @test gradient(x -> sum(gradient(y -> sum(y.^2), x)[1]), [1, 2])[1] ≈ [2, 2] + @test_broken gradient(x -> sum(gradient(y -> sum(sin.(y)), x)[1]), [1, 2])[1] ≈ [-0.8414709848078965, -0.9092974268256817] # MethodError: no method matching Diffractor.Jet(::Int64, ::Float64, ::Tuple{Float64, Float64}) -> MethodError: no method matching copy(::Nothing) + @test_broken gradient(x -> sum(abs, gradient(y -> sum(log.(2 .* exp.(y)) .^ 2), x)[1]), [1, 2])[1] ≈ [2,2] + + # getproperty, Tangents, etc + @test_broken gradient(xs -> sum((x->x.im^2).(xs)), [1+2im,3])[1] == [4im, 0] # SILENTLY WRONG ANSWER + @test gradient(xs -> sum((x->x.im^2), xs), [1+2im,3])[1] == [4im, 0] # MethodError: no method matching lastindex(::Diffractor.OpticBundle{Int64}) + @test_broken gradient(xs -> sum(map(x->x.im^2, xs)), [1+2im,3])[1] == [4im, 0] # Tried to take the gradient of a complex-valued function + @test_broken gradient(xs -> mapreduce(x->x.im^2, +, xs), [1+2im,3])[1] == [4im, 0] # MethodError: Cannot `convert` an object of type Tangent{Complex{Int64}, NamedTuple{(:im,), Tuple{Float64}}} to an object of type Complex{Int64} + +end + +##### +##### Zygote/test/structures.jl +##### + +@testset "async" begin + + function tasks1(x) + ch = Channel(Inf) + put!(ch, x^2) + take!(ch) + end + + @test_broken gradient(tasks1, 5) == (20,) + + function tasks2(x) + ch = Channel(0) + t = @async put!(ch, x^2) + y = take!(ch) + wait(t) + return y + end + + @test_broken gradient(tasks2, 5) == (10,) + + function tasks3(x) + ch = Channel(0) + @sync begin + @async put!(ch, x^2) + take!(ch) + end + end + + @test_broken gradient(tasks3, 5) == (10,) + + tasks4(x) = fetch(@async x^2) + @test_broken gradient(tasks4, 5) == (10,) + + tasks5(x) = fetch(schedule(Task(() -> x^2))) + @test_broken gradient(tasks5, 5) == (10,) + +end + +@testset "issues" begin + + @test pullback(Array, [1f0])[1] == [1f0] + + # issue 594 + + struct A594 
x::Float64 end + + f594(a,v) = a.x + v + g594(A,V) = sum(f594.(A,V)) + X = A594.(randn(2)) + Y = randn(2,2) + @test_skip begin + ∇ = gradient(g594,X,Y) # MethodError: Cannot `convert` an object of type Tangent{A594, NamedTuple{(:x,), Tuple{Float64}}} to an object of type ZeroTangent + @test ∇[1] == [(x = 2.0,); (x = 2.0,)] + @test vec(∇[1]) == [(x = 2.0,); (x = 2.0,)] + @test ∇[2] == [1 1; 1 1] + end + + # overflow + + struct M{T,B} + a::T + b::B + end + @test_skip m, b = pullback(nameof, M) # StackOverflowError + @test_skip @test b(m) == (nothing, nothing) + +end + +##### +##### Zygote/test/utils.jl +##### + +# This file contains tests of jacobian and hessian functions, +# and of adjoints of ForwardDiff functions. +# To add them, we would need to define various hessian & jacobian functions, +# possibly as in "forwarddiff.jl" in the tests here, possibly as exported functions. + + +@test_skip @testset "hessian, #hess $hess" for hess in HESSIANS + @test hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] + @test hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] + + @test hess(x -> sum(x.^3), [1 2; 3 4]) ≈ Diagonal([6, 18, 12, 24]) + @test hess(sin, pi/2) ≈ -1 + + @test_throws Exception hess(sin, im*pi) + @test_throws Exception hess(x -> x+im, pi) + @test_throws Exception hess(identity, randn(2)) +end + +@test_skip @testset "diagonal hessian" begin + @test diaghessian(x -> x[1]*x[2]^2, [1, pi]) == ([0, 2],) + + xs, y = randn(2,3), rand() + f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments + + dx, dy = diaghessian(f34, xs, y) + @test size(dx) == size(xs) + @test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs)) + @test dy ≈ hessian(y -> f34(xs,y), y) + + + zs = randn(7,13) # test chunk mode + f713(zs) = sum(vec(zs)' .* exp.(vec(zs))) + @test vec(diaghessian(f713, zs)[1]) ≈ diag(hessian(f713, zs)) + + @test_throws Exception diaghessian(sin, im*pi) + @test_throws Exception diaghessian(x -> x+im, pi) + @test_throws Exception diaghessian(identity, randn(2)) +end + +@test_skip @testset "jacobian(f, args...), $jacobian" for jacobian in JACOBIANS + @test jacobian(identity, [1,2])[1] == [1 0; 0 1] + @test withjacobian(identity, [1,2]) == (val = [1,2], grad = ([1 0; 0 1],)) + + j1 = jacobian((a,x) -> a.^2 .* x, [1,2,3], 1) + @test j1[1] ≈ Diagonal([2,4,6]) + @test j1[2] ≈ [1, 4, 9] + @test j1[2] isa Vector + + j2 = jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) # scalar output is OK + @test j2[1] == [4 4 4] + @test j2[1] isa Matrix + @test j2[2] === nothing # input other than Number, Array is ignored + + j3 = jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4], 1) + @test j3[1] ≈ [3 1 0 0; 0 0 4 2] + @test j3[2] ≈ [0, 0] # pullback is always Nothing, but array already allocated + + j4 = jacobian([1,2,-3,4,-5]) do xs + map(x -> x>0 ? 
x^3 : 0, xs) # pullback gives Nothing for some elements x + end + @test j4[1] ≈ Diagonal([3,12,0,48,0]) + + j5 = jacobian((x,y) -> hcat(x[1], y), fill(pi), exp(1)) # zero-array + @test j5[1] isa Matrix + @test vec(j5[1]) == [1, 0] + @test j5[2] == [0, 1] + + @test_throws ArgumentError jacobian(identity, [1,2,3+im]) + @test_throws ArgumentError jacobian(sum, [1,2,3+im]) # scalar, complex + + f6(x,y) = abs2.(x .* y) + g6 = gradient(first∘f6, [1+im, 2], 3+4im) + j6 = jacobian((x,y) -> abs2.(x .* y), [1+im, 2], 3+4im) + @test j6[1][1,:] ≈ g6[1] + @test j6[2][1] ≈ g6[2] +end + +# using ForwardDiff + +@test_skip @testset "adjoints of ForwardDiff functions" begin + f1(x) = ForwardDiff.gradient(x -> sum(exp.(x.+1)), x) + x1 = randn(3,7) + @test jacobian(f1, x1)[1] ≈ ForwardDiff.jacobian(f1, x1) + + f2(x) = ForwardDiff.jacobian(x -> log.(x[1:3] .+ x[2:4]), x) + x2 = rand(5) .+ 1 + @test jacobian(f2, x2)[1] ≈ ForwardDiff.jacobian(f2, x2) + + f3(x) = sum(ForwardDiff.hessian(x -> sum(x .^2 .* x'), x)[1:4:end]) + x3 = rand(3) + @test gradient(f3, x3)[1] ≈ ForwardDiff.gradient(f3, x3) + + @test gradient(x -> ForwardDiff.derivative(x -> x^4, x), 7) == (4 * 3 * 7^2,) + + f4(x) = ForwardDiff.derivative(x -> [x,x^2,x^3], x) + @test jacobian(f4, pi)[1] ≈ ForwardDiff.derivative(f4, pi) + + # Tests from https://github.com/FluxML/Zygote.jl/issues/769 + f(x) = [2x[1]^2 + x[1],x[2]^2 * x[1]] + g1(x) = sum(ForwardDiff.jacobian(f,x)) + out,back = pullback(g1,[2.0,3.2]) + stakehouse = back(1.0)[1] + @test typeof(stakehouse) <: Vector + @test size(stakehouse) == (2,) + @test stakehouse ≈ ForwardDiff.gradient(g1,[2.0,3.2]) + + g2(x) = prod(ForwardDiff.jacobian(f,x)) + out,back = Zygote.pullback(g2,[2.0,3.2]) + @test_skip back(1.0)[1] == ForwardDiff.gradient(g2,[2.0,3.2]) # contains NaN, @adjoint prod isn't careful + + g3(x) = sum(abs2,ForwardDiff.jacobian(f,x)) + out,back = pullback(g3,[2.0,3.2]) + @test back(1.0)[1] == ForwardDiff.gradient(g3,[2.0,3.2]) +end + +@testset "broadcasted for unary minus" begin + # TODO add test from https://github.com/FluxML/NNlib.jl/issues/432, needs a hessian function +end + +##### +##### Zygote issues +##### + +@testset "issue 954: range constructors" begin + @test_broken gradient(x -> (x:3)[1], 1.2) == (1,) + @test_broken gradient(x -> (x:1.0:3)[1], 1.2) == (1,) +end + +@testset "issue 1071: NamedTuple constructor" begin + x = [:a=>1, :b=>2] + @test gradient(x -> x[1].second, x) |> only ≈ [(first = 0, second = 1,), ZeroTangent()] + @test_broken gradient(x -> NamedTuple(x).a, x) |> only ≈ [(first = 0, second = 1,), ZeroTangent()] # ERROR: (1, get(d::IdDict{K, V}, key, default) where {K, V} @ Base iddict.jl:101, :($(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), :(%1), Core.Argument(3), Core.Argument(4))))) + + y = (1, 2) + @test_broken gradient(y -> NamedTuple{(:a,:b)}(y).a, y)[1] isa Tangent{<:Tuple} # makes a Tangent{NamedTuple}! 
+ @test_broken gradient(y -> NamedTuple{(:a, :b)}(y).a, y)[1] ≈ (1, 0) +end + +@testset "issue 1072: map on NamedTuple" begin + x = (; a=1, b=2) + @test map(sqrt, x) == (a = 1.0, b = 1.4142135623730951) + @test_broken gradient(x -> map(sqrt, x).a, x) |> only ≈ (a = 0.5, b = nothing) # MethodError: no method matching unzip_tuple(::Vector{Tuple{NoTangent, Float64}}) +end + +@testset "issue 1198: permuting values" begin + + function random_hermitian_matrix(N, specrad=1.0) + σ = 1 / √N + X = σ * (randn(N, N) + rand(N, N) * im) / √2 + H = specrad * (X + X') / (2 * √2) + end + + function random_state_vector(N) + Ψ = rand(N) .* exp.((2π * im) .* rand(N)) + Ψ ./= norm(Ψ) + return Ψ + end + + function cheby(Ψ::AbstractVector, H::AbstractMatrix, dt) + a = [0.9915910021578431, 0.18282201929219635, 0.008403088661031203, 0.000257307553262815] + Δ = 6.0 + E_min = -3.0 + β = (Δ / 2) + E_min + c = -2im / Δ + + v0 = Ψ + ϕ = a[1] * v0 + v1 = c * (H * v0 - β * v0) + ϕ = ϕ + a[2] * v1 + + c *= 2 + for i = 3:length(a) + v2 = c * (H * v1 - β * v1) + v0 + ϕ = ϕ + a[i] * v2 + + v0, v1, v2 = v1, v2, v0 # doesn't work + # aux = v0; v0 = v1; v1 = v2; v2 = aux # doesn't work + # aux = 1 * v0; v0 = 1 * v1; v1 = 1 * v2; v2 = 1 * aux # works + end + + return exp(-1im * β * dt) * ϕ + end + + N = 2 + dt = 0.1 + + Ψ0 = random_state_vector(N) + Ψ1 = random_state_vector(N) + H0 = random_hermitian_matrix(N) + H1 = random_hermitian_matrix(N) + + ϵ = 1.0 + res1 = abs2(Ψ1 ⋅ cheby(Ψ0, H0 + ϵ * H1, dt)) + @test_skip res2, _ = pullback(ϵ -> abs2(Ψ1 ⋅ cheby(Ψ0, H0 + ϵ * H1, dt)), ϵ) # TypeError: in typeassert, expected Int64, got a value of type Nothing + + @test_broken abs(res1 - res2) < 1e-12 + +end + +@testset "some rules" begin + + # https://github.com/FluxML/Zygote.jl/issues/1190 + g1 = gradient(x -> sum(normalize(x)), [1,2,3,4.0])[1] + @test_broken g1 ≈ vec(gradient(x -> sum(normalize(x)), [1 2; 3 4.0])[1]) # SILENTLY WRONG ANSWER! 
+ + # https://github.com/FluxML/Zygote.jl/issues/1201 + struct Point1201; x; y; end + tst1201(p) = [p[1] p[2]; p[3] p[4]] + pointar = [Point1201(1,2), Point1201(3,4), Point1201(5,6), Point1201(7,8)] + + @test gradient(p -> tst1201(p)[1,2].x, pointar) |> only ≈ [0, (x = 1, y = nothing), 0, 0] + +end + +const _BC_FIVE = 5 +_B_FIVE = 5 + +@testset "issue 1177, global + $prime^2" for prime in [Diffractor.PrimeDerivativeBack, Diffractor.PrimeDerivativeFwd] + # https://github.com/FluxML/Zygote.jl/issues/1177 + let var"'" = prime + # Non-const global: + f1(x) = 4*x^2 + _B_FIVE*x + 10 + g1 = f1' + @test g1'(25) ≈ 8.0 + + # Without global: + f2(x, b2=5) = 4*x^2 + b2*x + 10 + g2 = f2' + @test g2'(25) ≈ 8.0 + + # With const global: + f3(x) = 4*x^2 + _BC_FIVE*x + 10 + g3 = f3' + @test g3'(25) ≈ 8.0 + end +end + +@testset "issue 1127: mutable struct" begin + + mutable struct Particle + q::Array{Float64,1} + p::Array{Float64,1} + m::Float64 + end + + _dot(arr1,arr2) = sum(.*(arr1,arr2)) + _modsquare(arr1) = _dot(arr1,arr1) + _norm(arr1) = sqrt(_dot(arr1,arr1)) + + function energy(p1::Particle, p2::Particle) + σ = 0.1 + ϵ = 700.0 + perg = -1.0*100.0 * p1.m * p2.m * (1.0/(_norm(p1.q - p2.q))) + end + + function hamiltonian(parray) + s = 0.0 + for i in 1:length(parray)-1 + for j in i+1:(length(parray)) + s = s + energy(parray[i], parray[j]) + end + end + return s + sum([0.5*_modsquare(p.p)/p.m for p in parray]) + end + + p1 = Particle(zeros(3), ones(3), 1.0) + p2 = Particle(zeros(3) .+ 0.1, ones(3), 1.0) + + @test_skip gradient(hamiltonian, [p1, p2]) # TypeError: in typeassert, expected Int64, got a value of type Nothing + +end + +@testset "issue 1150: filter, and Iterators.filter" begin + A = [0.0 1.0 2.0] + @test_broken gradient(x -> sum([x[i] for i in 1:3 if i != 100]), A) == ([1 1 1],) + @test_broken gradient(x -> sum(map(i -> x[i], filter(i -> i != 100, 1:3))), A) == ([1 1 1],) # AssertionError: Base.issingletontype(typeof(f)) +end + +@testset "issue 1150: other comprehensions" begin + @test_broken gradient(x -> sum(Float64[x^2 for i in 1:2]), 3.0) == (12.0,) + @test_broken gradient(xs -> sum([i/j for i in xs for j in xs]), [1,2,3.0])[1] ≈ [-4.1666666, 0.3333333, 1.166666] atol=1e-4 +end + +@testset "issue 1181: maximum" begin + foo1181(s) = s / maximum(eachindex([1, 1, 1, 1])) + bar1181(s) = s / length([1, 1, 1, 1]) + + @test gradient(bar1181, 1) == (0.25,) + @test gradient(foo1181, 1) == (0.25,) +end + +@testset "issue 1208: NamedTuple + 2nd order" begin + + NT = (weight = randn(Float32, 2, 2),) + W = NT.weight + X = randn(Float32, 2) + + G = gradient(W) do w + sum(gradient(X) do x + sum(w * x)^2 + end[1]) + end[1] + + @test G ≈ gradient(NT) do nt + sum(gradient(X) do x + sum(nt.weight * x)^2 + end[1]) + end[1].weight + +end + +@testset "issue 1247: iteration on ranges" begin + @test gradient(r -> sum(x for x in r), 0:1.0) == ([1,1],) + @test gradient(first, 0:1.0) == ([1,0],) +end + +@testset "issue 1290: comprehension" begin + # Unsolved in Zygote at the time of writing + + function loss_adjoint1(p) + prediction = p .* ones(2,100) + prediction = [prediction[:, i] for i in axes(prediction, 2)] + sum(sum.(prediction)) + end + function loss_adjoint3(p) # does not re-use name, 3x faster, same answer + prediction = p.*ones(2,100) + prediction3 = [prediction[:, i] for i in axes(prediction, 2)] + sum(sum.(prediction3)) + end + + @test_broken gradient(loss_adjoint1, ones(2)) |> only == [100, 100] + @test_broken gradient(loss_adjoint3, ones(2)) |> only == [100, 100] + + function loss_ToucheSir(p) + 
prediction = 2p + boxed_fn(i) = prediction^i + # Trigger https://github.com/JuliaLang/julia/issues/15276 + prediction = boxed_fn(2) + return prediction + end + + @test_broken gradient(loss_ToucheSir, 1.0) == (8.0,) # MethodError: no method matching copy(::Nothing) + + @test gradient(nt -> 2*nt.a.x, (; a=Ref(1.0))) |> only ≈ (a = (x = 2.0,),) + @test gradient(nt -> 2*nt.a.x, (; a=Ref(1.0))) |> only isa Tangent{<:NamedTuple} +end + +@testset "issue 1236: control flow" begin + # https://github.com/FluxML/Zygote.jl/issues/1236 + # Unsolved in Zygote at the time of writing + + function f1236(x) + y = [[x]', [x]] + r = 0.0 + o = 1.0 + for n in 1:2 + o *= y[n] + if n < 2 + proj_o = o * [1.0] + else + # Error + proj_o = o + # Fix + # proj_o = o * 1.0 + end + r += proj_o + end + return r + end + + function f1236_fix(x) + y = [[x]', [x]] + r = 0.0 + o = 1.0 + for n in 1:2 + o *= y[n] + if n < 2 + proj_o = o * [1.0] + else + # Error + # proj_o = o + # Fix + proj_o = o * 1.0 + end + r += proj_o + end + return r + end + + @test gradient(f1236, 1.2)[1] ≈ 3.4 + @test gradient(f1236_fix, 1.2)[1] ≈ 3.4 +end + +@testset "issue 1271: second order & global scope" begin + + # α, β = randn(2, 2), randn(2, 2) + α, β = ([1.3608105 -0.6387457; -0.3293626 -0.3191105], [1.4995675 -0.28095096; -0.7656779 1.1175071]) + + g1271(v) = map(eachcol(v), eachcol(β)) do x, y + sum(x.*x.*y) + end |> sum + + # this fails on Zygote: + @test_broken gradient(α) do k + sum(gradient(g1271, k)[1]) + end |> only ≈ [2.999135 -0.56190192; -1.5313558 2.2350142] + + # this works on Zygote: + @test_broken gradient(α) do k + sum(gradient(k) do v + map(eachcol(v), eachcol(β)) do x, y + sum(x.*x.*y) + end |> sum + end[1]) + end |> only ≈ [2.999135 -0.56190192; -1.5313558 2.2350142] + +end diff --git a/test/zygote_gradcheck.jl b/test/zygote_gradcheck.jl new file mode 100644 index 00000000..25423249 --- /dev/null +++ b/test/zygote_gradcheck.jl @@ -0,0 +1,983 @@ + +# This file contains a selection of tests from Zygote's "gradcheck.jl", +# dealing with Base and standard library functions. Many of these use rules +# which have their own more exhaustive tests in ChainRules. + +# Tests for packages (Distances, LogExpFunctions, AbstractFFTs, FillArrays) are not included. + +# Ideally this would be extended to take `gradient` both forward and reverse, +# and `jacobicheck` including 2nd derivatives, for every testset. But not yet. + +using Diffractor, ChainRulesCore, FiniteDifferences +using Test, LinearAlgebra, Random, Distributed, Statistics + +##### +##### Zygote/test/gradcheck.jl : setup +##### + +begin + # Replace simple finite differencing code with FiniteDifferences: + n_grad(f, x::Real) = (central_fdm(5, 1)(f, x),) + n_grad(f, x::AbstractArray{<:Real}) = FiniteDifferences.grad(central_fdm(5, 1), f, float(x)) + n_grad(f, xs::Vararg{Any,N}) where {N} = ntuple(N) do i + n_grad(x -> f(ntuple(j -> j==i ? x : xs[j], N)...), xs[i])[1] + end + + # Zygote's tests define functions like these: + gradcheck(f, xs...) = all(isapprox.(unthunk.(gradient(f, xs...)), n_grad(f, xs...); rtol = 1e-5, atol = 1e-5)) + @test gradcheck(sqrt, 3.14) + @test gradcheck(sum, randn(10)) + @test gradcheck(dot, randn(3), rand(3)) + + # ... but this one is called `gradtest` there: + jacobicheck(f, xs::AbstractArray...) = f(xs...) isa Number ? gradcheck(f, xs...) : + gradcheck((xs...) -> sum(sin, f(xs...)), xs...) + @test jacobicheck(identity, [1,2,3]) # one given array + @test jacobicheck(sum, [1,2,3]) # fallback to gradcheck + + jacobicheck(f, dims...) 
= jacobicheck(f, randn.(Float64, dims)...) + @test jacobicheck(identity, (4,5)) # one random matrix + @test jacobicheck(+, 3, 3) # two random vectors +end + +isZero(x) = x isa AbstractZero + +# Zygote's misnamed hobbit function: +function pullback(f, x...) + y, b = Diffractor.∂⃖{1}()(f, x...) + back(dy) = map(unthunk, Base.tail(b(dy))) + y, back +end + +##### +##### Zygote/test/gradcheck.jl : Base +##### + +# 73 +@testset "power" begin + @test gradient(x -> x^2, -2) == (-4,) # literal_pow + @test gradient(x -> x^10, -1.0) == (-10,) + _pow = 10 + @test gradient(x -> x^_pow, -1.0) == (-_pow,) + @test unthunk(gradient(p -> real(2^p), 2)[1]) ≈ 4*log(2) + + @test gradient(xs ->sum(xs .^ 2), [2, -1]) == ([4, -2],) + @test gradient(xs ->sum(xs .^ 10), [3, -1]) == ([10*3^9, -10],) + @test gradient(xs ->sum(xs .^ _pow), [4, -1]) == ([_pow*4^9, -10],) + + @test gradient(x -> real((1+3im) * x^2), 5+7im) == (-32 - 44im,) + @test unthunk(gradient(p -> real((1+3im) * (5+7im)^p), 2)[1]) ≈ real((-234 + 2im)*log(5 - 7im)) + # D[(1+3I)x^p, p] /. {x->5+7I, p->2} // Conjugate +end + +@testset "sum, prod" begin + @test jacobicheck(x -> sum(x, dims = (2, 3)), rand(3,4,5)) + @test jacobicheck(x -> sum(abs2, x), randn(4, 3, 2)) + @test jacobicheck(x -> sum(abs2, x; dims=1), randn(4, 3, 2)) + + @test gradcheck(x -> sum(x[i] for i in 1:length(x)), randn(10)) + @test gradcheck(x -> sum(i->x[i], 1:length(x)), randn(10)) # Zygote issue #231 + @test gradcheck(x -> sum((i->x[i]).(1:length(x))), randn(10)) + @test gradcheck(X -> sum(x -> x^2, X), randn(10)) # MethodError: no method matching lastindex(::Diffractor.OpticBundle{Float64}) + @test_broken jacobicheck(X -> sum(x -> x^2, X; dims=1), randn(10)) # Zygote issue #681 # MethodError: no method matching (::Diffractor.∂⃖recurse{1})(::typeof(Core.arrayset), ::Bool, ::Vector{Float64}, ::Float64, ::Int64) + + # Non-differentiable sum of booleans + @test gradient(sum, [true, false, true]) == (NoTangent(),) + @test gradient(x->sum(x .== 0.0), [1.2, 0.2, 0.0, -1.1, 100.0]) |> only |> isZero + + # https://github.com/FluxML/Zygote.jl/issues/314 + @test gradient((x,y) -> sum(yi -> yi*x, y), 1, [1,1]) == (2, [1, 1]) + @test gradient((x,y) -> prod(yi -> yi*x, y), 1, [1,1]) == (2, [1, 1]) + + @test_broken gradient((x,y) -> sum(map(yi -> yi*x, y)), 1, [1,1]) == (2, [1, 1]) # AssertionError: Base.issingletontype(typeof(f)) + @test_broken gradient((x,y) -> prod(map(yi -> yi*x, y)), 1, [1,1]) == (2, [1, 1]) + + @test jacobicheck(x -> prod(x, dims = (2, 3)), randn(3,4,5)) + @test gradcheck(x -> prod(x), randn(3,4)) + @test gradient(x -> prod(x), (1,2,3))[1] == (6,3,2) +end + +@testset "cumsum" begin + @test jacobicheck(x -> cumsum(x, dims=2), (3,4,5)) + @test jacobicheck(x -> cumsum(x, dims=1), (3,)) + @test jacobicheck(x -> cumsum(x), (4,)) + @test jacobicheck(x -> cumsum(x, dims=3), (5,)) # trivial + @test jacobicheck(x -> cumsum(x, dims=3), (3,4)) # trivial +end + +# 146 +@testset "getindex" begin + @test jacobicheck(x -> x[:, 2, :], (3, 4, 5)) + @test jacobicheck(x -> x[1:2, 3:4], (3, 4)) + + imat = [1 2; 3 4] + @test jacobicheck(x -> x[:, imat], (3, 4)) + @test_broken jacobicheck(x -> x[:, [1, 2, 2]], (3, 4)) + irep = [1 2; 2 2] + @test_broken jacobicheck(x -> x[1, irep], (3, 4)) + + # https://github.com/invenia/Nabla.jl/issues/139 + x = rand(3) + z = [1, 2, 3, 3] + y139(x, z) = dot(ones(4), x[z]) + @test_broken gradient(y139, x, z) == ([1, 1, 2], nothing) # ArgumentError: indexed assignment with a single value to possibly many locations is not supported; perhaps use broadcasting 
`.=` instead? + + # https://github.com/FluxML/Zygote.jl/issues/376 + _, back = pullback(x->x[1]*im, randn(2)) + @test back(1.0)[1] == real([-im, 0]) == [0, 0] + + # _droplike + @test gradient(x -> sum(inv, x[1, :]'), ones(2, 2)) == ([-1 -1; 0 0],) + @test gradient(x -> sum(inv, transpose(x[1, :])), ones(2, 2)) == ([-1 -1; 0 0],) # same with transpose, in case ' overloaded! + @test gradient(x -> sum(inv, x[1:1, :]'), ones(2, 2)) == ([-1 -1; 0 0],) + @test gradient(x -> sum(inv, transpose(x[1:1, :])), ones(2, 2)) == ([-1 -1; 0 0],) + @test gradient(x -> sum(inv, transpose(view(x, 1, :))), ones(2, 2)) == ([-1 -1; 0 0],) + + # https://github.com/FluxML/Zygote.jl/issues/513 + @test_broken gradient(p -> sum(Float32[1, 0] - p), [2, 3]) == ([-1, -1],) + @test_broken gradient(x -> sum(Float32[1, x] .+ x), 4) == (3.0f0,) # MethodError: no method matching (::Diffractor.∂⃖recurse{1})(::typeof(Core.arrayset), ::Bool, ::Vector{Float32}, ::Float32, ::Int64) + + # Ensure that nothings work with numeric types. + _, back = pullback(getindex, randn(4), [1]) + @test back([ZeroTangent()]) == (zeros(4), NoTangent()) + # Ensure that nothings work with non-numeric types. + _, back = pullback(getindex, [randn(2) for _ in 1:3], [1]) + @test back([ZeroTangent()]) == (NoTangent(), NoTangent()) +end + +@test_skip @testset "view" begin # Rewrite reached intrinsic function and_int. Missing rule? + @test jacobicheck(x -> view(x,:,2,:), (3,4,5)) + @test jacobicheck(x -> view(x,1:2,3:4), (3,4)) + @test jacobicheck(x -> view(x,:,[1,2,2]), (3,4)) + + # https://github.com/FluxML/Zygote.jl/issues/272 + g272(x) = view(x,1:2)[1] + @test gradient(g272, ones(3)) == ([1,0,0],) +end + +# 194 +@testset "eachcol" begin + @test_broken jacobicheck(x -> map(sum, eachcol(x)), (3,4)) # MethodError: no method matching one(::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}) + @test_broken jacobicheck(x -> map(sum, eachcol(transpose(x))), (3,4)) + + @test_broken jacobicheck(x -> map(norm, eachcol(x)), (3,4)) + @test_broken jacobicheck(x -> map(norm, eachrow(x)), (3,4)) + @test_broken jacobicheck(x -> map(norm, eachslice(x, dims=3)), (3,4,5)) + + # some slices may have gradient nothing + @test gradient(x -> sum(y -> rand()>0.5 ? 
0 : first(y), eachcol(x)), rand(3,10))[1] isa Matrix + + # strange errors (on Zygote) + @test gradient(x -> sum(norm, eachcol(x)), [1 2 3; 4 5 6])[1] isa Matrix # BoundsError: attempt to access InplaceableThunk{Thunk{ChainRules.var"#1689#1692"{Float64, SubArray{Int64, 1, Matrix{Int64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Float64}}, ChainRules.var"#1688#1691"{Float64, SubArray{Int64, 1, Matrix{Int64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Float64}} at index [3] + @test gradient(x -> sum(norm, eachcol(x)), rand(3,400))[1] isa Matrix +end + +@test_skip @testset "collect" begin + @test gradient(x -> sum(inv, collect(x)), (1,2)) === ((-1.0, -1/4),) + + @test gradient(x -> sum(collect(view(x, 1:1))), rand(2)) == ([1,0],) + @test gradient(x -> sum(inv, collect(view(x', 1,:))), ones(2,2)) == ([-1 0; -1 0],) + + @test gradient(xs -> sum(inv, [x^2 for x in xs]), ones(2)) == ([-2, -2],) +end + +@testset "reverse" begin + @test jacobicheck(x -> reverse(x), rand(17)) + @test jacobicheck(x -> reverse(x, 8), rand(17)) + @test jacobicheck(x -> reverse(x, 8, 13), rand(17)) + @test jacobicheck(x -> reverse(x, dims=2), rand(17, 42)) +end + +@testset "permutedims" begin + @test jacobicheck(x -> permutedims(x), rand(2)) + @test jacobicheck(x -> permutedims(x), rand(2,3)) + @test jacobicheck(x -> permutedims(x, [3,1,2]), rand(4,5,6)) + @test jacobicheck(x -> PermutedDimsArray(x, (3,1,2)), rand(4,5,6)) + let + y, back = pullback(permutedims, randn(3)) + @test first(back(randn(1, 3))) isa Vector + end +end + +@testset "repeat" begin + @test jacobicheck(x -> repeat(x; inner=2), rand(5)) + @test jacobicheck(x -> repeat(x; inner=2, outer=3), rand(5)) + @test jacobicheck(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3)) + + @test jacobicheck(x -> repeat(x, 3), rand(5)) + @test jacobicheck(x -> repeat(x, 2, 3), rand(5)) + @test jacobicheck(x -> repeat(x, 5), rand(5,7)) + @test jacobicheck(x -> repeat(x, 3, 2), rand(5,3)) +end + +@testset "fill" begin + @test jacobicheck(x->fill(x[1], 3), rand(1)) + @test jacobicheck(x->fill(x[1], 3, 4, 2), rand(1)) + + # fill(struct, ...) handled by ChainRules after + # https://github.com/FluxML/Zygote.jl/pull/1051 + @test_broken gradient(x -> fill(x, 3)[1][1], (1,2)) === ((1.0, nothing),) # MethodError: no method matching zero(::Type{Tuple{Int64, Int64}}) + @test_broken gradient(x -> fill(x, 3)[1].a, (a=1, b=2)) == ((a=1.0, b=nothing),) # 1 not 1.0 +end + +# 256 +@testset "circshift" begin + for D in 1:5 + x0 = zeros(ntuple(d->5, D)) + g = gradient(x -> x[1], x0)[1] + shift = ntuple(_ -> rand(-5:5), D) + @test gradient(x -> circshift(x, shift)[1], x0)[1] == circshift(g, map(-, shift)) + end +end + +# 273 +@test_skip @testset "kron" begin + @test jacobicheck(kron, 5, 3) # TypeError: in typeassert, expected Int64, got a value of type Nothing + @test jacobicheck(kron, rand(5), rand(3), rand(8)) + @test jacobicheck(kron, rand(5,1), rand(3,1)) + @test jacobicheck(kron, rand(5,1), rand(3,1), rand(8,1)) + @test jacobicheck(kron, rand(5,2), rand(3,2), rand(8,2)) +end + +# 279 +@testset "map" begin + @testset "bascis" begin + @test jacobicheck(xs -> sum(map(x -> x^2, xs)), rand(2,3)) + @test_broken jacobicheck((xss...) -> sum(map((xs...) -> sqrt(sum(xs.^2)), xss...)), [rand(5) for _ in 1:6]...) # Rewrite reached intrinsic function bitcast. Missing rule? 
+ + function foo(y) + bar = (x) -> x*y + sum(map(bar, 1:5)) + end + @test_skip gradcheck(foo, 3) # MethodError: no method matching (::Diffractor.∂⃖recurse{1})(::typeof(Core.arrayset), + @test_skip gradient(v -> sum([x for x in v]), [1.1,2.2,3.3]) == ([1, 1, 1],) + end + + @test_skip @testset "bascis, pmap" begin + @test jacobicheck(xs -> sum(pmap(x -> x^2, xs)), rand(2,3)) + @test jacobicheck((xss...) -> sum(pmap((xs...) -> sqrt(sum(xs.^2)), xss...)), [rand(5) for _ in 1:6]...) + + function foo(y) + bar = (x) -> x*y + sum(pmap(bar, 1:5)) + end + @test gradtest(foo, 3) + @test gradient(v -> sum([x for x in v]), [1.1,2.2,3.3]) == ([1, 1, 1],) + end + + @testset "Tuple adjoint" begin + x = randn(3) + _, pb = pullback(x -> map(abs2, x), x) + Δy = randn(3) + @test first(pb((Δy..., ))) ≈ first(pb(Δy)) + end + + @testset "empty tuples" begin + out, pb = pullback(map, -, ()) + @test_broken pb(out) === (ZeroTangent(), ()) # ArgumentError: reducing with add_sum over an empty collection of element type Union{} is not allowed. You may be able to prevent this error by supplying an `init` value to the reducer. + + out, pb = pullback(map, +, (), ()) + @test_broken pb(()) === (ZeroTangent(), ZeroTangent(), ZeroTangent()) # MethodError: reducing over an empty collection is not allowed, ChainRules.var"#map_pullback#1234"{typeof(+), Tuple{Tuple{}, Tuple{}}, + + function build_foo(z) + foo(x) = x * z + return foo + end + out, pb = pullback(map, build_foo(5.0), ()) + @test_skip pb(()) === (ZeroTangent(), ()) + end + + @testset "Vector{Nothing} cotangent" begin + Δ = fill(ZeroTangent(), 5) + + # Unary stateless + out, pb = pullback(map, -, randn(5)) + @test pb(Δ)[2] isa Vector{ZeroTangent} + + # Binary stateless + out, pb = pullback(map, +, randn(5), randn(5)) + @test pb(Δ)[2] isa Vector{ZeroTangent} + @test pb(Δ)[3] isa Vector{ZeroTangent} + + # Stateful + function build_foo(z) + foo(x) = x * z + return foo + end + @test_skip out, pb = pullback(map, build_foo(5.0), randn(5)) # AssertionError: Base.issingletontype(typeof(f)) + @test_skip pb(Δ)[2] isa Vector{ZeroTangent} + end +end + +# Check that map infers correctly. pmap still doesn't infer. +@test_skip @testset "map inference" begin + @testset "$name" for (name, f, ȳ, xs) in [ + ("unary empty vector", sin, Float64[], (Float64[], )), + ("unary vector", sin, randn(3), (randn(3), )), + ("unary empty tuple", sin, (), ((), )), + ("unary tuple", sin, (randn(), randn()), ((randn(), randn()), )), + ("binary empty vector", +, Float64[], (Float64[], Float64[])), + ("binary vector", +, randn(2), (randn(2), randn(2))), + ("binary empty tuple", +, (), ((), ())), + ("binary tuple", +, (randn(), randn()), ((randn(), randn()), (randn(), randn()))), + ] + @inferred Zygote._pullback(Zygote.Context(), map, f, xs...) + y, pb = Zygote._pullback(Zygote.Context(), map, f, xs...) 
+ @inferred pb(ȳ) + end +end + +@testset "map and tuples" begin + # arrays of tuples + @test_broken gradient(x -> sum(map(first, x)), [(1,2), (3,4)]) == ([(1.0, NoTangent()), (1.0, NoTangent())],) # MethodError: no method matching one(::Tuple{Int64, Int64}) + @test gradient(x -> sum(first, x), [(1,2), (3,4)]) == ([Tangent{Tuple{Int,Int}}(1.0, NoTangent()), Tangent{Tuple{Int,Int}}(1.0, NoTangent())],) + + @test gradient(x -> map(+, x, (1,2,3))[1], (4,5,6)) |> only == Tangent{Tuple{Int,Int,Int}}(1.0, ZeroTangent(), ZeroTangent()) + @test_broken gradient(x -> map(+, x, [1,2,3])[1], (4,5,6)) == ((1.0, 0.0, 0.0),) # MethodError: no method matching (::Diffractor.∂⃖recurse{1})(::typeof(Core.arrayset), ::Bool, ::Vector{Int64}, ::Int64, ::Int64) + @test_broken gradient(x -> map(+, x, (1,2,3))[1], [4,5,6]) == ([1,0,0],) # Rewrite reached intrinsic function bitcast. Missing rule? + + # mismatched lengths, should zip + @test_broken gradient(x -> map(+, x, [1,2,3,99])[1], (4,5,6)) == ((1.0, 0.0, 0.0),) + @test_broken gradient(x -> map(+, x, [1,2,3])[1], (4,5,6,99)) == ((1.0, 0.0, 0.0, NoTangent()),) +end + +# 420 +@testset "filter" begin + @test jacobicheck(xs -> filter(x -> x > 0.5, xs), rand(20)) + + @test gradient(x -> sum(log, filter(iseven, x)), 1:10) == + (map(x -> iseven(x) ? 1/x : 0, 1:10),) + @test gradient(x -> sum(abs2, im .+ filter(iseven, x)), 1:10) == + (map(x -> iseven(x) ? 2x : 0, 1:10),) + # (map(x -> iseven(x) ? 2x+2im : 0, 1:10),) +end + +# 494 +@testset "maximum" begin + @test jacobicheck(maximum, rand(2, 3)) + + @test jacobicheck(x -> maximum(x, dims=1), rand(2, 3)) + @test jacobicheck(x -> maximum(x, dims=3), rand(2, 3, 4)) + @test jacobicheck(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4)) + + @test gradient(x -> 1 / maximum(x), [1., 2, 3])[1] == [0, 0, -1/9] +end + +@testset "minimum" begin + @test jacobicheck(minimum, rand(2, 3)) + + @test jacobicheck(x -> minimum(x, dims=1), rand(2, 3)) + @test jacobicheck(x -> minimum(x, dims=2), rand(2, 3)) +end + +@testset "dropdims" begin # https://github.com/JuliaDiff/Diffractor.jl/issues/72 + @test jacobicheck(x -> dropdims(x, dims = 3), rand(2, 2, 1, 2)) + @test jacobicheck(x -> dropdims(x, dims = (2, 3)), rand(2, 1, 1, 3)) +end + +# 1186 +@testset "vcat" begin + # Scalar + @test gradient((x,y) -> sum(vcat(x,y)), 1,2) == (1,1) + @test gradient((x,y) -> sum([x; y; x]), 1,2) == (2,1) + + # Scalar + Vector + @test gradient(x -> sum(vcat(x, 1, x)), rand(3)) == ([2,2,2],) + @test gradient((x,y) -> sum(vcat(x, y, y)), rand(3), 4) == ([1,1,1], 2) + + # Vector-only. 
+ @test jacobicheck(vcat, randn(10)) + @test jacobicheck(x -> vcat(x, [1,2,3], x), randn(3)) + + # Matrix-Vector + @test jacobicheck(x-> vcat(x, [1,2,3]), rand(2,1)) + @test jacobicheck(x-> vcat(x, ones(3,1)), rand(2)) +end + +@testset "hcat" begin + # Scalar + @test gradient((x,y) -> sum(hcat(x,y)), 1,2) == (1,1) + @test gradient((x,y) -> sum([x y]), 1,2) == (1,1) + @test gradient((a,b,c,d) -> sum(sqrt, [a b;c d]), 1,1,1,4) == (0.5, 0.5, 0.5, 0.25) + + # Vector-only + @test jacobicheck(hcat, rand(3)) + @test jacobicheck(x -> hcat(x, [1,2,3]), rand(3)) + + # Matrix-only + @test jacobicheck(hcat, rand(3,4)) + @test jacobicheck(x -> hcat(x, [1 2; 3 4], x), rand(2,2)) + + # Matrix-Scalar + @test gradient((x,y) -> sum(hcat(x, y)), 1, [2 3 4]) == (1, [1 1 1]) + @test gradient(x -> sum(hcat(1, x, 2)), transpose([3,4,5]))[1] isa Transpose + @test gradient(x -> sum(hcat(1, x, 2)), [3,4,5]')[1] isa Adjoint +end + +@testset "hvcat" begin + @test gradient(xs -> hvcat((2,2),xs...)[1,1], [1,2,3,4])[1] == [1,0,0,0] + @test gradient(xs -> hvcat((2,2),xs...)[2,1], [1,2,3,4])[1] == [0,0,1,0] + @test gradient(xs -> hvcat((2,2),xs...)[1,2], [1,2,3,4])[1] == [0,1,0,0] + @test gradient(xs -> hvcat((2,2),xs...)[2,2], [1,2,3,4])[1] == [0,0,0,1] + # https://github.com/FluxML/Zygote.jl/issues/513 + @test gradient(x -> hvcat((2,2),1,2,3,x)[4], 4.0) == (1.0,) +end + +@testset "cat(...; dims = $dim)" for dim in 1:3 + catdim = (x...) -> cat(x..., dims = dim) + @test jacobicheck(catdim, rand(4,1)) + @test jacobicheck(catdim, rand(5), rand(5,1)) + @test jacobicheck(catdim, rand(2,5), rand(2,5), rand(2,5)) + + catdimval = (x...) -> cat(x...; dims = Val(dim)) + @test jacobicheck(catdimval, rand(5), rand(5)) + @test jacobicheck(catdimval, rand(2,5), rand(2,5,1)) + + # one empty + dim == 1 || continue + @test jacobicheck(catdim, rand(0,5,3), rand(2,5,3)) +end + +# 1278 +@testset "one(s) and zero(s)" begin + @test gradient(x->sum(ones(size(x))), randn(5))[1] === NoTangent() + @test gradient(x->sum(one(x)), randn(3, 3))[1] === NoTangent() + @test gradient(x->sum(zeros(size(x))), randn(7))[1] === NoTangent() + @test gradient(x->sum(zero(x)), randn(3))[1] === NoTangent() +end + +@testset "fma and muladd" begin + @test gradcheck(x -> fma(x[1], x[2], x[3]), [2.0, 3.0, 5.0]) + @test gradcheck(x -> muladd(x[1], x[2], x[3]), [2.0, 3.0, 5.0]) +end + +# 1388 +@testset "broadcast" begin + @test gradient(x -> sum(sin.(x)), Diagonal([0,pi/2,pi]))[1] ≈ [1 0 0; 0 0 0; 0 0 -1] + + # mixing arrays & Ref(array) + a = rand(3) + b = rand(2,2) + @test jacobicheck(x -> sum(diag.((x,) .* a)), b) + @test jacobicheck(x -> sum(diag.(Ref(x) .* a)), b) + @test jacobicheck(x -> sum(diag.([x] .* a)), b) + + # tests for https://github.com/FluxML/Zygote.jl/issues/724 + x1 = rand(3, 3) + @test gradient(x -> sum(x .== 0.5), x1) |> only |> isZero + @test gradient(x -> sum(x .* (x .== maximum(x, dims=1))), x1)[1] == (x1 .== maximum(x1, dims=1)) + + # tests for un-broadcasting *, / via scalar rules + @test all(gradient((x,y) -> sum(x .* y), [1,2], 5) .≈ ([5, 5], 3)) + @test all(gradient((x,y) -> sum(x .* y), 5, [1,2]) .≈ (3, [5, 5])) + @test all(gradient((x,y) -> sum(x .* y), [1,2], [3 4 5]) .≈ ([12, 12], [3 3 3])) + @test all(gradient((x,y) -> sum(x ./ y), [1,2], 5) .≈ ([0.2, 0.2], -0.12)) + + @test_skip begin + using SparseArrays # not loaded at present + # https://github.com/FluxML/Zygote.jl/pull/1171 + sm = sprand(5, 5, 0.5) + @test gradient(x -> sum(abs2, Float32.(x)), sm)[1] ≈ gradient(x -> sum(abs2, x), Matrix{Float32}(sm))[1] + @test_broken gradient(x 
-> real(sum(ComplexF32.(x) .+ 1 .+ im)), sm)[1] isa SparseMatrixCSC{Float64} # MethodError: no method matching zero(::Type{Any}), in ProjectTo(xs::SparseMatrixCSC{Any, Int64}) + end + + # https://github.com/FluxML/Zygote.jl/issues/1178 + function f1179(x) + fs = Ref.(x) + getindex.(fs) + end + @test_broken gradient(sum∘f1179, ones(2)) == ([2.0, 2.0],) # MethodError: no method matching one(::Base.RefValue{Float64}) +end + +# 1489 +@testset "array +,-" begin + A, B = randn(3, 4, 5), randn(3, 4, 5) + @test jacobicheck(+, B) + @test jacobicheck(+, A, B) + @test_broken jacobicheck(+, A, B, A) + @test jacobicheck(-, A) + @test_broken jacobicheck(-, A, B) # in typeassert, expected Int64, got a value of type Nothing +end + +# 1666 +@testset "@nograd & friends" begin + @test_broken gradient(x->eachindex([10,20,30])[1], 11) == (NoTangent(),) # Rewrite reached intrinsic function and_int. Missing rule? + + @test gradient(x -> findfirst(ismissing, x), [1, missing]) == (NoTangent(),) + @test gradient(x -> findlast(ismissing, x), [1, missing]) == (NoTangent(),) + @test gradient(x -> findall(ismissing, x)[1], [1, missing]) == (NoTangent(),) + + # @test gradient(x -> Zygote.ignore(() -> x*x), 1) == (NoTangent(),) ?? replace with CRC versions? + # @test gradient(x -> Zygote.@ignore(x*x), 1) == (NoTangent(),) + # @test gradient(1) do x + # y = Zygote.@ignore x + # x * y + # end == (1,) +end + +# 1683 +@testset "fastmath" begin + @test_broken gradient(x -> begin @fastmath sin(x) end, 1) == gradient(x -> sin(x), 1) + @test_broken gradient(x -> begin @fastmath tanh(x) end, 1) == gradient(x -> tanh(x), 1) + @test_broken gradient((x, y) -> begin @fastmath x*y end, 3, 2) == gradient((x, y) -> x*y, 3, 2) + @test_broken gradient(x -> begin @fastmath real(log(x)) end, 1 + 2im) == gradient(x -> real(log(x)), 1 + 2im) # MethodError: no method matching copy(::Nothing) from perform_optic_transform(ff::Type{Diffractor.∂⃖recurse{1}}, args::Any) +end + +# 1704 +@testset "rand" begin + @test gradient(x -> rand(), 1) == (ZeroTangent(),) + @test gradient(x -> sum(rand(4)), 1) == (ZeroTangent(),) + @test gradient(x -> sum(rand(Float32, (1,1))), 1) == (ZeroTangent(),) + @test gradient(x -> sum(randn(Float32, 1,1)), 1) == (ZeroTangent(),) + @test gradient(x -> sum(randexp(Float32, (1,1))), 1) == (ZeroTangent(),) + + rng = Random.default_rng() + @test gradient(x -> sum(rand(rng, 4)), 1) == (ZeroTangent(),) + @test gradient(x -> sum(rand(rng, Float32, 1,1)), 1) == (ZeroTangent(),) + @test gradient(x -> sum(randn(rng, Float32, (1,1))), 1) == (ZeroTangent(),) + @test gradient(x -> sum(randexp(rng, Float32, 1,1)), 1) == (ZeroTangent(),) +end + +# 1737 +@testset "broadcasted($op, Array, Bool)" for op in (+,-,*) + @testset "with $bool and sizes $s" for s in ((4,), (2,3)), bool in (true,false) + r = rand(Int8, s) .+ 0.0 + z = fill(bool, s) .+ 0.0 + + fun(args...) = pullback((x, y) -> sum(op.(x,y)), args...)[1] + gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...) 
+ + @test fun(r, z) == fun(r, bool) + @test gfun(r, bool) == (gfun(r, z)[1], NoTangent()) + + @test fun(z, r) == fun(bool, r) + @test gfun(bool, r) == (NoTangent(), gfun(z, r)[2]) + end +end + +@testset "misc issues" begin + + # https://github.com/FluxML/Zygote.jl/issues/957 + @test gradcheck(x -> prod(Base.Fix1(+, 1), x), randn(10)) # MethodError: no method matching +(::Tuple{NoTangent}, ::Tuple{NoTangent}) + @test gradcheck(x -> prod(Base.Fix2(+, 1), x), randn(10)) + + # https://github.com/FluxML/Zygote.jl/issues/996 + @test gradient(x->sum(x .+ rand.()), rand(3)) == (ones(3),) + + # https://github.com/FluxML/Zygote.jl/pull/660 + function example660(x,N) + ax = axes(x) + extraAxe = ax[2+N:end] + filledLoc = fill(1, N) + return x[:, filledLoc..., extraAxe...] + end + y, back = pullback(example660, randn(5,3,4,3), 2) + @test back(zero(y) .= 1)[1] isa Array{Float64,4} + @test back(zero(y) .= 1)[2] |> isZero + + # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/440 + f440(x,y) = sum(sum, [[x[i],y[i]] for i=1:length(x)]) + g440(x,y) = sum(sum, [(x[i],y[i]) for i=1:length(x)]) + @test_broken gradient(f440, rand(3), rand(3)) == ([1.0, 1.0, 1.0], [1.0, 1.0, 1.0]) # MethodError: no method matching (::Diffractor.∂⃖recurse{1})(::typeof(Core.arrayset), ::Bool, ::Vector{Vector{Float64}}, ::Vector{Float64}, ::Int64) + @test_broken gradient(g440, rand(3), rand(3)) == ([1.0, 1.0, 1.0], [1.0, 1.0, 1.0]) + +@test_skip begin + + # https://github.com/FluxML/Zygote.jl/issues/804 + # Comprehension is used. + io = IOBuffer() + s = 0.0 + gs = gradient([1.0, 2.0]) do xs # UndefVarError: s not defined -> Rewrite reached intrinsic function bitcast. Missing rule? + sum([(print(io, x); s += x; s * x) for x in xs]) + end + @test String(take!(io)) == "1.02.0" + @test s == 3.0 + @test gs == ([4.0, 5.0],) + + # Comprehension is not used. + io = IOBuffer() + s = 0.0 + gs = gradient([1.0, 2.0]) do xs + sum([(print(io, x); s += x; s * x) for x in xs]) + 0.0 + end + @test String(take!(io)) == "1.02.0" + @test s == 3.0 + @test gs == (nothing,) + + # Comprehension is empty and not used. 
+    io = IOBuffer()
+    s = 0.0
+    gs = gradient([]) do xs
+        [(print(io, x); s += x; s * x) for x in xs]
+        0.0
+    end
+    @test String(take!(io)) == ""
+    @test s == 0.0
+    @test gs == (nothing,)
+
+end  # skip
+
+end
+
+@testset "Zygote #1184" begin
+    n, d = 3, 2
+    x = [randn(d) for _ in 1:n]
+
+    g1184(x) = sum.((sin,), x)
+    h1184(x) = sum(abs2, g1184(x))
+    @test gradient(h1184, x)[1] isa typeof(x)
+end
+
+@testset "Zygote #1162" begin
+    function zygote1162(as, bs)
+        results = [f1162(a, b) for (a, b) in zip(as, bs)]
+        return results[2][1] + results[2][2]
+    end
+    f1162(a, b) = [a^2, b^2]
+
+    as = (1.0, 2.0, 3.0)
+    bs = (4.0, 5.0, 6.0)
+
+    @test_broken gradient(zygote1162, as, bs) == ((NoTangent(), 2*as[2], NoTangent()), (NoTangent(), 2*bs[2], NoTangent()))  # MethodError: no method matching copy(::Nothing)
+end
+
+#####
+##### Zygote/test/gradcheck.jl : LinearAlgebra
+#####
+
+@testset "LinearAlgebra misc" begin
+    @test jacobicheck(x -> x', rand(5))
+    @test jacobicheck(x -> adjoint(x), rand(5))
+    @test jacobicheck(tr, rand(4, 4))
+end
+
+# 140
+@test_skip @testset "LinearAlgebra.det" begin  # ArgumentError: Tangent for the primal Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}} should be backed by a AbstractDict type
+    @test jacobicheck(det, (4, 4))
+    @test jacobicheck(logdet, map(x -> x*x', (rand(4, 4),))[1])
+    @test jacobicheck(x -> logabsdet(x)[1], (4, 4))
+    @test gradient(det, 2.0)[1] == 1
+    @test gradient(logdet, 2.0)[1] == 0.5
+end
+
+# 266
+@testset "LinearAlgebra.dot" begin
+    @test gradcheck((x, y)->dot(x[1], y[1]), [randn()], [randn()])
+    @test gradcheck(dot, randn(10), randn(10))
+    @test gradcheck(dot, randn(10, 3), randn(10, 3))
+end
+
+# 537
+@testset "LinearAlgebra.(p)inv" begin
+    P, Q = 13, 11
+    A, B, C = randn(P, Q), randn(P, P), randn(Q, P)
+    @test jacobicheck(pinv, A)
+    @test jacobicheck(inv, B)
+    @test jacobicheck(pinv, C)
+
+    @test gradient(inv, 2.0)[1] == -0.25
+end
+
+@testset "LinearAlgebra: *" begin
+    M, P, Q = 13, 7, 11
+    @test jacobicheck(*, randn(M, P), randn(P, Q))
+    @test jacobicheck(*, randn(M, P), randn(P))
+    @test jacobicheck(*, randn(M, 1), randn(1, Q))
+    @test jacobicheck(*, randn(M), randn(1, Q))
+    @test jacobicheck(*, randn(10)', randn(10))
+    @test jacobicheck(*, transpose(randn(10)), randn(10))
+    @test jacobicheck(*, randn(10)', randn(10))
+    @test jacobicheck(*, transpose(randn(10)), randn(10))
+end
+
+# 1383
+@testset "matrix multiplication size" begin
+    @test size(gradient((x, y)->sum(x * y), randn(1, 1), randn(1, 10))[1]) == (1, 1)
+    @test size(gradient((x, y)->sum(x * y), randn(1, 1), randn(1, 10))[2]) == (1, 10)
+end
+
+@testset "backsolve" begin
+    rng, M, P, Q = MersenneTwister(123456), 13, 10, 9
+    X, Y, y = randn(rng, P, P), randn(rng, P, Q), randn(rng, P)
+    A, B = randn(rng, P, M), randn(P, Q)
+    D = collect(Diagonal(randn(rng, P)))
+    U = collect(UpperTriangular(randn(rng, P, P)))
+    U[diagind(U)] .= 1 .+ 0.01 .* randn(rng, P)
+
+    # \ (Dense square)
+    @test jacobicheck(\, X, Y)
+    @test jacobicheck(\, X, y)
+
+    # \ (Dense rectangular)
+    @test jacobicheck(\, A, Y)
+    @test jacobicheck(\, A, y)
+    @test jacobicheck(\, B, Y)
+    @test jacobicheck(\, B, y)
+
+    # \ (Diagonal)
+    @test jacobicheck(\, D, Y)
+    @test jacobicheck(\, D, y)
+    @test jacobicheck((D, Y)-> Diagonal(D) \ Y, D, Y)
+    @test jacobicheck((D, Y)-> Diagonal(D) \ Y, D, y)
+
+    # \ (UpperTriangular)
+    @test jacobicheck(\, U, Y)
+    @test jacobicheck(\, U, y)
+    @test jacobicheck((U, Y) -> UpperTriangular(U) \ Y, U, Y)
+    @test jacobicheck((U, Y) -> UpperTriangular(U) \ Y, U, y)
+
+    # /
+    @test jacobicheck(/, Y', X)
+    @test jacobicheck((y, X)->y' / X, y, X)
+
+    # / (rectangular)
+    @test jacobicheck(/, Y', A')
+    @test jacobicheck((y, A)->y' / A', y, B)
+
+    # / (Diagonal)
+    @test jacobicheck((D, Y) -> Y' / D, D, Y)
+    @test jacobicheck((D, Y)-> Y' / Diagonal(D), D, y)
+
+    # / (UnitUpperTriangular)
+    @test jacobicheck((U, Y) -> Y' / U, U, Y)
+    @test_broken jacobicheck((U, Y) -> Y' / UnitUpperTriangular(U), U, y)  # MethodError: no method matching isapprox(::ChainRules.var"#1711#1714"{Adjoint{Float64, Vector{Float64}}, UnitUpperTriangular{Float64, Matrix{Float64}}, ProjectTo{UnitUpperTriangular, NamedTuple{(:parent,), Tuple{ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}}, Adjoint{Float64, Vector{Float64}}}, ::Matrix{Float64}; rtol=1.0e-5, atol=1.0e-5)
+end
+
+# 667
+@testset "LinearAlgebra.Symmetric{$T}($uplo)" for T in [Float64, ComplexF64], uplo in [:U, :L]
+
+    A = randn(T, 7, 7)
+    @test jacobicheck(x->Symmetric(x, uplo), real(A))
+
+    y, back = pullback(Symmetric, A, uplo)
+    @test y isa Symmetric
+
+    D̄ = Diagonal(randn(7))
+    @test back(Diagonal(D̄))[1] isa Diagonal
+    @test back(Diagonal(D̄))[1] ≈ back(Matrix(D̄))[1]
+
+    D̄ = LowerTriangular(randn(7, 7))
+    @test back(D̄)[1] isa Matrix
+    @test back(D̄)[2] === NoTangent()
+    @test back(D̄)[1] ≈ back(Matrix(D̄))[1]
+
+    if T <: Complex
+        @test gradcheck(real(A), imag(A)) do a, b
+            c = Symmetric(complex.(a, b), uplo)
+            d = exp.(c)
+            sum(real.(d) + imag.(d))
+        end
+    end
+end
+
+# 771
+@testset "LinearAlgebra.diag" begin
+    @test jacobicheck(diag, (10, 10))
+    @test jacobicheck(x -> diag(x, 2), (10, 13))
+end
+
+@testset "LinearAlgebra.Diagonal" begin
+    x = randn(5)
+    @test jacobicheck(Diagonal, x)
+
+    y, back = pullback(Diagonal, x)
+    D = randn(5,5)
+    @test back(D)[1] ≈ back(Diagonal(D))[1]
+    @test back(D)[1] ≈ back(Tangent{Diagonal}(; diag=diag(D)))[1]
+end
+
+@testset "dense + UniformScaling" begin
+    A, λ = randn(10, 10), randn()
+    @test jacobicheck(A->A + 5I, A)
+    @test jacobicheck(A->5I - A, A)
+    @test jacobicheck(λ->A + λ[1] * I, [λ])
+end
+
+# 795
+@testset "LinearAlgebra.cholesky" begin
+
+    @testset "dense" begin
+        A = randn(7, 7)
+        @test cholesky(A' * A + I).U ≈ first(pullback(A->cholesky(A' * A + I), A)).U
+        @test jacobicheck(A->cholesky(A' * A + I).U, A)
+        @test jacobicheck(A->logdet(cholesky(A' * A + I)), A)
+        @test jacobicheck(B->cholesky(Symmetric(B)).U, A * A' + I)
+        @test jacobicheck(B->logdet(cholesky(Symmetric(B))), A * A' + I)
+    end
+
+    @testset "scalar" begin
+        y, back = pullback(cholesky, 5.0 * ones(1, 1))
+        y′, back′ = pullback(cholesky, 5.0)
+        C̄ = randn(1, 1)
+        @test back′(Tangent{Cholesky}(factors=C̄,))[1] isa Real
+        @test back′(Tangent{Cholesky}(factors=C̄,))[1] ≈ back(Tangent{Cholesky}(factors=C̄,))[1][1, 1]
+    end
+
+    @testset "Diagonal" begin
+        D = Diagonal(exp.(randn(8)))
+        Dmat = Matrix(D)
+        y, back = pullback(cholesky, Dmat)
+        y′, back′ = pullback(cholesky, D)
+        C̄ = Tangent{Cholesky}(; factors=randn(8, 8))
+        @test back′(C̄)[1] isa Diagonal
+        @test diag(back′(C̄)[1]) ≈ diag(back(C̄)[1])
+    end
+
+end
+
+# 825
+@testset "LinearAlgebra.lyap" begin
+
+end
+
+# 835
+@testset "matrix exponential" begin
+
+    @testset "real dense" begin
+        A = randn(8, 8)
+        @test jacobicheck(exp, A)
+
+        λ, V = eigen(A)
+        λ[1] = λ[3] + sqrt(eps(real(eltype(λ)))) / 10
+        A2 = real.(V * Diagonal(λ) / V)
+        @test jacobicheck(exp, A2)
+    end
+
+    @testset "complex dense" begin
+        A = randn(ComplexF64, 9, 9)
+        @test jacobicheck(reim(A)...) do a,b
+            c = complex.(a, b)
+            d = exp(c)
+            return sum(real.(d) + 2 .* imag.(d))
+        end
+
+        λ, V = eigen(A)
+        λ[1] = λ[3] + sqrt(eps(real(eltype(λ)))) / 10
+        A2 = V * Diagonal(λ) / V
+        @test gradcheck(reim(A2)...) do a,b
+            c = complex.(a, b)
+            d = exp(c)
+            return sum(real.(d) + 2 .* imag.(d))
+        end
+    end
+
+    A = [ 0.0   1.0    0.0
+          0.0   0.0    1.0
+         -4.34 -18.31 -0.43]
+    _,back = pullback(exp,A)
+    Ȳ = rand(3,3)
+    @test isreal(back(Ȳ)[1])
+end
+
+# 891
+@testset "eigen(::RealHermSymComplexHerm)" begin
+
+end
+
+# 1767
+@testset "LinearAlgebra.norm" begin
+    # check that type is not unnecessarily promoted
+    # https://github.com/FluxML/Zygote.jl/issues/663
+    @test unthunk.(gradient(norm, randn(Float32, 2, 2))) isa Tuple{Matrix{Float32}}
+    @test_broken unthunk.(gradient(norm, randn(Float32, 2, 2), 3)) isa Tuple{Matrix{Float32},Float64}  # Float32 is OK?
+    @test unthunk.(gradient(norm, randn(Float32, 2, 2), 3f0)) isa Tuple{Matrix{Float32},Float32}
+    @test unthunk.(gradient(norm, randn(ComplexF32, 2, 2), 3.5f0)) isa Tuple{Matrix{ComplexF32},Float32}
+
+    # just check that these do not error
+    # https://github.com/FluxML/Zygote.jl/issues/331
+    gradient(x->norm(x*[1, 1]), 1.23)
+    gradient(x->norm(x*[1 1]), 1.23)
+    gradient(x->norm(x*[1im, 1]), 1.23)
+    gradient(x->norm(x*[1im 1]), 1.23)
+end
+
+# 1690
+@testset "LinearAlgebra.I |> Matrix" begin
+    @test gradient(x -> sum(Matrix(x*I, 2, 2)), 1.0) == (2.0,)
+
+    @test gradient(x -> sum(Matrix(x[1]*I, (2, 2))), [1.0]) == ([2.0],)
+    @test gradient(x -> sum(Matrix{Float64}(x[1]*I, 2, 2)), [1.0]) == ([2.0],)
+
+    # Check we haven't broken the forward pass:
+    @test first(pullback(x->Matrix(x*I, 2,2), 8.0)) == Matrix(8.0*I, 2,2)
+end
+
+
+#####
+##### Zygote/test/gradcheck.jl : Statistics
+#####
+
+# 430
+@testset "Statistics.mean" begin
+    @test jacobicheck(mean, rand(2, 3))
+
+    @test jacobicheck(x -> mean(x, dims=1), rand(2, 3))
+    @test jacobicheck(x -> mean(x, dims=2), rand(2, 3))
+    @test jacobicheck(x -> mean(x, dims=3), rand(2, 3, 4))
+
+    @test jacobicheck(x -> mean(x, dims=[1, 2]), rand(2, 3, 4))
+end
+
+@testset "Statistics.$var" for var in (std, var)
+    @test jacobicheck(var, rand(2, 3))
+    @test jacobicheck(x -> var(x, dims=2), rand(2, 3))
+    @test jacobicheck(x -> var(x, dims=(1, 2)), rand(2, 3, 4))
+
+    @test jacobicheck(x -> var(x, corrected=false), rand(2, 3))
+    @test jacobicheck(x -> var(x, dims=1, corrected=false), rand(2, 3))
+
+    @test jacobicheck(x -> var(x, mean=mean(x)), rand(2, 3))
+    @test jacobicheck(x -> var(x, dims=2, mean=mean(x, dims=2)), rand(2, 3))
+    @test jacobicheck(x -> var(x, dims=(1, 2), mean=mean(x, dims=(1, 2))), rand(2, 3, 4))
+
+    @test jacobicheck(x -> var(x, corrected=false, mean=mean(x)), rand(2, 3))
+    @test jacobicheck(x -> var(x, dims=1, corrected=false, mean=mean(x, dims=1)), rand(2, 3))
+end
+
+
+#####
+##### Yota/test/test_grad.jl
+#####
+
+@testset "grad: multipath" begin
+
+    # originally a simplified version of lstm_forward
+    function multipath(y, c)
+        f = tanh.(y[1:5, :])
+        c_ = f .* c
+        h_ = tanh.(c_)
+        return h_, c_
+    end
+
+    y = rand(20, 4)
+    c = rand(5, 4)
+    loss = (y, c) -> begin
+        h, c = multipath(y, c)
+        sum(h)
+    end
+    args = (y, c)
+    @test gradcheck(loss, args...)
+end
+
+@testset "grad: mean with kw" begin
+
+    loss_simple(W, b, x) = sum(W * x .+ b)
+    loss_double_broadcast(W, b, x) = sum(sin.(W * x) .+ b)
+    loss_double_broadcast2(b, x) = sum(x .* x .+ b)
+    loss_kw_mean(W, b, x) = Statistics.mean(W * x .+ b; dims=1)[1]
+
+    args = (rand(3, 4), rand(3), rand(4))
+    @test gradcheck(loss_simple, args...)
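+    # the remaining losses add broadcasting and a keyword-argument `mean(...; dims=1)`,
+    # exercising the broadcast pullback and differentiation through kwarg calls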
+ @test gradcheck(loss_double_broadcast, args...) + @test gradcheck(loss_double_broadcast2, rand(3), rand(3)) + + @test_skip begin + val, g = withgradient(loss_kw_mean, args...) + @test val == loss_kw_mean(args...) + end + @test gradcheck(loss_kw_mean, args...) + + @test gradcheck(x -> sum(sum(x, dims=1)), rand(2, 3)) +end +
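+# NOTE: `gradcheck` and `jacobicheck` used throughout are helpers defined near the top
+# of this file (above the lines shown here). As a rough sketch of the assumed behaviour
+# (not their exact definitions), they compare Diffractor's gradient against a
+# finite-difference reference, along the lines of:
+#
+#     using FiniteDifferences: central_fdm, grad
+#     gradcheck_sketch(f, xs...) = all(isapprox.(
+#         Diffractor.gradient(f, xs...),
+#         grad(central_fdm(5, 1), f, xs...), rtol=1e-5, atol=1e-5))
+#
+# with `jacobicheck` assumed to do the same for array-valued outputs (e.g. via `vec`).
+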