Skip to content

Commit 2252a9b

Browse files
torfjeldesunxd3github-actions[bot]yebaipenelopeysm
authored
Depreciate@submodel l ~ m in favour of l ~ to_submodel(m); rename generated_quantities to returned (#696)
* Added `@returned_quantities` macro * Added `@returned_quantities` to the docs * Fixed names of doctests for `@returned_quantities` * Update src/submodel_macro.jl Co-authored-by: Xianda Sun <[email protected]> * Added `@prefix` macro which calls `prefix` with a `Val` argument to make things easier to basic users * Convert the result of `prefix_expr` in `@prefix` into a `Sybmol` before wrapping in `Val` * Export `prefix` and `@prefix` * Updated docstring for `@returned_quantities` * Fixed bug in `rand` for `Model` where it would duplicate the non-leaf contexts in `model.context` * Update src/contexts.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Added `prefix` and `@prefix` to docs * removed the prefix=... syntax for `@returned_quantities` * added deprecation.jl + deprecated `generated_quantities` in favour of `returned_quantities` * removed export of `prefix` and `generated_quantities` (the latter is exported by the deprecation macro) * updated `DynamicPPLMCMCChainsExt` to define `returned_quantities` * updated docs * Update docs/src/api.md Co-authored-by: Hong Ge <[email protected]> * improved docstring for `prefix` and `@prefix` * added `@returned_quantities` macro taking two arguments + removed `returned_quantities` from exported functions * updated docs to reflect the new two-argument `@returned_quantities` * added depwarn to `@submodel` macro * fixed reference * fixed reference to `@prefix` in `@returned_quantities` macro * actually fixed doc references * updated doctests for `@submodel` to include the depwarn + added warning regarding deprecation of `@submodel` * added `to_sampleable` and limited `~` handling for submodels * added docs to `to_sampleable` + removed the unnecessary macro exports that we no longer need * updated more docstrings * added testing of deprecation warning of `@submodel` + replaced some usages in tests (though we don't support some of these so we cant' do that yet) * Update test/compiler.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * renamed `returned_quantities` to `returned` as requested * removed redundant `SampleableModelWrapper` in favour of `ReturnedModelWrapper` + introduced `rand_like!!` to hide explicit calls to `_evaluate!!` * updated tests + docstrings + warnings to use `returned` * updated docs * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix docs * export `to_sampleable` and add to docs * fixed typo in warning * removed unnecessary import in docstring * added docstring to `rand_like!!` * fixed docstring for `returned(model)` * improvements to docstrings thanks to @penelopesym Co-authored-by: Penelope Yong <[email protected]> * added abstract type `Distributional` and concrete type `Sampleable`, in addition to method `to_submodel` * replaced usages of `returned` with `to_submodel` * formatting * Update docs/src/api.md Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * removed export of `to_sampleable` since it currently has no purpose + fixed docs for `returned` * formatting * updated docstring for `condition` and `fix` to not use `@submdoel` * added `check_tilde_rhs` for `Sampleable` * let the field of sampleable determine whether it works or not * add automatic prefixing of submodels + remove support for dot-tilde since this is ambigious in this case * added automatic prefixing for sub-models involved in `~` statements * updated depwarn for `@submodel` and tests * formatting * updated docstrings * updated docs * added more depwarns to the doctests to see if that helps (though I don't understand why this is needed for Documenter.jl) * forgot one * replaced usage of `generated_quantities` with `returned` * foxed docstring for `to_submodel` * patch version bump --------- Co-authored-by: Xianda Sun <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Penelope Yong <[email protected]>
1 parent 82842bc commit 2252a9b

13 files changed

+456
-120
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.31.0"
3+
version = "0.31.1"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/api.md

+30-8
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,6 @@ These statements are rewritten by `@model` as calls of [internal functions](@ref
1414
@model
1515
```
1616

17-
One can nest models and call another model inside the model function with [`@submodel`](@ref).
18-
19-
```@docs
20-
@submodel
21-
```
22-
2317
### Type
2418

2519
A [`Model`](@ref) can be created by calling the model function, as defined by [`@model`](@ref).
@@ -110,6 +104,34 @@ Similarly, we can [`unfix`](@ref) variables, i.e. return them to their original
110104
unfix
111105
```
112106

107+
## Models within models
108+
109+
One can include models and call another model inside the model function with `left ~ to_submodel(model)`.
110+
111+
```@docs
112+
to_submodel
113+
```
114+
115+
Note that a `[to_submodel](@ref)` is only sampleable; one cannot compute `logpdf` for its realizations.
116+
117+
In the past, one would instead embed sub-models using [`@submodel`](@ref), which has been deprecated since the introduction of [`to_submodel(model)`](@ref)
118+
119+
```@docs
120+
@submodel
121+
```
122+
123+
In the context of including models within models, it's also useful to prefix the variables in sub-models to avoid variable names clashing:
124+
125+
```@docs
126+
prefix
127+
```
128+
129+
Under the hood, [`to_submodel`](@ref) makes use of the following method to indicate that the model it's wrapping is a model over its return-values rather than something else
130+
131+
```@docs
132+
returned(::Model)
133+
```
134+
113135
## Utilities
114136

115137
It is possible to manually increase (or decrease) the accumulated log density from within a model function.
@@ -118,10 +140,10 @@ It is possible to manually increase (or decrease) the accumulated log density fr
118140
@addlogprob!
119141
```
120142

121-
Return values of the model function for a collection of samples can be obtained with [`generated_quantities`](@ref).
143+
Return values of the model function for a collection of samples can be obtained with [`returned(model, chain)`](@ref).
122144

123145
```@docs
124-
generated_quantities
146+
returned(::DynamicPPL.Model, ::NamedTuple)
125147
```
126148

127149
For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). Similarly, the log-densities of the priors using

ext/DynamicPPLMCMCChainsExt.jl

+5-7
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
4343
end
4444

4545
"""
46-
generated_quantities(model::Model, chain::MCMCChains.Chains)
46+
returned(model::Model, chain::MCMCChains.Chains)
4747
4848
Execute `model` for each of the samples in `chain` and return an array of the values
4949
returned by the `model` for each sample.
@@ -63,12 +63,12 @@ m = demo(data)
6363
chain = sample(m, alg, n)
6464
# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
6565
# from the posterior/`chain`:
66-
generated_quantities(m, chain) # <= results in a `Vector` of returned values
66+
returned(m, chain) # <= results in a `Vector` of returned values
6767
# from `interesting_quantity(θ, x)`
6868
```
6969
## Concrete (and simple)
7070
```julia
71-
julia> using DynamicPPL, Turing
71+
julia> using Turing
7272
7373
julia> @model function demo(xs)
7474
s ~ InverseGamma(2, 3)
@@ -87,7 +87,7 @@ julia> model = demo(randn(10));
8787
8888
julia> chain = sample(model, MH(), 10);
8989
90-
julia> generated_quantities(model, chain)
90+
julia> returned(model, chain)
9191
10×1 Array{Tuple{Float64},2}:
9292
(2.1964758025119338,)
9393
(2.1964758025119338,)
@@ -101,9 +101,7 @@ julia> generated_quantities(model, chain)
101101
(-0.16489786710222099,)
102102
```
103103
"""
104-
function DynamicPPL.generated_quantities(
105-
model::DynamicPPL.Model, chain_full::MCMCChains.Chains
106-
)
104+
function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains)
107105
chain = MCMCChains.get_sections(chain_full, :parameters)
108106
varinfo = DynamicPPL.VarInfo(model)
109107
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))

src/DynamicPPL.jl

+7-2
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ export AbstractVarInfo,
8686
Model,
8787
getmissings,
8888
getargnames,
89-
generated_quantities,
9089
extract_priors,
9190
values_as_in_model,
9291
# Samplers
@@ -122,6 +121,9 @@ export AbstractVarInfo,
122121
decondition,
123122
fix,
124123
unfix,
124+
prefix,
125+
returned,
126+
to_submodel,
125127
# Convenience macros
126128
@addlogprob!,
127129
@submodel,
@@ -130,7 +132,8 @@ export AbstractVarInfo,
130132
check_model_and_trace,
131133
# Deprecated.
132134
@logprob_str,
133-
@prob_str
135+
@prob_str,
136+
generated_quantities
134137

135138
# Reexport
136139
using Distributions: loglikelihood
@@ -196,6 +199,8 @@ include("values_as_in_model.jl")
196199
include("debug_utils.jl")
197200
using .DebugUtils
198201

202+
include("deprecated.jl")
203+
199204
if !isdefined(Base, :get_extension)
200205
using Requires
201206
end

src/compiler.jl

+5
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,11 @@ function check_tilde_rhs(@nospecialize(x))
178178
end
179179
check_tilde_rhs(x::Distribution) = x
180180
check_tilde_rhs(x::AbstractArray{<:Distribution}) = x
181+
check_tilde_rhs(x::ReturnedModelWrapper) = x
182+
function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix}
183+
model = check_tilde_rhs(x.model)
184+
return Sampleable{typeof(model),AutoPrefix}(model)
185+
end
181186

182187
"""
183188
unwrap_right_vn(right, vn)

src/context_implementations.jl

+37-3
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,17 @@ By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log
103103
probability of `vi` with the returned value.
104104
"""
105105
function tilde_assume!!(context, right, vn, vi)
106-
value, logp, vi = tilde_assume(context, right, vn, vi)
107-
return value, acclogp_assume!!(context, vi, logp)
106+
return if is_rhs_model(right)
107+
# Prefix the variables using the `vn`.
108+
rand_like!!(
109+
right,
110+
should_auto_prefix(right) ? PrefixContext{Symbol(vn)}(context) : context,
111+
vi,
112+
)
113+
else
114+
value, logp, vi = tilde_assume(context, right, vn, vi)
115+
value, acclogp_assume!!(context, vi, logp)
116+
end
108117
end
109118

110119
# observe
@@ -159,6 +168,11 @@ Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the informati
159168
and indices; if needed, these can be accessed through this function, though.
160169
"""
161170
function tilde_observe!!(context, right, left, vname, vi)
171+
is_rhs_model(right) && throw(
172+
ArgumentError(
173+
"`~` with a model on the right-hand side of an observe statement is not supported",
174+
),
175+
)
162176
return tilde_observe!!(context, right, left, vi)
163177
end
164178

@@ -172,6 +186,11 @@ By default, calls `tilde_observe(context, right, left, vi)` and accumulates the
172186
probability of `vi` with the returned value.
173187
"""
174188
function tilde_observe!!(context, right, left, vi)
189+
is_rhs_model(right) && throw(
190+
ArgumentError(
191+
"`~` with a model on the right-hand side of an observe statement is not supported",
192+
),
193+
)
175194
logp, vi = tilde_observe(context, right, left, vi)
176195
return left, acclogp_observe!!(context, vi, logp)
177196
end
@@ -321,8 +340,13 @@ model inputs), accumulate the log probability, and return the sampled value and
321340
Falls back to `dot_tilde_assume(context, right, left, vn, vi)`.
322341
"""
323342
function dot_tilde_assume!!(context, right, left, vn, vi)
343+
is_rhs_model(right) && throw(
344+
ArgumentError(
345+
"`.~` with a model on the right-hand side is not supported; please use `~`"
346+
),
347+
)
324348
value, logp, vi = dot_tilde_assume(context, right, left, vn, vi)
325-
return value, acclogp_assume!!(context, vi, logp), vi
349+
return value, acclogp_assume!!(context, vi, logp)
326350
end
327351

328352
# `dot_assume`
@@ -573,6 +597,11 @@ Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the infor
573597
name and indices; if needed, these can be accessed through this function, though.
574598
"""
575599
function dot_tilde_observe!!(context, right, left, vn, vi)
600+
is_rhs_model(right) && throw(
601+
ArgumentError(
602+
"`~` with a model on the right-hand side of an observe statement is not supported",
603+
),
604+
)
576605
return dot_tilde_observe!!(context, right, left, vi)
577606
end
578607

@@ -585,6 +614,11 @@ probability, and return the observed value and updated `vi`.
585614
Falls back to `dot_tilde_observe(context, right, left, vi)`.
586615
"""
587616
function dot_tilde_observe!!(context, right, left, vi)
617+
is_rhs_model(right) && throw(
618+
ArgumentError(
619+
"`~` with a model on the right-hand side of an observe statement is not supported",
620+
),
621+
)
588622
logp, vi = dot_tilde_observe(context, right, left, vi)
589623
return left, acclogp_observe!!(context, vi, logp)
590624
end

src/contexts.jl

+28
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,34 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
281281
end
282282
end
283283

284+
"""
285+
prefix(model::Model, x)
286+
287+
Return `model` but with all random variables prefixed by `x`.
288+
289+
If `x` is known at compile-time, use `Val{x}()` to avoid runtime overheads for prefixing.
290+
291+
# Examples
292+
293+
```jldoctest
294+
julia> using DynamicPPL: prefix
295+
296+
julia> @model demo() = x ~ Dirac(1)
297+
demo (generic function with 2 methods)
298+
299+
julia> rand(prefix(demo(), :my_prefix))
300+
(var"my_prefix.x" = 1,)
301+
302+
julia> # One can also use `Val` to avoid runtime overheads.
303+
rand(prefix(demo(), Val(:my_prefix)))
304+
(var"my_prefix.x" = 1,)
305+
```
306+
"""
307+
prefix(model::Model, x) = contextualize(model, PrefixContext{Symbol(x)}(model.context))
308+
function prefix(model::Model, ::Val{x}) where {x}
309+
return contextualize(model, PrefixContext{Symbol(x)}(model.context))
310+
end
311+
284312
struct ConditionContext{Values,Ctx<:AbstractContext} <: AbstractContext
285313
values::Values
286314
context::Ctx

src/deprecated.jl

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
@deprecate generated_quantities(model, params) returned(model, params)

0 commit comments

Comments
 (0)