diff --git a/doubleml/plm/plpr.py b/doubleml/plm/plpr.py index bf860b77..c5507f56 100644 --- a/doubleml/plm/plpr.py +++ b/doubleml/plm/plpr.py @@ -335,9 +335,9 @@ def _transform_data(self): def _set_d_mean(self): if self._approach in ["cre_general", "cre_normal"]: - data = self._original_dml_data.data - d_cols = self._original_dml_data.d_cols - id_col = self._original_dml_data.id_col + data = self._dml_data.data + d_cols = self._dml_data.d_cols + id_col = self._dml_data.id_col help_d_mean = data.loc[:, [id_col] + d_cols] d_mean = help_d_mean.groupby(id_col).transform("mean").values self._d_mean = d_mean diff --git a/doubleml/plm/tests/test_plpr_transformations.py b/doubleml/plm/tests/test_plpr_transformations.py index db162d0a..d6ea2d23 100644 --- a/doubleml/plm/tests/test_plpr_transformations.py +++ b/doubleml/plm/tests/test_plpr_transformations.py @@ -87,6 +87,7 @@ def test_plpr_approach_x_dim(approach, time_type): static_panel=True, ) dml_plpr = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach=approach) + dml_plpr.fit() if approach == "wg_approx": assert len(dml_plpr._dml_data.x_cols) == dim_x else: @@ -106,6 +107,7 @@ def test_plpr_approach_d_mean(approach, time_type): static_panel=True, ) dml_plpr = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach=approach) + dml_plpr.fit() if approach in ["cre_general", "cre_normal"]: assert dml_plpr.d_mean is not None else: @@ -145,6 +147,7 @@ def test_plpr_fd_exact_unbalanced(time_type): ) with pytest.warns(UserWarning, match=msg_warn): obj_plpr = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach="fd_exact", n_folds=2) + obj_plpr.fit() # 4 rows after fd transformation as id 3 has no possible first difference assert obj_plpr.data_transform.data.shape[0] == 4 @@ -168,6 +171,7 @@ def test_plpr_one_id(approach, time_type): ) with pytest.warns(UserWarning, match=msg_warn): obj_plpr = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach=approach, n_folds=2) + obj_plpr.fit() # 2 rows after fd transformation, 4 rows else if approach == "fd_exact": assert obj_plpr.data_transform.data.shape[0] == 2 @@ -196,6 +200,7 @@ def test_plpr_fd_exact_one_id_unbalanced(time_type): # capture warnings with pytest.warns(UserWarning) as record: obj_plpr = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach="fd_exact", n_folds=2) + obj_plpr.fit() # assert two warnings were raised and content assert len(record) == 2 assert msg_warn_one_id in str(record[0].message) @@ -215,6 +220,7 @@ def test_plpr_time_cre_transformation(cre_approach, data_time_type): static_panel=True, ) dml_cre = dml.DoubleMLPLPR(obj_dml_data, ml_l, ml_m, approach=cre_approach, n_folds=2) + dml_cre.fit() assert dml_cre.transform_cols["y_col"] == "y" assert dml_cre.transform_cols["d_cols"] == ["d"] assert dml_cre.transform_cols["x_cols"] == ["x1", "x2", "x1_mean", "x2_mean"]