Skip to content

Commit 41a0236

Browse files
DiD switch: build Difference-in-Differences design matrices with formulae instead of patsy
1 parent a39e015 commit 41a0236

File tree

2 files changed

+20
-15
lines changed

2 files changed

+20
-15
lines changed

causalpy/experiments/diff_in_diff.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
import numpy as np
2020
import pandas as pd
2121
import seaborn as sns
22+
from formulae import design_matrices
2223
from matplotlib import pyplot as plt
23-
from patsy import build_design_matrices, dmatrices
2424
from sklearn.base import RegressorMixin
2525

2626
from causalpy.custom_exceptions import (
@@ -91,16 +91,18 @@ def __init__(
9191
self.data = data
9292
self.expt_type = "Difference in Differences"
9393
self.formula = formula
94+
self.rhs_formula = formula.split("~", 1)[1].strip()
9495
self.time_variable_name = time_variable_name
9596
self.group_variable_name = group_variable_name
9697
self.input_validation()
9798

98-
y, X = dmatrices(formula, self.data)
99-
self._y_design_info = y.design_info
100-
self._x_design_info = X.design_info
101-
self.labels = X.design_info.column_names
102-
self.y, self.X = np.asarray(y), np.asarray(X)
103-
self.outcome_variable_name = y.design_info.column_names[0]
99+
dm = design_matrices(self.formula, self.data)
100+
self.labels = list(dm.common.terms.keys())
101+
self.y, self.X = (
102+
np.asarray(dm.response.design_matrix).reshape(-1, 1),
103+
np.asarray(dm.common.design_matrix),
104+
)
105+
self.outcome_variable_name = dm.response.name
104106

105107
# fit model
106108
if isinstance(self.model, PyMCModel):
@@ -125,8 +127,8 @@ def __init__(
125127
)
126128
if self.x_pred_control.empty:
127129
raise ValueError("x_pred_control is empty")
128-
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
129-
self.y_pred_control = self.model.predict(np.asarray(new_x))
130+
new_x = np.array(design_matrices(self.rhs_formula, self.x_pred_control).common)
131+
self.y_pred_control = self.model.predict(new_x)
130132

131133
# predicted outcome for treatment group
132134
self.x_pred_treatment = (
@@ -142,8 +144,10 @@ def __init__(
142144
)
143145
if self.x_pred_treatment.empty:
144146
raise ValueError("x_pred_treatment is empty")
145-
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
146-
self.y_pred_treatment = self.model.predict(np.asarray(new_x))
147+
new_x = np.array(
148+
design_matrices(self.rhs_formula, self.x_pred_treatment).common
149+
)
150+
self.y_pred_treatment = self.model.predict(new_x)
147151

148152
# predicted outcome for counterfactual. This is given by removing the influence
149153
# of the interaction term between the group and the post_treatment variable
@@ -162,15 +166,15 @@ def __init__(
162166
)
163167
if self.x_pred_counterfactual.empty:
164168
raise ValueError("x_pred_counterfactual is empty")
165-
(new_x,) = build_design_matrices(
166-
[self._x_design_info], self.x_pred_counterfactual, return_type="dataframe"
169+
new_x = np.array(
170+
design_matrices(self.rhs_formula, self.x_pred_counterfactual).common
167171
)
168172
# INTERVENTION: set the interaction term between the group and the
169173
# post_treatment variable to zero. This is the counterfactual.
170174
for i, label in enumerate(self.labels):
171175
if "post_treatment" in label and self.group_variable_name in label:
172-
new_x.iloc[:, i] = 0
173-
self.y_pred_counterfactual = self.model.predict(np.asarray(new_x))
176+
new_x[:, i] = 0
177+
self.y_pred_counterfactual = self.model.predict(new_x)
174178

175179
# calculate causal impact
176180
if isinstance(self.model, PyMCModel):

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ dependencies = [
3434
"numpy",
3535
"pandas",
3636
"patsy",
37+
"formulae",
3738
"pymc>=5.15.1",
3839
"scikit-learn>=1",
3940
"scipy",

0 commit comments

Comments
 (0)