Skip to content

Commit 50fe9f7

Browse files
authored
Merge pull request #52 from tpapp/tp/zygote-tests
Get rid of RealVector, views, test Zygote.
2 parents c411b03 + 3583fd4 commit 50fe9f7

File tree

10 files changed

+169
-124
lines changed

10 files changed

+169
-124
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
- use new LogDensityProblems interface
44

5+
- rewrite internals to work better with AD (especially Zygote)
6+
57
# 0.3.4
68

79
- make `inverse(::ArrayTransform)` accept `AbstractArray`

Manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a"
130130
version = "0.11.0"
131131

132132
[[Pkg]]
133-
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
133+
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
134134
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
135135

136136
[[Printf]]

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1212
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
13+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1314
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1415

1516
[compat]
@@ -24,6 +25,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2425
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2526
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2627
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
28+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2729

2830
[targets]
29-
test = ["Flux", "LogDensityProblems", "OffsetArrays", "Random", "ReverseDiff", "StaticArrays", "Test"]
31+
test = ["Flux", "LogDensityProblems", "OffsetArrays", "Random", "ReverseDiff", "StaticArrays", "Test", "Zygote"]

src/aggregation.jl

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -54,37 +54,36 @@ function as(::Type{Matrix}, args...)
5454
t
5555
end
5656

57-
function transform_with(flag::LogJacFlag, t::ArrayTransform, x::RealVector)
57+
function transform_with(flag::LogJacFlag, t::ArrayTransform, x, index::T) where {T}
5858
@unpack transformation, dims = t
59+
# NOTE not using index increments as that somehow breaks type inference
5960
d = dimension(transformation)
60-
I = reshape(range(firstindex(x); length = prod(dims), step = d), dims)
61-
yℓ = map(i -> transform_with(flag, transformation, view_into(x, i, d)), I)
61+
𝐼 = reshape(range(index; length = prod(dims), step = d), dims)
62+
yℓ = map(index -> ((y, ℓ, _) = transform_with(flag, transformation, x, index); (y, ℓ)), 𝐼)
6263
ℓz = logjac_zero(flag, extended_eltype(x))
63-
first.(yℓ), isempty(yℓ) ? ℓz : ℓz + sum(last, yℓ)
64+
first.(yℓ), isempty(yℓ) ? ℓz : ℓz + sum(last, yℓ), index
6465
end
6566

66-
function transform_with(flag::LogJacFlag, t::ArrayTransform{Identity}, x::RealVector)
67+
function transform_with(flag::LogJacFlag, t::ArrayTransform{Identity}, x, index)
6768
# TODO use version below when https://github.com/FluxML/Flux.jl/issues/416 is fixed
6869
# y = reshape(copy(x), t.dims)
69-
y = reshape(map(identity, x), t.dims)
70-
y, logjac_zero(flag, extended_eltype(x))
70+
index′ = index+dimension(t)
71+
y = reshape(map(identity, x[index:(index′-1)]), t.dims)
72+
y, logjac_zero(flag, extended_eltype(x)), index′
7173
end
7274

73-
inverse_eltype(t::ArrayTransform, x::AbstractArray) =
75+
function inverse_eltype(t::ArrayTransform, x::AbstractArray)
7476
inverse_eltype(t.transformation, first(x)) # FIXME shortcut
77+
end
7578

76-
function inverse!(x::RealVector,
77-
transformation_array::ArrayTransform,
78-
y::AbstractArray)
79+
function inverse_at!(x::AbstractVector, index, transformation_array::ArrayTransform,
80+
y::AbstractArray)
7981
@unpack transformation, dims = transformation_array
8082
@argcheck size(y) == dims
81-
index = firstindex(x)
82-
d = dimension(transformation)
8383
for elt in vec(y)
84-
inverse!(view_into(x, index, d), transformation, elt)
85-
index += d;
84+
index = inverse_at!(x, index, transformation, elt)
8685
end
87-
x
86+
index
8887
end
8988

9089
####
@@ -152,24 +151,24 @@ $(SIGNATURES)
152151
Helper function for transforming tuples. Used internally, to help type inference. Use via
153152
`transfom_tuple`.
154153
"""
155-
_transform_tuple(flag::LogJacFlag, x::RealVector, index, ::Tuple{}) =
156-
(), logjac_zero(flag, extended_eltype(x))
154+
_transform_tuple(flag::LogJacFlag, x::AbstractVector, index, ::Tuple{}) =
155+
(), logjac_zero(flag, extended_eltype(x)), index
157156

158-
function _transform_tuple(flag::LogJacFlag, x::RealVector, index, ts)
157+
function _transform_tuple(flag::LogJacFlag, x::AbstractVector, index, ts)
159158
tfirst = first(ts)
160-
d = dimension(tfirst)
161-
yfirst, ℓfirst = transform_with(flag, tfirst, view_into(x, index, d))
162-
yrest, ℓrest = _transform_tuple(flag, x, index + d, Base.tail(ts))
163-
(yfirst, yrest...), ℓfirst + ℓrest
159+
yfirst, ℓfirst, index′ = transform_with(flag, tfirst, x, index)
160+
yrest, ℓrest, index′′ = _transform_tuple(flag, x, index′, Base.tail(ts))
161+
(yfirst, yrest...), ℓfirst + ℓrest, index′′
164162
end
165163

166164
"""
167165
$(SIGNATURES)
168166
169167
Helper function for tuple transformations.
170168
"""
171-
transform_tuple(flag::LogJacFlag, tt::NTransforms, x::RealVector) =
172-
_transform_tuple(flag, x, firstindex(x), tt)
169+
function transform_tuple(flag::LogJacFlag, tt::NTransforms, x, index)
170+
_transform_tuple(flag, x, index, tt)
171+
end
173172

174173
"""
175174
$(SIGNATURES)
@@ -189,39 +188,37 @@ Helper function for inverting tuples of transformations. Used internally.
189188
190189
*Performs no argument validation, caller should do this.*
191190
"""
192-
function _inverse!_tuple(x::RealVector, ts::NTransforms, ys::Tuple)
193-
index = firstindex(x)
191+
function _inverse!_tuple(x::AbstractVector, index, ts::NTransforms, ys::Tuple)
194192
for (t, y) in zip(ts, ys)
195-
d = dimension(t)
196-
inverse!(view_into(x, index, d), t, y)
197-
index += d
193+
index = inverse_at!(x, index, t, y)
198194
end
199-
x
195+
index
200196
end
201197

202-
transform_with(flag::LogJacFlag, tt::TransformTuple{<:Tuple}, x::RealVector) =
203-
transform_tuple(flag, tt.transformations, x)
198+
function transform_with(flag::LogJacFlag, tt::TransformTuple{<:Tuple}, x, index)
199+
transform_tuple(flag, tt.transformations, x, index)
200+
end
204201

205202
function inverse_eltype(tt::TransformTuple{<:Tuple}, y::Tuple)
206203
@unpack transformations = tt
207204
@argcheck length(transformations) == length(y)
208205
_inverse_eltype_tuple(transformations, y)
209206
end
210207

211-
function inverse!(x::RealVector, tt::TransformTuple{<:Tuple}, y::Tuple)
208+
function inverse_at!(x::AbstractVector, index, tt::TransformTuple{<:Tuple}, y::Tuple)
212209
@unpack transformations = tt
213210
@argcheck length(transformations) == length(y)
214211
@argcheck length(x) == dimension(tt)
215-
_inverse!_tuple(x, tt.transformations, y)
212+
_inverse!_tuple(x, index, tt.transformations, y)
216213
end
217214

218215
as(transformations::NamedTuple{N,<:NTransforms}) where N =
219216
TransformTuple(transformations)
220217

221-
function transform_with(flag::LogJacFlag, tt::TransformTuple{<:NamedTuple}, x::RealVector)
218+
function transform_with(flag::LogJacFlag, tt::TransformTuple{<:NamedTuple}, x, index)
222219
@unpack transformations = tt
223-
y, ℓ = transform_tuple(flag, values(transformations), x)
224-
NamedTuple{keys(transformations)}(y), ℓ
220+
y, ℓ, index′ = transform_tuple(flag, values(transformations), x, index)
221+
NamedTuple{keys(transformations)}(y), ℓ, index′
225222
end
226223

227224
function inverse_eltype(tt::TransformTuple{<:NamedTuple}, y::NamedTuple)
@@ -230,9 +227,9 @@ function inverse_eltype(tt::TransformTuple{<:NamedTuple}, y::NamedTuple)
230227
_inverse_eltype_tuple(values(transformations), values(y))
231228
end
232229

233-
function inverse!(x::RealVector, tt::TransformTuple{<:NamedTuple}, y::NamedTuple)
230+
function inverse_at!(x::AbstractVector, index, tt::TransformTuple{<:NamedTuple}, y::NamedTuple)
234231
@unpack transformations = tt
235232
@argcheck keys(transformations) == keys(y)
236233
@argcheck length(x) == dimension(tt)
237-
_inverse!_tuple(x, values(transformations), values(y))
234+
_inverse!_tuple(x, index, values(transformations), values(y))
238235
end

src/custom.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,16 @@ CustomTransform(n::Integer, f, flatten; kwargs...) =
7474

7575
dimension(t::CustomTransform) = dimension(t.g)
7676

77-
function transform_with(flag::NoLogJac, t::CustomTransform, x::RealVector)
77+
function transform_with(flag::NoLogJac, t::CustomTransform, x::AbstractVector, index)
7878
@unpack g, f = t
79-
f(first(transform_with(flag, g, x))), flag
79+
f(first(transform_with(flag, g, x, index))), flag, index + dimension(t)
8080
end
8181

82-
function transform_with(flag::LogJac, t::CustomTransform, x::RealVector)
82+
function transform_with(flag::LogJac, t::CustomTransform, x::AbstractVector, index)
8383
@unpack g, f, flatten, cfg = t
8484
index = firstindex(x)
85-
xv = @view x[index:(index + dimension(g) - 1)]
86-
value_and_logjac_forwarddiff(_custom_f(g, f), xv; flatten = flatten, cfg = cfg)
85+
index′ = index + dimension(g)
86+
y, ℓ = value_and_logjac_forwarddiff(_custom_f(g, f), x[index:(index′ - 1)];
87+
flatten = flatten, cfg = cfg)
88+
y, ℓ, index′
8789
end

src/generic.jl

Lines changed: 61 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
export dimension, transform, transform_and_logjac, transform_logdensity, inverse, inverse!,
22
inverse_eltype, as, random_arg, random_value
33

4-
####
5-
#### log absolute Jacobian determinant
6-
####
4+
###
5+
### log absolute Jacobian determinant
6+
###
77

88
"""
99
$(TYPEDEF)
@@ -48,8 +48,47 @@ logjac_zero(::LogJac, T::Type{<:Real}) = log(one(T))
4848

4949
logjac_zero(::NoLogJac, _) = NOLOGJAC
5050

51+
###
52+
### internal methods that implement transformations
53+
###
54+
55+
"""
56+
transform_with(flag::LogJacFlag, transformation, x::AbstractVector, index)
57+
58+
Transform elements of `x` from `index`, using `transformation`.
59+
60+
Return `(y, logjac), index′`, where
61+
62+
- `y` is the result of the transformation,
63+
64+
- `logjac` is the the log Jacobian determinant or a placeholder, depending on `flag`,
65+
66+
- `index′` is the next index in `x` after the elements used for the transformation
67+
68+
**Internal function**. Implementations
69+
70+
1. can assume that `x` has enough elements for `transformation` (ie `@inbounds` can be
71+
used),
72+
73+
2. should work with generalized indexing on `x`.
74+
"""
75+
function transform_with end
76+
77+
"""
78+
inverse_at!(x, index, transformation, y)
79+
80+
Invert transformation at `y` and put the result in `x` starting at `index`.
81+
82+
**Internal function**. Implementations
83+
84+
1. can assume that `x` has enough elements for the result (ie `@inbounds` can be used),
85+
86+
2. should work with generalized indexing on `x`.
87+
"""
88+
function inverse_at! end
89+
5190
####
52-
#### general
91+
#### API
5392
####
5493

5594
"""
@@ -69,22 +108,6 @@ The user interface consists of
69108
"""
70109
abstract type AbstractTransform end
71110

72-
"""
73-
transform_with(flag::LogJacFlag, t::AbstractTransform, x::RealVector)
74-
75-
Transform elements of `x`, starting using `transformation`.
76-
77-
The first value returned is the transformed value, the second the log Jacobian
78-
determinant or a placeholder, depending on `flag`.
79-
80-
In contrast to [`transform`] and [`transform_and_logjac`], this method always
81-
assumes that `x` is a `RealVector`, for efficient traversal. Some types
82-
implement the latter two via this method.
83-
84-
Implementations should assume generalized indexing on `x`.
85-
"""
86-
function transform_with end
87-
88111
"""
89112
$(TYPEDEF)
90113
@@ -116,15 +139,19 @@ The element type for vector `x` so that `inverse!(x, t, y)` works.
116139
function inverse_eltype end
117140

118141
"""
119-
inverse!(x, t::AbstractTransform, y)
142+
$(SIGNATURES)
120143
121144
Put `inverse(t, y)` into a preallocated vector `x`, returning `x`.
122145
123146
Generalized indexing should be assumed on `x`.
124147
125148
See [`inverse_eltype`](@ref) for determining the type of `x`.
126149
"""
127-
function inverse! end
150+
function inverse!(x::AbstractVector, transformation::AbstractTransform, y)
151+
@argcheck dimension(transformation) == length(x)
152+
inverse_at!(x, firstindex(x), transformation, y)
153+
x
154+
end
128155

129156
"""
130157
$(SIGNATURES)
@@ -169,17 +196,10 @@ function as end
169196
#### vector transformations
170197
####
171198

172-
"""
173-
An `AbstractVector` of `<:Real` elements.
174-
175-
Used internally as a type for transformations from vectors.
176-
"""
177-
const RealVector{T <: Real} = AbstractVector{T}
178-
179199
"""
180200
$(TYPEDEF)
181201
182-
Transformation that transforms `<: RealVector`s to other values.
202+
Transformation that transforms `<: AbstractVector`s to other values.
183203
184204
# Implementation
185205
@@ -193,18 +213,26 @@ $(SIGNATURES)
193213
194214
Transform `x` using `t`.
195215
"""
196-
transform(t::VectorTransform, x::RealVector) = first(transform_with(NOLOGJAC, t, x))
216+
function transform(t::VectorTransform, x::AbstractVector)
217+
@argcheck dimension(t) == length(x)
218+
first(transform_with(NOLOGJAC, t, x, firstindex(x)))
219+
end
197220

198221
"""
199222
$(SIGNATURES)
200223
201224
Transform `x` using `t`; calculating the log Jacobian determinant, returned as
202225
the second value.
203226
"""
204-
transform_and_logjac(t::VectorTransform, x::RealVector) = transform_with(LOGJAC, t, x)
227+
function transform_and_logjac(t::VectorTransform, x::AbstractVector)
228+
@argcheck dimension(t) == length(x)
229+
y, ℓ, _ = transform_with(LOGJAC, t, x, firstindex(x))
230+
y, ℓ
231+
end
205232

206-
inverse(t::VectorTransform, y) =
233+
function inverse(t::VectorTransform, y)
207234
inverse!(Vector{inverse_eltype(t, y)}(undef, dimension(t)), t, y)
235+
end
208236

209237
"""
210238
$(SIGNATURES)

src/scalar.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,17 @@ abstract type ScalarTransform <: AbstractTransform end
1313

1414
dimension(::ScalarTransform) = 1
1515

16-
transform_with(flag::NoLogJac, t::ScalarTransform, x::RealVector) =
17-
transform(t, @inbounds first(x)), flag
16+
function transform_with(flag::NoLogJac, t::ScalarTransform, x::AbstractVector, index)
17+
transform(t, @inbounds x[index]), flag, index + 1
18+
end
1819

19-
transform_with(::LogJac, t::ScalarTransform, x::RealVector) =
20-
transform_and_logjac(t, @inbounds first(x))
20+
function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index)
21+
transform_and_logjac(t, @inbounds x[index])..., index + 1
22+
end
2123

22-
function inverse!(x::RealVector, t::ScalarTransform, y::Real)
23-
x[firstindex(x)] = inverse(t, y)
24+
function inverse_at!(x::AbstractVector, index, t::ScalarTransform, y::Real)
25+
x[index] = inverse(t, y)
26+
index + 1
2427
end
2528

2629
inverse_eltype(t::ScalarTransform, y::T) where {T <: Real} = float(T)

0 commit comments

Comments
 (0)