Skip to content

Commit 178d411

Browse files
authored
Merge pull request #53 from tpapp/tp/fix-broadcasting
fix broadcasting
2 parents 50fe9f7 + ec7b7ac commit 178d411

File tree

3 files changed

+45
-28
lines changed

3 files changed

+45
-28
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
- rewrite internals to work better with AD (especially Zygote)
66

7+
- fix broadcasting (`Ref(transformation)` no longer necessary)
8+
79
# 0.3.4
810

911
- make `inverse(::ArrayTransform)` accept `AbstractArray`

src/generic.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ The user interface consists of
108108
"""
109109
abstract type AbstractTransform end
110110

111+
Base.broadcastable(t::AbstractTransform) = Ref(t)
112+
111113
"""
112114
$(TYPEDEF)
113115

test/runtests.jl

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -378,34 +378,34 @@ end
378378

379379
end
380380

381-
if VERSION v"1.1"
382-
if CIENV
383-
@info "installing Zygote#master"
384-
import Pkg
385-
Pkg.API.add(Pkg.PackageSpec(; name = "Zygote", rev = "master"))
386-
end
387-
388-
import Zygote
389-
390-
@testset "Zygote AD" begin
391-
# Zygote
392-
# NOTE @inferred removed as it currently fails
393-
# NOTE tests simplified disabled as they currently fail
394-
t = as((μ = asℝ, ))
395-
function f(θ)
396-
@unpack μ = θ
397-
-(abs2(μ))
398-
end
399-
P = TransformedLogDensity(t, f)
400-
x = zeros(dimension(t))
401-
PF = ADgradient(:ForwardDiff, P)
402-
PZ = ADgradient(:Zygote, P)
403-
@test @inferred(logdensity(PZ, x)) == logdensity(P, x)
404-
vZ, gZ = logdensity_and_gradient(PZ, x)
405-
@test vZ == logdensity(P, x)
406-
@test gZ last(logdensity_and_gradient(PF, x))
407-
end
408-
end
381+
# if VERSION ≥ v"1.1"
382+
# if CIENV
383+
# @info "installing Zygote"
384+
# import Pkg
385+
# Pkg.API.add(Pkg.PackageSpec(; name = "Zygote"))
386+
# end
387+
388+
# import Zygote
389+
390+
# @testset "Zygote AD" begin
391+
# # Zygote
392+
# # NOTE @inferred removed as it currently fails
393+
# # NOTE tests simplified disabled as they currently fail
394+
# t = as((μ = asℝ, ))
395+
# function f(θ)
396+
# @unpack μ = θ
397+
# -(abs2(μ))
398+
# end
399+
# P = TransformedLogDensity(t, f)
400+
# x = zeros(dimension(t))
401+
# PF = ADgradient(:ForwardDiff, P)
402+
# PZ = ADgradient(:Zygote, P)
403+
# @test @inferred(logdensity(PZ, x)) == logdensity(P, x)
404+
# vZ, gZ = logdensity_and_gradient(PZ, x)
405+
# @test vZ == logdensity(P, x)
406+
# @test gZ ≈ last(logdensity_and_gradient(PF, x))
407+
# end
408+
# end
409409

410410
@testset "inverse_and_logjac" begin
411411
# WIP, test separately until integrated
@@ -449,3 +449,16 @@ end
449449
t = as(Array, 2, 3)
450450
@test inverse(t, ones(SMatrix{2,3})) == ones(6)
451451
end
452+
453+
####
454+
#### broadcasting
455+
####
456+
457+
@testset "broadcasting" begin
458+
@test as𝕀.([0, 0]) == [0.5, 0.5]
459+
460+
t = UnitVector(3)
461+
d = dimension(t)
462+
x = [zeros(d), zeros(d)]
463+
@test t.(x) == map(t, x)
464+
end

0 commit comments

Comments
 (0)