import random
import copy
import inspect
import itertools
import torch
from collections import defaultdict
import networkx as nx
import matplotlib.pyplot as plt
class CausalModel:
    """A deterministic causal model: variables with finite value sets, parent
    relations, and a function computing each variable from its parents."""
    def __init__(self,
                 variables,
                 values,
                 parents,
                 functions,
                 timesteps=None,
                 equiv_classes=None,
                 pos=None):
        self.variables = variables
        self.variables.sort()
        self.values = values
        self.parents = parents
        self.children = {var: [] for var in variables}
        for variable in variables:
            assert variable in self.parents
            for parent in self.parents[variable]:
                self.children[parent].append(variable)
        self.functions = functions
        self.start_variables = []
        self.timesteps = timesteps
        for variable in self.variables:
            assert variable in self.values
            assert variable in self.children
            assert variable in self.functions
            assert len(inspect.getfullargspec(self.functions[variable])[0]) == len(self.parents[variable])
            if timesteps is not None:
                assert variable in timesteps
            for variable2 in copy.copy(self.variables):
                if variable2 in self.parents[variable]:
                    assert variable in self.children[variable2]
                    if timesteps is not None:
                        assert timesteps[variable2] < timesteps[variable]
                if variable2 in self.children[variable]:
                    assert variable in parents[variable2]
                    if timesteps is not None:
                        assert timesteps[variable2] > timesteps[variable]
            if len(self.parents[variable]) == 0:
                self.start_variables.append(variable)
        self.inputs = [var for var in self.variables if len(parents[var]) == 0]
        self.outputs = copy.deepcopy(variables)
        for child in variables:
            for parent in parents[child]:
                if parent in self.outputs:
                    self.outputs.remove(parent)
        if self.timesteps is not None:
            self.timesteps = timesteps
        else:
            self.timesteps, self.end_time = self.generate_timesteps()
            for output in self.outputs:
                self.timesteps[output] = self.end_time
        self.variables.sort(key=lambda x: self.timesteps[x])
        self.run_forward()
        # Assign a plotting position to every variable without one, laid out on
        # a grid indexed by timestep; defaultdict(int) avoids key errors when
        # the final timestep exceeds the number of variables.
        self.pos = pos if pos is not None else dict()
        width = defaultdict(int)
        for var in self.variables:
            if var not in self.pos:
                self.pos[var] = (width[self.timesteps[var]], self.timesteps[var])
                width[self.timesteps[var]] += 1
        # For each non-input variable, group the parent settings by the value
        # they produce; sample_input_tree_balanced uses these equivalence classes.
        if equiv_classes is not None:
            self.equiv_classes = equiv_classes
        else:
            self.equiv_classes = {}
        for var in self.variables:
            if var in self.inputs or var in self.equiv_classes:
                continue
            self.equiv_classes[var] = {val: [] for val in self.values[var]}
            for parent_values in itertools.product(*[self.values[par] for par in self.parents[var]]):
                value = self.functions[var](*parent_values)
                self.equiv_classes[var][value].append(
                    {par: parent_values[i] for i, par in enumerate(self.parents[var])})
    def generate_timesteps(self):
        """Propagate timesteps forward from the inputs (timestep 0). Returns the
        timestep of every variable and an end time one step past the last
        assignment, which __init__ uses for the output variables."""
        timesteps = {input: 0 for input in self.inputs}
        step = 1
        change = True
        while change:
            change = False
            copytimesteps = copy.deepcopy(timesteps)
            for parent in timesteps:
                if timesteps[parent] == step - 1:
                    for child in self.children[parent]:
                        copytimesteps[child] = step
                        change = True
            timesteps = copytimesteps
            step += 1
        for var in self.variables:
            assert var in timesteps
        return timesteps, step - 1
    def marginalize(self, target):
        # TODO: marginalize the model over the given target variable(s).
        return None
    def print_structure(self, pos=None):
        G = nx.DiGraph()
        G.add_edges_from([(parent, child) for child in self.variables for parent in self.parents[child]])
        plt.figure(figsize=(10, 10))
        nx.draw_networkx(G, with_labels=True, node_color='green', pos=self.pos)
        plt.show()
    def find_live_paths(self, intervention):
        """Return, keyed by path length, the directed paths all of whose edges are
        live under `intervention`, i.e. some alternative value of the edge's source
        changes the value of its target."""
        actual_setting = self.run_forward(intervention)
        paths = {1: [[variable] for variable in self.variables]}
        step = 2
        while True:
            paths[step] = []
            for path in paths[step - 1]:
                for child in self.children[path[-1]]:
                    actual_cause = False
                    for value in self.values[path[-1]]:
                        newintervention = copy.deepcopy(intervention)
                        newintervention[path[-1]] = value
                        counterfactual_setting = self.run_forward(newintervention)
                        if counterfactual_setting[child] != actual_setting[child]:
                            actual_cause = True
                    if actual_cause:
                        paths[step].append(copy.deepcopy(path) + [child])
            if len(paths[step]) == 0:
                break
            step += 1
        del paths[1]
        return paths
    def print_setting(self, total_setting):
        relabeler = {var: var + ": " + str(total_setting[var]) for var in self.variables}
        G = nx.DiGraph()
        G.add_edges_from([(parent, child) for child in self.variables for parent in self.parents[child]])
        plt.figure(figsize=(10, 10))
        G = nx.relabel_nodes(G, relabeler)
        newpos = dict()
        if self.pos is not None:
            for var in self.pos:
                newpos[relabeler[var]] = self.pos[var]
        nx.draw_networkx(G, with_labels=True, node_color='green', pos=newpos)
        plt.show()
    def run_forward(self, intervention=None):
        """Compute a total setting of all variables, optionally clamping the
        variables named in `intervention` to the given values."""
        total_setting = defaultdict(None)
        length = len(list(total_setting.keys()))
        while length != len(self.variables):
            for variable in self.variables:
                # Wait until every parent has been assigned a value.
                if any(parent not in total_setting for parent in self.parents[variable]):
                    continue
                if intervention is not None and variable in intervention:
                    total_setting[variable] = intervention[variable]
                else:
                    total_setting[variable] = self.functions[variable](
                        *[total_setting[parent] for parent in self.parents[variable]])
            length = len(list(total_setting.keys()))
        return total_setting
    def run_interchange(self, input, source_interventions):
        """Interchange intervention: for each variable in `source_interventions`,
        compute its value on the corresponding source input, then clamp it to that
        value while running the model on the base `input`."""
        # Copy so the caller's base input is not mutated in place.
        interchange_intervention = copy.deepcopy(input)
        for var in source_interventions:
            setting = self.run_forward(source_interventions[var])
            interchange_intervention[var] = setting[var]
        return self.run_forward(interchange_intervention)
    def add_variable(self, variable, values, parents, children, function, timestep=None):
        if timestep is not None:
            assert self.timesteps is not None
            self.timesteps[variable] = timestep
        for parent in parents:
            assert parent in self.variables
        for child in children:
            assert child in self.variables
        self.parents[variable] = parents
        self.children[variable] = children
        self.values[variable] = values
        self.functions[variable] = function
    def sample_intervention(self, mandatory=None):
        intervention = {}
        while len(intervention.keys()) == 0:
            for var in self.variables:
                if var in self.inputs or var in self.outputs:
                    continue
                if random.choice([0, 1]) == 0:
                    intervention[var] = random.choice(self.values[var])
        return intervention
    def sample_input(self, mandatory=None):
        input = {var: random.sample(self.values[var], 1)[0] for var in self.inputs}
        total = self.run_forward(intervention=input)
        while mandatory is not None and not mandatory(total):
            input = {var: random.sample(self.values[var], 1)[0] for var in self.inputs}
            total = self.run_forward(intervention=input)
        return input
    def sample_input_tree_balanced(self, output_var=None):
        """Sample an input setting by recursively selecting parent settings from
        the precomputed equivalence classes, starting from a uniformly chosen
        value of the output variable."""
        assert output_var is not None or len(self.outputs) == 1
        if output_var is None:
            output_var = self.outputs[0]
        def create_input(var, value, input=None):
            input = dict() if input is None else input
            parent_values = random.choice(self.equiv_classes[var][value])
            for parent in parent_values:
                if parent in self.inputs:
                    input[parent] = parent_values[parent]
                else:
                    create_input(parent, random.choice(self.values[parent]), input)
            return input
        return create_input(output_var, random.choice(self.values[output_var]))
    def get_path_maxlen_filter(self, lengths):
        def check_path(total_setting):
            input = {var: total_setting[var] for var in self.inputs}
            paths = self.find_live_paths(input)
            m = max([l for l in paths.keys() if len(paths[l]) != 0])
            if m in lengths:
                return True
            return False
        return check_path
    def get_partial_filter(self, partial_setting):
        def compare(total_setting):
            for var in partial_setting:
                if total_setting[var] != partial_setting[var]:
                    return False
            return True
        return compare
    def get_specific_path_filter(self, start, end):
        def check_path(total_setting):
            input = {var: total_setting[var] for var in self.inputs}
            paths = self.find_live_paths(input)
            for k in paths:
                for path in paths[k]:
                    if path[0] == start and path[-1] == end:
                        return True
            return False
        return check_path
    def inputToTensor(self, setting):
        result = []
        for input in self.inputs:
            temp = torch.tensor(setting[input]).float()
            if len(temp.size()) == 0:
                temp = torch.reshape(temp, (1,))
            result.append(temp)
        return torch.cat(result)
    def outputToTensor(self, setting):
        result = []
        for output in self.outputs:
            temp = torch.tensor(float(setting[output]))
            if len(temp.size()) == 0:
                temp = torch.reshape(temp, (1,))
            result.append(temp)
        return torch.cat(result)
    def generate_factual_dataset(self, size, sampler=None, filter=None):
        """Sample `size` input settings (optionally filtered) and return them,
        together with the model's outputs on them, as stacked tensors."""
        if sampler is None:
            sampler = self.sample_input
        X, y = [], []
        count = 0
        while count < size:
            input = sampler()
            if filter is None or filter(input):
                X.append(self.inputToTensor(input))
                y.append(self.outputToTensor(self.run_forward(input)))
                count += 1
        return torch.stack(X), torch.stack(y)
    def generate_counterfactual_dataset(self,
                                        size,
                                        intervention_id,
                                        batch_size,
                                        sampler=None,
                                        intervention_sampler=None,
                                        filter=None):
        """Sample interventions on intermediate variables and, for each one,
        `batch_size` base/source input pairs; returns stacked tensors of base
        inputs, base outputs, source inputs, interchange-intervention outputs,
        and the id assigned to each intervention by `intervention_id`."""
        maxlength = len([var for var in self.variables if var not in self.inputs and var not in self.outputs])
        if sampler is None:
            sampler = self.sample_input
        if intervention_sampler is None:
            intervention_sampler = self.sample_intervention
        bases, y, sourceses, yii, interventions = [], [], [], [], []
        count = 0
        while count < size:
            intervention = intervention_sampler()
            if filter is None or filter(intervention):
                for _ in range(batch_size):
                    base = sampler()
                    sources = []
                    source_dic = {}
                    for var in self.variables:
                        if var not in intervention:
                            continue
                        source = sampler()
                        sources.append(self.inputToTensor(source))
                        source_dic[var] = source
                    # Pad with zero tensors so every example has the same number of sources.
                    for _ in range(maxlength - len(sources)):
                        sources.append(torch.zeros(self.inputToTensor(self.sample_input()).shape))
                    y.append(self.outputToTensor(self.run_forward(base)))
                    yii.append(self.outputToTensor(self.run_interchange(base, source_dic)))
                    bases.append(self.inputToTensor(base))
                    sources = torch.stack(sources)
                    sourceses.append(sources)
                    interventions.append(torch.tensor([intervention_id(intervention)]))
                count += 1
        return torch.stack(bases), torch.stack(y), torch.stack(sourceses), torch.stack(yii), torch.stack(interventions)
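# The sketches below are illustrative additions, not part of the original module:
# chain_example_model and factual_example are hypothetical helpers showing one way
# to build a small three-variable chain X -> M -> Y and draw a factual dataset
# from it with generate_factual_dataset.
def chain_example_model():
    variables = ["X", "M", "Y"]
    values = {variable: [0, 1] for variable in variables}
    parents = {"X": [], "M": ["X"], "Y": ["M"]}
    functions = {"X": lambda: 0,        # value used only when X is not intervened on
                 "M": lambda x: x,      # M copies its single parent
                 "Y": lambda m: 1 - m}  # Y negates M
    return CausalModel(variables, values, parents, functions)
def factual_example():
    model = chain_example_model()
    X, y = model.generate_factual_dataset(8)
    # One input (X) and one output (Y), so the expected shapes are (8, 1) and (8, 1).
    print(X.shape, y.shape)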
def simple_example():
    variables = ["A", "B", "C"]
    values = {variable: [True, False] for variable in variables}
    parents = {"A": [], "B": [], "C": ["A", "B"]}
    def A():
        return True
    def B():
        return False
    def C(a, b):
        return a and b
    functions = {"A": A, "B": B, "C": C}
    model = CausalModel(variables, values, parents, functions)
    model.print_structure()
    print("No intervention:\n", model.run_forward(), "\n")
    model.print_setting(model.run_forward())
    print("Intervention setting A and B to TRUE:\n", model.run_forward({"A": True, "B": True}))
    print("Timesteps:", model.timesteps)
if __name__ == '__main__':
    simple_example()