Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions doubleml/plm/plpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,9 @@ def _transform_data(self):

def _set_d_mean(self):
if self._approach in ["cre_general", "cre_normal"]:
data = self._original_dml_data.data
d_cols = self._original_dml_data.d_cols
id_col = self._original_dml_data.id_col
data = self._dml_data.data
d_cols = self._dml_data.d_cols
id_col = self._dml_data.id_col
help_d_mean = data.loc[:, [id_col] + d_cols]
d_mean = help_d_mean.groupby(id_col).transform("mean").values
self._d_mean = d_mean
Expand Down
6 changes: 6 additions & 0 deletions doubleml/plm/tests/test_plpr_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def test_plpr_approach_x_dim(approach, time_type):
static_panel=True,
)
dml_plpr = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach=approach)
dml_plpr.fit()
if approach == "wg_approx":
assert len(dml_plpr._dml_data.x_cols) == dim_x
else:
Expand All @@ -106,6 +107,7 @@ def test_plpr_approach_d_mean(approach, time_type):
static_panel=True,
)
dml_plpr = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach=approach)
dml_plpr.fit()
if approach in ["cre_general", "cre_normal"]:
assert dml_plpr.d_mean is not None
else:
Expand Down Expand Up @@ -145,6 +147,7 @@ def test_plpr_fd_exact_unbalanced(time_type):
)
with pytest.warns(UserWarning, match=msg_warn):
obj_plpr = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach="fd_exact", n_folds=2)
obj_plpr.fit()
# 4 rows after fd transformation as id 3 has no possible first difference
assert obj_plpr.data_transform.data.shape[0] == 4

Expand All @@ -168,6 +171,7 @@ def test_plpr_one_id(approach, time_type):
)
with pytest.warns(UserWarning, match=msg_warn):
obj_plpr = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach=approach, n_folds=2)
obj_plpr.fit()
# 2 rows after fd transformation, 4 rows else
if approach == "fd_exact":
assert obj_plpr.data_transform.data.shape[0] == 2
Expand Down Expand Up @@ -196,6 +200,7 @@ def test_plpr_fd_exact_one_id_unbalanced(time_type):
# capture warnings
with pytest.warns(UserWarning) as record:
obj_plpr = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach="fd_exact", n_folds=2)
obj_plpr.fit()
# assert two warnings were raised and content
assert len(record) == 2
assert msg_warn_one_id in str(record[0].message)
Expand All @@ -215,6 +220,7 @@ def test_plpr_time_cre_transformation(cre_approach, data_time_type):
static_panel=True,
)
dml_cre = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach=cre_approach, n_folds=2)
dml_cre.fit()
assert dml_cre.transform_cols["y_col"] == "y"
assert dml_cre.transform_cols["d_cols"] == ["d"]
assert dml_cre.transform_cols["x_cols"] == ["x1", "x2", "x1_mean", "x2_mean"]
Expand Down
Loading