Skip to content

Commit f2a8a77

Browse files
committed
move all old tests
1 parent 0840a82 commit f2a8a77

File tree

2 files changed

+299
-286
lines changed

2 files changed

+299
-286
lines changed

test/diffractor_01.jl

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
# The rest of this file is unchanged, except the very end,
2+
# but IMO we should move these tests to a new file.
3+
4+
# Loading Diffractor: var"'" globally will break many tests above, which use it for adjoint.
5+
6+
using Diffractor: var"'", ∂⃖, DiffractorRuleConfig
7+
using ChainRules
8+
using ChainRulesCore
9+
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
10+
using Symbolics
11+
using LinearAlgebra
12+
13+
14+
const fwd = Diffractor.PrimeDerivativeFwd
15+
const bwd = Diffractor.PrimeDerivativeBack
16+
17+
# Unit tests
18+
function tup2(f)
19+
a, b = ∂⃖{2}()(f, 1)
20+
c, d = b((2,))
21+
e, f = d(ZeroTangent(), 3)
22+
f((4,))
23+
end
24+
25+
@test tup2(tuple) == (NoTangent(), 4)
26+
27+
my_tuple(args...) = args
28+
ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent(), Δ...)
29+
@test tup2(my_tuple) == (ZeroTangent(), 4)
30+
31+
# Check characteristic of exp rule
32+
@variables ω α β γ δ ϵ ζ η
33+
(x1, c1) = ∂⃖{3}()(exp, ω)
34+
@test isequal(simplify(x1), simplify(exp(ω)))
35+
((_, x2), c2) = c1(α)
36+
@test isequal(simplify(x2), simplify*exp(ω)))
37+
(x3, c3) = c2(ZeroTangent(), β)
38+
@test isequal(simplify(x3), simplify*exp(ω)))
39+
((_, x4), c4) = c3(γ)
40+
@test isequal(simplify(x4), simplify(exp(ω)*+*β))))
41+
(x5, c5) = c4(ZeroTangent(), δ)
42+
@test isequal(simplify(x5), simplify*exp(ω)))
43+
((_, x6), c6) = c5(ϵ)
44+
@test isequal(simplify(x6), simplify*exp(ω) + α*δ*exp(ω)))
45+
(x7, c7) = c6(ZeroTangent(), ζ)
46+
@test isequal(simplify(x7), simplify*exp(ω) + β*δ*exp(ω)))
47+
(_, x8) = c7(η)
48+
@test isequal(simplify(x8), simplify((η +*ζ) +*ϵ) +*+*β))))*exp(ω)))
49+
50+
# Minimal 2-nd order forward smoke test
51+
@test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin),
52+
Diffractor.TangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0)
53+
54+
function simple_control_flow(b, x)
55+
if b
56+
return sin(x)
57+
else
58+
return cos(x)
59+
end
60+
end
61+
62+
function myprod(xs)
63+
s = 1
64+
for x in xs
65+
s *= x
66+
end
67+
return s
68+
end
69+
70+
function mypow(x, n)
71+
r = one(x)
72+
while n > 0
73+
n -= 1
74+
r *= x
75+
end
76+
return r
77+
end
78+
79+
function times_three_while(x)
80+
z = x
81+
i = 3
82+
while i > 1
83+
z += x
84+
i -= 1
85+
end
86+
z
87+
end
88+
89+
isa_control_flow(::Type{T}, x) where {T} = isa(x, T) ? x : T(x)
90+
91+
# Simple Reverse Mode tests
92+
let var"'" = Diffractor.PrimeDerivativeBack
93+
# Integration tests
94+
@test @inferred(sin'(1.0)) == cos(1.0)
95+
@test @inferred(sin''(1.0)) == -sin(1.0)
96+
@test sin'''(1.0) == -cos(1.0)
97+
@test sin''''(1.0) == sin(1.0)
98+
@test sin'''''(1.0) == cos(1.0) # broken = VERSION >= v"1.8"
99+
@test sin''''''(1.0) == -sin(1.0) # broken = VERSION >= v"1.8"
100+
101+
f_getfield(x) = getfield((x,), 1)
102+
@test f_getfield'(1) == 1
103+
@test f_getfield''(1) == 0
104+
@test f_getfield'''(1) == 0
105+
106+
# Higher order mixed mode tests
107+
108+
complicated_2sin(x) = (x = map(sin, Diffractor.xfill(x, 2)); x[1] + x[2])
109+
@test @inferred(complicated_2sin'(1.0)) == 2sin'(1.0)
110+
@test @inferred(complicated_2sin''(1.0)) == 2sin''(1.0) broken=true
111+
@test @inferred(complicated_2sin'''(1.0)) == 2sin'''(1.0) broken=true
112+
@test @inferred(complicated_2sin''''(1.0)) == 2sin''''(1.0) broken=true
113+
114+
# Control flow cases
115+
@test @inferred((x->simple_control_flow(true, x))'(1.0)) == sin'(1.0)
116+
@test @inferred((x->simple_control_flow(false, x))'(1.0)) == cos'(1.0)
117+
@test (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;]
118+
@test times_three_while'(1.0) == 3.0
119+
120+
pow5p(x) = (x->mypow(x, 5))'(x)
121+
@test pow5p(1.0) == 5.0
122+
end
123+
124+
# Simple Forward Mode tests
125+
let var"'" = Diffractor.PrimeDerivativeFwd
126+
recursive_sin(x) = sin(x)
127+
ChainRulesCore.frule(∂, ::typeof(recursive_sin), x) = frule(∂, sin, x)
128+
129+
# Integration tests
130+
@test recursive_sin'(1.0) == cos(1.0)
131+
@test recursive_sin''(1.0) == -sin(1.0)
132+
# Error: ArgumentError: Tangent for the primal Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}
133+
# should be backed by a NamedTuple type, not by Tuple{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}.
134+
@test_broken recursive_sin'''(1.0) == -cos(1.0)
135+
@test_broken recursive_sin''''(1.0) == sin(1.0)
136+
@test_broken recursive_sin'''''(1.0) == cos(1.0)
137+
@test_broken recursive_sin''''''(1.0) == -sin(1.0)
138+
139+
# Test the special rules for sin/cos/exp
140+
@test sin''''''(1.0) == -sin(1.0)
141+
@test cos''''''(1.0) == -cos(1.0)
142+
@test exp''''''(1.0) == exp(1.0)
143+
end
144+
145+
# Some Basic Mixed Mode tests
146+
function sin_twice_fwd(x)
147+
let var"'" = Diffractor.PrimeDerivativeFwd
148+
sin''(x)
149+
end
150+
end
151+
let var"'" = Diffractor.PrimeDerivativeFwd
152+
@test sin_twice_fwd'(1.0) == sin'''(1.0)
153+
end
154+
155+
# Regression tests
156+
@test gradient(x -> sum(abs2, x .+ 1.0), zeros(3))[1] == [2.0, 2.0, 2.0]
157+
158+
function f_broadcast(a)
159+
l = a / 2.0 * [[0. 1. 1.]; [1. 0. 1.]; [1. 1. 0.]]
160+
return sum(l)
161+
end
162+
@test fwd(f_broadcast)(1.0) == bwd(f_broadcast)(1.0)
163+
164+
# Make sure that there's no infinite recursion in kwarg calls
165+
g_kw(;x=1.0) = sin(x)
166+
f_kw(x) = g_kw(;x)
167+
@test bwd(f_kw)(1.0) == bwd(sin)(1.0)
168+
169+
function f_crit_edge(a, b, c, x)
170+
# A function with two critical edges. This used to trigger an issue where
171+
# Diffractor would fail to insert edges for the second split critical edge.
172+
y = 1x
173+
if a && b
174+
y = 2x
175+
end
176+
if b && c
177+
y = 3x
178+
end
179+
180+
if c
181+
y = 4y
182+
end
183+
184+
return y
185+
end
186+
@test bwd(x->f_crit_edge(false, false, false, x))(1.0) == 1.0
187+
@test bwd(x->f_crit_edge(true, true, false, x))(1.0) == 2.0
188+
@test bwd(x->f_crit_edge(false, true, true, x))(1.0) == 12.0
189+
@test bwd(x->f_crit_edge(false, false, true, x))(1.0) == 4.0
190+
@test bwd(bwd(x->5))(1.0) == ZeroTangent()
191+
@test fwd(fwd(x->5))(1.0) == ZeroTangent()
192+
193+
# Issue #27 - Mixup in lifting of getfield
194+
let var"'" = bwd
195+
@test (x->x^5)''(1.0) == 20.
196+
@test (x->(x*x)*(x*x)*x)'''(1.0) == 60.
197+
# Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
198+
@test_broken (x->x^5)'''(1.0) == 60.
199+
end
200+
201+
# Issue #38 - Splatting arrays
202+
@test gradient(x -> max(x...), (1,2,3))[1] == (0.0, 0.0, 1.0)
203+
@test gradient(x -> max(x...), [1,2,3])[1] == [0.0, 0.0, 1.0]
204+
205+
# Issue #40 - Symbol type parameters not properly quoted
206+
@test Diffractor.∂⃖recurse{1}()(Val{:transformations})[1] === Val{:transformations}()
207+
208+
# PR #43
209+
loss(res, z, w) = sum(res.U * Diagonal(res.S) * res.V) + sum(res.S .* w)
210+
x43 = rand(10, 10)
211+
@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x43) isa Tuple{Matrix{Float64}}
212+
213+
# PR # 45 - Calling back into AD from ChainRules
214+
r45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2)
215+
@test r45 isa Tuple
216+
y45, back45 = r45
217+
@test y45 2.0
218+
@test back45(1) == (ZeroTangent(), 1.0)
219+
220+
z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
221+
@test z45 2.0
222+
@test delta45 1.0
223+
224+
# PR #82 - getindex on non-numeric arrays
225+
@test gradient(ls -> ls[1](1.), [Base.Fix1(*, 1.)])[1][1] isa Tangent{<:Base.Fix1}
226+
227+
@testset "broadcast" begin
228+
@test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output
229+
@test gradient(x -> sum(sqrt.(atan.(x, transpose(x)))), [1,2,3])[1] [0.2338, -0.0177, -0.0661] atol=1e-3
230+
@test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],)
231+
232+
@test gradient(x -> sum((explog).(x)), [1,2,3]) == ([1,1,1],) # frule_via_ad
233+
exp_log(x) = exp(log(x))
234+
@test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],)
235+
@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75])
236+
@test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4)
237+
@test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) # closure
238+
239+
@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] [-4.1666, 0.3333, 1.1666] atol=1e-3 # array of arrays
240+
@test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] [-4.1666, 0.3333, 1.1666] atol=1e-3
241+
@test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] [-4.1666, 0.3333, 1.1666] atol=1e-3
242+
@test gradient(x -> sum(sum, (x,) .* transpose(x)), [1,2,3])[1] [12, 12, 12] # must not take the * fast path
243+
244+
@test gradient(x -> sum(x ./ 4), [1,2,3]) == ([0.25, 0.25, 0.25],)
245+
@test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) # x/y rule
246+
@test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],) # x.^2 rule
247+
@test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) # scalar^2 rule
248+
249+
@test gradient(x -> sum((1,2,3) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-1.0, -1.0, -1.0),)
250+
@test gradient(x -> sum(transpose([1,2,3]) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-3.0, -3.0, -3.0),)
251+
@test gradient(x -> sum([1 2 3] .+ x .^ 2), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(6.0, 12.0, 18.0),)
252+
253+
@test gradient(x -> sum(x .> 2), [1,2,3]) |> only |> iszero # Bool output
254+
@test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) |> only |> iszero
255+
@test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (NoTangent(), NoTangent())
256+
@test gradient(x -> sum(x .+ [1,2,3]), true) |> only |> iszero # Bool input
257+
@test gradient(x -> sum(x ./ [1,2,3]), [true false]) |> only |> iszero
258+
@test gradient(x -> sum(x .* transpose([1,2,3])), (true, false)) |> only |> iszero
259+
260+
tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), transpose([3,4,5]))
261+
@test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0)
262+
@test tup_adj[2] [0.6666666666666666 0.5 0.4]
263+
@test tup_adj[2] isa Transpose
264+
@test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal
265+
266+
@test gradient(x -> sum((y -> (x*y)).([1,2,3])), 4.0) == (6.0,) # closure
267+
end
268+
269+
@testset "broadcast, 2nd order" begin
270+
@test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] # calls "split broadcasting generic" with f = unthunk
271+
@test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27]
272+
@test_broken gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12] # Control flow support not fully implemented yet for higher-order
273+
274+
@test_broken gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6] # BoundsError: attempt to access 18-element Vector{Core.Compiler.BasicBlock} at index [0]
275+
@test_broken gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2]
276+
@test_broken gradient(x -> sum(gradient(x -> sum(x .* x ./ 2), x)[1]), [1,2,3.0])[1] == [1,1,1]
277+
278+
@test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] exp.(1:3) # MethodError: no method matching copy(::Nothing)
279+
@test_broken gradient(x -> sum(gradient(x -> sum(atan.(x, x')), x)[1]), [1,2,3.0])[1] [0,0,0]
280+
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # accum(a::Transpose{Float64, Vector{Float64}}, b::ChainRulesCore.Tangent{Transpose{Int64, Vector{Int64}}, NamedTuple{(:parent,), Tuple{ChainRulesCore.NoTangent}}})
281+
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] [27.675925925925927, -0.824074074074074, -2.1018518518518516]
282+
283+
@test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,)
284+
end
285+

0 commit comments

Comments
 (0)