Skip to content

Commit fef6aa8

Browse files
mcabbottoxinabox
andauthored
Remove NTuple from unbroadcast (#661)
* remove NTuple * spaces Co-authored-by: Frames Catherine White <[email protected]> Co-authored-by: Frames Catherine White <[email protected]>
1 parent 63cc4e0 commit fef6aa8

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.44.1"
3+
version = "1.44.2"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/rulesets/Base/broadcast.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,13 +328,13 @@ end
328328
unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx
329329

330330
function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N}
331-
val = if length(x) == length(dx)
331+
val = if N == length(dx)
332332
dx
333333
else
334334
sum(dx; dims=2:ndims(dx))
335335
end
336336
eltype(val) <: AbstractZero && return NoTangent()
337-
return ProjectTo(x)(NTuple{length(x)}(val)) # Tangent
337+
return ProjectTo(x)(Tuple{Vararg{Any,N}}(val)) # Tangent
338338
end
339339
unbroadcast(x::Tuple, dx::AbstractZero) = dx
340340

test/rulesets/Base/broadcast.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,4 +173,8 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
173173
test_rrule(copybroadcasted, complex, rand())
174174
end
175175
end
176+
177+
@testset "bugs" begin
178+
@test ChainRules.unbroadcast((1, 2, [3]), [4, 5, [6]]) isa Tangent # earlier, NTuple demanded same type
179+
end
176180
end

0 commit comments

Comments
 (0)