Skip to content

Commit

Permalink
[Minor] Fixed sampling generating counterfactuals bug in causal_model…
Browse files Browse the repository at this point in the history
….py and introduced tests for causal model code; updated DAS main introduction to match new causal model schematic; created notebook for MQNLI dataset exploring DAS on a nested hierarchical causal structure
  • Loading branch information
AmirZur committed Mar 14, 2024
1 parent 89b8e4f commit 92c00f1
Show file tree
Hide file tree
Showing 9 changed files with 2,317 additions and 209 deletions.
105 changes: 63 additions & 42 deletions pyvene/data_generators/causal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def generate_timesteps(self):
step += 1
for var in self.variables:
assert var in timesteps
return timesteps, step - 1
# return all timesteps and timestep of root
return timesteps, step - 2

def marginalize(self, target):
pass
Expand Down Expand Up @@ -148,9 +149,12 @@ def find_live_paths(self, intervention):
del paths[1]
return paths

def print_setting(self, total_setting):
def print_setting(self, total_setting, display=None):
labeler = lambda var: var + ": " + str(total_setting[var]) \
if display is None or display[var] \
else var
relabeler = {
var: var + ": " + str(total_setting[var]) for var in self.variables
var: labeler(var) for var in self.variables
}
G = nx.DiGraph()
G.add_edges_from(
Expand Down Expand Up @@ -227,21 +231,27 @@ def sample_input(self, mandatory=None):
total = self.run_forward(intervention=input)
return input

def sample_input_tree_balanced(self, output_var=None):
def sample_input_tree_balanced(self, output_var=None, output_var_value=None):
assert output_var is not None or len(self.outputs) == 1
if output_var is None:
output_var = self.outputs[0]
if output_var_value is None:
output_var_value = random.choice(self.values[output_var])

def create_input(var, value, input={}):
parent_values = random.choice(self.equiv_classes[var][value])
for parent in parent_values:
if parent in self.inputs:
input[parent] = parent_values[parent]
else:
create_input(parent, random.choice(self.values[parent]), input)
create_input(parent, parent_values[parent], input)
return input

return create_input(output_var, random.choice(self.values[output_var]))
input_setting = create_input(output_var, output_var_value)
for input_var in self.inputs:
if input_var not in input_setting:
input_setting[input_var] = random.choice(self.values[input_var])
return input_setting

def get_path_maxlen_filter(self, lengths):
def check_path(total_setting):
Expand Down Expand Up @@ -299,24 +309,26 @@ def generate_factual_dataset(
sampler=None,
filter=None,
device="cpu",
inputFunction=None,
outputFunction=None
return_tensors=True,
):
if inputFunction is None:
inputFunction = self.input_to_tensor
if outputFunction is None:
outputFunction = self.output_to_tensor
if sampler is None:
sampler = self.sample_input
X, y = [], []
count = 0
while count < size:

examples = []
while len(examples) < size:
example = dict()
input = sampler()
if filter is None or filter(input):
X.append(inputFunction(input))
y.append(outputFunction(self.run_forward(input)))
count += 1
return torch.stack(X).to(device), torch.stack(y).to(device)
output = self.run_forward(input)
if return_tensors:
example['input_ids'] = self.input_to_tensor(input).to(device)
example['labels'] = self.output_to_tensor(output).to(device)
else:
example['input_ids'] = input
example['labels'] = output
examples.append(example)

return examples

def generate_counterfactual_dataset(
self,
Expand All @@ -327,8 +339,7 @@ def generate_counterfactual_dataset(
intervention_sampler=None,
filter=None,
device="cpu",
inputFunction=None,
outputFunction=None
return_tensors=True,
):
maxlength = len(
[
Expand All @@ -337,17 +348,12 @@ def generate_counterfactual_dataset(
if var not in self.inputs and var not in self.outputs
]
)
if inputFunction is None:
inputFunction = self.input_to_tensor
if outputFunction is None:
outputFunction = self.output_to_tensor
if sampler is None:
sampler = self.sample_input
if intervention_sampler is None:
intervention_sampler = self.sample_intervention
examples = []
count = 0
while count < size:
while len(examples) < size:
intervention = intervention_sampler()
if filter is None or filter(intervention):
for _ in range(batch_size):
Expand All @@ -358,24 +364,39 @@ def generate_counterfactual_dataset(
for var in self.variables:
if var not in intervention:
continue
source = sampler()
sources.append(inputFunction(source))
# sample input to match sampled intervention value
source = sampler(output_var=var, output_var_value=intervention[var])
if return_tensors:
sources.append(self.input_to_tensor(source))
else:
sources.append(source)
source_dic[var] = source
for _ in range(maxlength - len(sources)):
sources.append(torch.zeros(self.input_to_tensor(sampler()).shape))
example["labels"] = outputFunction(
self.run_interchange(base, source_dic)
).to(device)
example["base_labels"] = outputFunction(
self.run_forward(base)
).to(device)
example["input_ids"] = inputFunction(base).to(device)
example["source_input_ids"] = torch.stack(sources).to(device)
example["intervention_id"] = torch.tensor(
[intervention_id(intervention)]
).to(device)
if return_tensors:
sources.append(torch.zeros(self.input_to_tensor(base).shape))
else:
sources.append({})

if return_tensors:
example["labels"] = self.output_to_tensor(
self.run_interchange(base, source_dic)
).to(device)
example["base_labels"] = self.output_to_tensor(
self.run_forward(base)
).to(device)
example["input_ids"] = self.input_to_tensor(base).to(device)
example["source_input_ids"] = torch.stack(sources).to(device)
example["intervention_id"] = torch.tensor(
[intervention_id(intervention)]
).to(device)
else:
example['labels'] = self.run_interchange(base, source_dic)
example['base_labels'] = self.run_forward(base)
example['input_ids'] = base
example['source_input_ids'] = sources
example['intervention_id'] = [intervention_id(intervention)]

examples.append(example)
count += 1
return examples


Expand Down
208 changes: 208 additions & 0 deletions tests/unit_tests/CausalModelTestCase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
import unittest
import random
import torch
from pyvene import CausalModel
# Fix the RNG seed so the sampling-based tests below are deterministic.
random.seed(42)


class CasualModelTestCase(unittest.TestCase):
    """Tests for CausalModel using a minimal boolean model: C = A AND B.

    NOTE(review): the class name misspells "Causal"; it is kept unchanged
    because suite() below constructs tests by this exact name.
    """

    @classmethod
    def setUpClass(cls):
        # `cls`, not `self`: setUpClass receives the class object, and the
        # attributes set here are class-level fixtures shared by all tests.
        print("=== Test Suite: CausalModelTestCase ===")
        cls.variables = ['A', 'B', 'C']
        cls.values = {
            'A': [False, True],
            'B': [False, True],
            'C': [False, True]
        }

        # A and B are exogenous inputs; C is the single output.
        cls.parents = {
            'A': [],
            'B': [],
            'C': ['A', 'B']
        }

        # Default (non-intervened) values: A and B evaluate to True.
        cls.functions = {
            "A": lambda: True,
            "B": lambda: True,
            "C": lambda a, b: a and b
        }

        cls.causal_model = CausalModel(
            cls.variables,
            cls.values,
            cls.parents,
            cls.functions
        )

    def test_initialization(self):
        """Inputs, outputs, timesteps, and equivalence classes are derived correctly."""
        inputs = ['A', 'B']
        outputs = ['C']
        timesteps = {
            'A': 0,
            'B': 0,
            'C': 1
        }
        # For each value of C, the set of input settings that produce it.
        equivalence_classes = {
            'C': {
                False: [
                    {'A': False, 'B': False},
                    {'A': False, 'B': True},
                    {'A': True, 'B': False}
                ],
                True: [
                    {'A': True, 'B': True}
                ]
            }
        }

        self.assertEqual(set(self.causal_model.inputs), set(inputs))
        self.assertEqual(set(self.causal_model.outputs), set(outputs))
        self.assertEqual(self.causal_model.timesteps, timesteps)
        self.assertEqual(self.causal_model.equiv_classes, equivalence_classes)

    def test_run_forward(self):
        """Forward evaluation matches C = A AND B for all settings."""
        # test run forward with default values (A and B set to True)
        self.assertEqual(
            self.causal_model.run_forward(),
            {'A': True, 'B': True, 'C': True}
        )

        # test run forward on all possible input values
        for a in [False, True]:
            for b in [False, True]:
                input_setting = {
                    'A': a,
                    'B': b
                }
                output_setting = {
                    'A': a,
                    'B': b,
                    'C': a and b
                }
                self.assertEqual(self.causal_model.run_forward(input_setting), output_setting)

        # test run forward on fully specified setting: a total setting is
        # passed through unchanged, even if internally inconsistent.
        output_setting = {'A': False, 'B': False, 'C': True}
        self.assertEqual(self.causal_model.run_forward(output_setting), output_setting)

    def test_run_interchange(self):
        """Interchange interventions overwrite base values with source values."""
        # interchange intervention on input
        base = {'A': True, 'B': False}
        source = {'A': False, 'B': True}
        self.assertEqual(self.causal_model.run_forward(base)['C'], False)
        self.assertEqual(self.causal_model.run_forward(source)['C'], False)
        self.assertEqual(
            self.causal_model.run_interchange(base, {'B': source})['C'],
            True
        )

        # interchange intervention on output
        base = {'A': False, 'B': False}
        source = {'A': True, 'B': True}
        self.assertEqual(self.causal_model.run_forward(base)['C'], False)
        self.assertEqual(
            self.causal_model.run_interchange(base, {'B': source})['C'],
            False
        )
        self.assertEqual(
            self.causal_model.run_interchange(base, {'C': source})['C'],
            True
        )

    def test_sample_input_tree_balanced(self):
        """Sampled settings are roughly balanced over outputs and inputs."""
        # NOTE: not quite sure how to test a function with random behavior
        # right now, fixing seed and assuming approximate behavior
        # (taking balanced to be less than 30-70 split)

        K = 100
        # test sampling by output value
        outputs = []
        for _ in range(K):
            sample = self.causal_model.sample_input_tree_balanced()
            output = self.causal_model.run_forward(sample)
            outputs.append(output['C'])
        self.assertGreaterEqual(sum(outputs), 30)
        self.assertLessEqual(sum(outputs), 70)

        # test sampling by input value
        inputs = []
        for _ in range(K):
            sample = self.causal_model.sample_input_tree_balanced()
            inputs.append(sample['A'])
        # BUG FIX: the original re-asserted sum(outputs) here; the
        # input-balance check must inspect the sampled input values.
        self.assertGreaterEqual(sum(inputs), 30)
        self.assertLessEqual(sum(inputs), 70)

    def test_generate_factual_dataset(self):
        """Factual datasets carry raw settings or tensors per return_tensors."""
        def sampler():
            return {'A': False, 'B': False}

        size = 4
        factual_dataset = self.causal_model.generate_factual_dataset(
            size=size,
            sampler=sampler,
            return_tensors=False
        )
        self.assertEqual(len(factual_dataset), size)

        self.assertEqual(factual_dataset[0]['input_ids'], {'A': False, 'B': False})
        self.assertEqual(factual_dataset[0]['labels']['C'], False)

        factual_dataset_tensors = self.causal_model.generate_factual_dataset(
            size=size,
            sampler=sampler,
            return_tensors=True
        )
        self.assertEqual(len(factual_dataset_tensors), size)
        X = torch.stack([example['input_ids'] for example in factual_dataset_tensors])
        y = torch.stack([example['labels'] for example in factual_dataset_tensors])
        self.assertEqual(X.shape, (size, 2))
        self.assertEqual(y.shape, (size, 1))
        self.assertTrue(torch.equal(X[0], torch.tensor([0., 0.])))
        self.assertTrue(torch.equal(y[0], torch.tensor([0.])))

    def test_generate_counterfactual_dataset(self):
        """Counterfactual sampling matches sources to the intervention value."""
        def sampler(*args, **kwargs):
            # When asked to match an intervention value, return the setting
            # that realizes it; otherwise return a fixed base.
            if kwargs.get('output_var', None):
                return {'A': True, 'B': True}

            return {'A': True, 'B': False}

        def intervention_sampler(*args, **kwargs):
            return {'B': True}

        def intervention_id(*args, **kwargs):
            return 0

        size = 4
        counterfactual_dataset = self.causal_model.generate_counterfactual_dataset(
            size=size,
            batch_size=1,
            intervention_id=intervention_id,
            sampler=sampler,
            intervention_sampler=intervention_sampler,
            return_tensors=False
        )
        self.assertEqual(len(counterfactual_dataset), size)
        example = counterfactual_dataset[0]
        self.assertEqual(example['input_ids'], {'A': True, 'B': False})
        self.assertEqual(example['source_input_ids'][0]['B'], True)
        self.assertEqual(example['intervention_id'], [0])
        self.assertEqual(example['base_labels']['C'], False)  # T and F
        self.assertEqual(example['labels']['C'], True)  # T and T


def suite():
    """Assemble the CausalModel tests in a fixed, explicit execution order."""
    test_names = [
        "test_initialization",
        "test_run_forward",
        "test_run_interchange",
        "test_sample_input_tree_balanced",
        "test_generate_factual_dataset",
        "test_generate_counterfactual_dataset",
    ]
    ordered = unittest.TestSuite()
    for name in test_names:
        ordered.addTest(CasualModelTestCase(name))
    return ordered


if __name__ == "__main__":
    # Run the explicitly ordered suite instead of unittest auto-discovery.
    runner = unittest.TextTestRunner()
    runner.run(suite())
Loading

0 comments on commit 92c00f1

Please sign in to comment.