From 254fa958df8fb430709684f9342206f7d4422f6a Mon Sep 17 00:00:00 2001
From: Akihiro Shimoda
Date: Fri, 10 Apr 2026 08:38:53 +0100
Subject: [PATCH 1/2] Add plot_overlap_common_support() method to DoubleMLIRM

---
 doubleml/irm/irm.py                         | 110 ++++++++++++++
 doubleml/irm/tests/test_irm_overlap_plot.py | 150 ++++++++++++++++++++
 doubleml/utils/_plots.py                    | 135 ++++++++++++++++++
 3 files changed, 395 insertions(+)
 create mode 100644 doubleml/irm/tests/test_irm_overlap_plot.py

diff --git a/doubleml/irm/irm.py b/doubleml/irm/irm.py
index cb8e093f4..b268326cc 100644
--- a/doubleml/irm/irm.py
+++ b/doubleml/irm/irm.py
@@ -674,3 +674,113 @@ def policy_tree(self, features, depth=2, **tree_params):
         model = DoubleMLPolicyTree(orth_signal, depth=depth, features=features, **tree_params).fit()
 
         return model
+
+    def plot_overlap_common_support(
+        self,
+        idx_treatment: int = 0,
+        i_rep: int = 0,
+        threshold: float = 0.05,
+        show_warning: bool = True,
+    ) -> "go.Figure":
+        """Plot the propensity score overlap (common support) for treatment and control groups.
+
+        Visualizes the distribution of estimated propensity scores :math:`\\hat{m}_0(X) = \\hat{E}[D|X]`
+        split by treatment status using kernel density estimation. Highlights regions near 0 and 1
+        where the positivity assumption :math:`\\eta < m_0(X) < 1 - \\eta` may be violated.
+
+        Parameters
+        ----------
+        idx_treatment : int
+            Index of the treatment variable (for multi-treatment settings).
+            Default is ``0``.
+
+        i_rep : int
+            Index of the repetition to use for the propensity score predictions.
+            Default is ``0``.
+
+        threshold : float
+            Threshold for positivity violation warning zones. Vertical lines are drawn at ``threshold``
+            and ``1 - threshold`` to highlight the danger zones. Must be in ``(0, 0.5)``.
+            Default is ``0.05``.
+
+        show_warning : bool
+            If ``True``, a warning is issued when the share of observations in the positivity violation
+            zones exceeds 5%. This indicates that IPW-based estimators may suffer from high variance.
+            Default is ``True``.
+
+        Returns
+        -------
+        fig : :class:`plotly.graph_objects.Figure`
+            Plotly figure with the propensity score overlap plot.
+
+        Raises
+        ------
+        ValueError
+            If ``fit()`` has not been called or predictions are not stored.
+        """
+        import plotly.graph_objects as go
+
+        from doubleml.utils._plots import _propensity_score_overlap_plot
+
+        # Input validation
+        if self._framework is None:
+            raise ValueError("Apply fit() before plot_overlap_common_support().")
+
+        if self.predictions is None:
+            raise ValueError(
+                "Predictions are not stored. Call fit() with store_predictions=True "
+                "before plot_overlap_common_support()."
+            )
+
+        if "ml_m" not in self.predictions:
+            raise ValueError(
+                "Propensity score predictions ('ml_m') are not available. "
+                "Ensure fit() was called with store_predictions=True."
+            )
+
+        if not isinstance(idx_treatment, int):
+            raise TypeError(f"idx_treatment must be an integer. Got {type(idx_treatment)}.")
+        if idx_treatment < 0 or idx_treatment >= self._dml_data.n_treat:
+            raise ValueError(
+                f"idx_treatment must be in [0, {self._dml_data.n_treat - 1}]. Got {idx_treatment}."
+            )
+
+        if not isinstance(i_rep, int):
+            raise TypeError(f"i_rep must be an integer. Got {type(i_rep)}.")
+        if i_rep < 0 or i_rep >= self.n_rep:
+            raise ValueError(f"i_rep must be in [0, {self.n_rep - 1}]. Got {i_rep}.")
+
+        if not isinstance(threshold, (int, float)):
+            raise TypeError(f"threshold must be a float. Got {type(threshold)}.")
+        if threshold <= 0 or threshold >= 0.5:
+            raise ValueError(f"threshold must be in (0, 0.5). Got {threshold}.")
+
+        if not isinstance(show_warning, bool):
+            raise TypeError(f"show_warning must be a boolean. Got {type(show_warning)}.")
+
+        # Extract propensity scores and treatment indicator
+        ps_scores = self.predictions["ml_m"][:, i_rep, idx_treatment]
+        treatment = self._dml_data.d
+
+        # Generate plot
+        fig = _propensity_score_overlap_plot(ps_scores, treatment, threshold)
+
+        # Positivity violation warning
+        # When propensity scores cluster near 0 or 1, IPW estimators suffer from extreme weights,
+        # leading to inflated variance of the treatment effect estimate.
+        if show_warning:
+            n_total = len(ps_scores)
+            n_violations = np.sum((ps_scores < threshold) | (ps_scores > 1 - threshold))
+            pct_violations = n_violations / n_total * 100
+            if pct_violations > 5.0:
+                warnings.warn(
+                    f"Potential positivity violation detected: {pct_violations:.1f}% of observations have "
+                    f"propensity scores outside [{threshold}, {1 - threshold}]. "
+                    f"This may lead to high variance in IPW-based estimators. "
+                    f"Consider using trimming (ps_processor_config) or checking the covariate balance.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+
+        return fig
+

diff --git a/doubleml/irm/tests/test_irm_overlap_plot.py b/doubleml/irm/tests/test_irm_overlap_plot.py
new file mode 100644
index 000000000..880251d8e
--- /dev/null
+++ b/doubleml/irm/tests/test_irm_overlap_plot.py
@@ -0,0 +1,150 @@
+import warnings
+
+import numpy as np
+import plotly.graph_objects as go
+import pytest
+from sklearn.linear_model import Lasso, LogisticRegression
+
+from doubleml import DoubleMLIRM
+from doubleml.irm.datasets import make_irm_data
+
+np.random.seed(3141)
+n_obs = 200
+dml_data = make_irm_data(n_obs=n_obs)
+
+
+@pytest.fixture(scope="module")
+def fitted_irm():
+    """IRM model fitted with stored predictions."""
+    irm = DoubleMLIRM(dml_data, Lasso(), LogisticRegression(), n_rep=2, n_folds=3)
+    irm.fit(store_predictions=True)
+    return irm
+
+
+@pytest.fixture(scope="module")
+def fitted_irm_no_preds():
+    """IRM model fitted without stored predictions."""
+    irm = DoubleMLIRM(dml_data, Lasso(), LogisticRegression())
+    irm.fit(store_predictions=False)
+    return irm
+
+
+@pytest.mark.ci
+class TestOverlapPlotReturnType:
+    """Test that plot_overlap_common_support returns a plotly Figure."""
+
+    def test_returns_plotly_figure(self, fitted_irm):
+        fig = fitted_irm.plot_overlap_common_support()
+        assert isinstance(fig, go.Figure)
+
+    def test_returns_plotly_figure_custom_params(self, fitted_irm):
+        fig = fitted_irm.plot_overlap_common_support(
+            idx_treatment=0, i_rep=1, threshold=0.1, show_warning=False
+        )
+        assert isinstance(fig, go.Figure)
+
+
+@pytest.mark.ci
+class TestOverlapPlotErrors:
+    """Test error handling for plot_overlap_common_support."""
+
+    def test_error_before_fit(self):
+        """Calling before fit() should raise ValueError."""
+        irm = DoubleMLIRM(dml_data, Lasso(), LogisticRegression())
+        with pytest.raises(ValueError, match="Apply fit"):
+            irm.plot_overlap_common_support()
+
+    def test_error_no_predictions(self, fitted_irm_no_preds):
+        """Calling after fit(store_predictions=False) should raise ValueError."""
+        with pytest.raises(ValueError, match="Predictions are not stored"):
+            fitted_irm_no_preds.plot_overlap_common_support()
+
+    def test_error_invalid_idx_treatment_type(self, fitted_irm):
+        with pytest.raises(TypeError, match="idx_treatment must be an integer"):
+            fitted_irm.plot_overlap_common_support(idx_treatment=0.5)
+
+    def test_error_invalid_idx_treatment_range(self, fitted_irm):
+        with pytest.raises(ValueError, match="idx_treatment must be in"):
+            fitted_irm.plot_overlap_common_support(idx_treatment=5)
+
+    def test_error_negative_idx_treatment(self, fitted_irm):
+        with pytest.raises(ValueError, match="idx_treatment must be in"):
+            fitted_irm.plot_overlap_common_support(idx_treatment=-1)
+
+    def test_error_invalid_i_rep_type(self, fitted_irm):
+        with pytest.raises(TypeError, match="i_rep must be an integer"):
+            fitted_irm.plot_overlap_common_support(i_rep=0.5)
+
+    def test_error_invalid_i_rep_range(self, fitted_irm):
+        with pytest.raises(ValueError, match="i_rep must be in"):
+            fitted_irm.plot_overlap_common_support(i_rep=10)
+
+    def test_error_invalid_threshold_type(self, fitted_irm):
+        with pytest.raises(TypeError, match="threshold must be a float"):
+            fitted_irm.plot_overlap_common_support(threshold="0.05")
+
+    def test_error_invalid_threshold_range_low(self, fitted_irm):
+        with pytest.raises(ValueError, match="threshold must be in"):
+            fitted_irm.plot_overlap_common_support(threshold=0.0)
+
+    def test_error_invalid_threshold_range_high(self, fitted_irm):
+        with pytest.raises(ValueError, match="threshold must be in"):
+            fitted_irm.plot_overlap_common_support(threshold=0.5)
+
+    def test_error_invalid_show_warning_type(self, fitted_irm):
+        with pytest.raises(TypeError, match="show_warning must be a boolean"):
+            fitted_irm.plot_overlap_common_support(show_warning="True")
+
+
+@pytest.mark.ci
+class TestOverlapPlotContent:
+    """Test plot content and structure."""
+
+    def test_figure_has_traces(self, fitted_irm):
+        fig = fitted_irm.plot_overlap_common_support()
+        # Should have at least 2 traces (treated and control KDE)
+        assert len(fig.data) >= 2
+
+    def test_figure_trace_names(self, fitted_irm):
+        fig = fitted_irm.plot_overlap_common_support()
+        trace_names = [trace.name for trace in fig.data]
+        assert "Treated" in trace_names
+        assert "Control" in trace_names
+
+    def test_figure_layout(self, fitted_irm):
+        fig = fitted_irm.plot_overlap_common_support()
+        assert fig.layout.title.text == "Propensity Score Overlap (Common Support)"
+        assert fig.layout.xaxis.title.text == "Estimated Propensity Score"
+        assert fig.layout.yaxis.title.text == "Density"
+
+    def test_custom_threshold_in_annotations(self, fitted_irm):
+        fig = fitted_irm.plot_overlap_common_support(threshold=0.1)
+        # The annotation text should reference the threshold
+        annotations = fig.layout.annotations
+        found = any("0.1" in str(ann.text) for ann in annotations if ann.text)
+        assert found
+
+    def test_different_repetitions(self, fitted_irm):
+        """Results should differ between repetitions."""
+        fig0 = fitted_irm.plot_overlap_common_support(i_rep=0)
+        fig1 = fitted_irm.plot_overlap_common_support(i_rep=1)
+        # Traces should exist for both
+        assert len(fig0.data) >= 2
+        assert len(fig1.data) >= 2
+
+
+@pytest.mark.ci
+class TestOverlapPlotWarning:
+    """Test positivity violation warning behavior."""
+
+    def test_no_warning_when_disabled(self, fitted_irm):
+        with warnings.catch_warnings():
+            warnings.simplefilter("error")
+            # Should not raise any warning
+            fitted_irm.plot_overlap_common_support(show_warning=False)
+
+    def test_warning_with_extreme_threshold(self, fitted_irm):
+        """With a very wide threshold (e.g., 0.49), almost all observations
+        should be flagged, triggering the warning."""
+        with pytest.warns(UserWarning, match="Potential positivity violation"):
+            fitted_irm.plot_overlap_common_support(threshold=0.49, show_warning=True)

diff --git a/doubleml/utils/_plots.py b/doubleml/utils/_plots.py
index 67b449b38..0c50ff9a0 100644
--- a/doubleml/utils/_plots.py
+++ b/doubleml/utils/_plots.py
@@ -1,5 +1,6 @@
 import numpy as np
 import plotly.graph_objects as go
+from scipy.stats import gaussian_kde
 
 
 def _sensitivity_contour_plot(
@@ -92,3 +93,137 @@ def _sensitivity_contour_plot(
     fig.update_yaxes(range=[0, np.max(y)])
 
     return fig
+
+
+def _propensity_score_overlap_plot(
+    ps_scores: np.ndarray,
+    treatment: np.ndarray,
+    threshold: float = 0.05,
+) -> go.Figure:
+    """Create an interactive propensity score overlap (common support) plot.
+
+    Visualizes the distribution of estimated propensity scores split by treatment status
+    using kernel density estimation. Highlights regions near 0 and 1 where the positivity
+    assumption may be violated.
+
+    Parameters
+    ----------
+    ps_scores : :class:`numpy.ndarray`
+        Array of estimated propensity scores.
+    treatment : :class:`numpy.ndarray`
+        Binary treatment indicator array.
+    threshold : float
+        Threshold for positivity violation warning zones. Lines are drawn at ``threshold``
+        and ``1 - threshold``. Default is ``0.05``.
+
+    Returns
+    -------
+    fig : :class:`plotly.graph_objects.Figure`
+        Plotly figure with the propensity score overlap plot.
+    """
+    ps_treated = ps_scores[treatment == 1]
+    ps_control = ps_scores[treatment == 0]
+
+    # Compute KDE for both groups
+    x_grid = np.linspace(0, 1, 200)
+
+    kde_treated = gaussian_kde(ps_treated)
+    kde_control = gaussian_kde(ps_control)
+    density_treated = kde_treated(x_grid)
+    density_control = kde_control(x_grid)
+
+    fig = go.Figure()
+
+    # Add danger zone shading (near 0 and near 1)
+    fig.add_vrect(
+        x0=0,
+        x1=threshold,
+        fillcolor="red",
+        opacity=0.08,
+        line_width=0,
+        annotation_text="Positivity<br>concern",
+        annotation_position="top left",
+        annotation_font_size=10,
+        annotation_font_color="red",
+    )
+    fig.add_vrect(
+        x0=1 - threshold,
+        x1=1,
+        fillcolor="red",
+        opacity=0.08,
+        line_width=0,
+        annotation_text="Positivity<br>concern",
+        annotation_position="top right",
+        annotation_font_size=10,
+        annotation_font_color="red",
+    )
+
+    # KDE curves for treated and control
+    fig.add_trace(
+        go.Scatter(
+            x=x_grid,
+            y=density_treated,
+            mode="lines",
+            name="Treated",
+            fill="tozeroy",
+            line=dict(color="rgba(31, 119, 180, 0.9)", width=2),
+            fillcolor="rgba(31, 119, 180, 0.25)",
+            hovertemplate="PS: %{x:.3f}<br>Density: %{y:.3f}<extra>Treated</extra>",
+        )
+    )
+    fig.add_trace(
+        go.Scatter(
+            x=x_grid,
+            y=density_control,
+            mode="lines",
+            name="Control",
+            fill="tozeroy",
+            line=dict(color="rgba(255, 127, 14, 0.9)", width=2),
+            fillcolor="rgba(255, 127, 14, 0.25)",
+            hovertemplate="PS: %{x:.3f}<br>Density: %{y:.3f}<extra>Control</extra>",
+        )
+    )
+
+    # Threshold boundary lines
+    fig.add_vline(x=threshold, line_dash="dash", line_color="red", line_width=1.5, opacity=0.7)
+    fig.add_vline(x=1 - threshold, line_dash="dash", line_color="red", line_width=1.5, opacity=0.7)
+
+    # Summary annotation with positivity diagnostics
+    n_total = len(ps_scores)
+    n_below = np.sum(ps_scores < threshold)
+    n_above = np.sum(ps_scores > 1 - threshold)
+    pct_violation = (n_below + n_above) / n_total * 100
+
+    annotation_text = (
+        f"Positivity diagnostics<br>"
+        f"PS < {threshold}: {n_below} ({n_below / n_total * 100:.1f}%)<br>"
+        f"PS > {1 - threshold}: {n_above} ({n_above / n_total * 100:.1f}%)<br>"
+        f"Total violations: {pct_violation:.1f}%"
+    )
+    fig.add_annotation(
+        text=annotation_text,
+        xref="paper",
+        yref="paper",
+        x=0.98,
+        y=0.98,
+        showarrow=False,
+        font=dict(size=11),
+        align="left",
+        bordercolor="gray",
+        borderwidth=1,
+        borderpad=6,
+        bgcolor="rgba(255, 255, 255, 0.85)",
+    )
+
+    fig.update_layout(
+        title="Propensity Score Overlap (Common Support)",
+        xaxis_title="Estimated Propensity Score",
+        yaxis_title="Density",
+        xaxis=dict(range=[0, 1]),
+        template="plotly_white",
+        legend=dict(yanchor="top", y=0.85, xanchor="left", x=0.02),
+        hovermode="x unified",
+    )
+
+    return fig
+
From 7362ed727f86667bebdc0cf7e1599ffe7b2e4b9a Mon Sep 17 00:00:00 2001
From: Akihiro Shimoda
Date: Sat, 18 Apr 2026 09:09:41 +0100
Subject: [PATCH 2/2] REFACTOR: Replace KDE overlap plot with histogram-based
 calibration plot

- Replace KDE overlap plot with histogram-based calibration plot
- Generic array-based API: propensity_score, treatment, bins, density, palette
- 2x2 matplotlib figure: histograms + binned calibration curves
- Move tests to doubleml/utils/tests/
- Address review feedback from PR #389
---
 doubleml/irm/irm.py                           |  77 ++++----
 doubleml/irm/tests/test_irm_overlap_plot.py   | 150 ----------------
 doubleml/utils/__init__.py                    |   2 +
 doubleml/utils/_plots.py                      | 132 --------------
 doubleml/utils/plots.py                       | 132 ++++++++++++++
 .../test_propensity_score_calibration.py      | 165 ++++++++++++++++++
 6 files changed, 330 insertions(+), 328 deletions(-)
 delete mode 100644 doubleml/irm/tests/test_irm_overlap_plot.py
 create mode 100644 doubleml/utils/plots.py
 create mode 100644 doubleml/utils/tests/test_propensity_score_calibration.py

diff --git a/doubleml/irm/irm.py b/doubleml/irm/irm.py
index b268326cc..500e41b8a 100644
--- a/doubleml/irm/irm.py
+++ b/doubleml/irm/irm.py
@@ -679,14 +679,15 @@ def plot_overlap_common_support(
         self,
         idx_treatment: int = 0,
         i_rep: int = 0,
-        threshold: float = 0.05,
-        show_warning: bool = True,
-    ) -> "go.Figure":
-        """Plot the propensity score overlap (common support) for treatment and control groups.
+        bins=10,
+        density: bool = False,
+        palette: str = "colorblind",
+    ):
+        """Plot propensity score distributions and binned calibration curves.
 
         Visualizes the distribution of estimated propensity scores :math:`\\hat{m}_0(X) = \\hat{E}[D|X]`
-        split by treatment status using kernel density estimation. Highlights regions near 0 and 1
-        where the positivity assumption :math:`\\eta < m_0(X) < 1 - \\eta` may be violated.
+        split by treatment status using histograms, together with calibration curves comparing
+        predicted propensity scores against observed treatment fractions.
 
         Parameters
         ----------
@@ -698,29 +699,31 @@ def plot_overlap_common_support(
             Index of the repetition to use for the propensity score predictions.
             Default is ``0``.
 
-        threshold : float
-            Threshold for positivity violation warning zones. Vertical lines are drawn at ``threshold``
-            and ``1 - threshold`` to highlight the danger zones. Must be in ``(0, 0.5)``.
-            Default is ``0.05``.
+        bins : int or array-like
+            Number of bins or explicit bin edges for the histograms and calibration curves.
+            Default is ``10``.
+
+        density : bool
+            If ``True``, histogram heights are normalized to density.
+            Default is ``False``.
 
-        show_warning : bool
-            If ``True``, a warning is issued when the share of observations in the positivity violation
-            zones exceeds 5%. This indicates that IPW-based estimators may suffer from high variance.
-            Default is ``True``.
+        palette : str or sequence
+            Seaborn palette name or explicit colors.
+            Default is ``"colorblind"``.
 
         Returns
        -------
-        fig : :class:`plotly.graph_objects.Figure`
-            Plotly figure with the propensity score overlap plot.
+        fig : :class:`matplotlib.figure.Figure`
+            Matplotlib figure.
+        axes : :class:`numpy.ndarray`
+            2x2 axes array.
 
         Raises
         ------
         ValueError
             If ``fit()`` has not been called or predictions are not stored.
         """
-        import plotly.graph_objects as go
-
-        from doubleml.utils._plots import _propensity_score_overlap_plot
+        from doubleml.utils.plots import plot_propensity_score_calibration
 
         # Input validation
         if self._framework is None:
             raise ValueError("Apply fit() before plot_overlap_common_support().")
@@ -750,37 +753,19 @@ def plot_overlap_common_support(
         if i_rep < 0 or i_rep >= self.n_rep:
             raise ValueError(f"i_rep must be in [0, {self.n_rep - 1}]. Got {i_rep}.")
 
-        if not isinstance(threshold, (int, float)):
-            raise TypeError(f"threshold must be a float. Got {type(threshold)}.")
-        if threshold <= 0 or threshold >= 0.5:
-            raise ValueError(f"threshold must be in (0, 0.5). Got {threshold}.")
-
-        if not isinstance(show_warning, bool):
-            raise TypeError(f"show_warning must be a boolean. Got {type(show_warning)}.")
-
         # Extract propensity scores and treatment indicator
         ps_scores = self.predictions["ml_m"][:, i_rep, idx_treatment]
         treatment = self._dml_data.d
 
         # Generate plot
-        fig = _propensity_score_overlap_plot(ps_scores, treatment, threshold)
-
-        # Positivity violation warning
-        # When propensity scores cluster near 0 or 1, IPW estimators suffer from extreme weights,
-        # leading to inflated variance of the treatment effect estimate.
-        if show_warning:
-            n_total = len(ps_scores)
-            n_violations = np.sum((ps_scores < threshold) | (ps_scores > 1 - threshold))
-            pct_violations = n_violations / n_total * 100
-            if pct_violations > 5.0:
-                warnings.warn(
-                    f"Potential positivity violation detected: {pct_violations:.1f}% of observations have "
-                    f"propensity scores outside [{threshold}, {1 - threshold}]. "
-                    f"This may lead to high variance in IPW-based estimators. "
-                    f"Consider using trimming (ps_processor_config) or checking the covariate balance.",
-                    UserWarning,
-                    stacklevel=2,
-                )
+        fig, axes = plot_propensity_score_calibration(
+            propensity_score=ps_scores,
+            treatment=treatment,
+            bins=bins,
+            density=density,
+            palette=palette,
+        )
+
+        return fig, axes
 
-        return fig
-

diff --git a/doubleml/irm/tests/test_irm_overlap_plot.py b/doubleml/irm/tests/test_irm_overlap_plot.py
deleted file mode 100644
index 880251d8e..000000000
--- a/doubleml/irm/tests/test_irm_overlap_plot.py
+++ /dev/null
@@ -1,150 +0,0 @@
-import warnings
-
-import numpy as np
-import plotly.graph_objects as go
-import pytest
-from sklearn.linear_model import Lasso, LogisticRegression
-
-from doubleml import DoubleMLIRM
-from doubleml.irm.datasets import make_irm_data
-
-np.random.seed(3141)
-n_obs = 200
-dml_data = make_irm_data(n_obs=n_obs)
-
-
-@pytest.fixture(scope="module")
-def fitted_irm():
-    """IRM model fitted with stored predictions."""
-    irm = DoubleMLIRM(dml_data, Lasso(), LogisticRegression(), n_rep=2, n_folds=3)
-    irm.fit(store_predictions=True)
-    return irm
-
-
-@pytest.fixture(scope="module")
-def fitted_irm_no_preds():
-    """IRM model fitted without stored predictions."""
-    irm = DoubleMLIRM(dml_data, Lasso(), LogisticRegression())
-    irm.fit(store_predictions=False)
-    return irm
-
-
-@pytest.mark.ci
-class TestOverlapPlotReturnType:
-    """Test that plot_overlap_common_support returns a plotly Figure."""
-
-    def test_returns_plotly_figure(self, fitted_irm):
-        fig = fitted_irm.plot_overlap_common_support()
-        assert isinstance(fig, go.Figure)
-
-    def test_returns_plotly_figure_custom_params(self, fitted_irm):
-        fig = fitted_irm.plot_overlap_common_support(
-            idx_treatment=0, i_rep=1, threshold=0.1, show_warning=False
-        )
-        assert isinstance(fig, go.Figure)
-
-
-@pytest.mark.ci
-class TestOverlapPlotErrors:
-    """Test error handling for plot_overlap_common_support."""
-
-    def test_error_before_fit(self):
-        """Calling before fit() should raise ValueError."""
-        irm = DoubleMLIRM(dml_data, Lasso(), LogisticRegression())
-        with pytest.raises(ValueError, match="Apply fit"):
-            irm.plot_overlap_common_support()
-
-    def test_error_no_predictions(self, fitted_irm_no_preds):
-        """Calling after fit(store_predictions=False) should raise ValueError."""
-        with pytest.raises(ValueError, match="Predictions are not stored"):
-            fitted_irm_no_preds.plot_overlap_common_support()
-
-    def test_error_invalid_idx_treatment_type(self, fitted_irm):
-        with pytest.raises(TypeError, match="idx_treatment must be an integer"):
-            fitted_irm.plot_overlap_common_support(idx_treatment=0.5)
-
-    def test_error_invalid_idx_treatment_range(self, fitted_irm):
-        with pytest.raises(ValueError, match="idx_treatment must be in"):
-            fitted_irm.plot_overlap_common_support(idx_treatment=5)
-
-    def test_error_negative_idx_treatment(self, fitted_irm):
-        with pytest.raises(ValueError, match="idx_treatment must be in"):
-            fitted_irm.plot_overlap_common_support(idx_treatment=-1)
-
-    def test_error_invalid_i_rep_type(self, fitted_irm):
-        with pytest.raises(TypeError, match="i_rep must be an integer"):
-            fitted_irm.plot_overlap_common_support(i_rep=0.5)
-
-    def test_error_invalid_i_rep_range(self, fitted_irm):
-        with pytest.raises(ValueError, match="i_rep must be in"):
-            fitted_irm.plot_overlap_common_support(i_rep=10)
-
-    def test_error_invalid_threshold_type(self, fitted_irm):
-        with pytest.raises(TypeError, match="threshold must be a float"):
-            fitted_irm.plot_overlap_common_support(threshold="0.05")
-
-    def test_error_invalid_threshold_range_low(self, fitted_irm):
-        with pytest.raises(ValueError, match="threshold must be in"):
-            fitted_irm.plot_overlap_common_support(threshold=0.0)
-
-    def test_error_invalid_threshold_range_high(self, fitted_irm):
-        with pytest.raises(ValueError, match="threshold must be in"):
-            fitted_irm.plot_overlap_common_support(threshold=0.5)
-
-    def test_error_invalid_show_warning_type(self, fitted_irm):
-        with pytest.raises(TypeError, match="show_warning must be a boolean"):
-            fitted_irm.plot_overlap_common_support(show_warning="True")
-
-
-@pytest.mark.ci
-class TestOverlapPlotContent:
-    """Test plot content and structure."""
-
-    def test_figure_has_traces(self, fitted_irm):
-        fig = fitted_irm.plot_overlap_common_support()
-        # Should have at least 2 traces (treated and control KDE)
-        assert len(fig.data) >= 2
-
-    def test_figure_trace_names(self, fitted_irm):
-        fig = fitted_irm.plot_overlap_common_support()
-        trace_names = [trace.name for trace in fig.data]
-        assert "Treated" in trace_names
-        assert "Control" in trace_names
-
-    def test_figure_layout(self, fitted_irm):
-        fig = fitted_irm.plot_overlap_common_support()
-        assert fig.layout.title.text == "Propensity Score Overlap (Common Support)"
-        assert fig.layout.xaxis.title.text == "Estimated Propensity Score"
-        assert fig.layout.yaxis.title.text == "Density"
-
-    def test_custom_threshold_in_annotations(self, fitted_irm):
-        fig = fitted_irm.plot_overlap_common_support(threshold=0.1)
-        # The annotation text should reference the threshold
-        annotations = fig.layout.annotations
-        found = any("0.1" in str(ann.text) for ann in annotations if ann.text)
-        assert found
-
-    def test_different_repetitions(self, fitted_irm):
-        """Results should differ between repetitions."""
-        fig0 = fitted_irm.plot_overlap_common_support(i_rep=0)
-        fig1 = fitted_irm.plot_overlap_common_support(i_rep=1)
-        # Traces should exist for both
-        assert len(fig0.data) >= 2
-        assert len(fig1.data) >= 2
-
-
-@pytest.mark.ci
-class TestOverlapPlotWarning:
-    """Test positivity violation warning behavior."""
-
-    def test_no_warning_when_disabled(self, fitted_irm):
-        with warnings.catch_warnings():
-            warnings.simplefilter("error")
-            # Should not raise any warning
-            fitted_irm.plot_overlap_common_support(show_warning=False)
-
-    def test_warning_with_extreme_threshold(self, fitted_irm):
-        """With a very wide threshold (e.g., 0.49), almost all observations
-        should be flagged, triggering the warning."""
-        with pytest.warns(UserWarning, match="Potential positivity violation"):
-            fitted_irm.plot_overlap_common_support(threshold=0.49, show_warning=True)

diff --git a/doubleml/utils/__init__.py b/doubleml/utils/__init__.py
index 868429dad..960988269 100644
--- a/doubleml/utils/__init__.py
+++ b/doubleml/utils/__init__.py
@@ -7,6 +7,7 @@
 from .dummy_learners import DMLDummyClassifier, DMLDummyRegressor
 from .gain_statistics import gain_statistics
 from .global_learner import GlobalClassifier, GlobalRegressor
+from .plots import plot_propensity_score_calibration
 from .policytree import DoubleMLPolicyTree
 from .propensity_score_processing import PSProcessor, PSProcessorConfig
 from .resampling import DoubleMLClusterResampling, DoubleMLResampling
@@ -22,6 +23,7 @@
     "gain_statistics",
     "GlobalClassifier",
     "GlobalRegressor",
+    "plot_propensity_score_calibration",
     "PSProcessor",
     "PSProcessorConfig",
 ]

diff --git a/doubleml/utils/_plots.py b/doubleml/utils/_plots.py
index 0c50ff9a0..f21121f5c 100644
--- a/doubleml/utils/_plots.py
+++ b/doubleml/utils/_plots.py
@@ -1,6 +1,5 @@
 import numpy as np
 import plotly.graph_objects as go
-from scipy.stats import gaussian_kde
 
 
 def _sensitivity_contour_plot(
@@ -95,135 +94,4 @@ def _sensitivity_contour_plot(
     return fig
 
 
-def _propensity_score_overlap_plot(
-    ps_scores: np.ndarray,
-    treatment: np.ndarray,
-    threshold: float = 0.05,
-) -> go.Figure:
-    """Create an interactive propensity score overlap (common support) plot.
-
-    Visualizes the distribution of estimated propensity scores split by treatment status
-    using kernel density estimation. Highlights regions near 0 and 1 where the positivity
-    assumption may be violated.
-
-    Parameters
-    ----------
-    ps_scores : :class:`numpy.ndarray`
-        Array of estimated propensity scores.
-    treatment : :class:`numpy.ndarray`
-        Binary treatment indicator array.
-    threshold : float
-        Threshold for positivity violation warning zones. Lines are drawn at ``threshold``
-        and ``1 - threshold``. Default is ``0.05``.
-
-    Returns
-    -------
-    fig : :class:`plotly.graph_objects.Figure`
-        Plotly figure with the propensity score overlap plot.
-    """
-    ps_treated = ps_scores[treatment == 1]
-    ps_control = ps_scores[treatment == 0]
-
-    # Compute KDE for both groups
-    x_grid = np.linspace(0, 1, 200)
-
-    kde_treated = gaussian_kde(ps_treated)
-    kde_control = gaussian_kde(ps_control)
-    density_treated = kde_treated(x_grid)
-    density_control = kde_control(x_grid)
-
-    fig = go.Figure()
-
-    # Add danger zone shading (near 0 and near 1)
-    fig.add_vrect(
-        x0=0,
-        x1=threshold,
-        fillcolor="red",
-        opacity=0.08,
-        line_width=0,
-        annotation_text="Positivity<br>concern",
-        annotation_position="top left",
-        annotation_font_size=10,
-        annotation_font_color="red",
-    )
-    fig.add_vrect(
-        x0=1 - threshold,
-        x1=1,
-        fillcolor="red",
-        opacity=0.08,
-        line_width=0,
-        annotation_text="Positivity<br>concern",
-        annotation_position="top right",
-        annotation_font_size=10,
-        annotation_font_color="red",
-    )
-
-    # KDE curves for treated and control
-    fig.add_trace(
-        go.Scatter(
-            x=x_grid,
-            y=density_treated,
-            mode="lines",
-            name="Treated",
-            fill="tozeroy",
-            line=dict(color="rgba(31, 119, 180, 0.9)", width=2),
-            fillcolor="rgba(31, 119, 180, 0.25)",
-            hovertemplate="PS: %{x:.3f}<br>Density: %{y:.3f}<extra>Treated</extra>",
-        )
-    )
-    fig.add_trace(
-        go.Scatter(
-            x=x_grid,
-            y=density_control,
-            mode="lines",
-            name="Control",
-            fill="tozeroy",
-            line=dict(color="rgba(255, 127, 14, 0.9)", width=2),
-            fillcolor="rgba(255, 127, 14, 0.25)",
-            hovertemplate="PS: %{x:.3f}<br>Density: %{y:.3f}<extra>Control</extra>",
-        )
-    )
-
-    # Threshold boundary lines
-    fig.add_vline(x=threshold, line_dash="dash", line_color="red", line_width=1.5, opacity=0.7)
-    fig.add_vline(x=1 - threshold, line_dash="dash", line_color="red", line_width=1.5, opacity=0.7)
-
-    # Summary annotation with positivity diagnostics
-    n_total = len(ps_scores)
-    n_below = np.sum(ps_scores < threshold)
-    n_above = np.sum(ps_scores > 1 - threshold)
-    pct_violation = (n_below + n_above) / n_total * 100
-
-    annotation_text = (
-        f"Positivity diagnostics<br>"
-        f"PS < {threshold}: {n_below} ({n_below / n_total * 100:.1f}%)<br>"
-        f"PS > {1 - threshold}: {n_above} ({n_above / n_total * 100:.1f}%)<br>"
-        f"Total violations: {pct_violation:.1f}%"
-    )
-    fig.add_annotation(
-        text=annotation_text,
-        xref="paper",
-        yref="paper",
-        x=0.98,
-        y=0.98,
-        showarrow=False,
-        font=dict(size=11),
-        align="left",
-        bordercolor="gray",
-        borderwidth=1,
-        borderpad=6,
-        bgcolor="rgba(255, 255, 255, 0.85)",
-    )
-
-    fig.update_layout(
-        title="Propensity Score Overlap (Common Support)",
-        xaxis_title="Estimated Propensity Score",
-        yaxis_title="Density",
-        xaxis=dict(range=[0, 1]),
-        template="plotly_white",
-        legend=dict(yanchor="top", y=0.85, xanchor="left", x=0.02),
-        hovermode="x unified",
-    )
-
-    return fig

diff --git a/doubleml/utils/plots.py b/doubleml/utils/plots.py
new file mode 100644
index 000000000..75b31d85f
--- /dev/null
+++ b/doubleml/utils/plots.py
@@ -0,0 +1,132 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import seaborn as sns
+
+
+def plot_propensity_score_calibration(
+    propensity_score,
+    treatment,
+    bins=10,
+    density=False,
+    palette="colorblind",
+):
+    """
+    Plot propensity score distributions and binned calibration curves.
+
+    Parameters
+    ----------
+    propensity_score : array-like
+        Predicted propensity scores of shape (n_samples,).
+    treatment : array-like
+        Binary treatment indicator of shape (n_samples,).
+    bins : int or array-like
+        Number of bins or explicit bin edges.
+    density : bool
+        If True, histogram heights are normalized.
+    palette : str or sequence
+        Seaborn palette name or explicit colors.
+
+    Returns
+    -------
+    fig : :class:`matplotlib.figure.Figure`
+        Matplotlib figure.
+    axes : :class:`numpy.ndarray`
+        2x2 axes array.
+    """
+    ps = np.asarray(propensity_score, dtype=float).reshape(-1)
+    tr = np.asarray(treatment).reshape(-1)
+
+    if ps.shape != tr.shape:
+        raise ValueError("propensity_score and treatment must have the same shape.")
+    if ps.ndim != 1:
+        raise ValueError("propensity_score and treatment must be one-dimensional.")
+    if not np.isin(tr, [0, 1]).all():
+        raise ValueError("treatment must be binary with values 0 and 1.")
+    if np.any((ps < 0) | (ps > 1)):
+        raise ValueError("propensity_score must lie in [0, 1].")
+
+    tr = tr.astype(int)
+
+    if isinstance(bins, int):
+        if bins < 2:
+            raise ValueError("bins must be at least 2.")
+        bins = np.linspace(0.0, 1.0, bins + 1)
+    else:
+        bins = np.asarray(bins, dtype=float)
+        if bins.ndim != 1 or len(bins) < 2:
+            raise ValueError("bins must contain at least two edges.")
+        if np.any(np.diff(bins) <= 0):
+            raise ValueError("bins must be strictly increasing.")
+
+    x_min, x_max = float(bins[0]), float(bins[-1])
+    centers = 0.5 * (bins[:-1] + bins[1:])
+    widths = np.diff(bins)
+
+    treated_frac = []
+    control_frac = []
+    for i in range(len(bins) - 1):
+        if i < len(bins) - 2:
+            mask = (ps >= bins[i]) & (ps < bins[i + 1])
+        else:
+            mask = (ps >= bins[i]) & (ps <= bins[i + 1])
+        if np.sum(mask) == 0:
+            treated_frac.append(np.nan)
+            control_frac.append(np.nan)
+        else:
+            p_treated = np.mean(tr[mask] == 1)
+            treated_frac.append(p_treated)
+            control_frac.append(1.0 - p_treated)
+
+    colors = sns.color_palette(palette, n_colors=2)
+    fig, axes = plt.subplots(2, 2, figsize=(12, 10), gridspec_kw={"height_ratios": [2, 1]})
+
+    sns.histplot(
+        ps[tr == 1],
+        bins=bins,
+        stat="density" if density else "count",
+        kde=False,
+        color=colors[0],
+        ax=axes[0, 0],
+        label="Treated",
+    )
+    axes[0, 0].set_title("Treated: Propensity Score Distribution")
+    axes[0, 0].set_xlim(x_min, x_max)
+    axes[0, 0].set_ylabel("Density" if density else "Count")
+    axes[0, 0].legend()
+
+    sns.histplot(
+        ps[tr == 0],
+        bins=bins,
stat="density" if density else "count", + kde=False, + color=colors[1], + ax=axes[0, 1], + label="Control", + ) + axes[0, 1].set_title("Control: Propensity Score Distribution") + axes[0, 1].set_xlim(x_min, x_max) + axes[0, 1].set_ylabel("Density" if density else "Count") + axes[0, 1].legend() + + axes[1, 0].bar(centers, treated_frac, width=widths, color=colors[0], alpha=0.7) + axes[1, 0].plot([x_min, x_max], [x_min, x_max], "k--", label="Ideal calibration") + axes[1, 0].set_title("Treated: Calibration") + axes[1, 0].set_xlabel("Predicted propensity score") + axes[1, 0].set_ylabel("Observed treatment fraction") + axes[1, 0].set_xlim(x_min, x_max) + axes[1, 0].set_ylim(0, 1) + axes[1, 0].legend() + + axes[1, 1].bar(centers, control_frac, width=widths, color=colors[1], alpha=0.7) + axes[1, 1].plot([x_min, x_max], [1 - x_min, 1 - x_max], "k--", label="Ideal calibration") + axes[1, 1].set_title("Control: Calibration") + axes[1, 1].set_xlabel("Predicted propensity score") + axes[1, 1].set_ylabel("Observed control fraction") + axes[1, 1].set_xlim(x_min, x_max) + axes[1, 1].set_ylim(0, 1) + axes[1, 1].legend() + + fig.suptitle("Propensity Score Calibration") + plt.tight_layout() + + return fig, axes diff --git a/doubleml/utils/tests/test_propensity_score_calibration.py b/doubleml/utils/tests/test_propensity_score_calibration.py new file mode 100644 index 000000000..7680cc2d4 --- /dev/null +++ b/doubleml/utils/tests/test_propensity_score_calibration.py @@ -0,0 +1,165 @@ +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pytest + +from doubleml.utils.plots import plot_propensity_score_calibration + + +@pytest.mark.ci +class TestInputValidation: + """Test input validation for plot_propensity_score_calibration.""" + + def test_shape_mismatch(self): + with pytest.raises(ValueError, match="same shape"): + plot_propensity_score_calibration(np.array([0.5, 0.3]), np.array([0, 1, 0])) + + def test_non_binary_treatment(self): + with pytest.raises(ValueError, match="binary with values 0 and 1"): + plot_propensity_score_calibration(np.array([0.5, 0.3, 0.7]), np.array([0, 1, 2])) + + def test_scores_below_zero(self): + with pytest.raises(ValueError, match="must lie in"): + plot_propensity_score_calibration(np.array([-0.1, 0.5]), np.array([0, 1])) + + def test_scores_above_one(self): + with pytest.raises(ValueError, match="must lie in"): + plot_propensity_score_calibration(np.array([0.5, 1.1]), np.array([0, 1])) + + def test_bins_too_few(self): + with pytest.raises(ValueError, match="bins must be at least 2"): + plot_propensity_score_calibration(np.array([0.5, 0.3]), np.array([0, 1]), bins=1) + + def test_bins_array_too_short(self): + with pytest.raises(ValueError, match="at least two edges"): + plot_propensity_score_calibration(np.array([0.5, 0.3]), np.array([0, 1]), bins=np.array([0.5])) + + def test_bins_not_increasing(self): + with pytest.raises(ValueError, match="strictly increasing"): + plot_propensity_score_calibration(np.array([0.5, 0.3]), np.array([0, 1]), bins=np.array([0.5, 0.3, 0.8])) + + +@pytest.mark.ci +class TestReturnType: + """Test return type and basic plot structure.""" + + def test_returns_figure_and_axes(self): + ps = np.array([0.1, 0.3, 0.5, 0.7, 0.9, 0.2, 0.4, 0.6, 0.8, 0.95]) + tr = np.array([0, 0, 0, 1, 1, 0, 0, 1, 1, 1]) + fig, axes = plot_propensity_score_calibration(ps, tr) + assert isinstance(fig, matplotlib.figure.Figure) + assert isinstance(axes, np.ndarray) + assert axes.shape == (2, 2) + plt.close(fig) + + def test_density_mode(self): + ps = 
np.array([0.1, 0.3, 0.5, 0.7, 0.9, 0.2, 0.4, 0.6, 0.8, 0.95]) + tr = np.array([0, 0, 0, 1, 1, 0, 0, 1, 1, 1]) + fig, axes = plot_propensity_score_calibration(ps, tr, density=True) + assert axes[0, 0].get_ylabel() == "Density" + plt.close(fig) + + def test_count_mode(self): + ps = np.array([0.1, 0.3, 0.5, 0.7, 0.9, 0.2, 0.4, 0.6, 0.8, 0.95]) + tr = np.array([0, 0, 0, 1, 1, 0, 0, 1, 1, 1]) + fig, axes = plot_propensity_score_calibration(ps, tr, density=False) + assert axes[0, 0].get_ylabel() == "Count" + plt.close(fig) + + +@pytest.mark.ci +class TestBinHandling: + """Test bin handling with int and array bins.""" + + def test_int_bins(self): + np.random.seed(42) + ps = np.random.uniform(0, 1, 100) + tr = (ps > 0.5).astype(int) + fig, axes = plot_propensity_score_calibration(ps, tr, bins=5) + assert isinstance(fig, matplotlib.figure.Figure) + plt.close(fig) + + def test_explicit_bins(self): + np.random.seed(42) + ps = np.random.uniform(0, 1, 100) + tr = (ps > 0.5).astype(int) + custom_bins = np.array([0.0, 0.25, 0.5, 0.75, 1.0]) + fig, axes = plot_propensity_score_calibration(ps, tr, bins=custom_bins) + assert isinstance(fig, matplotlib.figure.Figure) + plt.close(fig) + + +@pytest.mark.ci +class TestBoundaryValues: + """Test boundary values at 0 and 1.""" + + def test_all_scores_at_zero(self): + ps = np.zeros(10) + tr = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1]) + fig, axes = plot_propensity_score_calibration(ps, tr) + assert isinstance(fig, matplotlib.figure.Figure) + plt.close(fig) + + def test_all_scores_at_one(self): + ps = np.ones(10) + tr = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1]) + fig, axes = plot_propensity_score_calibration(ps, tr) + assert isinstance(fig, matplotlib.figure.Figure) + plt.close(fig) + + def test_scores_at_bin_edges(self): + ps = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]) + tr = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]) + fig, axes = plot_propensity_score_calibration(ps, tr, bins=10) + assert isinstance(fig, matplotlib.figure.Figure) + plt.close(fig) + + +@pytest.mark.ci +class TestEmptyBinBehavior: + """Test empty-bin behavior (NaN in calibration).""" + + def test_empty_bins_do_not_crash(self): + # Only scores in [0.4, 0.6], so bins outside that range are empty + ps = np.array([0.45, 0.5, 0.55, 0.5, 0.45, 0.55]) + tr = np.array([0, 1, 1, 0, 0, 1]) + fig, axes = plot_propensity_score_calibration(ps, tr, bins=10) + assert isinstance(fig, matplotlib.figure.Figure) + plt.close(fig) + + +@pytest.mark.ci +class TestCalibrationContent: + """Test that calibration subplots have expected properties.""" + + def test_calibration_axes_labels(self): + np.random.seed(42) + ps = np.random.uniform(0, 1, 200) + tr = (np.random.uniform(0, 1, 200) < ps).astype(int) + fig, axes = plot_propensity_score_calibration(ps, tr) + + assert axes[1, 0].get_xlabel() == "Predicted propensity score" + assert axes[1, 0].get_ylabel() == "Observed treatment fraction" + assert axes[1, 1].get_xlabel() == "Predicted propensity score" + assert axes[1, 1].get_ylabel() == "Observed control fraction" + plt.close(fig) + + def test_titles(self): + np.random.seed(42) + ps = np.random.uniform(0, 1, 200) + tr = (np.random.uniform(0, 1, 200) < ps).astype(int) + fig, axes = plot_propensity_score_calibration(ps, tr) + + assert axes[0, 0].get_title() == "Treated: Propensity Score Distribution" + assert axes[0, 1].get_title() == "Control: Propensity Score Distribution" + assert axes[1, 0].get_title() == "Treated: Calibration" + assert axes[1, 1].get_title() == "Control: Calibration" + 
+        plt.close(fig)
+
+    def test_suptitle(self):
+        np.random.seed(42)
+        ps = np.random.uniform(0, 1, 50)
+        tr = (np.random.uniform(0, 1, 50) < ps).astype(int)
+        fig, axes = plot_propensity_score_calibration(ps, tr)
+        assert fig._suptitle.get_text() == "Propensity Score Calibration"
+        plt.close(fig)
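Usage sketch for the refactored API in PATCH 2/2 (illustrative only, not part of the patch):
after this commit the method returns a matplotlib figure plus a 2x2 axes array, and the
underlying helper is exported from doubleml.utils, so it can also be called directly on plain
arrays. Everything below uses only names introduced in the diffs above.

    import numpy as np
    from sklearn.linear_model import Lasso, LogisticRegression

    from doubleml import DoubleMLIRM
    from doubleml.irm.datasets import make_irm_data
    from doubleml.utils import plot_propensity_score_calibration

    np.random.seed(3141)
    dml_data = make_irm_data(n_obs=500)
    dml_irm = DoubleMLIRM(dml_data, Lasso(), LogisticRegression(), n_folds=3)
    dml_irm.fit(store_predictions=True)

    # via the model method: histograms (top row) + binned calibration curves (bottom row)
    fig, axes = dml_irm.plot_overlap_common_support(bins=10, density=True)

    # or directly on arrays, e.g. propensity scores from the first repetition/treatment
    ps = dml_irm.predictions["ml_m"][:, 0, 0]
    fig2, axes2 = plot_propensity_score_calibration(ps, dml_data.d, bins=10)
    fig2.savefig("ps_calibration.png")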