@@ -54,37 +54,36 @@ function as(::Type{Matrix}, args...)
5454 t
5555end
5656
57- function transform_with (flag:: LogJacFlag , t:: ArrayTransform , x:: RealVector )
57+ function transform_with (flag:: LogJacFlag , t:: ArrayTransform , x, index :: T ) where {T}
5858 @unpack transformation, dims = t
59+ # NOTE not using index increments as that somehow breaks type inference
5960 d = dimension (transformation)
60- I = reshape (range (firstindex (x) ; length = prod (dims), step = d), dims)
61- yℓ = map (i -> transform_with (flag, transformation, view_into ( x, i, d )), I )
61+ 𝐼 = reshape (range (index ; length = prod (dims), step = d), dims)
62+ yℓ = map (index -> ((y, ℓ, _) = transform_with (flag, transformation, x, index); (y, ℓ )), 𝐼 )
6263 ℓz = logjac_zero (flag, extended_eltype (x))
63- first .(yℓ), isempty (yℓ) ? ℓz : ℓz + sum (last, yℓ)
64+ first .(yℓ), isempty (yℓ) ? ℓz : ℓz + sum (last, yℓ), index
6465end
6566
66- function transform_with (flag:: LogJacFlag , t:: ArrayTransform{Identity} , x:: RealVector )
67+ function transform_with (flag:: LogJacFlag , t:: ArrayTransform{Identity} , x, index )
6768 # TODO use version below when https://github.com/FluxML/Flux.jl/issues/416 is fixed
6869 # y = reshape(copy(x), t.dims)
69- y = reshape (map (identity, x), t. dims)
70- y, logjac_zero (flag, extended_eltype (x))
70+ index′ = index+ dimension (t)
71+ y = reshape (map (identity, x[index: (index′- 1 )]), t. dims)
72+ y, logjac_zero (flag, extended_eltype (x)), index′
7173end
7274
73- inverse_eltype (t:: ArrayTransform , x:: AbstractArray ) =
75+ function inverse_eltype (t:: ArrayTransform , x:: AbstractArray )
7476 inverse_eltype (t. transformation, first (x)) # FIXME shortcut
77+ end
7578
76- function inverse! (x:: RealVector ,
77- transformation_array:: ArrayTransform ,
78- y:: AbstractArray )
79+ function inverse_at! (x:: AbstractVector , index, transformation_array:: ArrayTransform ,
80+ y:: AbstractArray )
7981 @unpack transformation, dims = transformation_array
8082 @argcheck size (y) == dims
81- index = firstindex (x)
82- d = dimension (transformation)
8383 for elt in vec (y)
84- inverse! (view_into (x, index, d), transformation, elt)
85- index += d;
84+ index = inverse_at! (x, index, transformation, elt)
8685 end
87- x
86+ index
8887end
8988
9089# ###
@@ -152,24 +151,24 @@ $(SIGNATURES)
152151Helper function for transforming tuples. Used internally, to help type inference. Use via
153152`transfom_tuple`.
154153"""
155- _transform_tuple (flag:: LogJacFlag , x:: RealVector , index, :: Tuple{} ) =
156- (), logjac_zero (flag, extended_eltype (x))
154+ _transform_tuple (flag:: LogJacFlag , x:: AbstractVector , index, :: Tuple{} ) =
155+ (), logjac_zero (flag, extended_eltype (x)), index
157156
158- function _transform_tuple (flag:: LogJacFlag , x:: RealVector , index, ts)
157+ function _transform_tuple (flag:: LogJacFlag , x:: AbstractVector , index, ts)
159158 tfirst = first (ts)
160- d = dimension (tfirst)
161- yfirst, ℓfirst = transform_with (flag, tfirst, view_into (x, index, d))
162- yrest, ℓrest = _transform_tuple (flag, x, index + d, Base. tail (ts))
163- (yfirst, yrest... ), ℓfirst + ℓrest
159+ yfirst, ℓfirst, index′ = transform_with (flag, tfirst, x, index)
160+ yrest, ℓrest, index′′ = _transform_tuple (flag, x, index′, Base. tail (ts))
161+ (yfirst, yrest... ), ℓfirst + ℓrest, index′′
164162end
165163
166164"""
167165$(SIGNATURES)
168166
169167Helper function for tuple transformations.
170168"""
171- transform_tuple (flag:: LogJacFlag , tt:: NTransforms , x:: RealVector ) =
172- _transform_tuple (flag, x, firstindex (x), tt)
169+ function transform_tuple (flag:: LogJacFlag , tt:: NTransforms , x, index)
170+ _transform_tuple (flag, x, index, tt)
171+ end
173172
174173"""
175174$(SIGNATURES)
@@ -189,39 +188,37 @@ Helper function for inverting tuples of transformations. Used internally.
189188
190189*Performs no argument validation, caller should do this.*
191190"""
192- function _inverse!_tuple (x:: RealVector , ts:: NTransforms , ys:: Tuple )
193- index = firstindex (x)
191+ function _inverse!_tuple (x:: AbstractVector , index, ts:: NTransforms , ys:: Tuple )
194192 for (t, y) in zip (ts, ys)
195- d = dimension (t)
196- inverse! (view_into (x, index, d), t, y)
197- index += d
193+ index = inverse_at! (x, index, t, y)
198194 end
199- x
195+ index
200196end
201197
202- transform_with (flag:: LogJacFlag , tt:: TransformTuple{<:Tuple} , x:: RealVector ) =
203- transform_tuple (flag, tt. transformations, x)
198+ function transform_with (flag:: LogJacFlag , tt:: TransformTuple{<:Tuple} , x, index)
199+ transform_tuple (flag, tt. transformations, x, index)
200+ end
204201
205202function inverse_eltype (tt:: TransformTuple{<:Tuple} , y:: Tuple )
206203 @unpack transformations = tt
207204 @argcheck length (transformations) == length (y)
208205 _inverse_eltype_tuple (transformations, y)
209206end
210207
211- function inverse ! (x:: RealVector , tt:: TransformTuple{<:Tuple} , y:: Tuple )
208+ function inverse_at ! (x:: AbstractVector , index , tt:: TransformTuple{<:Tuple} , y:: Tuple )
212209 @unpack transformations = tt
213210 @argcheck length (transformations) == length (y)
214211 @argcheck length (x) == dimension (tt)
215- _inverse!_tuple (x, tt. transformations, y)
212+ _inverse!_tuple (x, index, tt. transformations, y)
216213end
217214
218215as (transformations:: NamedTuple{N,<:NTransforms} ) where N =
219216 TransformTuple (transformations)
220217
221- function transform_with (flag:: LogJacFlag , tt:: TransformTuple{<:NamedTuple} , x:: RealVector )
218+ function transform_with (flag:: LogJacFlag , tt:: TransformTuple{<:NamedTuple} , x, index )
222219 @unpack transformations = tt
223- y, ℓ = transform_tuple (flag, values (transformations), x)
224- NamedTuple {keys(transformations)} (y), ℓ
220+ y, ℓ, index′ = transform_tuple (flag, values (transformations), x, index )
221+ NamedTuple {keys(transformations)} (y), ℓ, index′
225222end
226223
227224function inverse_eltype (tt:: TransformTuple{<:NamedTuple} , y:: NamedTuple )
@@ -230,9 +227,9 @@ function inverse_eltype(tt::TransformTuple{<:NamedTuple}, y::NamedTuple)
230227 _inverse_eltype_tuple (values (transformations), values (y))
231228end
232229
233- function inverse ! (x:: RealVector , tt:: TransformTuple{<:NamedTuple} , y:: NamedTuple )
230+ function inverse_at ! (x:: AbstractVector , index , tt:: TransformTuple{<:NamedTuple} , y:: NamedTuple )
234231 @unpack transformations = tt
235232 @argcheck keys (transformations) == keys (y)
236233 @argcheck length (x) == dimension (tt)
237- _inverse!_tuple (x, values (transformations), values (y))
234+ _inverse!_tuple (x, index, values (transformations), values (y))
238235end
0 commit comments