From b35001b346a0db15d48241d0c4cc5d1c102f1b89 Mon Sep 17 00:00:00 2001
From: Will Dean
Date: Mon, 16 Jun 2025 13:18:19 -0400
Subject: [PATCH 1/6] add pymc-extras to environment

---
 environment.yml | 1 +
 pyproject.toml  | 1 +
 2 files changed, 2 insertions(+)

diff --git a/environment.yml b/environment.yml
index 02b7f920..2bc8ed20 100644
--- a/environment.yml
+++ b/environment.yml
@@ -15,3 +15,4 @@ dependencies:
   - seaborn>=0.11.2
   - statsmodels
   - xarray>=v2022.11.0
+  - pymc-extras>=0.2.7

diff --git a/pyproject.toml b/pyproject.toml
index 29f86277..bcc4bc7e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,7 @@ dependencies = [
     "seaborn>=0.11.2",
     "statsmodels",
     "xarray>=v2022.11.0",
+    "pymc-extras>=0.2.7",
 ]
 
 # List additional groups of dependencies here (e.g. development dependencies). Users
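Before the code changes, a quick orientation on the `Prior` class from pymc-extras that the following patches build on: a `Prior` captures a distribution name, its parameters, and optional dims up front, and `create_variable` (and, for observed data, `create_likelihood_variable`) turns that specification into a PyMC variable once a model context is active. The sketch below is illustrative only and not part of the patch series; the coordinate values are invented, while the distribution, parameters, and dims mirror the "beta" default added in the next patch.

    import pymc as pm
    from pymc_extras.prior import Prior

    # Same specification as the "beta" default introduced in the next patch.
    beta_prior = Prior("Normal", mu=0, sigma=50, dims="coeffs")

    # The PyMC random variable is only created inside an active model context.
    with pm.Model(coords={"coeffs": ["x1", "x2"]}) as model:
        beta = beta_prior.create_variable("beta")

    print(model.free_RVs)  # the "beta" variable, with dims ("coeffs",)
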
""" # noqa: W605 + default_priors = { + "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1)), + } + def build_model(self, X, y, coords): """ Defines the PyMC model @@ -286,9 +303,8 @@ def build_model(self, X, y, coords): X = pm.Data("X", X, dims=["obs_ind", "coeffs"]) y = pm.Data("y", y[:, 0], dims="obs_ind") beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs") - sigma = pm.HalfNormal("sigma", 1) mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind") - pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind") + self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y) class InstrumentalVariableRegression(PyMCModel): @@ -477,13 +493,17 @@ class PropensityScore(PyMCModel): Inference... """ # noqa: W605 + default_priors = { + "b": Prior("Normal", mu=0, sigma=1, dims="coeffs"), + } + def build_model(self, X, t, coords): "Defines the PyMC propensity model" with self: self.add_coords(coords) X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"]) t_data = pm.Data("t", t.flatten(), dims="obs_ind") - b = pm.Normal("b", mu=0, sigma=1, dims="coeffs") + b = self.priors["b"].create_variable("b") mu = pm.math.dot(X_data, b) p = pm.Deterministic("p", pm.math.invlogit(mu)) pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind") From a60035e61d0b7653515721d83dad958d4d368ee2 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Mon, 16 Jun 2025 13:52:22 -0400 Subject: [PATCH 3/6] get pymc_models tests to pass --- causalpy/pymc_models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py index 3ed4cac9..6f15f0cf 100644 --- a/causalpy/pymc_models.py +++ b/causalpy/pymc_models.py @@ -69,7 +69,9 @@ class PyMCModel(pm.Model): Inference data... """ - default_priors: dict[str, Any] + @property + def default_priors(self): + return {} def __init__( self, @@ -248,7 +250,7 @@ class LinearRegression(PyMCModel): default_priors = { "beta": Prior("Normal", mu=0, sigma=50, dims="coeffs"), - "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1)), + "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"), } def build_model(self, X, y, coords): From 367c9220b835e35b32c1b48c77c4e76f578f91e9 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Mon, 16 Jun 2025 17:05:02 -0400 Subject: [PATCH 4/6] add dim to y_hat --- causalpy/pymc_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py index 6f15f0cf..812e9d70 100644 --- a/causalpy/pymc_models.py +++ b/causalpy/pymc_models.py @@ -292,7 +292,7 @@ class WeightedSumFitter(PyMCModel): """ # noqa: W605 default_priors = { - "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1)), + "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"), } def build_model(self, X, y, coords): From a9f821c8e84a5edaea2615a0c60582a205056318 Mon Sep 17 00:00:00 2001 From: "Benjamin T. 
Vincent" Date: Fri, 20 Jun 2025 11:31:53 +0100 Subject: [PATCH 5/6] fix for sigma -> y_hat_sigma --- causalpy/pymc_models.py | 6 +++--- docs/source/_static/interrogate_badge.svg | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py index 812e9d70..f95b6371 100644 --- a/causalpy/pymc_models.py +++ b/causalpy/pymc_models.py @@ -199,15 +199,15 @@ def print_row( coeffs = az.extract(self.idata.posterior, var_names="beta") # Determine the width of the longest label - max_label_length = max(len(name) for name in labels + ["sigma"]) + max_label_length = max(len(name) for name in labels + ["y_hat_sigma"]) for name in labels: coeff_samples = coeffs.sel(coeffs=name) print_row(max_label_length, name, coeff_samples, round_to) # Add coefficient for measurement std - coeff_samples = az.extract(self.idata.posterior, var_names="sigma") - name = "sigma" + coeff_samples = az.extract(self.idata.posterior, var_names="y_hat_sigma") + name = "y_hat_sigma" print_row(max_label_length, name, coeff_samples, round_to) diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 9975f47a..4a908d60 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,5 +1,5 @@ - interrogate: 94.9% + interrogate: 94.5% @@ -12,8 +12,8 @@ interrogate interrogate - 94.9% - 94.9% + 94.5% + 94.5% From 91aee009f60506bdb2af55faef256b864d8bb483 Mon Sep 17 00:00:00 2001 From: "Benjamin T. Vincent" Date: Fri, 20 Jun 2025 12:02:06 +0100 Subject: [PATCH 6/6] fix failing doctest --- causalpy/experiments/prepostnegd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/causalpy/experiments/prepostnegd.py b/causalpy/experiments/prepostnegd.py index beec847e..3ab18968 100644 --- a/causalpy/experiments/prepostnegd.py +++ b/causalpy/experiments/prepostnegd.py @@ -82,7 +82,7 @@ class PrePostNEGD(BaseExperiment): Intercept -0.5, 94% HDI [-1, 0.2] C(group)[T.1] 2, 94% HDI [2, 2] pre 1, 94% HDI [1, 1] - sigma 0.5, 94% HDI [0.5, 0.6] + y_hat_sigma 0.5, 94% HDI [0.5, 0.6] """ supports_ols = False