diff --git a/pymc/initial_point.py b/pymc/initial_point.py index 241409f683..f7c7bd5114 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -66,12 +66,12 @@ def make_initial_point_fns_per_chain( model, overrides: StartDict | Sequence[StartDict | None] | None, jitter_rvs: set[TensorVariable] | None = None, + jitter_scale: float = 1.0, chains: int, ) -> list[Callable]: """Create an initial point function for each chain, as defined by initvals. - If a single initval dictionary is passed, the function is replicated for each - chain, otherwise a unique function is compiled for each entry in the dictionary. + If a single initval dictionary is passed, the function is replicated for each chain, otherwise a unique function is compiled for each entry in the dictionary. Parameters ---------- @@ -81,6 +81,8 @@ def make_initial_point_fns_per_chain( jitter_rvs : set, optional Random variable tensors for which U(-1, 1) jitter shall be applied. (To the transformed space if applicable.) + jitter_scale : float, optional + The scale of the jitter in the jitter_rvs set. Defaults to 1.0. Raises ------ @@ -96,6 +98,7 @@ def make_initial_point_fns_per_chain( model=model, overrides=overrides, jitter_rvs=jitter_rvs, + jitter_scale=jitter_scale, return_transformed=True, ) ] * chains @@ -104,6 +107,7 @@ def make_initial_point_fns_per_chain( make_initial_point_fn( model=model, jitter_rvs=jitter_rvs, + jitter_scale=jitter_scale, overrides=chain_overrides, return_transformed=True, ) @@ -122,6 +126,7 @@ def make_initial_point_fn( model, overrides: StartDict | None = None, jitter_rvs: set[TensorVariable] | None = None, + jitter_scale: float = 1.0, default_strategy: str = "support_point", return_transformed: bool = True, ) -> Callable: @@ -130,8 +135,9 @@ def make_initial_point_fn( Parameters ---------- jitter_rvs : set - The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be - added to the initial value. Only available for variables that have a transform or real-valued support. + The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be added to the initial value. Only available for variables that have a transform or real-valued support. + jitter_scale : float, optional + The scale of the jitter in the jitter_rvs set. Defaults to 1.0. default_strategy : str Which of { "support_point", "prior" } to prefer if the initval setting for an RV is None. overrides : dict @@ -150,6 +156,7 @@ def make_initial_point_fn( rvs_to_transforms=model.rvs_to_transforms, initval_strategies=initval_strats, jitter_rvs=jitter_rvs, + jitter_scale=jitter_scale, default_strategy=default_strategy, return_transformed=return_transformed, ) @@ -188,6 +195,7 @@ def make_initial_point_expression( rvs_to_transforms: dict[TensorVariable, Transform], initval_strategies: dict[TensorVariable, np.ndarray | Variable | str | None], jitter_rvs: set[TensorVariable] | None = None, + jitter_scale: float = 1.0, default_strategy: str = "support_point", return_transformed: bool = False, ) -> list[TensorVariable]: @@ -203,8 +211,10 @@ def make_initial_point_expression( Mapping of free random variable tensors to initial value strategies. For example the `Model.initial_values` dictionary. jitter_rvs : set - The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be + The set (or list or tuple) of random variables for which a U(-1, 1) jitter should be added to the initial value. Only available for variables that have a transform or real-valued support. + jitter_scale : float, optional + The scale of the jitter in the jitter_rvs set. Defaults to 1.0. default_strategy : str Which of { "support_point", "prior" } to prefer if the initval strategy setting for an RV is None. return_transformed : bool @@ -265,7 +275,7 @@ def make_initial_point_expression( value = transform.forward(value, *variable.owner.inputs) if variable in jitter_rvs: - jitter = pt.random.uniform(-1, 1, size=value.shape) + jitter = pt.random.uniform(-jitter_scale, jitter_scale, size=value.shape) jitter.name = f"{variable.name}_jitter" value = value + jitter diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 64d6829fc8..ed78efbb98 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1427,10 +1427,11 @@ def _init_jitter( initvals: StartDict | Sequence[StartDict | None] | None, seeds: Sequence[int] | np.ndarray, jitter: bool, + jitter_scale: float, jitter_max_retries: int, logp_dlogp_func=None, ) -> list[PointType]: - """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain. + """Apply a uniform jitter in [-jitter_scale, jitter_scale] to the test value as starting point in each chain. ``model.check_start_vals`` is used to test whether the jittered starting values produce a finite log probability. Invalid values are resampled @@ -1441,6 +1442,8 @@ def _init_jitter( ---------- jitter: bool Whether to apply jitter or not. + jitter_scale : float, optional + The scale of the jitter in set(model.free_RVs). Defaults to 1.0. jitter_max_retries : int Maximum number of repeated attempts at initializing values (per chain). @@ -1453,6 +1456,7 @@ def _init_jitter( model=model, overrides=initvals, jitter_rvs=set(model.free_RVs) if jitter else set(), + jitter_scale=jitter_scale if jitter else 1.0, chains=len(seeds), ) diff --git a/tests/test_initial_point.py b/tests/test_initial_point.py index 9138f37b3e..ddc087c14e 100644 --- a/tests/test_initial_point.py +++ b/tests/test_initial_point.py @@ -152,6 +152,26 @@ def test_adds_jitter(self): assert fn(0) == fn(0) assert fn(0) != fn(1) + def test_jitter_scale(self): + with pm.Model() as pmodel: + A = pm.HalfFlat("A", initval="support_point") + + fn_default = make_initial_point_fn( + model=pmodel, + jitter_rvs=set(pmodel.free_RVs), + return_transformed=True, + ) + + fn_large = make_initial_point_fn( + model=pmodel, + jitter_rvs=set(pmodel.free_RVs), + jitter_scale=1000.0, + return_transformed=True, + ) + + assert fn_large(0)["A_log__"] > 10 + assert fn_default(0)["A_log__"] < 1 + def test_respects_overrides(self): with pm.Model() as pmodel: A = pm.Flat("A", initval="support_point")