Skip to content

Commit 2b166cb

Browse files
igerberclaude
andcommitted
callaway-santanna: make per-(g,t) analytical SE cluster-aware
CI codex R3 P0: the cluster wiring contract documented in REGISTRY.md ("cluster=X means CR1 Liang-Zeger on the IF") was honored at the aggregate inference surface (overall_se, event-study, group, bootstrap) but the per-cell public surface results.group_time_effects[(g,t)]["se"] remained unit-level. Users inspecting per-cell ATT(g,t) inference under cluster= got silently misleading SE/t/p/CI even though overall inference was correctly cluster-robust. Fix: new module-level _cluster_robust_se_from_per_gt_if helper that aggregates the per-(g,t) IF by PSU and returns CR1 Liang-Zeger SE. Applied at all 4 ATT(g,t) computation sites identified by the codex: 1. _compute_all_att_gt_vectorized (no-covariate vectorized batch) — recompute se after building inf_info, overwrite group_time_effects [(g,t)]["se"] which was set with the unit-level value 2. _compute_all_att_gt_covariate_reg (covariate-reg batch) — same pattern 3. Main panel single-cell loop (after _compute_att_gt_fast) — local se_gt update flows into gte_entry["se"] 4. RC fit loop (after _compute_att_gt_rc) — uses resolved_survey.psu (per-obs) instead of resolved_survey_unit.psu (per-unit) The recompute is gated by `if psu is not None`, so cluster=None remains bit-equal to pre-PR. For cluster=unit (each unit its own cluster), the CR1 formula coincides with the unit-level IF formula (modulo ddof conventions in the underlying OR path) — methodologically consistent with Williams (2000) CR1-on-IF for IF-based estimators. Tests: - test_per_gt_analytical_se_changes_with_cluster: asserts at least one (g,t) cell shows measurable SE divergence between cluster=None and cluster="state" on a panel with intra-cluster correlation - test_per_gt_se_matches_explicit_survey_design: asserts per-(g,t) SE agrees (rel=1e-10) between bare cluster="state" and explicit SurveyDesign(psu="state") — both activate the same CR1 aggregation All 414 tests (test_staggered + test_staggered_rc + test_triple_diff + test_honest_did + test_two_stage) pass. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 2449842 commit 2b166cb

2 files changed

Lines changed: 211 additions & 2 deletions

File tree

diff_diff/staggered.py

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,65 @@ def _linear_regression(
9292
return beta, residuals
9393

9494

95+
def _cluster_robust_se_from_per_gt_if(
96+
inf_info: Dict[str, Any],
97+
psu_array: np.ndarray,
98+
) -> Optional[float]:
99+
"""CR1 Liang-Zeger cluster-robust SE for a single (g,t) ATT.
100+
101+
Computes the cluster-aggregated IF variance for one per-(g,t) cell:
102+
103+
psi_per_index[i] = sum of IFs at this index for this (g, t)
104+
psi_per_cluster[c] = sum_{i in c} psi_per_index[i]
105+
se = sqrt(sum(psi_per_cluster ** 2))
106+
107+
For the panel path, ``psu_array`` is ``resolved_survey_unit.psu``
108+
(length n_units) and the IF index space is per-unit. For the RCS
109+
path, ``psu_array`` is ``resolved_survey.psu`` (length n_obs) and
110+
the IF index space is per-obs. The helper is index-space agnostic
111+
— it just requires ``treated_idx`` / ``control_idx`` in ``inf_info``
112+
to be valid offsets into ``psu_array``.
113+
114+
Returns ``None`` when ``inf_info`` lacks the required IF fields or
115+
when index alignment cannot be verified (caller falls back to the
116+
unit-level SE returned by the underlying estimation method).
117+
"""
118+
if (
119+
inf_info is None
120+
or "treated_inf" not in inf_info
121+
or "control_inf" not in inf_info
122+
or "treated_idx" not in inf_info
123+
or "control_idx" not in inf_info
124+
):
125+
return None
126+
treated_idx = np.asarray(inf_info["treated_idx"])
127+
control_idx = np.asarray(inf_info["control_idx"])
128+
treated_inf = np.asarray(inf_info["treated_inf"])
129+
control_inf = np.asarray(inf_info["control_inf"])
130+
n = len(psu_array)
131+
if (
132+
treated_idx.size > 0
133+
and (treated_idx.max(initial=-1) >= n or treated_idx.min(initial=0) < 0)
134+
) or (
135+
control_idx.size > 0
136+
and (control_idx.max(initial=-1) >= n or control_idx.min(initial=0) < 0)
137+
):
138+
return None
139+
psi_per_index = np.zeros(n)
140+
if treated_idx.size:
141+
np.add.at(psi_per_index, treated_idx, treated_inf)
142+
if control_idx.size:
143+
np.add.at(psi_per_index, control_idx, control_inf)
144+
# Factorize PSU labels for index-friendly aggregation
145+
_, psu_codes = np.unique(psu_array, return_inverse=True)
146+
n_clusters = int(psu_codes.max() + 1) if psu_codes.size else 0
147+
if n_clusters == 0:
148+
return None
149+
psi_per_cluster = np.zeros(n_clusters)
150+
np.add.at(psi_per_cluster, psu_codes, psi_per_index)
151+
return float(np.sqrt(np.sum(psi_per_cluster**2)))
152+
153+
95154
def _safe_inv(
96155
A: np.ndarray,
97156
tracker: Optional[list] = None,
@@ -1035,14 +1094,30 @@ def _compute_all_att_gt_vectorized(
10351094
all_units = precomputed["all_units"]
10361095
treated_positions = np.where(treated_valid)[0]
10371096
control_positions = np.where(control_valid)[0]
1038-
influence_func_info[(g, t)] = {
1097+
inf_info_gt = {
10391098
"treated_idx": treated_positions,
10401099
"control_idx": control_positions,
10411100
"treated_units": all_units[treated_positions],
10421101
"control_units": all_units[control_positions],
10431102
"treated_inf": inf_treated,
10441103
"control_inf": inf_control,
10451104
}
1105+
influence_func_info[(g, t)] = inf_info_gt
1106+
1107+
# Cluster-aware per-(g,t) SE: aggregate the per-(g,t) IF by
1108+
# PSU when a survey design (explicit OR synthesized from bare
1109+
# cluster=) provides one. Bit-equal to pre-PR when psu is None.
1110+
rsu_for_gt = precomputed.get("resolved_survey_unit")
1111+
if rsu_for_gt is not None and getattr(rsu_for_gt, "psu", None) is not None:
1112+
se_cluster = _cluster_robust_se_from_per_gt_if(inf_info_gt, rsu_for_gt.psu)
1113+
if se_cluster is not None and np.isfinite(se_cluster):
1114+
se = se_cluster
1115+
# gte_entry["se"] was set with the unit-level value
1116+
# at the gte_entry construction above; overwrite with
1117+
# the cluster-aware value so the public surface
1118+
# group_time_effects[(g,t)]["se"] reflects the
1119+
# documented CR1 contract.
1120+
group_time_effects[(g, t)]["se"] = se
10461121

10471122
atts.append(att)
10481123
ses.append(se)
@@ -1379,14 +1454,24 @@ def _compute_all_att_gt_covariate_reg(
13791454
all_units = precomputed["all_units"]
13801455
treated_positions = np.where(treated_valid)[0]
13811456
control_positions = np.where(control_valid)[0]
1382-
influence_func_info[(g, t)] = {
1457+
inf_info_gt = {
13831458
"treated_idx": treated_positions,
13841459
"control_idx": control_positions,
13851460
"treated_units": all_units[treated_positions],
13861461
"control_units": all_units[control_positions],
13871462
"treated_inf": inf_treated,
13881463
"control_inf": inf_control,
13891464
}
1465+
influence_func_info[(g, t)] = inf_info_gt
1466+
1467+
# Cluster-aware per-(g,t) SE — see same pattern in
1468+
# _compute_all_att_gt_vectorized.
1469+
rsu_for_gt = precomputed.get("resolved_survey_unit")
1470+
if rsu_for_gt is not None and getattr(rsu_for_gt, "psu", None) is not None:
1471+
se_cluster = _cluster_robust_se_from_per_gt_if(inf_info_gt, rsu_for_gt.psu)
1472+
if se_cluster is not None and np.isfinite(se_cluster):
1473+
se = se_cluster
1474+
group_time_effects[(g, t)]["se"] = se
13901475

13911476
atts.append(att)
13921477
ses.append(se)
@@ -1820,6 +1905,22 @@ def fit(
18201905
agg_w = rc_result[6] if len(rc_result) > 6 else n_treat
18211906

18221907
if att_gt is not None:
1908+
# Cluster-aware per-(g,t) SE on the RCS path. RC
1909+
# IF indices are per-obs (vs per-unit on the panel
1910+
# path); the corresponding PSU array is
1911+
# ``resolved_survey.psu`` (length n_obs), not
1912+
# ``resolved_survey_unit.psu``. Bit-equal to pre-PR
1913+
# when psu is None.
1914+
rs_for_gt = precomputed.get("resolved_survey") if precomputed else None
1915+
if (
1916+
rs_for_gt is not None
1917+
and getattr(rs_for_gt, "psu", None) is not None
1918+
and inf_info is not None
1919+
):
1920+
se_cluster = _cluster_robust_se_from_per_gt_if(inf_info, rs_for_gt.psu)
1921+
if se_cluster is not None and np.isfinite(se_cluster):
1922+
se_gt = se_cluster
1923+
18231924
t_stat, p_val, ci = safe_inference(
18241925
att_gt,
18251926
se_gt,
@@ -1912,6 +2013,22 @@ def fit(
19122013
)
19132014

19142015
if att_gt is not None:
2016+
# Cluster-aware per-(g,t) SE: when a survey PSU is
2017+
# in play (explicit OR synthesized from bare
2018+
# cluster=), aggregate the per-(g,t) IF by PSU
2019+
# and use CR1 Liang-Zeger SE instead of the
2020+
# unit-level diff-of-means SE returned by OR/IPW/DR.
2021+
# Preserves bit-equality when psu is None.
2022+
rsu_for_gt = precomputed.get("resolved_survey_unit")
2023+
if (
2024+
rsu_for_gt is not None
2025+
and getattr(rsu_for_gt, "psu", None) is not None
2026+
and inf_info is not None
2027+
):
2028+
se_cluster = _cluster_robust_se_from_per_gt_if(inf_info, rsu_for_gt.psu)
2029+
if se_cluster is not None and np.isfinite(se_cluster):
2030+
se_gt = se_cluster
2031+
19152032
t_stat, p_val, ci = safe_inference(
19162033
att_gt,
19172034
se_gt,

tests/test_staggered.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4829,6 +4829,98 @@ def test_bare_cluster_bootstrap_se_differs_from_unit_level(self):
48294829
"not be reaching the bootstrap multiplier-weights routing."
48304830
)
48314831

4832+
def test_per_gt_analytical_se_changes_with_cluster(self):
4833+
"""Per-(g,t) analytical SE at results.group_time_effects[(g,t)]
4834+
["se"] must change when cluster= is set (mirrors the overall_se
4835+
contract). Pre-fix, per-(g,t) SEs were unit-level even with
4836+
cluster=, only the aggregate path + bootstrap honored cluster=.
4837+
Per CI codex R3 P0 finding."""
4838+
data = _generate_clustered_staggered_data(seed=97)
4839+
4840+
cs_unit = CallawaySantAnna()
4841+
res_unit = cs_unit.fit(
4842+
data,
4843+
outcome="outcome",
4844+
unit="unit",
4845+
time="time",
4846+
first_treat="first_treat",
4847+
)
4848+
cs_cluster = CallawaySantAnna(cluster="state")
4849+
res_cluster = cs_cluster.fit(
4850+
data,
4851+
outcome="outcome",
4852+
unit="unit",
4853+
time="time",
4854+
first_treat="first_treat",
4855+
)
4856+
4857+
# Pick a representative (g, t) cell that exists in both fits
4858+
gt_keys = sorted(
4859+
set(res_unit.group_time_effects.keys()) & set(res_cluster.group_time_effects.keys())
4860+
)
4861+
assert len(gt_keys) > 0, "expected overlapping (g, t) keys"
4862+
4863+
# At least one (g, t) cell must show measurable SE divergence —
4864+
# cluster-aware aggregation should differ from unit-level for at
4865+
# least one cell on a panel with intra-cluster correlation.
4866+
diffs = []
4867+
for gt in gt_keys:
4868+
se_unit = res_unit.group_time_effects[gt]["se"]
4869+
se_cluster = res_cluster.group_time_effects[gt]["se"]
4870+
if np.isfinite(se_unit) and np.isfinite(se_cluster):
4871+
diffs.append(abs(se_unit - se_cluster))
4872+
max_diff = max(diffs) if diffs else 0.0
4873+
assert max_diff > 1e-6, (
4874+
f"Per-(g,t) SEs did not change with cluster= (max diff "
4875+
f"across {len(diffs)} cells: {max_diff:.6g}). The cluster= "
4876+
"parameter may not be reaching the per-(g,t) analytical SE "
4877+
"computation."
4878+
)
4879+
4880+
def test_per_gt_se_matches_explicit_survey_design(self):
4881+
"""When bare cluster=X and explicit SurveyDesign(psu=X) produce
4882+
equivalent variance contracts, the per-(g,t) SE surface must
4883+
also agree (modulo the deterministic synthesis path). Per CI
4884+
codex R3 P0 finding."""
4885+
from diff_diff import SurveyDesign
4886+
4887+
data = _generate_clustered_staggered_data(seed=101)
4888+
4889+
cs_bare = CallawaySantAnna(cluster="state")
4890+
res_bare = cs_bare.fit(
4891+
data,
4892+
outcome="outcome",
4893+
unit="unit",
4894+
time="time",
4895+
first_treat="first_treat",
4896+
)
4897+
4898+
cs_explicit = CallawaySantAnna()
4899+
res_explicit = cs_explicit.fit(
4900+
data,
4901+
outcome="outcome",
4902+
unit="unit",
4903+
time="time",
4904+
first_treat="first_treat",
4905+
survey_design=SurveyDesign(psu="state"),
4906+
)
4907+
4908+
gt_keys = sorted(
4909+
set(res_bare.group_time_effects.keys()) & set(res_explicit.group_time_effects.keys())
4910+
)
4911+
assert len(gt_keys) > 0
4912+
4913+
for gt in gt_keys:
4914+
se_bare = res_bare.group_time_effects[gt]["se"]
4915+
se_explicit = res_explicit.group_time_effects[gt]["se"]
4916+
if np.isfinite(se_bare) and np.isfinite(se_explicit):
4917+
assert se_bare == pytest.approx(se_explicit, rel=1e-10, abs=1e-12), (
4918+
f"Per-(g,t) SE divergence at {gt}: bare cluster=state "
4919+
f"({se_bare}) vs explicit SurveyDesign(psu=state) "
4920+
f"({se_explicit}). Both should activate the same CR1 "
4921+
"aggregation."
4922+
)
4923+
48324924
def test_survey_design_psu_wins_under_bootstrap(self):
48334925
"""Bootstrap path: when survey_design=SurveyDesign(psu=Y) is
48344926
explicit AND cluster=X is also set with a different partition,

0 commit comments

Comments
 (0)