Enzyme throws lots of errors when using MLDatasets #24
```julia
module MWE

using ChangesOfVariables: ChangesOfVariables
import Enzyme: Enzyme, Reverse, Const, set_runtime_activity

# Print the full Enzyme error messages.
Enzyme.Compiler.VERBOSE_ERRORS[] = true

# These `getproperty` overloads on the types themselves are never called below,
# but they survived the minimisation; presumably they are needed to trigger the error.
struct PStruct
    p
end
Base.getproperty(::Type{PStruct}, s::Symbol) = getfield(PStruct, s)

struct QStruct
    q
end
Base.getproperty(::Type{QStruct}, s::Symbol) = getfield(QStruct, s)

struct RStruct
    r
end
Base.getproperty(::Type{RStruct}, s::Symbol) = getfield(RStruct, s)

# The below minimised from:
# @model function f()
#     s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
#     1.5 ~ Normal(s[1])
# end

# Stripped-down version of ChangesOfVariables' `with_logabsdet_jacobian`
# for a broadcast function `Base.Fix1(broadcast, f)`.
function wladj(mapped_f::Base.Fix1{typeof(broadcast)}, X)
    f = mapped_f.x
    y_with_ladj = broadcast(Base.Fix1(ChangesOfVariables.with_logabsdet_jacobian, f), X)
    ChangesOfVariables._with_ladj_on_mapped(broadcast, y_with_ladj)
end

function f(x::AbstractVector)
    logp = Ref(0.0)
    g = Base.Fix1(broadcast, Base.Fix1(broadcast, exp))
    s, _ = wladj(g, x)
    logp[] += sum(s)
    return logp[]
end

x = [0.5, 1.0]
@show f(x)
Enzyme.gradient(Reverse, f, x)  # Enzyme throws here

end
```
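```julia
# Further minimised: ChangesOfVariables is no longer needed.
module MWE

import Enzyme: Enzyme, Reverse, Const, set_runtime_activity

# Print the full Enzyme error messages.
Enzyme.Compiler.VERBOSE_ERRORS[] = true

# These `getproperty` overloads on the types themselves are never called below,
# but they survived the minimisation; presumably they are needed to trigger the error.
struct PStruct
    p
end
Base.getproperty(::Type{PStruct}, s::Symbol) = getfield(PStruct, s)

struct QStruct
    q
end
Base.getproperty(::Type{QStruct}, s::Symbol) = getfield(QStruct, s)

struct RStruct
    r
end
Base.getproperty(::Type{RStruct}, s::Symbol) = getfield(RStruct, s)

# The below minimised from:
# @model function f()
#     s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
#     1.5 ~ Normal(s[1])
# end

# For a scalar `xs`, `broadcast` simply returns the tuple `(exp(xs), xs)`.
function g(xs)
    return broadcast(x -> (exp(x), x), xs)
end

function f(x::AbstractVector)
    logp = Ref(0.0)
    ts = broadcast(g, x)  # ts[i] == (exp(x[i]), x[i])
    logp[] += (ts[1][1] + ts[2][1])
    return logp[]
end

x = [0.5, 1.0]
@show f(x)
Enzyme.gradient(Reverse, f, x)  # Enzyme throws here

end
```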
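For comparison, here is a sketch of the gradient Enzyme should return for the second MWE. It does not use Enzyme. It assumes that `f` reduces to `exp(x[1]) + exp(x[2])`, which follows from `ts[i] == (exp(x[i]), x[i])`. `f_ref`, `grad_analytic` and `grad_fd` are hypothetical helpers introduced only for this check.

```julia
# Hypothetical stand-in reproducing the arithmetic of the second `MWE.f`:
# ts[i][1] == exp(x[i]), so f(x) == exp(x[1]) + exp(x[2]).
f_ref(x::AbstractVector) = exp(x[1]) + exp(x[2])

# Hypothetical helper: analytic gradient, d/dx[i] exp(x[i]) = exp(x[i]).
# Entries beyond the first two do not affect f, so they stay zero.
function grad_analytic(x::AbstractVector)
    g = zero(x)
    g[1] = exp(x[1])
    g[2] = exp(x[2])
    return g
end

# Hypothetical helper: central finite differences with step h.
function grad_fd(f, x::AbstractVector{<:AbstractFloat}; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = [0.5, 1.0]
@show f_ref(x)
@show grad_analytic(x)
@show grad_fd(f_ref, x)
```

Once the upstream bug is fixed, `Enzyme.gradient(Reverse, f, x)` should agree with `grad_analytic(x)` up to floating-point error.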
Upstreamed EnzymeAD/Enzyme.jl#2408
That's why, for example, https://turinglang.org/ADTests/pr/ (built from #23) currently shows a lot of errors. The failure is currently minimised to the MWEs above.