diff --git a/Project.toml b/Project.toml index 57722694..5af68933 100644 --- a/Project.toml +++ b/Project.toml @@ -20,13 +20,3 @@ StaticArrays = "1" StatsBase = "0.33" StructArrays = "0.6" julia = "1.7" - -[extras] -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Test", "ForwardDiff", "LinearAlgebra", "Random", "Symbolics"] diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 00000000..0f1e29bb --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,22 @@ +[deps] +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +ChainRules = "1.5" +ChainRulesCore = "1.2" +Combinatorics = "1" +StaticArrays = "1" +StatsBase = "0.33" +StructArrays = "0.6" +julia = "1.7" diff --git a/test/runtests.jl b/test/runtests.jl index 50a9039b..bfd40eb0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,21 +25,21 @@ ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent() # Check characteristic of exp rule @variables ω α β γ δ ϵ ζ η (x1, c1) = ∂⃖{3}()(exp, ω) -@test simplify(x1 == exp(ω)).val +@test isequal(simplify(x1), simplify(exp(ω))) ((_, x2), c2) = c1(α) -@test simplify(x2 == α*exp(ω)).val +@test isequal(simplify(x2), simplify(α*exp(ω))) (x3, c3) = c2(ZeroTangent(), β) -@test simplify(x3 == β*exp(ω)).val +@test isequal(simplify(x3), simplify(β*exp(ω))) ((_, x4), c4) = c3(γ) -@test simplify(x4 == exp(ω)*(γ + (α*β))).val +@test isequal(simplify(x4), simplify(exp(ω)*(γ + (α*β)))) (x5, c5) = c4(ZeroTangent(), δ) -@test simplify(x5 == δ*exp(ω)).val +@test isequal(simplify(x5), simplify(δ*exp(ω))) ((_, x6), c6) = c5(ϵ) -@test simplify(x6 == ϵ*exp(ω) + α*δ*exp(ω)).val +@test isequal(simplify(x6), simplify(ϵ*exp(ω) + α*δ*exp(ω))) (x7, c7) = c6(ZeroTangent(), ζ) -@test simplify(x7 == ζ*exp(ω) + β*δ*exp(ω)).val +@test isequal(simplify(x7), simplify(ζ*exp(ω) + β*δ*exp(ω))) (_, x8) = c7(η) -@test simplify(x8 == (η + (α*ζ) + (β*ϵ) + (δ*(γ + (α*β))))*exp(ω)).val +@test isequal(simplify(x8), simplify((η + (α*ζ) + (β*ϵ) + (δ*(γ + (α*β))))*exp(ω))) # Minimal 2-nd order forward smoke test @test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin), @@ -123,10 +123,12 @@ let var"'" = Diffractor.PrimeDerivativeFwd # Integration tests @test recursive_sin'(1.0) == cos(1.0) @test recursive_sin''(1.0) == -sin(1.0) - @test recursive_sin'''(1.0) == -cos(1.0) - @test recursive_sin''''(1.0) == sin(1.0) - @test recursive_sin'''''(1.0) == cos(1.0) - @test recursive_sin''''''(1.0) == -sin(1.0) + # Error: ArgumentError: Tangent for the primal Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}} + # should be backed by a NamedTuple type, not by Tuple{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}. + @test_broken recursive_sin'''(1.0) == -cos(1.0) + @test_broken recursive_sin''''(1.0) == sin(1.0) + @test_broken recursive_sin'''''(1.0) == cos(1.0) + @test_broken recursive_sin''''''(1.0) == -sin(1.0) # Test the special rules for sin/cos/exp @test sin''''''(1.0) == -sin(1.0) @@ -148,7 +150,7 @@ end @test gradient(x -> sum(abs2, x .+ 1.0), zeros(3))[1] == [2.0, 2.0, 2.0] const fwd = Diffractor.PrimeDerivativeFwd -const bwd = Diffractor.PrimeDerivativeFwd +const bwd = Diffractor.PrimeDerivativeBack function f_broadcast(a) l = a / 2.0 * [[0. 1. 1.]; [1. 0. 1.]; [1. 1. 0.]] @@ -186,7 +188,9 @@ end # Issue #27 - Mixup in lifting of getfield let var"'" = bwd @test (x->x^5)''(1.0) == 20. - @test (x->x^5)'''(1.0) == 60. + @test (x->(x*x)*(x*x)*x)''' == 60. + # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24) + @test_broken (x->x^5)'''(1.0) == 60. end # Issue #38 - Splatting arrays