Forward compile_kwargs to ADVI when init = "advi+..." (#7640)
jessegrabowski authored Jan 9, 2025
1 parent 82716fb commit 52f7fe1
Showing 4 changed files with 56 additions and 42 deletions.
26 changes: 0 additions & 26 deletions pymc/pytensorf.py
@@ -22,7 +22,6 @@
import pytensor.tensor as pt
import scipy.sparse as sps

from pytensor import scalar
from pytensor.compile import Function, Mode, get_mode
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import grad
@@ -415,31 +414,6 @@ def hessian_diag(f, vars=None, negate_output=True):
return empty_gradient


class IdentityOp(scalar.UnaryScalarOp):
@staticmethod
def st_impl(x):
return x

def impl(self, x):
return x

def grad(self, inp, grads):
return grads

def c_code(self, node, name, inp, out, sub):
return f"{out[0]} = {inp[0]};"

def __eq__(self, other):
return isinstance(self, type(other))

def __hash__(self):
return hash(type(self))


scalar_identity = IdentityOp(scalar.upgrade_to_float, name="scalar_identity")
identity = Elemwise(scalar_identity, name="identity")


def make_shared_replacements(point, vars, model):
"""
Make shared replacements for all *other* variables than the ones passed.
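
The elemwise identity removed here is not lost: the pymc/variational/opvi.py hunk further down rebuilds it from PyTensor's own scalar op. A minimal sketch of that equivalence, assuming pytensor.scalar.basic.identity behaves like the deleted IdentityOp:

    import numpy as np
    import pytensor.tensor as pt

    from pytensor.scalar.basic import identity as scalar_identity
    from pytensor.tensor.elemwise import Elemwise

    # Lift the scalar identity to tensors, mirroring the removed
    # `identity = Elemwise(scalar_identity, name="identity")` helper.
    identity = Elemwise(scalar_identity)

    x = pt.vector("x")
    y = identity(x)
    assert np.allclose(y.eval({x: np.array([1.0, 2.0])}), [1.0, 2.0])
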
4 changes: 4 additions & 0 deletions pymc/sampling/mcmc.py
@@ -1553,6 +1553,7 @@ def init_nuts(
callbacks=cb,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
compile_kwargs=compile_kwargs,
)
approx_sample = approx.sample(
draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
@@ -1566,6 +1567,7 @@ def init_nuts(
potential = quadpotential.QuadPotentialDiagAdapt(
n, mean, cov, weight, rng=random_seed_list[0]
)

elif init == "advi":
approx = pm.fit(
random_seed=random_seed_list[0],
@@ -1575,6 +1577,7 @@ def init_nuts(
callbacks=cb,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
compile_kwargs=compile_kwargs,
)
approx_sample = approx.sample(
draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
@@ -1592,6 +1595,7 @@ def init_nuts(
callbacks=cb,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
compile_kwargs=compile_kwargs,
)
approx_sample = approx.sample(
draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
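
As a usage sketch of what these mcmc.py hunks enable: compiler options passed to pm.sample are now forwarded into the ADVI warm-up run by pm.fit instead of being dropped. The toy model below is hypothetical, and it assumes the installed PyMC exposes compile_kwargs on pm.sample (the keyword is already in scope inside init_nuts above).

    import pymc as pm

    with pm.Model():
        mu = pm.Normal("mu", 0.0, 1.0)
        pm.Normal("obs", mu, 1.0, observed=[0.1, -0.3, 0.2])

        # With this commit, the compile_kwargs below reach pm.fit during the
        # "advi+adapt_diag" initialization as well as the NUTS step function.
        idata = pm.sample(
            draws=500,
            tune=500,
            init="advi+adapt_diag",
            compile_kwargs={"mode": "FAST_COMPILE"},  # any pytensor.function kwargs
        )
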
19 changes: 14 additions & 5 deletions pymc/variational/inference.py
@@ -82,9 +82,18 @@ def _maybe_score(self, score):

def run_profiling(self, n=1000, score=None, **kwargs):
score = self._maybe_score(score)
fn_kwargs = kwargs.pop("fn_kwargs", {})
fn_kwargs["profile"] = True
step_func = self.objective.step_function(score=score, fn_kwargs=fn_kwargs, **kwargs)
if "fn_kwargs" in kwargs:
warnings.warn(
"fn_kwargs is deprecated, please use compile_kwargs instead", DeprecationWarning
)
compile_kwargs = kwargs.pop("fn_kwargs")
else:
compile_kwargs = kwargs.pop("compile_kwargs", {})

compile_kwargs["profile"] = True
step_func = self.objective.step_function(
score=score, compile_kwargs=compile_kwargs, **kwargs
)
try:
for _ in track(range(n)):
step_func()
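
A hedged usage sketch of the reworked run_profiling: profiling options now travel via compile_kwargs, while the old fn_kwargs spelling still works but emits a DeprecationWarning. The model and settings are illustrative only.

    import pymc as pm

    with pm.Model():
        pm.Normal("x", 0.0, 1.0)
        advi = pm.ADVI()

    # Preferred after this commit: extra pytensor.function kwargs go in compile_kwargs
    # (run_profiling forces profile=True on top of whatever is passed).
    advi.run_profiling(n=200, compile_kwargs={"mode": "FAST_COMPILE"})

    # Old spelling is still accepted but warns:
    # advi.run_profiling(n=200, fn_kwargs={"mode": "FAST_COMPILE"})
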
@@ -134,7 +143,7 @@ def fit(
Add custom updates to resulting updates
total_grad_norm_constraint: `float`
Bounds gradient norm, prevents exploding gradient problem
fn_kwargs: `dict`
compile_kwargs: `dict`
Add kwargs to pytensor.function (e.g. `{'profile': True}`)
more_replacements: `dict`
Apply custom replacements before calculating gradients
@@ -729,7 +738,7 @@ def fit(
Add custom updates to resulting updates
total_grad_norm_constraint: `float`
Bounds gradient norm, prevents exploding gradient problem
fn_kwargs: `dict`
compile_kwargs: `dict`
Add kwargs to pytensor.function (e.g. `{'profile': True}`)
more_replacements: `dict`
Apply custom replacements before calculating gradients
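
Both fit docstrings above now document compile_kwargs in place of fn_kwargs. A minimal example of threading compiler options through pm.fit, assuming extra keywords are forwarded to the compiled step function as documented:

    import pymc as pm

    with pm.Model():
        mu = pm.Normal("mu", 0.0, 1.0)
        pm.Normal("y", mu, 1.0, observed=[0.2, 0.5, -0.1])

        # compile_kwargs ends up as keyword arguments to pytensor.function
        # when the ADVI step function is compiled.
        approx = pm.fit(n=1000, method="advi", compile_kwargs={"profile": True})
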
49 changes: 38 additions & 11 deletions pymc/variational/opvi.py
@@ -61,6 +61,8 @@

from pytensor.graph.basic import Variable
from pytensor.graph.replace import graph_replace
from pytensor.scalar.basic import identity as scalar_identity
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.shape import unbroadcast

import pymc as pm
@@ -74,7 +76,6 @@
SeedSequenceSeed,
compile,
find_rng_nodes,
identity,
reseed_rngs,
)
from pymc.util import (
@@ -332,6 +333,7 @@ def step_function(
more_replacements=None,
total_grad_norm_constraint=None,
score=False,
compile_kwargs=None,
fn_kwargs=None,
):
R"""Step function that should be called on each optimization step.
@@ -362,17 +364,30 @@
Bounds gradient norm, prevents exploding gradient problem
score: `bool`
calculate loss on each step? Defaults to False for speed
fn_kwargs: `dict`
compile_kwargs: `dict`
Add kwargs to pytensor.function (e.g. `{'profile': True}`)
fn_kwargs: dict
arbitrary kwargs passed to `pytensor.function`
.. warning:: `fn_kwargs` is deprecated and will be removed in future versions
more_replacements: `dict`
Apply custom replacements before calculating gradients
Returns
-------
`pytensor.function`
"""
if fn_kwargs is None:
fn_kwargs = {}
if fn_kwargs is not None:
warnings.warn(
"`fn_kwargs` is deprecated and will be removed in future versions. Use "
"`compile_kwargs` instead.",
DeprecationWarning,
)
compile_kwargs = fn_kwargs

if compile_kwargs is None:
compile_kwargs = {}
if score and not self.op.returns_loss:
raise NotImplementedError(f"{self.op} does not have loss")
updates = self.updates(
@@ -388,14 +403,14 @@
)
seed = self.approx.rng.randint(2**30, dtype=np.int64)
if score:
step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **fn_kwargs)
step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **compile_kwargs)
else:
step_fn = compile([], [], updates=updates, random_seed=seed, **fn_kwargs)
step_fn = compile([], [], updates=updates, random_seed=seed, **compile_kwargs)
return step_fn

@pytensor.config.change_flags(compute_test_value="off")
def score_function(
self, sc_n_mc=None, more_replacements=None, fn_kwargs=None
self, sc_n_mc=None, more_replacements=None, compile_kwargs=None, fn_kwargs=None
): # pragma: no cover
R"""Compile scoring function that operates which takes no inputs and returns Loss.
@@ -405,22 +420,34 @@
number of scoring MC samples
more_replacements:
Apply custom replacements before compiling a function
compile_kwargs: `dict`
arbitrary kwargs passed to `pytensor.function`
fn_kwargs: `dict`
arbitrary kwargs passed to `pytensor.function`
.. warning:: `fn_kwargs` is deprecated and will be removed in future versions
Returns
-------
pytensor.function
"""
if fn_kwargs is None:
fn_kwargs = {}
if fn_kwargs is not None:
warnings.warn(
"`fn_kwargs` is deprecated and will be removed in future versions. Use "
"`compile_kwargs` instead",
DeprecationWarning,
)
compile_kwargs = fn_kwargs

if compile_kwargs is None:
compile_kwargs = {}
if not self.op.returns_loss:
raise NotImplementedError(f"{self.op} does not have loss")
if more_replacements is None:
more_replacements = {}
loss = self(sc_n_mc, more_replacements=more_replacements)
seed = self.approx.rng.randint(2**30, dtype=np.int64)
return compile([], loss, random_seed=seed, **fn_kwargs)
return compile([], loss, random_seed=seed, **compile_kwargs)

@pytensor.config.change_flags(compute_test_value="off")
def __call__(self, nmc, **kwargs):
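
The same deprecation shim appears in both step_function and score_function above. As a standalone sketch of the pattern (hypothetical helper, not PyMC code itself), the remapping reduces to:

    import warnings

    def resolve_compile_kwargs(compile_kwargs=None, fn_kwargs=None):
        # An explicit fn_kwargs wins, but only after warning the caller
        # to migrate to compile_kwargs.
        if fn_kwargs is not None:
            warnings.warn(
                "`fn_kwargs` is deprecated, use `compile_kwargs` instead.",
                DeprecationWarning,
            )
            compile_kwargs = fn_kwargs
        return {} if compile_kwargs is None else compile_kwargs

    assert resolve_compile_kwargs(fn_kwargs={"profile": True}) == {"profile": True}
    assert resolve_compile_kwargs() == {}
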
@@ -451,7 +478,7 @@ class Operator:
require_logq = True
objective_class = ObjectiveFunction
supports_aevb = property(lambda self: not self.approx.any_histograms)
T = identity
T = Elemwise(scalar_identity)

def __init__(self, approx):
self.approx = approx
