Skip to content

Commit 655ce07

Browse files
committed
.wip
1 parent 04a6259 commit 655ce07

File tree

7 files changed

+319
-106
lines changed

7 files changed

+319
-106
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import pymc_extras.model.marginal.rewrites # Need import to register rewrites

pymc_extras/model/marginal/distributions.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,37 +24,20 @@
2424
from pymc_extras.distributions import DiscreteMarkovChain
2525

2626

27-
class MarginalRV(OpFromGraph, MeasurableOp):
27+
class MarginalRV(OpFromGraph):
2828
"""Base class for Marginalized RVs"""
2929

3030
def __init__(
3131
self,
3232
*args,
33-
dims_connections: tuple[tuple[int | None], ...],
3433
dims: tuple[Variable, ...],
34+
n_dependent_rvs: int,
3535
**kwargs,
3636
) -> None:
37-
self.dims_connections = dims_connections
3837
self.dims = dims
38+
self.n_dependent_rvs = n_dependent_rvs
3939
super().__init__(*args, **kwargs)
4040

41-
@property
42-
def support_axes(self) -> tuple[tuple[int]]:
43-
"""Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable."""
44-
marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp
45-
support_axes_vars = []
46-
for dims_connection in self.dims_connections:
47-
ndim = len(dims_connection)
48-
marginalized_supp_axes = ndim - marginalized_ndim_supp
49-
support_axes_vars.append(
50-
tuple(
51-
-i
52-
for i, dim in enumerate(reversed(dims_connection), start=1)
53-
if (dim is None or dim > marginalized_supp_axes)
54-
)
55-
)
56-
return tuple(support_axes_vars)
57-
5841
def __eq__(self, other):
5942
# Just to allow easy testing of equivalent models,
6043
# This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
@@ -124,11 +107,35 @@ def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
124107
return rv_support_point
125108

126109

127-
class MarginalFiniteDiscreteRV(MarginalRV):
110+
class MarginalEnumerableRV(MarginalRV, MeasurableOp):
111+
112+
def __init__(self, *args, dims_connections: tuple[tuple[int | None], ...], **kwargs):
113+
super().__init__(*args, **kwargs)
114+
self.dims_connections = dims_connections
115+
116+
@property
117+
def support_axes(self) -> tuple[tuple[int]]:
118+
"""Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable."""
119+
marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp
120+
support_axes_vars = []
121+
for dims_connection in self.dims_connections:
122+
ndim = len(dims_connection)
123+
marginalized_supp_axes = ndim - marginalized_ndim_supp
124+
support_axes_vars.append(
125+
tuple(
126+
-i
127+
for i, dim in enumerate(reversed(dims_connection), start=1)
128+
if (dim is None or dim > marginalized_supp_axes)
129+
)
130+
)
131+
return tuple(support_axes_vars)
132+
133+
134+
class MarginalFiniteDiscreteRV(MarginalEnumerableRV):
128135
"""Base class for Marginalized Finite Discrete RVs"""
129136

130137

131-
class MarginalDiscreteMarkovChainRV(MarginalRV):
138+
class MarginalDiscreteMarkovChainRV(MarginalEnumerableRV):
132139
"""Base class for Marginalized Discrete Markov Chain RVs"""
133140

134141

@@ -239,7 +246,9 @@ def warn_non_separable_logp(values):
239246
@_logprob.register(MarginalFiniteDiscreteRV)
240247
def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs):
241248
# Clone the inner RV graph of the Marginalized RV
242-
marginalized_rv, *inner_rvs = inline_ofg_outputs(op, inputs)
249+
marginalized_rv, *inner_rvs_and_rngs = inline_ofg_outputs(op, inputs)
250+
inner_rvs = inner_rvs_and_rngs[:op.n_dependent_rvs]
251+
assert len(values) == len(inner_rvs)
243252

244253
# Obtain the joint_logp graph of the inner RV graph
245254
inner_rv_values = dict(zip(inner_rvs, values))
@@ -302,7 +311,9 @@ def logp_fn(marginalized_rv_const, *non_sequences):
302311

303312
@_logprob.register(MarginalDiscreteMarkovChainRV)
304313
def marginal_hmm_logp(op, values, *inputs, **kwargs):
305-
chain_rv, *dependent_rvs = inline_ofg_outputs(op, inputs)
314+
chain_rv, *dependent_rvs_and_rngs = inline_ofg_outputs(op, inputs)
315+
dependent_rvs = dependent_rvs_and_rngs[:op.n_dependent_rvs]
316+
assert len(values) == len(dependent_rvs)
306317

307318
P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
308319
domain = pt.arange(P.shape[-1], dtype="int32")

pymc_extras/model/marginal/marginal_model.py

Lines changed: 80 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -62,50 +62,6 @@
6262

6363

6464
class MarginalModel(Model):
65-
"""Subclass of PyMC Model that implements functionality for automatic
66-
marginalization of variables in the logp transformation
67-
68-
After defining the full Model, the `marginalize` method can be used to indicate a
69-
subset of variables that should be marginalized
70-
71-
Notes
72-
-----
73-
Marginalization functionality is still very restricted. Only finite discrete
74-
variables can be marginalized. Deterministics and Potentials cannot be conditionally
75-
dependent on the marginalized variables.
76-
77-
Furthermore, not all instances of such variables can be marginalized. If a variable
78-
has batched dimensions, it is required that any conditionally dependent variables
79-
use information from an individual batched dimension. In other words, the graph
80-
connecting the marginalized variable(s) to the dependent variable(s) must be
81-
composed strictly of Elemwise Operations. This is necessary to ensure an efficient
82-
logprob graph can be generated. If you want to bypass this restriction you can
83-
separate each dimension of the marginalized variable into the scalar components
84-
and then stack them together. Note that such graphs will grow exponentially in the
85-
number of marginalized variables.
86-
87-
For the same reason, it's not possible to marginalize RVs with multivariate
88-
dependent RVs.
89-
90-
Examples
91-
--------
92-
Marginalize over a single variable
93-
94-
.. code-block:: python
95-
96-
import pymc as pm
97-
from pymc_extras import MarginalModel
98-
99-
with MarginalModel() as m:
100-
p = pm.Beta("p", 1, 1)
101-
x = pm.Bernoulli("x", p=p, shape=(3,))
102-
y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
103-
104-
m.marginalize([x])
105-
106-
idata = pm.sample()
107-
108-
"""
10965

11066
def __init__(self, *args, **kwargs):
11167
raise TypeError(
@@ -147,10 +103,29 @@ def _unique(seq: Sequence) -> list:
147103
def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
148104
"""Marginalize a subset of variables in a PyMC model.
149105
150-
This creates a class of `MarginalModel` from an existing `Model`, with the specified
151-
variables marginalized.
106+
Notes
107+
-----
108+
Marginalization functionality is still very restricted. Only finite discrete
109+
variables and some closed from graphs can be marginalized.
110+
Deterministics and Potentials cannot be conditionally dependent on the marginalized variables.
152111
153-
See documentation for `MarginalModel` for more information.
112+
113+
Examples
114+
--------
115+
Marginalize over a single variable
116+
117+
.. code-block:: python
118+
119+
import pymc as pm
120+
from pymc_extras import marginalize
121+
122+
with pm.Model() as m:
123+
p = pm.Beta("p", 1, 1)
124+
x = pm.Bernoulli("x", p=p, shape=(3,))
125+
y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
126+
127+
with marginalize(m, [x]) as marginal_m:
128+
idata = pm.sample()
154129
155130
Parameters
156131
----------
@@ -161,8 +136,8 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
161136
162137
Returns
163138
-------
164-
marginal_model: MarginalModel
165-
Marginal model with the specified variables marginalized.
139+
marginal_model: Model
140+
PyMC model with the specified variables marginalized.
166141
"""
167142
if isinstance(rvs_to_marginalize, str | Variable):
168143
rvs_to_marginalize = (rvs_to_marginalize,)
@@ -176,20 +151,20 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
176151
if rv_to_marginalize not in model.free_RVs:
177152
raise ValueError(f"Marginalized RV {rv_to_marginalize} is not a free RV in the model")
178153

179-
rv_op = rv_to_marginalize.owner.op
180-
if isinstance(rv_op, DiscreteMarkovChain):
181-
if rv_op.n_lags > 1:
182-
raise NotImplementedError(
183-
"Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
184-
)
185-
if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
186-
raise NotImplementedError(
187-
"Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
188-
)
189-
elif not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
190-
raise NotImplementedError(
191-
f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
192-
)
154+
# rv_op = rv_to_marginalize.owner.op
155+
# if isinstance(rv_op, DiscreteMarkovChain):
156+
# if rv_op.n_lags > 1:
157+
# raise NotImplementedError(
158+
# "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
159+
# )
160+
# if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
161+
# raise NotImplementedError(
162+
# "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
163+
# )
164+
# elif not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
165+
# raise NotImplementedError(
166+
# f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
167+
# )
193168

194169
fg, memo = fgraph_from_model(model)
195170
rvs_to_marginalize = [memo[rv] for rv in rvs_to_marginalize]
@@ -241,11 +216,52 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
241216
]
242217
input_rvs = _unique((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))
243218

244-
replace_finite_discrete_marginal_subgraph(fg, rv_to_marginalize, dependent_rvs, input_rvs)
219+
marginalize_subgraph(fg, rv_to_marginalize, dependent_rvs, input_rvs)
245220

246221
return model_from_fgraph(fg, mutate_fgraph=True)
247222

248223

224+
def marginalize_subgraph(
225+
fgraph, rv_to_marginalize, dependent_rvs, input_rvs
226+
) -> None:
227+
228+
output_rvs = [rv_to_marginalize, *dependent_rvs]
229+
rng_updates = collect_default_updates(output_rvs, inputs=input_rvs, must_be_shared=False)
230+
outputs = output_rvs + list(rng_updates.values())
231+
inputs = input_rvs + list(rng_updates.keys())
232+
# Add any other shared variable inputs
233+
inputs += collect_shared_vars(output_rvs, blockers=inputs)
234+
235+
inner_inputs = [inp.clone() for inp in inputs]
236+
inner_outputs = clone_replace(outputs, replace=dict(zip(inputs, inner_inputs)))
237+
inner_outputs = remove_model_vars(inner_outputs)
238+
239+
_, _, *dims = rv_to_marginalize.owner.inputs
240+
marginalization_op = MarginalRV(
241+
inputs=inner_inputs,
242+
outputs=inner_outputs,
243+
dims=dims,
244+
n_dependent_rvs=len(dependent_rvs)
245+
)
246+
247+
new_outputs = marginalization_op(*inputs)
248+
assert len(new_outputs) == len(outputs)
249+
for old_output, new_output in zip(outputs, new_outputs):
250+
new_output.name = old_output.name
251+
252+
model_replacements = []
253+
for old_output, new_output in zip(outputs, new_outputs):
254+
if old_output is rv_to_marginalize or not isinstance(old_output.owner.op, ModelValuedVar):
255+
# Replace the marginalized ModelFreeRV (or non model-variables) themselves
256+
var_to_replace = old_output
257+
else:
258+
# Replace the underlying RV, keeping the same value, transform and dims
259+
var_to_replace = old_output.owner.inputs[0]
260+
model_replacements.append((var_to_replace, new_output))
261+
262+
fgraph.replace_all(model_replacements)
263+
264+
249265
@node_rewriter(tracks=[MarginalRV])
250266
def local_unmarginalize(fgraph, node):
251267
unmarginalized_rv, *dependent_rvs_and_rngs = inline_ofg_outputs(node.op, node.inputs)

0 commit comments

Comments
 (0)