diff --git a/src/projection.jl b/src/projection.jl index 7be03e60c..a01774522 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -241,6 +241,12 @@ function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore fro return fill(project.element(dx)) end +# Accept the Tangent corresponding to a Tuple -- Zygote's splats produce these +function (project::ProjectTo{AbstractArray})(dx::Tangent{<:Any, <:Tuple}) + dy = reshape(collect(backing(dx)), project.axes) + return project(dy) +end + # Ref -- works like a zero-array, also allows restoration from a number: ProjectTo(x::Ref) = ProjectTo{Ref}(; x=ProjectTo(x[])) (project::ProjectTo{Ref})(dx::Ref) = Ref(project.x(dx[]))