Skip to content

Commit a729f00

Browse files
authored
Merge pull request #253 from igerber/replicate-weight-expansion
Add replicate weight support to 7 estimators
2 parents 8ffd344 + 692dab4 commit a729f00

11 files changed

Lines changed: 1391 additions & 235 deletions

File tree

TODO.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ Deferred items from PR reviews that were not addressed before merge.
6969
| StaggeredTripleDifference R parity: benchmark only tests no-covariate path (xformla=~1). Add covariate-adjusted scenarios and aggregation SE parity assertions. | `benchmarks/R/benchmark_staggered_triplediff.R` | #245 | Medium |
7070
| StaggeredTripleDifference: per-cohort group-effect SEs include WIF (conservative vs R's wif=NULL). Documented in REGISTRY. Could override mixin for exact R match. | `staggered_triple_diff.py` | #245 | Low |
7171
| HonestDiD Delta^RM: uses naive FLCI instead of paper's ARP conditional/hybrid confidence sets (Sections 3.2.1-3.2.2). ARP infrastructure exists but moment inequality transformation needs calibration. CIs are conservative (wider, valid coverage). | `honest_did.py` | #248 | Medium |
72+
| Replicate weight tests use Fay-like BRR perturbations (0.5/1.5), not true half-sample BRR. Add true BRR regressions per estimator family. Existing `test_survey_phase6.py` covers true BRR at the helper level. | `tests/test_replicate_weight_expansion.py` | #253 | Low |
7273

7374
#### Performance
7475

diff_diff/estimators.py

Lines changed: 165 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -240,14 +240,14 @@ def fit(
240240
resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
241241
_resolve_survey_for_fit(survey_design, data, self.inference)
242242
)
243-
# Reject replicate-weight designs — base DiD uses compute_survey_vcov
244-
# (TSL) directly, not LinearRegression's replicate dispatch.
245-
if resolved_survey is not None and resolved_survey.uses_replicate_variance:
246-
raise NotImplementedError(
247-
"DifferenceInDifferences does not yet support replicate-weight "
248-
"survey designs. Use CallawaySantAnna, EfficientDiD, "
249-
"ContinuousDiD, or TripleDifference for replicate-weight "
250-
"inference, or use a TSL-based survey design (strata/psu/fpc)."
243+
_uses_replicate = (
244+
resolved_survey is not None and resolved_survey.uses_replicate_variance
245+
)
246+
if _uses_replicate and self.inference == "wild_bootstrap":
247+
raise ValueError(
248+
"Cannot use inference='wild_bootstrap' with replicate-weight "
249+
"survey designs. Replicate weights provide their own variance "
250+
"estimation."
251251
)
252252

253253
# Handle absorbed fixed effects (within-transformation)
@@ -358,6 +358,13 @@ def fit(
358358
)
359359
survey_metadata = compute_survey_metadata(resolved_survey, raw_w)
360360

361+
# When absorb + replicate: pass survey_design=None to prevent
362+
# LinearRegression from computing replicate vcov on already-demeaned
363+
# data (demeaning depends on weights, so replicate refits must re-demean).
364+
_lr_survey = resolved_survey
365+
if _uses_replicate and absorbed_vars:
366+
_lr_survey = None
367+
361368
reg = LinearRegression(
362369
include_intercept=False, # Intercept already in X
363370
robust=self.robust,
@@ -366,7 +373,7 @@ def fit(
366373
rank_deficient_action=self.rank_deficient_action,
367374
weights=survey_weights,
368375
weight_type=survey_weight_type,
369-
survey_design=resolved_survey,
376+
survey_design=_lr_survey,
370377
).fit(X, y, df_adjustment=n_absorbed_effects)
371378

372379
coefficients = reg.coefficients_
@@ -375,14 +382,69 @@ def fit(
375382
assert coefficients is not None
376383
att = coefficients[att_idx]
377384

378-
# Get inference - either from bootstrap or analytical
379-
if self.inference == "wild_bootstrap" and self.cluster is not None:
385+
# Get inference - replicate absorb override, bootstrap, or analytical
386+
if _uses_replicate and absorbed_vars:
387+
# Estimator-level replicate variance: re-demean + re-solve per replicate
388+
from diff_diff.survey import compute_replicate_refit_variance
389+
from diff_diff.utils import safe_inference
390+
391+
_absorb_list = list(absorbed_vars) # capture for closure
392+
393+
# Handle rank-deficient nuisance: refit only identified columns
394+
_id_mask = ~np.isnan(coefficients)
395+
_id_cols = np.where(_id_mask)[0]
396+
_att_idx_reduced = int(np.searchsorted(_id_cols, att_idx))
397+
398+
def _refit_did_absorb(w_r):
399+
nz = w_r > 0
400+
wd = data[nz].copy()
401+
w_nz = w_r[nz]
402+
wd["_treat_time"] = (
403+
wd[treatment].values.astype(float) * wd[time].values.astype(float)
404+
)
405+
vars_dm = [outcome, treatment, time, "_treat_time"] + (covariates or [])
406+
for ab_var in _absorb_list:
407+
wd, _ = demean_by_group(wd, vars_dm, ab_var, inplace=True, weights=w_nz)
408+
y_r = wd[outcome].values.astype(float)
409+
d_r = wd[treatment].values.astype(float)
410+
t_r = wd[time].values.astype(float)
411+
dt_r = wd["_treat_time"].values.astype(float)
412+
X_r = np.column_stack([np.ones(len(y_r)), d_r, t_r, dt_r])
413+
if covariates:
414+
for cov in covariates:
415+
X_r = np.column_stack([X_r, wd[cov].values.astype(float)])
416+
coef_r, _, _ = solve_ols(
417+
X_r[:, _id_cols], y_r,
418+
weights=w_nz, weight_type=survey_weight_type,
419+
rank_deficient_action="silent", return_vcov=False,
420+
)
421+
return coef_r
422+
423+
vcov_reduced, _n_valid_rep = compute_replicate_refit_variance(
424+
_refit_did_absorb, coefficients[_id_mask], resolved_survey
425+
)
426+
vcov = _expand_vcov_with_nan(vcov_reduced, len(coefficients), _id_cols)
427+
se = float(np.sqrt(max(vcov[att_idx, att_idx], 0.0)))
428+
_df_rep = (
429+
survey_metadata.df_survey
430+
if survey_metadata and survey_metadata.df_survey
431+
else 0 # rank-deficient replicate → NaN inference
432+
)
433+
if _n_valid_rep < resolved_survey.n_replicates:
434+
_df_rep = _n_valid_rep - 1 if _n_valid_rep > 1 else 0
435+
if survey_metadata is not None:
436+
survey_metadata.df_survey = _df_rep if _df_rep > 0 else None
437+
t_stat, p_value, conf_int = safe_inference(
438+
att, se, alpha=self.alpha, df=_df_rep
439+
)
440+
elif self.inference == "wild_bootstrap" and self.cluster is not None:
380441
# Override with wild cluster bootstrap inference
381442
se, p_value, conf_int, t_stat, vcov, _ = self._run_wild_bootstrap_inference(
382443
X, y, residuals, cluster_ids, att_idx
383444
)
384445
else:
385446
# Use analytical inference from LinearRegression
447+
# (handles replicate vcov for no-absorb path automatically)
386448
vcov = reg.vcov_
387449
inference = reg.get_inference(att_idx)
388450
se = inference.se
@@ -1017,14 +1079,14 @@ def fit( # type: ignore[override]
10171079
resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
10181080
_resolve_survey_for_fit(survey_design, data, effective_inference)
10191081
)
1020-
# Reject replicate-weight designs — MultiPeriodDiD uses
1021-
# compute_survey_vcov (TSL) directly without replicate dispatch.
1022-
if resolved_survey is not None and resolved_survey.uses_replicate_variance:
1023-
raise NotImplementedError(
1024-
"MultiPeriodDiD does not yet support replicate-weight survey "
1025-
"designs. Use CallawaySantAnna for staggered adoption with "
1026-
"replicate weights, or use a TSL-based survey design "
1027-
"(strata/psu/fpc)."
1082+
_uses_replicate_mp = (
1083+
resolved_survey is not None and resolved_survey.uses_replicate_variance
1084+
)
1085+
if _uses_replicate_mp and effective_inference == "wild_bootstrap":
1086+
raise ValueError(
1087+
"Cannot use inference='wild_bootstrap' with replicate-weight "
1088+
"survey designs. Replicate weights provide their own variance "
1089+
"estimation."
10281090
)
10291091

10301092
# Handle absorbed fixed effects (within-transformation)
@@ -1177,7 +1239,80 @@ def fit( # type: ignore[override]
11771239
)
11781240

11791241
# Compute survey vcov if applicable
1180-
if _use_survey_vcov:
1242+
_n_valid_rep_mp = None
1243+
if _use_survey_vcov and _uses_replicate_mp and absorb:
1244+
# Absorb + replicate: estimator-level refit (demeaning depends on weights)
1245+
from diff_diff.survey import compute_replicate_refit_variance
1246+
1247+
_absorb_list_mp = list(absorb)
1248+
# Handle rank-deficient nuisance: refit only identified columns
1249+
_id_mask_mp = ~np.isnan(coefficients)
1250+
_id_cols_mp = np.where(_id_mask_mp)[0]
1251+
1252+
def _refit_mp_absorb(w_r):
1253+
nz = w_r > 0
1254+
wd = data[nz].copy()
1255+
w_nz = w_r[nz]
1256+
d_raw_ = wd[treatment].values.astype(float)
1257+
t_raw_ = wd[time].values
1258+
wd["_did_treatment"] = d_raw_
1259+
for period_ in non_ref_periods:
1260+
wd[f"_did_period_{period_}"] = (t_raw_ == period_).astype(float)
1261+
wd[f"_did_interact_{period_}"] = d_raw_ * (t_raw_ == period_).astype(float)
1262+
vars_dm_ = (
1263+
[outcome, "_did_treatment"]
1264+
+ [f"_did_period_{p}" for p in non_ref_periods]
1265+
+ [f"_did_interact_{p}" for p in non_ref_periods]
1266+
+ (covariates or [])
1267+
)
1268+
for ab_var_ in _absorb_list_mp:
1269+
wd, _ = demean_by_group(wd, vars_dm_, ab_var_, inplace=True, weights=w_nz)
1270+
y_r = wd[outcome].values.astype(float)
1271+
d_r = wd["_did_treatment"].values.astype(float)
1272+
X_r = np.column_stack([np.ones(len(y_r)), d_r])
1273+
for period_ in non_ref_periods:
1274+
X_r = np.column_stack(
1275+
[X_r, wd[f"_did_period_{period_}"].values.astype(float)]
1276+
)
1277+
for period_ in non_ref_periods:
1278+
X_r = np.column_stack(
1279+
[X_r, wd[f"_did_interact_{period_}"].values.astype(float)]
1280+
)
1281+
if covariates:
1282+
for cov_ in covariates:
1283+
X_r = np.column_stack([X_r, wd[cov_].values.astype(float)])
1284+
coef_r, _, _ = solve_ols(
1285+
X_r[:, _id_cols_mp], y_r,
1286+
weights=w_nz, weight_type=survey_weight_type,
1287+
rank_deficient_action="silent", return_vcov=False,
1288+
)
1289+
return coef_r
1290+
1291+
vcov_reduced_mp, _n_valid_rep_mp = compute_replicate_refit_variance(
1292+
_refit_mp_absorb, coefficients[_id_mask_mp], resolved_survey
1293+
)
1294+
vcov = _expand_vcov_with_nan(vcov_reduced_mp, len(coefficients), _id_cols_mp)
1295+
elif _use_survey_vcov and _uses_replicate_mp:
1296+
# No absorb + replicate: X is fixed, use compute_replicate_vcov directly
1297+
from diff_diff.survey import compute_replicate_vcov
1298+
1299+
nan_mask = np.isnan(coefficients)
1300+
if np.any(nan_mask):
1301+
kept_cols = np.where(~nan_mask)[0]
1302+
if len(kept_cols) > 0:
1303+
vcov_reduced, _n_valid_rep_mp = compute_replicate_vcov(
1304+
X[:, kept_cols], y, coefficients[kept_cols], resolved_survey,
1305+
weight_type=survey_weight_type,
1306+
)
1307+
vcov = _expand_vcov_with_nan(vcov_reduced, X.shape[1], kept_cols)
1308+
else:
1309+
vcov = np.full((X.shape[1], X.shape[1]), np.nan)
1310+
_n_valid_rep_mp = 0
1311+
else:
1312+
vcov, _n_valid_rep_mp = compute_replicate_vcov(
1313+
X, y, coefficients, resolved_survey, weight_type=survey_weight_type,
1314+
)
1315+
elif _use_survey_vcov:
11811316
from diff_diff.survey import compute_survey_vcov
11821317

11831318
nan_mask = np.isnan(coefficients)
@@ -1201,9 +1336,18 @@ def fit( # type: ignore[override]
12011336
df = n_eff_df - k_effective - n_absorbed_effects
12021337
if resolved_survey is not None and resolved_survey.df_survey is not None:
12031338
df = resolved_survey.df_survey
1339+
# Replicate df: rank-deficient → NaN inference; dropped replicates → n_valid-1
1340+
if _uses_replicate_mp:
1341+
if resolved_survey.df_survey is None:
1342+
df = 0 # rank-deficient replicate → NaN inference
1343+
if _n_valid_rep_mp is not None and _n_valid_rep_mp < resolved_survey.n_replicates:
1344+
df = _n_valid_rep_mp - 1 if _n_valid_rep_mp > 1 else 0
1345+
if survey_metadata is not None:
1346+
survey_metadata.df_survey = df if df > 0 else None
12041347

12051348
# Guard: fall back to normal distribution if df is non-positive
1206-
if df is not None and df <= 0:
1349+
# Skip for replicate designs — df=0 is intentional for NaN inference
1350+
if df is not None and df <= 0 and not _uses_replicate_mp:
12071351
warnings.warn(
12081352
f"Degrees of freedom is non-positive (df={df}). "
12091353
"Using normal distribution instead of t-distribution for inference.",

0 commit comments

Comments
 (0)