@@ -40,6 +40,17 @@ Base.propertynames(p::ProjectTo) = propertynames(backing(p))
40
40
# Return the raw `info` field of a projector. `getfield` is used, so the access does
# not go through `getproperty`.
function backing(project::ProjectTo)
    return getfield(project, :info)
end
41
41
42
42
# Return the type parameter `T` of a `ProjectTo{T}`.
function project_type(::ProjectTo{T}) where {T}
    return T
end
43
+ # Return `eltype(T)` for a `ProjectTo{T}`.
+ function project_eltype(::ProjectTo{T}) where {T}
+     return eltype(T)
+ end
44
+
45
+ # Combine a collection of projectors into one: their type parameters are reduced with
+ # `promote_type`. If the promoted type is a `Number` type the result is
+ # `ProjectTo(zero(T))`, otherwise it is `ProjectTo{Any}()`.
+ # Per the original note, this is meant to build `p.element` for arrays and is not in
+ # use yet.
+ function project_promote_type(projectors)
+     promoted = mapreduce(project_type, promote_type, projectors)
+     promoted <: Number || return ProjectTo{Any}()
+     return ProjectTo(zero(promoted))
+ end
43
54
44
55
function Base. show (io:: IO , project:: ProjectTo{T} ) where {T}
45
56
print (io, " ProjectTo{" )
181
192
# If we don't have a more specialized `ProjectTo` rule, we just assume that there is
182
193
# no structure worth re-imposing. Then any array is acceptable as a gradient.
183
194
184
- # For arrays of numbers, just store one projector:
185
- function ProjectTo (x:: AbstractArray{T} ) where {T<: Number }
186
- return ProjectTo {AbstractArray} (; element= _eltype_projectto (T), axes= axes (x))
195
+ # For arrays of numbers, just store one projector, and construct it without branches:
196
+ # Entry point for arrays of numbers: the choice of projector type is made by dispatch
+ # of `_array_projectto` on the type of `axes(x)`.
+ function ProjectTo(x::AbstractArray{<:Number})
+     return _array_projectto(x, axes(x))
+ end
197
+ # Fastest path: there are exactly `N` axes and all are `Base.OneTo{Int}`. `N` is recorded
+ # in the projector's type parameter, `ProjectTo{AbstractArray{S,N}}`, which (per the
+ # original note) lets projection skip `reshape`. `S` is the type parameter of the
+ # element projector built from `T`.
+ function _array_projectto(
+     x::AbstractArray{T,N}, axes::NTuple{N,<:Base.OneTo{Int}}
+ ) where {T,N}
+     elproj = _eltype_projectto(T)
+     return ProjectTo{AbstractArray{project_type(elproj),N}}(; element=elproj, axes=axes)
+ end
203
+ # Fallback for any other tuple of axes (the original note names OffsetArrays and SArrays).
+ # `N` is left out of the type parameter, giving `ProjectTo{AbstractArray{S}}`; per the
+ # original note this means `reshape` will be called when projecting.
+ function _array_projectto(x::AbstractArray{T,N}, axes::Tuple) where {T,N}
+     elproj = _eltype_projectto(T)
+     return ProjectTo{AbstractArray{project_type(elproj)}}(; element=elproj, axes=axes)
end
188
209
# Arrays of `Bool` get the `ProjectTo{NoTangent}()` projector, independent of their content.
function ProjectTo(::AbstractArray{Bool})
    return ProjectTo{NoTangent}()
end
189
210
@@ -201,7 +222,7 @@ function ProjectTo(xs::AbstractArray)
201
222
end
202
223
end
203
224
204
- function (project:: ProjectTo{AbstractArray} )(dx:: AbstractArray{S,M} ) where {S,M}
225
+ function (project:: ProjectTo{<: AbstractArray} )(dx:: AbstractArray{S,M} ) where {S,M}
205
226
# First deal with shape. The rule is that we reshape to add or remove trivial dimensions
206
227
# like dx = ones(4,1), where x = ones(4), but throw an error on dx = ones(1,4) etc.
207
228
dy = if axes (dx) == project. axes
@@ -225,24 +246,34 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
225
246
return dz
226
247
end
227
248
249
+ # Fast paths, for arrays of numbers:
250
+ # (1) eltype `S <: T` and the same number of dimensions `N`: `dx` is returned as-is.
+ #     Note that this method does not compare `axes(dx)` with the stored axes.
+ function (::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{S,N}) where {T,N,S<:T}
+     return dx
+ end
+ # (2) eltype `S <: T`, number of dimensions not tied to `N`: only the shape is changed,
+ #     by `reshape` to the stored axes.
+ function (project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{S}) where {T,N,S<:T}
+     return reshape(dx, project.axes)
+ end
+ # (3) same `N`, arbitrary eltype: every element goes through the element projector.
+ function (project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{S,N}) where {S,T,N}
+     return map(project.element, dx)
+ end
+ # (4) anything else: reshape to the stored axes first, then project each element.
+ function (project::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray) where {T,N}
+     reshaped = reshape(dx, project.axes)
+     return map(project.element, reshaped)
+ end
254
+
228
255
# Trivial case, this won't collapse Any[NoTangent(), NoTangent()] but that's OK.
229
- (project:: ProjectTo{AbstractArray} )(dx:: AbstractArray{<:AbstractZero} ) = NoTangent ()
256
+ # An array whose eltype is a subtype of `AbstractZero` projects to a single `NoTangent()`.
+ # NOTE(review): for a `dx` with exactly `N` dimensions, the fast path taking
+ # `dx::AbstractArray{S,N}` also applies and neither signature is a subtype of the
+ # other -- confirm that dispatch is unambiguous here.
+ function (::ProjectTo{AbstractArray{T,N}})(dx::AbstractArray{<:AbstractZero}) where {T,N}
+     return NoTangent()
+ end
230
257
231
258
# Row vectors aren't acceptable as gradients for 1-row matrices:
232
- function (project:: ProjectTo{AbstractArray} )(dx:: LinearAlgebra.AdjOrTransAbsVec )
259
+ # function (project::ProjectTo{<:AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec)
260
+ # return project(reshape(vec(dx), 1, :))
261
+ # end
262
+ # A row vector (`Adjoint` or `Transpose` of a vector) is first rebuilt as an explicit
+ # one-row matrix with `reshape(vec(dx), 1, :)`, and that matrix is projected instead.
+ function (project::ProjectTo{AbstractArray{T,N}})(
+     dx::LinearAlgebra.AdjOrTransAbsVec
+ ) where {T,N}
+     as_matrix = reshape(vec(dx), 1, :)
+     return project(as_matrix)
end
235
265
236
266
# Zero-dimensional arrays -- these have a habit of going missing,
237
267
# although really Ref() is probably a better structure.
238
- function (project:: ProjectTo{AbstractArray} )(dx:: Number ) # ... so we restore from numbers
239
- if ! (project. axes isa Tuple{})
240
- throw (DimensionMismatch (
241
- " array with ndims(x) == $(length (project. axes)) > 0 cannot have dx::Number" ,
242
- ))
243
- end
244
- return fill (project. element (dx))
245
- end
268
+ # function (project::ProjectTo{<:AbstractArray})(dx::Number) # ... so we restore from numbers
269
+ # if !(project.axes isa Tuple{})
270
+ # throw(DimensionMismatch(
271
+ # "array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number",
272
+ # ))
273
+ # end
274
+ # return fill(project.element(dx))
275
+ # end
276
+ # Restore a zero-dimensional array from a plain number: the number is passed through
+ # the element projector and wrapped with `fill`.
+ #
+ # Type parameters are invariant, so `ProjectTo{AbstractArray{<:Number,0}}` would match
+ # only a projector whose parameter is literally the UnionAll `AbstractArray{<:Number,0}`,
+ # and never e.g. `ProjectTo{AbstractArray{Float64,0}}` as built by `_array_projectto`.
+ # The `<:` on the outer parameter accepts every such concrete parameter (and still
+ # accepts the UnionAll itself).
+ function (project::ProjectTo{<:AbstractArray{<:Number,0}})(dx::Number)
+     return fill(project.element(dx))
+ end
246
277
247
278
function _projection_mismatch (axes_x:: Tuple , size_dx:: Tuple )
248
279
size_x = map (length, axes_x)
0 commit comments