6262
6363
6464class  MarginalModel (Model ):
65-     """Subclass of PyMC Model that implements functionality for automatic 
66-     marginalization of variables in the logp transformation 
67- 
68-     After defining the full Model, the `marginalize` method can be used to indicate a 
69-     subset of variables that should be marginalized 
70- 
71-     Notes 
72-     ----- 
73-     Marginalization functionality is still very restricted. Only finite discrete 
74-     variables can be marginalized. Deterministics and Potentials cannot be conditionally 
75-     dependent on the marginalized variables. 
76- 
77-     Furthermore, not all instances of such variables can be marginalized. If a variable 
78-     has batched dimensions, it is required that any conditionally dependent variables 
79-     use information from an individual batched dimension. In other words, the graph 
80-     connecting the marginalized variable(s) to the dependent variable(s) must be 
81-     composed strictly of Elemwise Operations. This is necessary to ensure an efficient 
82-     logprob graph can be generated. If you want to bypass this restriction you can 
83-     separate each dimension of the marginalized variable into the scalar components 
84-     and then stack them together. Note that such graphs will grow exponentially in the 
85-     number of  marginalized variables. 
86- 
87-     For the same reason, it's not possible to marginalize RVs with multivariate 
88-     dependent RVs. 
89- 
90-     Examples 
91-     -------- 
92-     Marginalize over a single variable 
93- 
94-     .. code-block:: python 
95- 
96-         import pymc as pm 
97-         from pymc_extras import MarginalModel 
98- 
99-         with MarginalModel() as m: 
100-             p = pm.Beta("p", 1, 1) 
101-             x = pm.Bernoulli("x", p=p, shape=(3,)) 
102-             y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10]) 
103- 
104-             m.marginalize([x]) 
105- 
106-             idata = pm.sample() 
107- 
108-     """ 
10965
11066    def  __init__ (self , * args , ** kwargs ):
11167        raise  TypeError (
@@ -147,10 +103,29 @@ def _unique(seq: Sequence) -> list:
147103def  marginalize (model : Model , rvs_to_marginalize : ModelRVs ) ->  MarginalModel :
148104    """Marginalize a subset of variables in a PyMC model. 
149105
150-     This creates a class of `MarginalModel` from an existing `Model`, with the specified 
151-     variables marginalized. 
106+     Notes 
107+     ----- 
108+     Marginalization functionality is still very restricted. Only finite discrete 
109+     variables and some closed form graphs can be marginalized. 
110+     Deterministics and Potentials cannot be conditionally dependent on the marginalized variables. 
152111
153-     See documentation for `MarginalModel` for more information. 
112+ 
113+     Examples 
114+     -------- 
115+     Marginalize over a single variable 
116+ 
117+     .. code-block:: python 
118+ 
119+         import pymc as pm 
120+         from pymc_extras import marginalize 
121+ 
122+         with pm.Model() as m: 
123+             p = pm.Beta("p", 1, 1) 
124+             x = pm.Bernoulli("x", p=p, shape=(3,)) 
125+             y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10]) 
126+ 
127+         with marginalize(m, [x]) as marginal_m: 
128+             idata = pm.sample() 
154129
155130    Parameters 
156131    ---------- 
@@ -161,8 +136,8 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
161136
162137    Returns 
163138    ------- 
164-     marginal_model: MarginalModel  
165-         Marginal  model with the specified variables marginalized. 
139+     marginal_model: Model  
140+         PyMC  model with the specified variables marginalized. 
166141    """ 
167142    if  isinstance (rvs_to_marginalize , str  |  Variable ):
168143        rvs_to_marginalize  =  (rvs_to_marginalize ,)
@@ -176,20 +151,20 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
176151        if  rv_to_marginalize  not  in model .free_RVs :
177152            raise  ValueError (f"Marginalized RV { rv_to_marginalize }  )
178153
179-         rv_op  =  rv_to_marginalize .owner .op 
180-         if  isinstance (rv_op , DiscreteMarkovChain ):
181-             if  rv_op .n_lags  >  1 :
182-                 raise  NotImplementedError (
183-                     "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported" 
184-                 )
185-             if  rv_to_marginalize .owner .inputs [0 ].type .ndim  >  2 :
186-                 raise  NotImplementedError (
187-                     "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported" 
188-                 )
189-         elif  not  isinstance (rv_op , Bernoulli  |  Categorical  |  DiscreteUniform ):
190-             raise  NotImplementedError (
191-                 f"Marginalization of RV with distribution { rv_to_marginalize .owner .op }  
192-             )
154+         #  rv_op = rv_to_marginalize.owner.op
155+         #  if isinstance(rv_op, DiscreteMarkovChain):
156+         #      if rv_op.n_lags > 1:
157+         #          raise NotImplementedError(
158+         #              "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
159+         #          )
160+         #      if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
161+         #          raise NotImplementedError(
162+         #              "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
163+         #          )
164+         #  elif not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
165+         #      raise NotImplementedError(
166+         #          f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
167+         #      )
193168
194169    fg , memo  =  fgraph_from_model (model )
195170    rvs_to_marginalize  =  [memo [rv ] for  rv  in  rvs_to_marginalize ]
@@ -241,11 +216,52 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
241216        ]
242217        input_rvs  =  _unique ((* marginalized_rv_input_rvs , * other_direct_rv_ancestors ))
243218
244-         replace_finite_discrete_marginal_subgraph (fg , rv_to_marginalize , dependent_rvs , input_rvs )
219+         marginalize_subgraph (fg , rv_to_marginalize , dependent_rvs , input_rvs )
245220
246221    return  model_from_fgraph (fg , mutate_fgraph = True )
247222
248223
def marginalize_subgraph(
    fgraph, rv_to_marginalize, dependent_rvs, input_rvs
) -> None:
    """Collapse a marginalized RV and its dependent RVs into one `MarginalRV` node.

    The subgraph producing ``rv_to_marginalize`` and ``dependent_rvs`` is cloned
    onto fresh dummy inputs, wrapped in a `MarginalRV` op, and the outputs of
    that op are then substituted for the original variables in ``fgraph``.

    Parameters
    ----------
    fgraph
        Graph that is mutated in place through ``fgraph.replace_all``.
    rv_to_marginalize
        Variable being marginalized; it becomes the first output of the new op.
    dependent_rvs
        Variables that depend on ``rv_to_marginalize``; they follow it in the
        op outputs.
    input_rvs
        Variables used as the leading inputs of the new op. Must support
        ``+`` with a list (the caller is presumed to pass a list -- confirm).

    Returns
    -------
    None
        The function is used only for its effect on ``fgraph``.
    """
    marginal_and_dependents = [rv_to_marginalize, *dependent_rvs]
    default_updates = collect_default_updates(
        marginal_and_dependents, inputs=input_rvs, must_be_shared=False
    )

    # Keys of the update mapping become extra op inputs and the matching values
    # become extra op outputs, appended after the RVs themselves.
    outer_outputs = marginal_and_dependents + [*default_updates.values()]
    outer_inputs = input_rvs + [*default_updates.keys()]
    # Any remaining shared variables not already blocked by the inputs above
    # must be explicit inputs as well.
    outer_inputs.extend(collect_shared_vars(marginal_and_dependents, blockers=outer_inputs))

    # Build the inner graph of the op on cloned inputs, stripped of model variables.
    dummy_inputs = [outer_input.clone() for outer_input in outer_inputs]
    input_substitutions = dict(zip(outer_inputs, dummy_inputs))
    inner_graph_outputs = remove_model_vars(
        clone_replace(outer_outputs, replace=input_substitutions)
    )

    # Everything after the first two inputs of the marginalized RV's owner node
    # is forwarded to the op as `dims`.
    _, _, *trailing_inputs = rv_to_marginalize.owner.inputs
    marginal_op = MarginalRV(
        inputs=dummy_inputs,
        outputs=inner_graph_outputs,
        dims=trailing_inputs,
        n_dependent_rvs=len(dependent_rvs),
    )

    replacement_outputs = marginal_op(*outer_inputs)
    assert len(replacement_outputs) == len(outer_outputs)

    replacement_pairs = []
    for original, replacement in zip(outer_outputs, replacement_outputs):
        replacement.name = original.name
        if original is not rv_to_marginalize and isinstance(original.owner.op, ModelValuedVar):
            # Replace the underlying RV, keeping the same value, transform and dims
            target = original.owner.inputs[0]
        else:
            # Replace the marginalized ModelFreeRV (or non model-variables) themselves
            target = original
        replacement_pairs.append((target, replacement))

    fgraph.replace_all(replacement_pairs)
263+ 
264+ 
249265@node_rewriter (tracks = [MarginalRV ]) 
250266def  local_unmarginalize (fgraph , node ):
251267    unmarginalized_rv , * dependent_rvs_and_rngs  =  inline_ofg_outputs (node .op , node .inputs )
0 commit comments