Skip to content

Commit e64f76e

Browse files
committed
add fast paths for arrays of correct numbers, same ndims
1 parent 8218c2c commit e64f76e

File tree

2 files changed

+54
-15
lines changed

2 files changed

+54
-15
lines changed

src/projection.jl

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@ Base.propertynames(p::ProjectTo) = propertynames(backing(p))
4040
backing(project::ProjectTo) = getfield(project, :info)
4141

4242
project_type(p::ProjectTo{T}) where {T} = T
43+
project_eltype(p::ProjectTo{T}) where {T} = eltype(T)
44+
45+
function project_promote_type(projectors)
46+
T = mapreduce(project_type, promote_type, projectors)
47+
if T <: Number
48+
# The point of this function is to make p.element for arrays. Not in use yet!
49+
return ProjectTo(zero(T))
50+
else
51+
return ProjectTo{Any}()
52+
end
53+
end
4354

4455
function Base.show(io::IO, project::ProjectTo{T}) where {T}
4556
print(io, "ProjectTo{")
@@ -181,9 +192,19 @@ end
181192
# If we don't have a more specialized `ProjectTo` rule, we just assume that there is
182193
# no structure worth re-imposing. Then any array is acceptable as a gradient.
183194

184-
# For arrays of numbers, just store one projector:
185-
function ProjectTo(x::AbstractArray{T}) where {T<:Number}
186-
return ProjectTo{AbstractArray}(; element=_eltype_projectto(T), axes=axes(x))
195+
# For arrays of numbers, just store one projector, and construct it without branches:
196+
ProjectTo(x::AbstractArray{<:Number}) = _array_projectto(x, axes(x))
197+
function _array_projectto(x::AbstractArray{T,N}, axes::NTuple{N,<:Base.OneTo{Int}}) where {T,N}
198+
element = _eltype_projectto(T)
199+
S = project_type(element)
200+
# Fastest path: N means they are OneTo, hence reshape can be skipped
201+
return ProjectTo{AbstractArray{S,N}}(; element=element, axes=axes)
202+
end
203+
function _array_projectto(x::AbstractArray{T,N}, axes::Tuple) where {T,N}
204+
element = _eltype_projectto(T)
205+
S = project_type(element)
206+
# Omitting N means reshape will be called, for OffsetArrays, SArrays, etc.
207+
return ProjectTo{AbstractArray{S}}(; element=element, axes=axes)
187208
end
188209
ProjectTo(x::AbstractArray{Bool}) = ProjectTo{NoTangent}()
189210

@@ -201,7 +222,7 @@ function ProjectTo(xs::AbstractArray)
201222
end
202223
end
203224

204-
function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
225+
function (project::ProjectTo{<:AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
205226
# First deal with shape. The rule is that we reshape to add or remove trivial dimensions
206227
# like dx = ones(4,1), where x = ones(4), but throw an error on dx = ones(1,4) etc.
207228
dy = if axes(dx) == project.axes
@@ -225,24 +246,34 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
225246
return dz
226247
end
227248

249+
# Fast paths, for arrays of numbers:
250+
(::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{S,N}) where {S<:T} where {T,N} = dx
251+
(project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{S}) where {S<:T} where {T,N} = reshape(dx, project.axes)
252+
(project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{S,N}) where {S,T,N} = map(project.element, dx)
253+
(project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray) where {T,N} = map(project.element, reshape(dx, project.axes))
254+
228255
# Trivial case, this won't collapse Any[NoTangent(), NoTangent()] but that's OK.
229-
(project::ProjectTo{AbstractArray})(dx::AbstractArray{<:AbstractZero}) = NoTangent()
256+
(project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{<:AbstractZero}) where {T,N} = NoTangent()
230257

231258
# Row vectors aren't acceptable as gradients for 1-row matrices:
232-
function (project::ProjectTo{AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec)
259+
# function (project::ProjectTo{<:AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec)
260+
# return project(reshape(vec(dx), 1, :))
261+
# end
262+
function (project::ProjectTo{AbstractArray{T,N}})(dx::LinearAlgebra.AdjOrTransAbsVec) where {T,N}
233263
return project(reshape(vec(dx), 1, :))
234264
end
235265

236266
# Zero-dimensional arrays -- these have a habit of going missing,
237267
# although really Ref() is probably a better structure.
238-
function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore from numbers
239-
if !(project.axes isa Tuple{})
240-
throw(DimensionMismatch(
241-
"array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number",
242-
))
243-
end
244-
return fill(project.element(dx))
245-
end
268+
# function (project::ProjectTo{<:AbstractArray})(dx::Number) # ... so we restore from numbers
269+
# if !(project.axes isa Tuple{})
270+
# throw(DimensionMismatch(
271+
# "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number",
272+
# ))
273+
# end
274+
# return fill(project.element(dx))
275+
# end
276+
(project::ProjectTo{AbstractArray{<:Number,0}})(dx::Number) = fill(project.element(dx))
246277

247278
function _projection_mismatch(axes_x::Tuple, size_dx::Tuple)
248279
size_x = map(length, axes_x)

test/projection.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
2424
@test ProjectTo(2.0+3.0im)(1+1im) === 1.0+1.0im
2525
@test ProjectTo(2.0)(1+1im) === 1.0
2626

27-
2827
# storage
2928
@test ProjectTo(1)(pi) === pi
3029
@test ProjectTo(1 + im)(pi) === ComplexF64(pi)
@@ -285,6 +284,15 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
285284
##### `OffsetArrays`
286285
#####
287286

287+
# function ProjectTo(x::OffsetArray{T,N}) where {T<:Number,N}
288+
# # As usual:
289+
# element = ChainRulesCore._eltype_projectto(T)
290+
# S = ChainRulesCore.project_type(element)
291+
# # But don't save N? Avoids fast path?
292+
# # Or perhaps the default constructor can check whether axes(x) is NTuple{N,OneTo}?
293+
# return ProjectTo{AbstractArray{S}}(; element=element, axes=axes(x))
294+
# end
295+
288296
@testset "OffsetArrays" begin
289297
# While there is no code for this, the rule that it checks axes(x) == axes(dx) else
290298
# reshape means that it restores offsets. (It throws an error on nontrivial size mismatch.)

0 commit comments

Comments
 (0)