19
19
import numpy as np
20
20
import pandas as pd
21
21
import seaborn as sns
22
+ from formulae import design_matrices
22
23
from matplotlib import pyplot as plt
23
- from patsy import build_design_matrices , dmatrices
24
24
from sklearn .base import RegressorMixin
25
25
26
26
from causalpy .custom_exceptions import (
@@ -91,16 +91,18 @@ def __init__(
91
91
self .data = data
92
92
self .expt_type = "Difference in Differences"
93
93
self .formula = formula
94
+ self .rhs_formula = formula .split ("~" , 1 )[1 ].strip ()
94
95
self .time_variable_name = time_variable_name
95
96
self .group_variable_name = group_variable_name
96
97
self .input_validation ()
97
98
98
- y , X = dmatrices (formula , self .data )
99
- self ._y_design_info = y .design_info
100
- self ._x_design_info = X .design_info
101
- self .labels = X .design_info .column_names
102
- self .y , self .X = np .asarray (y ), np .asarray (X )
103
- self .outcome_variable_name = y .design_info .column_names [0 ]
99
+ dm = design_matrices (self .formula , self .data )
100
+ self .labels = list (dm .common .terms .keys ())
101
+ self .y , self .X = (
102
+ np .asarray (dm .response .design_matrix ).reshape (- 1 , 1 ),
103
+ np .asarray (dm .common .design_matrix ),
104
+ )
105
+ self .outcome_variable_name = dm .response .name
104
106
105
107
# fit model
106
108
if isinstance (self .model , PyMCModel ):
@@ -125,8 +127,8 @@ def __init__(
125
127
)
126
128
if self .x_pred_control .empty :
127
129
raise ValueError ("x_pred_control is empty" )
128
- ( new_x ,) = build_design_matrices ([ self ._x_design_info ] , self .x_pred_control )
129
- self .y_pred_control = self .model .predict (np . asarray ( new_x ) )
130
+ new_x = np . array ( design_matrices ( self .rhs_formula , self .x_pred_control ). common )
131
+ self .y_pred_control = self .model .predict (new_x )
130
132
131
133
# predicted outcome for treatment group
132
134
self .x_pred_treatment = (
@@ -142,8 +144,10 @@ def __init__(
142
144
)
143
145
if self .x_pred_treatment .empty :
144
146
raise ValueError ("x_pred_treatment is empty" )
145
- (new_x ,) = build_design_matrices ([self ._x_design_info ], self .x_pred_treatment )
146
- self .y_pred_treatment = self .model .predict (np .asarray (new_x ))
147
+ new_x = np .array (
148
+ design_matrices (self .rhs_formula , self .x_pred_treatment ).common
149
+ )
150
+ self .y_pred_treatment = self .model .predict (new_x )
147
151
148
152
# predicted outcome for counterfactual. This is given by removing the influence
149
153
# of the interaction term between the group and the post_treatment variable
@@ -162,15 +166,15 @@ def __init__(
162
166
)
163
167
if self .x_pred_counterfactual .empty :
164
168
raise ValueError ("x_pred_counterfactual is empty" )
165
- ( new_x ,) = build_design_matrices (
166
- [ self ._x_design_info ] , self .x_pred_counterfactual , return_type = "dataframe"
169
+ new_x = np . array (
170
+ design_matrices ( self .rhs_formula , self .x_pred_counterfactual ). common
167
171
)
168
172
# INTERVENTION: set the interaction term between the group and the
169
173
# post_treatment variable to zero. This is the counterfactual.
170
174
for i , label in enumerate (self .labels ):
171
175
if "post_treatment" in label and self .group_variable_name in label :
172
- new_x . iloc [:, i ] = 0
173
- self .y_pred_counterfactual = self .model .predict (np . asarray ( new_x ) )
176
+ new_x [:, i ] = 0
177
+ self .y_pred_counterfactual = self .model .predict (new_x )
174
178
175
179
# calculate causal impact
176
180
if isinstance (self .model , PyMCModel ):
0 commit comments