|
18 | 18 | class TestDowhy(unittest.TestCase):
|
19 | 19 |
|
20 | 20 | def _get_data(self):
|
21 |
| - X = np.random.normal(0, 1, size=(500, 5)) |
22 |
| - T = np.random.binomial(1, .5, size=(500,)) |
23 |
| - Y = np.random.normal(0, 1, size=(500,)) |
24 |
| - Z = np.random.normal(0, 1, size=(500,)) |
| 21 | + X = np.random.normal(0, 1, size=(250, 5)) |
| 22 | + T = np.random.binomial(1, .5, size=(250,)) |
| 23 | + Y = np.random.normal(0, 1, size=(250,)) |
| 24 | + Z = np.random.normal(0, 1, size=(250,)) |
25 | 25 | return Y, T, X[:, [0]], X[:, 1:], Z
|
26 | 26 |
|
27 | 27 | def test_dowhy(self):
|
@@ -65,7 +65,7 @@ def clf():
|
65 | 65 | # test causal graph
|
66 | 66 | est_dowhy.view_model()
|
67 | 67 | # test refutation estimate
|
68 |
| - est_dowhy.refute_estimate(method_name="random_common_cause") |
| 68 | + est_dowhy.refute_estimate(method_name="random_common_cause", num_simulations=3) |
69 | 69 | if name != "orf":
|
70 | 70 | est_dowhy.refute_estimate(method_name="add_unobserved_common_cause",
|
71 | 71 | confounders_effect_on_treatment="binary_flip",
|
|
0 commit comments