Skip to content

Commit f808ae6

Browse files
committed
efficient-did: make set_params atomic — rollback all attribute mutations on validation failure
Previously, EfficientDiD.set_params iterated over the kwargs and called setattr() before invoking _validate_params(). A rejected batched call such as set_params(vcov_type="classical", alpha=0.1, anticipation=2) raised on validation but left every kwarg in the partially-mutated state — the rejected vcov_type plus the otherwise-valid alpha and anticipation. Callers that catch ValueError and continue using the estimator would then operate on a silently corrupted parameter configuration, defeating the eager-validation contract introduced when _validate_vcov_type was wired into _validate_params. Fix: snapshot original attribute values for every kwarg before applying mutations, run validation, and on exception restore the snapshot before re-raising. The snapshot pass also moves the "Unknown parameter" check ahead of any mutation so that even unknown-name rejections leave the estimator untouched. Regression test added: test_set_params_rollback_on_validation_failure fires the 3-kwarg batched call and pins all three attribute values to their pre-call snapshot after the raise.
1 parent 79f428d commit f808ae6

2 files changed

Lines changed: 40 additions & 5 deletions

File tree

diff_diff/efficient_did.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -351,13 +351,30 @@ def get_params(self) -> Dict[str, Any]:
351351
}
352352

353353
def set_params(self, **params: Any) -> "EfficientDiD":
354-
"""Set estimator parameters (sklearn-compatible)."""
354+
"""Set estimator parameters (sklearn-compatible).
355+
356+
Atomic: snapshots the original attribute values before applying
357+
mutations, validates the new state via ``_validate_params``, and
358+
rolls every attribute back to its pre-call value if validation
359+
raises. Without this, ``set_params(vcov_type="classical",
360+
alpha=0.1)`` would leave ``self.vcov_type`` partially mutated
361+
even though the call raised, defeating the eager-validation
362+
contract for callers that catch ``ValueError`` and keep using
363+
the estimator.
364+
"""
365+
snapshot: Dict[str, Any] = {}
366+
for key in params:
367+
if not hasattr(self, key):
368+
raise ValueError(f"Unknown parameter: {key}")
369+
snapshot[key] = getattr(self, key)
355370
for key, value in params.items():
356-
if hasattr(self, key):
371+
setattr(self, key, value)
372+
try:
373+
self._validate_params()
374+
except Exception:
375+
for key, value in snapshot.items():
357376
setattr(self, key, value)
358-
else:
359-
raise ValueError(f"Unknown parameter: {key}")
360-
self._validate_params()
377+
raise
361378
return self
362379

363380
# -- Main estimation ------------------------------------------------------

tests/test_efficient_did.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2521,6 +2521,24 @@ def test_set_params_unknown_vcov_caught_immediately(self):
25212521
with pytest.raises(ValueError, match="hc4"):
25222522
ed.set_params(vcov_type="hc4")
25232523

2524+
def test_set_params_rollback_on_validation_failure(self):
2525+
# set_params is atomic: when validation rejects a batched call, NO
2526+
# attribute mutation persists. Pre-fix, set_params assigned every
2527+
# kwarg before invoking _validate_params, so a rejected
2528+
# `set_params(vcov_type="classical", alpha=0.1, anticipation=2)`
2529+
# raised but left all three attributes mutated — weakening eager-
2530+
# validation for callers that catch ValueError and keep using the
2531+
# estimator.
2532+
ed = EfficientDiD()
2533+
original_vcov = ed.vcov_type
2534+
original_alpha = ed.alpha
2535+
original_anticipation = ed.anticipation
2536+
with pytest.raises(ValueError, match="influence-function"):
2537+
ed.set_params(vcov_type="classical", alpha=0.1, anticipation=2)
2538+
assert ed.vcov_type == original_vcov
2539+
assert ed.alpha == original_alpha
2540+
assert ed.anticipation == original_anticipation
2541+
25242542
# ---- Surface 7: bootstrap n_psu<2 NaN propagation ---------------------
25252543

25262544
def test_bootstrap_n_psu_less_than_2_returns_nan(self):

0 commit comments

Comments
 (0)