diff --git a/src/adjust.jl b/src/adjust.jl index c0ce936..51f6005 100644 --- a/src/adjust.jl +++ b/src/adjust.jl @@ -144,3 +144,33 @@ function _adjust(r::T, nt::NamedTuple) where T <: AbstractRule end T(vals...) # relies on having the default constructor end + +### +#adjust with type control +### + +adjust!(ℓ::Leaf, oT::Type, eta::Real) = (ℓ.rule = adjust(ℓ.rule, oT, eta); nothing) +adjust!(ℓ::Leaf, oT::Type; kw...) = (ℓ.rule = adjust(ℓ.rule, oT; kw...); nothing) + +adjust(ℓ::Leaf, oT::Type, eta::Real) = Leaf(adjust(ℓ.rule, oT, eta), ℓ.state, ℓ.frozen) +adjust(ℓ::Leaf, oT::Type; kw...) = Leaf(adjust(ℓ.rule, oT; kw...), ℓ.state, ℓ.frozen) + +adjust!(tree, oT::Type, eta::Real) = foreach(st -> adjust!(st, oT, eta), tree) +adjust!(tree, oT::Type; kw...) = foreach(st -> adjust!(st, oT; kw...), tree) + +adjust(r::AbstractRule, oT::Type, eta::Real) = ifelse(isa(r, oT), adjust(r, eta), r) +adjust(r::AbstractRule, oT::Type; kw...) = ifelse(isa(r, oT), adjust(r; kw...), r) + +adjust!(r::AbstractRule, oT::Type, eta::Real) = ifelse(isa(r, oT), adjust!(r, eta), r) +adjust!(r::AbstractRule, oT::Type; kw...) = ifelse(isa(r, oT), adjust!(r; kw...), r) + +function adjust(tree, oT::Type, eta::Real) + t′ = fmap(copy, tree; exclude = maywrite) + adjust!(t′, oT, eta) + t′ +end +function adjust(tree, oT::Type; kw...) + t′ = fmap(copy, tree; exclude = maywrite) + adjust!(t′, oT; kw...) + t′ +end \ No newline at end of file diff --git a/src/rules.jl b/src/rules.jl index 0cd8d30..abd653a 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -799,6 +799,9 @@ end adjust(ℓ::OptimiserChain, eta::Real) = OptimiserChain(map(opt -> adjust(opt, eta), ℓ.opts)...) adjust(ℓ::OptimiserChain; kw...) = OptimiserChain(map(opt -> adjust(opt; kw...), ℓ.opts)...) +adjust(ℓ::OptimiserChain, oT::Type, eta::Real) = OptimiserChain(map(opt -> adjust(opt, oT, eta), ℓ.opts)...) +adjust(ℓ::OptimiserChain, oT::Type; kw...) = OptimiserChain(map(opt -> adjust(opt, oT; kw...), ℓ.opts)...) + """ AccumGrad(n::Int)