diff --git a/.github/workflows/code-cov.yaml b/.github/workflows/code-cov.yaml new file mode 100644 index 0000000..8182928 --- /dev/null +++ b/.github/workflows/code-cov.yaml @@ -0,0 +1,64 @@ +name: CI-code-cov + +on: + push: + branches: [ main ] + pull_request: + branches: + - main + - develop + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.11' # Specify the Python version you want to use + + - name: Install Package in Editable Mode with Python Dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Setup R + uses: r-lib/actions/setup-r@v2 + with: + r-version: '4.3.2' # Use the R version you prefer + + - name: Install R packages + uses: r-lib/actions/setup-r-dependencies@v2 + with: + cache: true + cache-version: 1 + dependencies: 'NA' + install-pandoc: false + packages: | + grf + causalweight + mediation + + - name: Install plmed package + run: | + R -e "pak::pkg_install('ohines/plmed')" + + - name: Install Pytest and Coverage + run: | + pip install pytest pytest-cov + + - name: Run tests with coverage + run: | + pytest --cov=med_bench --cov-report=xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4.3.0 + with: + # token: ${{ secrets.CODECOV_TOKEN }} + token: 'e4829e41-01da-4d08-9e04-04443da957e3' + slug: judithabk6/med_bench + diff --git a/.github/workflows/tests-with-R.yaml b/.github/workflows/tests-with-R.yaml index 93cc116..086ddb5 100644 --- a/.github/workflows/tests-with-R.yaml +++ b/.github/workflows/tests-with-R.yaml @@ -4,7 +4,9 @@ on: push: branches: [ main ] pull_request: - branches: [ main ] + branches: + - main + - develop jobs: test: diff --git a/.github/workflows/tests-without-R.yaml b/.github/workflows/tests-without-R.yaml index afe51cb..e8fb624 100644 --- a/.github/workflows/tests-without-R.yaml +++ b/.github/workflows/tests-without-R.yaml @@ -4,7 +4,9 @@ on: push: branches: [ main ] pull_request: - branches: [ main ] + branches: + - main + - develop jobs: test: diff --git a/.gitignore b/.gitignore index 7f6b817..834d0a8 100644 --- a/.gitignore +++ b/.gitignore @@ -128,4 +128,6 @@ dmypy.json # Pyre type checker .pyre/ +# DS_STORE files +src/.DS_Store .DS_Store diff --git a/README.md b/README.md index 8d7191c..b95d123 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +[![codecov](https://codecov.io/gh/judithabk6/med_bench/graph/badge.svg?token=PASB71N41D)](https://codecov.io/gh/judithabk6/med_bench) + # med_bench **med_bench** is a Python package designed to wrap the most common estimators for causal mediation analysis in a single framework. We additionally allow for some flexibility in the choice of nuisance parameters models. 
diff --git a/setup.py b/setup.py index ca61aa1..618e661 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,8 @@ 'rpy2>=2.9.4', 'scipy>=1.5.2', 'seaborn>=0.11.1', - 'matplotlib>=3.3.2' + 'matplotlib>=3.3.2', + "pytest" ], classifiers=[ 'Programming Language :: Python :: 3', diff --git a/src/med_bench/get_estimation.py b/src/med_bench/get_estimation.py index 45578cf..4ee3ae6 100644 --- a/src/med_bench/get_estimation.py +++ b/src/med_bench/get_estimation.py @@ -2,17 +2,19 @@ # -*- coding:utf-8 -*- import numpy as np + from .mediation import ( mediation_IPW, mediation_coefficient_product, mediation_g_formula, mediation_multiply_robust, - mediation_DML, + mediation_dml, r_mediation_g_estimator, - r_mediation_DML, + r_mediation_dml, r_mediate, ) + def get_estimation(x, t, m, y, estimator, config): """Wrapper estimator fonction ; calls an estimator given mediation data in order to estimate total, direct, and indirect effects. @@ -70,7 +72,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=False, forest=False, crossfit=0, - clip=0.0, + clip=1e-6, calibration=None, ) elif estimator == "mediation_ipw_noreg_cf": @@ -83,7 +85,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=False, forest=False, crossfit=2, - clip=0.0, + clip=1e-6, calibration=None, ) elif estimator == "mediation_ipw_reg": @@ -96,7 +98,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, forest=False, crossfit=0, - clip=0.0, + clip=1e-6, calibration=None, ) elif estimator == "mediation_ipw_reg_cf": @@ -109,7 +111,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, forest=False, crossfit=2, - clip=0.0, + clip=1e-6, calibration=None, ) elif estimator == "mediation_ipw_reg_calibration": @@ -122,7 +124,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, forest=False, crossfit=0, - clip=0.0, + clip=1e-6, calibration=None, ) elif estimator == "mediation_ipw_reg_calibration_iso": @@ -135,7 +137,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, forest=False, crossfit=0, - clip=0.0, + clip=1e-6, calibration="isotonic", ) elif estimator == "mediation_ipw_reg_calibration_cf": @@ -148,7 +150,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, forest=False, crossfit=2, - clip=0.0, + clip=1e-6, calibration='sigmoid', ) elif estimator == "mediation_ipw_reg_calibration_iso_cf": @@ -161,7 +163,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, forest=False, crossfit=2, - clip=0.0, + clip=1e-6, calibration="isotonic", ) elif estimator == "mediation_ipw_forest": @@ -174,7 +176,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, forest=True, crossfit=0, - clip=0.0, + clip=1e-6, calibration=None, ) elif estimator == "mediation_ipw_forest_cf": @@ -187,7 +189,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, forest=True, crossfit=2, - clip=0.0, + clip=1e-6, calibration=None, ) elif estimator == "mediation_ipw_forest_calibration": @@ -200,7 +202,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, forest=True, crossfit=0, - clip=0.0, + clip=1e-6, calibration=None, ) elif estimator == "mediation_ipw_forest_calibration_iso": @@ -213,7 +215,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, forest=True, crossfit=0, - clip=0.0, + clip=1e-6, calibration="isotonic", ) elif estimator == "mediation_ipw_forest_calibration_cf": @@ -226,7 +228,7 @@ def get_estimation(x, t, m, y, 
estimator, config): regularization=True, forest=True, crossfit=2, - clip=0.0, + clip=1e-6, calibration='sigmoid', ) elif estimator == "mediation_ipw_forest_calibration_iso_cf": @@ -239,7 +241,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, forest=True, crossfit=2, - clip=0.0, + clip=1e-6, calibration="isotonic", ) elif estimator == "mediation_g_computation_noreg": @@ -434,7 +436,7 @@ def get_estimation(x, t, m, y, estimator, config): interaction=False, forest=False, crossfit=0, - clip=0.0, + clip=1e-6, regularization=False, calibration=None, ) @@ -448,7 +450,7 @@ def get_estimation(x, t, m, y, estimator, config): interaction=False, forest=False, crossfit=2, - clip=0.0, + clip=1e-6, regularization=False, calibration=None, ) @@ -462,7 +464,7 @@ def get_estimation(x, t, m, y, estimator, config): interaction=False, forest=False, crossfit=0, - clip=0.0, + clip=1e-6, regularization=True, calibration=None, ) @@ -476,7 +478,7 @@ def get_estimation(x, t, m, y, estimator, config): interaction=False, forest=False, crossfit=2, - clip=0.0, + clip=1e-6, regularization=True, calibration=None, ) @@ -490,7 +492,7 @@ def get_estimation(x, t, m, y, estimator, config): interaction=False, forest=False, crossfit=0, - clip=0.0, + clip=1e-6, regularization=True, calibration='sigmoid', ) @@ -504,7 +506,7 @@ def get_estimation(x, t, m, y, estimator, config): interaction=False, forest=False, crossfit=0, - clip=0.0, + clip=1e-6, regularization=True, calibration="isotonic", ) @@ -518,7 +520,7 @@ def get_estimation(x, t, m, y, estimator, config): interaction=False, forest=False, crossfit=2, - clip=0.0, + clip=1e-6, regularization=True, calibration='sigmoid', ) @@ -532,7 +534,7 @@ def get_estimation(x, t, m, y, estimator, config): interaction=False, forest=False, crossfit=2, - clip=0.0, + clip=1e-6, regularization=True, calibration="isotonic", ) @@ -546,7 +548,7 @@ def get_estimation(x, t, m, y, estimator, config): interaction=False, forest=True, crossfit=0, - clip=0.0, + clip=1e-6, regularization=True, calibration=None, ) @@ -560,7 +562,7 @@ def get_estimation(x, t, m, y, estimator, config): interaction=False, forest=True, crossfit=2, - clip=0.0, + clip=1e-6, regularization=True, calibration=None, ) @@ -574,7 +576,7 @@ def get_estimation(x, t, m, y, estimator, config): interaction=False, forest=True, crossfit=0, - clip=0.0, + clip=1e-6, regularization=True, calibration='sigmoid', ) @@ -588,7 +590,7 @@ def get_estimation(x, t, m, y, estimator, config): interaction=False, forest=True, crossfit=0, - clip=0.0, + clip=1e-6, regularization=True, calibration="isotonic", ) @@ -602,7 +604,7 @@ def get_estimation(x, t, m, y, estimator, config): interaction=False, forest=True, crossfit=2, - clip=0.0, + clip=1e-6, regularization=True, calibration='sigmoid', ) @@ -616,55 +618,104 @@ def get_estimation(x, t, m, y, estimator, config): interaction=False, forest=True, crossfit=2, - clip=0.0, + clip=1e-6, regularization=True, calibration="isotonic", ) elif estimator == "simulation_based": if config in (0, 1, 2): effects = r_mediate(y, t, m, x, interaction=False) - elif estimator == "mediation_DML": + elif estimator == "mediation_dml": if config > 0: - effects = r_mediation_DML(y, t, m, x, trim=0.0, order=1) - elif estimator == "mediation_DML_noreg": - effects = mediation_DML( - y, t, m, x, trim=0, regularization=False, calibration=None) - elif estimator == "mediation_DML_reg": - effects = mediation_DML(y, t, m, x, trim=0, calibration=None) - elif estimator == "mediation_DML_reg_fixed_seed": - effects = 
mediation_DML( - y, t, m, x, trim=0, random_state=321, calibration=None) - elif estimator == "mediation_DML_noreg_cf": - effects = mediation_DML( + effects = r_mediation_dml(y, t, m, x, trim=0.0, order=1) + elif estimator == "mediation_dml_noreg": + effects = mediation_dml( y, t, m, x, trim=0, + clip=1e-6, + regularization=False, + calibration=None) + elif estimator == "mediation_dml_reg": + effects = mediation_dml( + y, t, m, x, trim=0, clip=1e-6, calibration=None) + elif estimator == "mediation_dml_reg_fixed_seed": + effects = mediation_dml( + y, t, m, x, trim=0, clip=1e-6, random_state=321, calibration=None) + elif estimator == "mediation_dml_noreg_cf": + effects = mediation_dml( + y, + t, + m, + x, + trim=0, + clip=1e-6, crossfit=2, regularization=False, calibration=None) - elif estimator == "mediation_DML_reg_cf": - effects = mediation_DML( - y, t, m, x, trim=0, crossfit=2, calibration=None) - elif estimator == "mediation_DML_reg_calibration": - effects = mediation_DML( - y, t, m, x, trim=0, crossfit=0, calibration='sigmoid') - elif estimator == "mediation_DML_forest": - effects = mediation_DML( - y, t, m, x, trim=0, crossfit=0, calibration=None, forest=True) - elif estimator == "mediation_DML_forest_calibration": - effects = mediation_DML( - y, t, m, x, trim=0, crossfit=0, calibration='sigmoid', forest=True) - elif estimator == "mediation_DML_reg_calibration_cf": - effects = mediation_DML( - y, t, m, x, trim=0, crossfit=2, calibration='sigmoid', forest=False) - elif estimator == "mediation_DML_forest_cf": - effects = mediation_DML( - y, t, m, x, trim=0, crossfit=2, calibration=None, forest=True) - elif estimator == "mediation_DML_forest_calibration_cf": - effects = mediation_DML( - y, t, m, x, trim=0, crossfit=2, calibration='sigmoid', forest=True) + elif estimator == "mediation_dml_reg_cf": + effects = mediation_dml( + y, t, m, x, trim=0, clip=1e-6, crossfit=2, calibration=None) + elif estimator == "mediation_dml_reg_calibration": + effects = mediation_dml( + y, t, m, x, trim=0, clip=1e-6, crossfit=0, calibration='sigmoid') + elif estimator == "mediation_dml_forest": + effects = mediation_dml( + y, + t, + m, + x, + trim=0, + clip=1e-6, + crossfit=0, + calibration=None, + forest=True) + elif estimator == "mediation_dml_forest_calibration": + effects = mediation_dml( + y, + t, + m, + x, + trim=0, + clip=1e-6, + crossfit=0, + calibration='sigmoid', + forest=True) + elif estimator == "mediation_dml_reg_calibration_cf": + effects = mediation_dml( + y, + t, + m, + x, + trim=0, + clip=1e-6, + crossfit=2, + calibration='sigmoid', + forest=False) + elif estimator == "mediation_dml_forest_cf": + effects = mediation_dml( + y, + t, + m, + x, + trim=0, + clip=1e-6, + crossfit=2, + calibration=None, + forest=True) + elif estimator == "mediation_dml_forest_calibration_cf": + effects = mediation_dml( + y, + t, + m, + x, + trim=0, + clip=1e-6, + crossfit=2, + calibration='sigmoid', + forest=True) elif estimator == "mediation_g_estimator": if config in (0, 1, 2): effects = r_mediation_g_estimator(y, t, m, x) diff --git a/src/med_bench/get_simulated_data.py b/src/med_bench/get_simulated_data.py index 6b8eefa..b646390 100644 --- a/src/med_bench/get_simulated_data.py +++ b/src/med_bench/get_simulated_data.py @@ -1,13 +1,7 @@ import numpy as np from numpy.random import default_rng from scipy import stats -import pandas as pd -from pathlib import Path -from scipy.stats import bernoulli from scipy.special import expit -import matplotlib.pyplot as plt -import pathlib -import seaborn as sns def 
simulate_data(n, @@ -23,16 +17,16 @@ def simulate_data(n, beta_t_factor=1, beta_m_factor=1): """Simulate data for mediation analysis - + Parameters ---------- n: :obj:`int`, Number of samples to generate. - + rg: RandomState instance, Controls the pseudo random number generator used to generate the data at fit time. - + mis_spec_m: obj:`bool`, Whether the mediator generation is misspecified or not defaults to False @@ -40,7 +34,7 @@ def simulate_data(n, mis_spec_y: obj:`bool`, Whether the output model is misspecified or not defaults to False - + dim_x: :obj:`int`, optional, Number of covariates in the input. Defaults to 1 @@ -48,13 +42,13 @@ def simulate_data(n, dim_m: :obj:`int`, optional, Number of mediatiors to generate. Defaults to 1 - + seed: :obj:`int` or None, optional, Controls the pseudo random number generator used to generate the coefficients of the model. Pass an int for reproducible output across multiple function calls. Defaults to None - + type_m: :obj:`str`, Whether the mediator is binary or continuous Defaults to 'binary', @@ -66,7 +60,7 @@ def simulate_data(n, sigma_m :obj:`float`, noise variance on mediator Defaults to 0.5, - + beta_t_factor: :obj:`float`, scaling factor on treatment effect, Defaults to 1, @@ -74,18 +68,18 @@ def simulate_data(n, beta_m_factor: :obj:`float`, scaling factor on mediator, Defaults to 1, - + returns ------- x: ndarray of shape (n, dim_x) the simulated covariates - + t: ndarray of shape (n, 1) the simulated treatment - + m: ndarray of shape (n, dim_m) the simulated mediators - + y: ndarray of shape (n, 1) the simulated outcome @@ -137,9 +131,11 @@ def simulate_data(n, m = m_2d[np.arange(n), t[:, 0]].reshape(-1, 1) else: random_noise = sigma_m * rg.standard_normal((n, dim_m)) - m0 = x.dot(beta_x) + t0.dot(beta_t) + t0 * (x.dot(beta_xt)) + random_noise - m1 = x.dot(beta_x) + t1.dot(beta_t) + t1 * (x.dot(beta_xt)) + random_noise - m = x.dot(beta_x) + t.dot(beta_t) + t * (x.dot(beta_xt)) + random_noise + m0 = x.dot(beta_x) + t0.dot(beta_t) + t0 * \ + (x.dot(beta_xt)) + random_noise + m1 = x.dot(beta_x) + t1.dot(beta_t) + t1 * \ + (x.dot(beta_xt)) + random_noise + m = x.dot(beta_x) + t.dot(beta_t) + t * (x.dot(beta_xt)) + random_noise # generate the outcome Y gamma_m = np.ones((dim_m, 1)) * 0.5 / dim_m * beta_m_factor @@ -150,31 +146,38 @@ def simulate_data(n, else: gamma_t_m = np.zeros((dim_m, 1)) - y = x.dot(gamma_x) + gamma_t * t + m.dot(gamma_m) + m.dot(gamma_t_m) * t + sigma_y * rg.standard_normal((n, 1)) + y = x.dot(gamma_x) + gamma_t * t + m.dot(gamma_m) + \ + m.dot(gamma_t_m) * t + sigma_y * rg.standard_normal((n, 1)) # Compute differents types of effects if type_m == 'binary': theta_1 = gamma_t + gamma_t_m * np.mean(p_m1) theta_0 = gamma_t + gamma_t_m * np.mean(p_m0) - delta_1 = np.mean((p_m1 - p_m0) * (gamma_m.flatten() + gamma_t_m.dot(t1.T))) - delta_0 = np.mean((p_m1 - p_m0) * (gamma_m.flatten() + gamma_t_m.dot(t0.T))) + delta_1 = np.mean( + (p_m1 - p_m0) * (gamma_m.flatten() + gamma_t_m.dot(t1.T))) + delta_0 = np.mean( + (p_m1 - p_m0) * (gamma_m.flatten() + gamma_t_m.dot(t0.T))) else: - theta_1 = gamma_t + gamma_t_m.T.dot(np.mean(m1, axis=0)) # to do mean(m1) pour avoir un vecteur de taille dim_m + # to do mean(m1) pour avoir un vecteur de taille dim_m + theta_1 = gamma_t + gamma_t_m.T.dot(np.mean(m1, axis=0)) theta_0 = gamma_t + gamma_t_m.T.dot(np.mean(m0, axis=0)) - delta_1 = (gamma_t * t1 + m1.dot(gamma_m) + m1.dot(gamma_t_m) * t1 - (gamma_t * t1 + m0.dot(gamma_m) + m0.dot(gamma_t_m) * t1)).mean() - delta_0 = (gamma_t * t0 + 
m1.dot(gamma_m) + m1.dot(gamma_t_m) * t0 - (gamma_t * t0 + m0.dot(gamma_m) + m0.dot(gamma_t_m) * t0)).mean() + delta_1 = (gamma_t * t1 + m1.dot(gamma_m) + m1.dot(gamma_t_m) * t1 - + (gamma_t * t1 + m0.dot(gamma_m) + m0.dot(gamma_t_m) * t1)).mean() + delta_0 = (gamma_t * t0 + m1.dot(gamma_m) + m1.dot(gamma_t_m) * t0 - + (gamma_t * t0 + m0.dot(gamma_m) + m0.dot(gamma_t_m) * t0)).mean() if type_m == 'binary': pre_pm = np.hstack((p_m0.reshape(-1, 1), p_m1.reshape(-1, 1))) - pre_pm[m.ravel()==0, :] = 1 - pre_pm[m.ravel()==0, :] + pre_pm[m.ravel() == 0, :] = 1 - pre_pm[m.ravel() == 0, :] pm = pre_pm[:, 1].reshape(-1, 1) else: - p_m0 = np.prod(stats.norm.pdf((m - x.dot(beta_x)) - t0.dot(beta_t) - t0 * (x.dot(beta_xt)) / sigma_m), axis=1) - p_m1 = np.prod(stats.norm.pdf((m - x.dot(beta_x)) - t1.dot(beta_t) - t1 * (x.dot(beta_xt)) / sigma_m), axis=1) + p_m0 = np.prod(stats.norm.pdf((m - x.dot(beta_x)) - + t0.dot(beta_t) - t0 * (x.dot(beta_xt)) / sigma_m), axis=1) + p_m1 = np.prod(stats.norm.pdf((m - x.dot(beta_x)) - + t1.dot(beta_t) - t1 * (x.dot(beta_xt)) / sigma_m), axis=1) pre_pm = np.hstack((p_m0.reshape(-1, 1), p_m1.reshape(-1, 1))) pm = pre_pm[:, 1].reshape(-1, 1) - px = np.prod(stats.norm.pdf(x), axis=1) pre_pt = np.hstack(((1-p_t).reshape(-1, 1), p_t.reshape(-1, 1))) @@ -182,15 +185,15 @@ def simulate_data(n, denom = np.sum(pre_pm * pre_pt * double_px, axis=1) num = pm.ravel() * p_t.ravel() * px.ravel() th_p_t_mx = num.ravel() / denom - - return (x, - t, - m, - y, + + return (x, + t, + m, + y, theta_1.flatten()[0] + delta_0.flatten()[0], - theta_1.flatten()[0], - theta_0.flatten()[0], + theta_1.flatten()[0], + theta_0.flatten()[0], delta_1.flatten()[0], - delta_0.flatten()[0], - p_t, - th_p_t_mx) \ No newline at end of file + delta_0.flatten()[0], + p_t, + th_p_t_mx) diff --git a/src/med_bench/mediation.py b/src/med_bench/mediation.py index f19f93c..5d2071b 100644 --- a/src/med_bench/mediation.py +++ b/src/med_bench/mediation.py @@ -8,16 +8,9 @@ import numpy as np import pandas as pd -from numpy.random import default_rng -from scipy import stats -from scipy.special import expit -from scipy.stats import bernoulli from sklearn.base import clone -from sklearn.calibration import CalibratedClassifierCV -from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor -from sklearn.linear_model import LassoCV, LogisticRegressionCV, RidgeCV -from sklearn.model_selection import KFold -from sklearn.preprocessing import PolynomialFeatures +from sklearn.linear_model import RidgeCV + from .utils.nuisances import (_estimate_conditional_mean_outcome, _estimate_cross_conditional_mean_outcome, @@ -33,7 +26,7 @@ def mediation_IPW(y, t, m, x, trim, regularization=True, forest=False, - crossfit=0, clip=0.01, calibration='sigmoid'): + crossfit=0, clip=1e-6, calibration='sigmoid'): """ IPW estimator presented in HUBER, Martin. 
Identifying causal mechanisms (primarily) based on inverse @@ -76,7 +69,7 @@ def mediation_IPW(y, t, m, x, trim, regularization=True, forest=False, crossfit : integer, default=0 number of folds for cross-fitting - clip : float, default=0.01 + clip : float, default=1e-6 limit to clip for numerical stability (min=clip, max=1-clip) calibration : str, default=sigmoid @@ -123,12 +116,12 @@ def mediation_IPW(y, t, m, x, trim, regularization=True, forest=False, y0m1 = np.sum(y * (1 - t) * p_xm / ((1 - p_xm) * p_x)) /\ np.sum((1 - t) * p_xm / ((1 - p_xm) * p_x)) - return(y1m1 - y0m0, - y1m1 - y0m1, - y1m0 - y0m0, - y1m1 - y1m0, - y0m1 - y0m0, - np.sum(ind)) + return (y1m1 - y0m0, + y1m1 - y0m1, + y1m0 - y0m0, + y1m1 - y1m0, + y0m1 - y0m0, + np.sum(ind)) def mediation_coefficient_product(y, t, m, x, interaction=False, @@ -203,12 +196,12 @@ def mediation_coefficient_product(y, t, m, x, interaction=False, # return total, direct and indirect effect direct_effect = y_reg.coef_[x.shape[1]] indirect_effect = sum(y_reg.coef_[x.shape[1] + 1:] * coef_t_m) - return [direct_effect + indirect_effect, + return (direct_effect + indirect_effect, direct_effect, direct_effect, indirect_effect, indirect_effect, - None] + None) def mediation_g_formula(y, t, m, x, interaction=False, forest=False, @@ -284,12 +277,12 @@ def mediation_g_formula(y, t, m, x, interaction=False, forest=False, + indirect_effect_i0 * mu_00x).sum() / n total_effect = direct_effect_control + indirect_effect_treated - return [total_effect, + return (total_effect, direct_effect_treated, direct_effect_control, indirect_effect_treated, indirect_effect_control, - None] + None) def alternative_estimator(y, t, m, x, regularization=True): @@ -352,16 +345,16 @@ def alternative_estimator(y, t, m, x, regularization=True): # computation of indirect effect indirect_effect = total_effect - direct_effect - return [total_effect, + return (total_effect, direct_effect, direct_effect, indirect_effect, indirect_effect, - None] + None) def mediation_multiply_robust(y, t, m, x, interaction=False, forest=False, - crossfit=0, clip=0.01, normalized=True, + crossfit=0, clip=1e-6, normalized=True, regularization=True, calibration="sigmoid"): """ Presented in Eric J. Tchetgen Tchetgen. Ilya Shpitser. @@ -397,7 +390,7 @@ def mediation_multiply_robust(y, t, m, x, interaction=False, forest=False, Number of folds for cross-fitting. 
If crossfit<2, no cross-fitting is applied - clip : float, default=0.01 + clip : float, default=1e-6 Limit to clip p_x and f_mtx for numerical stability (min=clip, max=1-clip) @@ -510,28 +503,28 @@ def mediation_multiply_robust(y, t, m, x, interaction=False, forest=False, y0m0 = (((1 - t) / (1 - p_x) * (y - E_mu_t0_t0)) / sum_score_m0 + E_mu_t0_t0) y1m0 = ( - ((t / p_x) * (f_m0x / f_m1x) * (y - mu_1mx)) / sum_score_t1m0 - + ((1 - t) / (1 - p_x) * (mu_1mx - E_mu_t1_t0)) / sum_score_m0 - + E_mu_t1_t0 + ((t / p_x) * (f_m0x / f_m1x) * (y - mu_1mx)) / sum_score_t1m0 + + ((1 - t) / (1 - p_x) * (mu_1mx - E_mu_t1_t0)) / sum_score_m0 + + E_mu_t1_t0 ) y0m1 = ( - ((1 - t) / (1 - p_x) * (f_m1x / f_m0x) * (y - mu_0mx)) - / sum_score_t0m1 + t / p_x * ( - mu_0mx - E_mu_t0_t1) / sum_score_m1 - + E_mu_t0_t1 + ((1 - t) / (1 - p_x) * (f_m1x / f_m0x) * (y - mu_0mx)) + / sum_score_t0m1 + t / p_x * ( + mu_0mx - E_mu_t0_t1) / sum_score_m1 + + E_mu_t0_t1 ) else: y1m1 = t / p_x * (y - E_mu_t1_t1) + E_mu_t1_t1 y0m0 = (1 - t) / (1 - p_x) * (y - E_mu_t0_t0) + E_mu_t0_t0 y1m0 = ( - (t / p_x) * (f_m0x / f_m1x) * (y - mu_1mx) - + (1 - t) / (1 - p_x) * (mu_1mx - E_mu_t1_t0) - + E_mu_t1_t0 + (t / p_x) * (f_m0x / f_m1x) * (y - mu_1mx) + + (1 - t) / (1 - p_x) * (mu_1mx - E_mu_t1_t0) + + E_mu_t1_t0 ) y0m1 = ( - (1 - t) / (1 - p_x) * (f_m1x / f_m0x) * (y - mu_0mx) - + t / p_x * (mu_0mx - E_mu_t0_t1) - + E_mu_t0_t1 + (1 - t) / (1 - p_x) * (f_m1x / f_m0x) * (y - mu_0mx) + + t / p_x * (mu_0mx - E_mu_t0_t1) + + E_mu_t0_t1 ) # effects computing @@ -665,16 +658,16 @@ def r_mediation_g_estimator(y, t, m, x): data=base.as_symbol('df')) direct_effect = res.rx2('coef')[0] indirect_effect = res.rx2('coef')[1] - return [direct_effect + indirect_effect, + return (direct_effect + indirect_effect, direct_effect, direct_effect, indirect_effect, indirect_effect, - None] + None) @r_dependency_required(['causalweight', 'base']) -def r_mediation_DML(y, t, m, x, trim=0.05, order=1): +def r_mediation_dml(y, t, m, x, trim=0.05, order=1): """ This function calls the R Double Machine Learning estimator from the package causalweight (https://cran.r-project.org/web/packages/causalweight) @@ -709,7 +702,7 @@ def r_mediation_DML(y, t, m, x, trim=0.05, order=1): Polynomials/interactions are created using the Generate. Powers command of the LARF package. """ - + import rpy2.robjects.packages as rpackages from rpy2.robjects import numpy2ri, pandas2ri from .utils.utils import _convert_array_to_R @@ -728,7 +721,7 @@ def r_mediation_DML(y, t, m, x, trim=0.05, order=1): return list(raw_res_R[0, :5]) + [ntrimmed] -def mediation_DML(y, t, m, x, forest=False, crossfit=0, trim=0.05, +def mediation_dml(y, t, m, x, forest=False, crossfit=0, trim=0.05, clip=1e-6, normalized=True, regularization=True, random_state=None, calibration=None): """ @@ -765,6 +758,9 @@ def mediation_DML(y, t, m, x, forest=False, crossfit=0, trim=0.05, trim : float, default=0.05 Trimming treshold for discarding observations with extreme probability. 
+ clip : float, default=1e-6 + limit to clip for numerical stability (min=clip, max=1-clip) + normalized : boolean, default=True Normalizes the inverse probability-based weights so they add up to 1, as described in "Identifying causal mechanisms (primarily) based on @@ -831,7 +827,6 @@ def mediation_DML(y, t, m, x, forest=False, crossfit=0, trim=0.05, nobs = 0 - var_name = [ "p_x", "p_xm", @@ -870,6 +865,10 @@ def mediation_DML(y, t, m, x, forest=False, crossfit=0, trim=0.05, exec(f"{var} = {var}[not_trimmed]") nobs = np.sum(not_trimmed) + # clipping + p_x = np.clip(p_x, clip, 1 - clip) + p_xm = np.clip(p_xm, clip, 1 - clip) + # score computing if normalized: sum_score_m1 = np.mean(t / p_x) @@ -916,4 +915,4 @@ def mediation_DML(y, t, m, x, forest=False, crossfit=0, trim=0.05, direct0 = my1m0 - my0m0 indirect1 = my1m1 - my1m0 indirect0 = my0m1 - my0m0 - return total, direct1, direct0, indirect1, indirect0, n - nobs \ No newline at end of file + return total, direct1, direct0, indirect1, indirect0, n - nobs diff --git a/src/med_bench/utils/constants.py b/src/med_bench/utils/constants.py new file mode 100644 index 0000000..417f993 --- /dev/null +++ b/src/med_bench/utils/constants.py @@ -0,0 +1,158 @@ +import itertools +import numpy as np +from numpy.random import default_rng + +# CONSTANTS USED FOR TESTS + +# TOLERANCE THRESHOLDS + +TOLERANCE_THRESHOLDS = { + "SMALL": { + "ATE": 0.05, + "DIRECT": 0.05, + "INDIRECT": 0.2, + }, + "MEDIUM": { + "ATE": 0.10, + "DIRECT": 0.10, + "INDIRECT": 0.4, + }, + "LARGE": { + "ATE": 0.15, + "DIRECT": 0.15, + "INDIRECT": 0.9, + }, + "INFINITE": { + "ATE": np.inf, + "DIRECT": np.inf, + "INDIRECT": np.inf, + }, +} + + +def get_tolerance_array(tolerance_size: str) -> np.array: + """Get tolerance array for tolerance tests + + Parameters + ---------- + tolerance_size : str + tolerance size, can be "SMALL", "MEDIUM", "LARGE" or "INFINITE" + + Returns + ------- + np.array + array of size 5 containing the ATE, DIRECT (*2) and INDIRECT (*2) effects tolerance + """ + + return np.array( + [ + TOLERANCE_THRESHOLDS[tolerance_size]["ATE"], + TOLERANCE_THRESHOLDS[tolerance_size]["DIRECT"], + TOLERANCE_THRESHOLDS[tolerance_size]["DIRECT"], + TOLERANCE_THRESHOLDS[tolerance_size]["INDIRECT"], + TOLERANCE_THRESHOLDS[tolerance_size]["INDIRECT"], + ] + ) + + +SMALL_TOLERANCE = get_tolerance_array("SMALL") +MEDIUM_TOLERANCE = get_tolerance_array("MEDIUM") +LARGE_TOLERANCE = get_tolerance_array("LARGE") +INFINITE_TOLERANCE = get_tolerance_array("INFINITE") + +TOLERANCE_DICT = { + "coefficient_product": LARGE_TOLERANCE, + "mediation_ipw_noreg": INFINITE_TOLERANCE, + "mediation_ipw_reg": INFINITE_TOLERANCE, + "mediation_ipw_reg_calibration": INFINITE_TOLERANCE, + "mediation_ipw_forest": INFINITE_TOLERANCE, + "mediation_ipw_forest_calibration": INFINITE_TOLERANCE, + "mediation_g_computation_noreg": LARGE_TOLERANCE, + "mediation_g_computation_reg": MEDIUM_TOLERANCE, + "mediation_g_computation_reg_calibration": LARGE_TOLERANCE, + "mediation_g_computation_forest": LARGE_TOLERANCE, + "mediation_g_computation_forest_calibration": INFINITE_TOLERANCE, + "mediation_multiply_robust_noreg": INFINITE_TOLERANCE, + "mediation_multiply_robust_reg": LARGE_TOLERANCE, + "mediation_multiply_robust_reg_calibration": LARGE_TOLERANCE, + "mediation_multiply_robust_forest": INFINITE_TOLERANCE, + "mediation_multiply_robust_forest_calibration": LARGE_TOLERANCE, + "simulation_based": LARGE_TOLERANCE, + "mediation_dml": INFINITE_TOLERANCE, + "mediation_dml_reg_fixed_seed": INFINITE_TOLERANCE, + 
"mediation_g_estimator": LARGE_TOLERANCE, + "mediation_ipw_noreg_cf": INFINITE_TOLERANCE, + "mediation_ipw_reg_cf": INFINITE_TOLERANCE, + "mediation_ipw_reg_calibration_cf": INFINITE_TOLERANCE, + "mediation_ipw_forest_cf": INFINITE_TOLERANCE, + "mediation_ipw_forest_calibration_cf": INFINITE_TOLERANCE, + "mediation_g_computation_noreg_cf": SMALL_TOLERANCE, + "mediation_g_computation_reg_cf": LARGE_TOLERANCE, + "mediation_g_computation_reg_calibration_cf": LARGE_TOLERANCE, + "mediation_g_computation_forest_cf": INFINITE_TOLERANCE, + "mediation_g_computation_forest_calibration_cf": LARGE_TOLERANCE, + "mediation_multiply_robust_noreg_cf": MEDIUM_TOLERANCE, + "mediation_multiply_robust_reg_cf": LARGE_TOLERANCE, + "mediation_multiply_robust_reg_calibration_cf": MEDIUM_TOLERANCE, + "mediation_multiply_robust_forest_cf": INFINITE_TOLERANCE, + "mediation_multiply_robust_forest_calibration_cf": INFINITE_TOLERANCE, +} + +ESTIMATORS = list(TOLERANCE_DICT.keys()) + +R_DEPENDENT_ESTIMATORS = [ + "mediation_IPW_R", "simulation_based", "mediation_dml", "mediation_g_estimator" +] + +# PARAMETERS VALUES FOR DATA GENERATION + +PARAMETER_NAME = [ + "n", + "rg", + "mis_spec_m", + "mis_spec_y", + "dim_x", + "dim_m", + "seed", + "type_m", + "sigma_y", + "sigma_m", + "beta_t_factor", + "beta_m_factor", +] + +PARAMETER_LIST = list( + itertools.product( + [1000], + [default_rng(321)], + [False], + [False], + [1, 5], + [1], + [123], + ["binary"], + [0.5], + [0.5], + [0.5], + [0.5], + ) +) + +PARAMETER_LIST.extend( + list( + itertools.product( + [1000], + [default_rng(321)], + [False], + [False], + [1, 5], + [1, 5], + [123], + ["continuous"], + [0.5], + [0.5], + [0.5], + [0.5], + ) + ) +) diff --git a/src/med_bench/utils/nuisances.py b/src/med_bench/utils/nuisances.py index 2d13449..ded68f0 100644 --- a/src/med_bench/utils/nuisances.py +++ b/src/med_bench/utils/nuisances.py @@ -3,22 +3,16 @@ used in mediation estimators in causal inference """ import numpy as np -import pandas as pd -from numpy.random import default_rng -from scipy import stats -from scipy.special import expit -from scipy.stats import bernoulli from sklearn.base import clone from sklearn.calibration import CalibratedClassifierCV from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor -from sklearn.linear_model import LassoCV, LogisticRegressionCV, RidgeCV +from sklearn.linear_model import LogisticRegressionCV, RidgeCV from sklearn.model_selection import KFold -from sklearn.preprocessing import PolynomialFeatures -from .utils import check_r_dependencies +from .utils import check_r_dependencies, _get_interactions if check_r_dependencies(): - from .utils import _convert_array_to_R, _get_interactions + from .utils import _convert_array_to_R ALPHAS = np.logspace(-5, 5, 8) @@ -103,7 +97,8 @@ def _get_regressor(regularization, forest, random_state=42): if not forest: reg = RidgeCV(alphas=alphas, cv=CV_FOLDS) else: - reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=random_state) return reg @@ -195,7 +190,6 @@ def _estimate_mediator_density(t, m, x, y, crossfit, clf_m, interaction): # f_mtx model fitting clf_m = clf_m.fit(t_x[train_index, :], m[train_index]) - #clf_m = clf_m.fit(t_x[train_index, :], m.ravel()[train_index]) # predict f(M=m|T=t,X) fm_0 = clf_m.predict_proba(t0_x[test_index, :]) @@ -368,12 +362,12 @@ def _estimate_cross_conditional_mean_outcome(t, m, x, y, crossfit, reg_y, # predict E[E[Y|T=1,M=m,X]|T=t,X] 
E_mu_t1_t0[test_index] = ( - reg_y_t1m0_t0.predict(x[test_index, :]) * f_00x[test_index] - + reg_y_t1m1_t0.predict(x[test_index, :]) * f_01x[test_index] + reg_y_t1m0_t0.predict(x[test_index, :]) * f_00x[test_index] + + reg_y_t1m1_t0.predict(x[test_index, :]) * f_01x[test_index] ) E_mu_t1_t1[test_index] = ( - reg_y_t1m0_t1.predict(x[test_index, :]) * f_10x[test_index] - + reg_y_t1m1_t1.predict(x[test_index, :]) * f_11x[test_index] + reg_y_t1m0_t1.predict(x[test_index, :]) * f_10x[test_index] + + reg_y_t1m1_t1.predict(x[test_index, :]) * f_11x[test_index] ) # E[E[Y|T=0,M=m,X]|T=t,X] model fitting @@ -392,12 +386,12 @@ def _estimate_cross_conditional_mean_outcome(t, m, x, y, crossfit, reg_y, # predict E[E[Y|T=0,M=m,X]|T=t,X] E_mu_t0_t0[test_index] = ( - reg_y_t0m0_t0.predict(x[test_index, :]) * f_00x[test_index] - + reg_y_t0m1_t0.predict(x[test_index, :]) * f_01x[test_index] + reg_y_t0m0_t0.predict(x[test_index, :]) * f_00x[test_index] + + reg_y_t0m1_t0.predict(x[test_index, :]) * f_01x[test_index] ) E_mu_t0_t1[test_index] = ( - reg_y_t0m0_t1.predict(x[test_index, :]) * f_10x[test_index] - + reg_y_t0m1_t1.predict(x[test_index, :]) * f_11x[test_index] + reg_y_t0m0_t1.predict(x[test_index, :]) * f_10x[test_index] + + reg_y_t0m1_t1.predict(x[test_index, :]) * f_11x[test_index] ) return mu_0mx, mu_1mx, E_mu_t0_t0, E_mu_t0_t1, E_mu_t1_t0, E_mu_t1_t1 diff --git a/src/med_bench/utils/utils.py b/src/med_bench/utils/utils.py index 99d7a67..ab7f44d 100644 --- a/src/med_bench/utils/utils.py +++ b/src/med_bench/utils/utils.py @@ -1,4 +1,6 @@ import numpy as np +import pandas as pd + import subprocess import warnings @@ -7,14 +9,14 @@ def check_r_dependencies(): try: # Check if R is accessible by trying to get its version subprocess.check_output(["R", "--version"]) - + # If the above command fails, it will raise a subprocess.CalledProcessError and won't reach here - + # Assuming reaching here means R is accessible, now try importing rpy2 packages import rpy2.robjects.packages as rpackages required_packages = [ 'causalweight', 'mediation', 'stats', 'base', 'grf', 'plmed' - ] + ] for package in required_packages: rpackages.importr(package) @@ -33,6 +35,7 @@ def is_r_installed(): except: return False + def check_r_package(package_name): try: import rpy2.robjects.packages as rpackages @@ -40,31 +43,38 @@ def check_r_package(package_name): return True except: return False - + + +class DependencyNotInstalledError(Exception): + pass + def r_dependency_required(required_packages): def decorator(func): def wrapper(*args, **kwargs): if not is_r_installed(): - warnings.warn( + raise DependencyNotInstalledError( "R is not installed or not found. " "Please install R and set it up correctly in your system." ) - return None - + + # To get rid of the 'DataFrame' object has no attribute 'iteritems' error due to pandas version mismatch in rpy2 + # https://stackoverflow.com/a/76404841 + pd.DataFrame.iteritems = pd.DataFrame.items + for package in required_packages: if not check_r_package(package): - if package!='plmed': - warnings.warn( + if package != 'plmed': + raise DependencyNotInstalledError( f"The '{package}' R package is not installed. " "Please install it using R by running:\n" "import rpy2.robjects.packages as rpackages\n" "utils = rpackages.importr('utils')\n" "utils.chooseCRANmirror(ind=33)\n" f"utils.install_packages('{package}')" - ) + ) else: - warnings.warn( + raise DependencyNotInstalledError( "The 'plmed' R package is not installed. 
" "Please install it using R by running:\n" "import rpy2.robjects.packages as rpackages\n" @@ -73,12 +83,12 @@ def wrapper(*args, **kwargs): "utils.install_packages('devtools')\n" "devtools = rpackages.importr('devtools')\n" "devtools.install_github('ohines/plmed')" - ) + ) return None return func(*args, **kwargs) return wrapper return decorator - + if is_r_installed(): import rpy2.robjects as robjects @@ -121,7 +131,7 @@ def _get_interactions(interaction, *args): variables = list(args) for index, var in enumerate(variables): if len(var.shape) == 1: - variables[index] = var.reshape(-1,1) + variables[index] = var.reshape(-1, 1) pre_inter_variables = np.hstack(variables) if not interaction: return pre_inter_variables @@ -135,6 +145,7 @@ def _get_interactions(interaction, *args): result = np.hstack((pre_inter_variables, new_vars)) return result + def _convert_array_to_R(x): """ converts a numpy array to a R matrix or vector diff --git a/src/tests/estimation/generate_tests_results.py b/src/tests/estimation/generate_tests_results.py new file mode 100644 index 0000000..b698a05 --- /dev/null +++ b/src/tests/estimation/generate_tests_results.py @@ -0,0 +1,63 @@ +import numpy as np + +from med_bench.get_simulated_data import simulate_data +from med_bench.get_estimation import get_estimation + +from med_bench.utils.constants import ESTIMATORS, PARAMETER_LIST, PARAMETER_NAME + + +def _get_data_from_list(data): + """Get x, t, m, y from simulated data + """ + x = data[0] + t = data[1].ravel() + m = data[2] + y = data[3].ravel() + + return x, t, m, y + + +def _get_config_from_dict(dict_params): + """Get config parameter used for estimators parametrisation + """ + if dict_params["dim_m"] == 1 and dict_params["type_m"] == "binary": + config = 0 + else: + config = 5 + return config + + +def _get_estimators_results(x, t, m, y, config, estimator): + """Get estimation result from specified input parameters and estimator name + """ + + try: + res = get_estimation(x, t, m, y, estimator, config)[0:5] + return res + + except Exception as e: + print(f"{e}") + return str(e) + + +if __name__ == "__main__": + + results = [] + + for param_list in PARAMETER_LIST: + + # Get synthetic input data from parameters list defined above + dict_params = dict(zip(PARAMETER_NAME, param_list)) + data = simulate_data(**dict_params) + x, t, m, y = _get_data_from_list(data) + config = _get_config_from_dict(dict_params=dict_params) + + for estimator in ESTIMATORS: + + # Get results from synthetic inputs + result = _get_estimators_results(x, t, m, y, config, estimator) + row = [estimator, x, t, m, y, config, result] + results.append(row) + + # Store the results in a npy file + np.save("tests_results.npy", np.array(results, dtype="object")) diff --git a/src/tests/estimation/test_exact_estimation.py b/src/tests/estimation/test_exact_estimation.py new file mode 100644 index 0000000..98d2d96 --- /dev/null +++ b/src/tests/estimation/test_exact_estimation.py @@ -0,0 +1,109 @@ +""" +Pytest file for get_estimation.py + +It tests all the benchmark_mediation estimators : +- for a certain tolerance +- whether their effects satisfy "total = direct + indirect" +- whether they support (n,1) and (n,) inputs + +To be robust to future updates, tests are adjusted with a smaller tolerance when possible. +The test is skipped if estimator has not been implemented yet, i.e. if ValueError is raised. +The test fails for any other unwanted behavior. 
+""" + +from pprint import pprint +import pytest +import os +import numpy as np + +from med_bench.get_estimation import get_estimation +from med_bench.utils.constants import R_DEPENDENT_ESTIMATORS +from med_bench.utils.utils import DependencyNotInstalledError, check_r_dependencies + +current_dir = os.path.dirname(__file__) +true_estimations_file_path = os.path.join(current_dir, 'tests_results.npy') +TRUE_ESTIMATIONS = np.load(true_estimations_file_path, allow_pickle=True) + + +@pytest.fixture(params=range(TRUE_ESTIMATIONS.shape[0])) +def tests_results_idx(request): + return request.param + + +@pytest.fixture +def data(tests_results_idx): + return TRUE_ESTIMATIONS[tests_results_idx] + + +@pytest.fixture +def estimator(data): + return data[0] + + +@pytest.fixture +def x(data): + return data[1] + + +# t is raveled because some estimators fail with (n,1) inputs +@pytest.fixture +def t(data): + return data[2].ravel() + + +@pytest.fixture +def m(data): + return data[3] + + +@pytest.fixture +def y(data): + return data[4].ravel() # same reason as t + + +@pytest.fixture +def config(data): + return data[5] + + +@pytest.fixture +def result(data): + return data[6] + + +@pytest.fixture +def effects_chap(x, t, m, y, estimator, config): + # try whether estimator is implemented or not + + try: + res = get_estimation(x, t, m, y, estimator, config)[0:5] + + # NaN situations + if np.all(np.isnan(res)): + pytest.xfail("all effects are NaN") + elif np.any(np.isnan(res)): + pprint("NaN found") + + except Exception as e: + if str(e) in ( + "Estimator only supports 1D binary mediator.", + "Estimator does not support 1D binary mediator.", + ): + pytest.skip(f"{e}") + + # We skip the test if an error with function from glmet rpy2 package occurs + elif "glmnet::glmnet" in str(e): + pytest.skip(f"{e}") + + elif estimator in R_DEPENDENT_ESTIMATORS and not check_r_dependencies(): + assert isinstance(e, DependencyNotInstalledError) == True + pytest.skip(f"{e}") + + else: + pytest.fail(f"{e}") + + return res + + +def test_estimation_exactness(result, effects_chap): + assert np.all(effects_chap == pytest.approx(result, abs=0.01)) diff --git a/src/tests/estimation/test_get_estimation.py b/src/tests/estimation/test_get_estimation.py index 70de0a9..2e1adae 100644 --- a/src/tests/estimation/test_get_estimation.py +++ b/src/tests/estimation/test_get_estimation.py @@ -12,159 +12,14 @@ """ from pprint import pprint -import itertools import pytest import numpy as np -from numpy.random import default_rng + from med_bench.get_simulated_data import simulate_data from med_bench.get_estimation import get_estimation -from med_bench.utils.utils import check_r_dependencies - - -SMALL_ATE_TOLERANCE = 0.05 -SMALL_DIRECT_TOLERANCE = 0.05 -SMALL_INDIRECT_TOLERANCE = 0.2 - -MEDIUM_ATE_TOLERANCE = 0.10 -MEDIUM_DIRECT_TOLERANCE = 0.10 -MEDIUM_INDIRECT_TOLERANCE = 0.4 - -LARGE_ATE_TOLERANCE = 0.15 -LARGE_DIRECT_TOLERANCE = 0.15 -LARGE_INDIRECT_TOLERANCE = 0.8 -# indirect effect is weak, leading to a large relative error - -SMALL_TOLERANCE = np.array( - [ - SMALL_ATE_TOLERANCE, - SMALL_DIRECT_TOLERANCE, - SMALL_DIRECT_TOLERANCE, - SMALL_INDIRECT_TOLERANCE, - SMALL_INDIRECT_TOLERANCE, - ] -) - -MEDIUM_TOLERANCE = np.array( - [ - MEDIUM_ATE_TOLERANCE, - MEDIUM_DIRECT_TOLERANCE, - MEDIUM_DIRECT_TOLERANCE, - MEDIUM_INDIRECT_TOLERANCE, - MEDIUM_INDIRECT_TOLERANCE, - ] -) - -LARGE_TOLERANCE = np.array( - [ - LARGE_ATE_TOLERANCE, - LARGE_DIRECT_TOLERANCE, - LARGE_DIRECT_TOLERANCE, - LARGE_INDIRECT_TOLERANCE, - LARGE_INDIRECT_TOLERANCE, - ] -) - 
-INFINITE_TOLERANCE = np.array( - [ - np.inf, - np.inf, - np.inf, - np.inf, - np.inf, - ] -) - - -TOLERANCE_DICT = { - "coefficient_product": LARGE_TOLERANCE, - "mediation_ipw_noreg": INFINITE_TOLERANCE, - "mediation_ipw_reg": INFINITE_TOLERANCE, - "mediation_ipw_reg_calibration": INFINITE_TOLERANCE, - "mediation_ipw_forest": INFINITE_TOLERANCE, - "mediation_ipw_forest_calibration": INFINITE_TOLERANCE, - "mediation_g_computation_noreg": LARGE_TOLERANCE, - "mediation_g_computation_reg": MEDIUM_TOLERANCE, - "mediation_g_computation_reg_calibration": LARGE_TOLERANCE, - "mediation_g_computation_forest": LARGE_TOLERANCE, - "mediation_g_computation_forest_calibration": INFINITE_TOLERANCE, - "mediation_multiply_robust_noreg": INFINITE_TOLERANCE, - "mediation_multiply_robust_reg": LARGE_TOLERANCE, - "mediation_multiply_robust_reg_calibration": LARGE_TOLERANCE, - "mediation_multiply_robust_forest": INFINITE_TOLERANCE, - "mediation_multiply_robust_forest_calibration": LARGE_TOLERANCE, - "simulation_based": LARGE_TOLERANCE, - "mediation_DML": INFINITE_TOLERANCE, - "mediation_DML_reg_fixed_seed": INFINITE_TOLERANCE, - "mediation_g_estimator": SMALL_TOLERANCE, - "mediation_ipw_noreg_cf": INFINITE_TOLERANCE, - "mediation_ipw_reg_cf": INFINITE_TOLERANCE, - "mediation_ipw_reg_calibration_cf": INFINITE_TOLERANCE, - "mediation_ipw_forest_cf": INFINITE_TOLERANCE, - "mediation_ipw_forest_calibration_cf": INFINITE_TOLERANCE, - "mediation_g_computation_noreg_cf": SMALL_TOLERANCE, - "mediation_g_computation_reg_cf": LARGE_TOLERANCE, - "mediation_g_computation_reg_calibration_cf": LARGE_TOLERANCE, - "mediation_g_computation_forest_cf": INFINITE_TOLERANCE, - "mediation_g_computation_forest_calibration_cf": LARGE_TOLERANCE, - "mediation_multiply_robust_noreg_cf": MEDIUM_TOLERANCE, - "mediation_multiply_robust_reg_cf": LARGE_TOLERANCE, - "mediation_multiply_robust_reg_calibration_cf": MEDIUM_TOLERANCE, - "mediation_multiply_robust_forest_cf": INFINITE_TOLERANCE, - "mediation_multiply_robust_forest_calibration_cf": INFINITE_TOLERANCE, -} - - -PARAMETER_NAME = [ - "n", - "rg", - "mis_spec_m", - "mis_spec_y", - "dim_x", - "dim_m", - "seed", - "type_m", - "sigma_y", - "sigma_m", - "beta_t_factor", - "beta_m_factor", -] - -PARAMETER_LIST = list( - itertools.product( - [1000], - [default_rng(321)], - [False], - [False], - [1, 5], - [1], - [123], - ["binary"], - [0.5], - [0.5], - [0.5], - [0.5], - ) -) - -PARAMETER_LIST.extend( - list( - itertools.product( - [1000], - [default_rng(321)], - [False], - [False], - [1, 5], - [1, 5], - [123], - ["continuous"], - [0.5], - [0.5], - [0.5], - [0.5], - ) - ) -) +from med_bench.utils.utils import DependencyNotInstalledError, check_r_dependencies +from med_bench.utils.constants import PARAMETER_LIST, PARAMETER_NAME, R_DEPENDENT_ESTIMATORS, TOLERANCE_DICT @pytest.fixture(params=PARAMETER_LIST) @@ -224,30 +79,25 @@ def config(dict_param): def effects_chap(x, t, m, y, estimator, config): # try whether estimator is implemented or not - r_dependent_estimators = [ - "mediation_IPW_R", "simulation_based", "mediation_DML", "mediation_g_estimator" - ] - - if estimator in r_dependent_estimators and not check_r_dependencies(): - warning_message = ( - "R or some required R packages ('causalweight', 'mediation', 'stats', 'base', " - "'grf', 'plmed') not available" - ) - print(warning_message) - pytest.skip( - f"Skipping {estimator} as the required R environment/packages are not available." 
- ) - try: res = get_estimation(x, t, m, y, estimator, config)[0:5] - except ValueError as message_error: - if message_error.args[0] in ( + except Exception as e: + if str(e) in ( "Estimator only supports 1D binary mediator.", "Estimator does not support 1D binary mediator.", ): - pytest.skip(f"{message_error}") + pytest.skip(f"{e}") + + # We skip the test if an error with function from glmet rpy2 package occurs + elif "glmnet::glmnet" in str(e): + pytest.skip(f"{e}") + + elif estimator in R_DEPENDENT_ESTIMATORS and not check_r_dependencies(): + assert isinstance(e, DependencyNotInstalledError) == True + pytest.skip(f"{e}") + else: - pytest.fail(f"{message_error}") + pytest.fail(f"{e}") # NaN situations if np.all(np.isnan(res)): @@ -265,9 +115,11 @@ def test_tolerance(effects, effects_chap, tolerance): def test_total_is_direct_plus_indirect(effects_chap): if not np.isnan(effects_chap[1]): - assert effects_chap[0] == pytest.approx(effects_chap[1] + effects_chap[4]) + assert effects_chap[0] == pytest.approx( + effects_chap[1] + effects_chap[4]) if not np.isnan(effects_chap[2]): - assert effects_chap[0] == pytest.approx(effects_chap[2] + effects_chap[3]) + assert effects_chap[0] == pytest.approx( + effects_chap[2] + effects_chap[3]) @pytest.mark.xfail @@ -275,7 +127,8 @@ def test_robustness_to_ravel_format(data, estimator, config, effects_chap): if "forest" in estimator: pytest.skip("Forest estimator skipped") assert np.all( - get_estimation(data[0], data[1], data[2], data[3], estimator, config)[0:5] + get_estimation(data[0], data[1], data[2], + data[3], estimator, config)[0:5] == pytest.approx( effects_chap, nan_ok=True ) # effects_chap is obtained with data[1].ravel() and data[3].ravel() diff --git a/src/tests/estimation/tests_results.npy b/src/tests/estimation/tests_results.npy new file mode 100644 index 0000000..b468ed3 Binary files /dev/null and b/src/tests/estimation/tests_results.npy differ diff --git a/src/tests/simulate_data/test_get_simulated_data.py b/src/tests/simulate_data/test_get_simulated_data.py index e13f4d8..b7e066f 100644 --- a/src/tests/simulate_data/test_get_simulated_data.py +++ b/src/tests/simulate_data/test_get_simulated_data.py @@ -14,45 +14,11 @@ """ from pprint import pprint -import itertools import pytest import numpy as np from numpy.random import default_rng from med_bench.get_simulated_data import simulate_data - - -PARAMETER_NAME = [ - "n", - "rg", - "mis_spec_m", - "mis_spec_y", - "dim_x", - "dim_m", - "seed", - "type_m", - "sigma_y", - "sigma_m", - "beta_t_factor", - "beta_m_factor", -] - - -PARAMETER_LIST = list( - itertools.product( - [1, 500, 1000], - [default_rng(321)], - [False, True], - [False, True], - [1, 5], - [1], - [123], - ["binary", "continuous"], - [0.5], - [0.5], - [0.5], - [0.5], - ) -) +from med_bench.utils.constants import PARAMETER_LIST, PARAMETER_NAME @pytest.fixture(params=PARAMETER_LIST)
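Not part of the patch: a minimal usage sketch, assuming the package and its Python dependencies are installed, assembled only from signatures visible in this diff (`simulate_data`, `get_estimation`, and the renamed `"mediation_dml_reg"` estimator key). It mirrors the flow of the new `generate_tests_results.py` helper; the parameter values echo `PARAMETER_LIST` in `constants.py` and are illustrative, not prescriptive.

```python
from numpy.random import default_rng

from med_bench.get_simulated_data import simulate_data
from med_bench.get_estimation import get_estimation

# Simulate a small mediation dataset (same keyword names as PARAMETER_NAME).
data = simulate_data(
    n=1000,
    rg=default_rng(321),
    mis_spec_m=False,
    mis_spec_y=False,
    dim_x=1,
    dim_m=1,
    seed=123,
    type_m="binary",
    sigma_y=0.5,
    sigma_m=0.5,
    beta_t_factor=0.5,
    beta_m_factor=0.5,
)

# x, t, m, y come first in the returned tuple; t and y are raveled because
# some estimators reject (n, 1) inputs, as noted in the test fixtures.
x, t, m, y = data[0], data[1].ravel(), data[2], data[3].ravel()

# config=0 corresponds to a 1D binary mediator in _get_config_from_dict.
config = 0

# The renamed double machine learning estimator (the R bridge is not exercised
# on this code path); the first five returned values are the total, direct
# (treated/control) and indirect (treated/control) effects.
total, direct_1, direct_0, indirect_1, indirect_0 = get_estimation(
    x, t, m, y, "mediation_dml_reg", config
)[0:5]
print(total, direct_1, direct_0, indirect_1, indirect_0)
```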