Skip to content

Commit

Permalink
Merge pull request #91 from atticusg/main
Browse files Browse the repository at this point in the history
 Update datagenerators to support tokenizing for LMs
  • Loading branch information
atticusg authored Jan 24, 2024
2 parents 243d7a3 + c7b44b6 commit cdeab77
Showing 1 changed file with 28 additions and 10 deletions.
38 changes: 28 additions & 10 deletions pyvene/data_generators/causal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def check_path(total_setting):

return check_path

def inputToTensor(self, setting):
def input_to_tensor(self, setting):
result = []
for input in self.inputs:
temp = torch.tensor(setting[input]).float()
Expand All @@ -284,7 +284,7 @@ def inputToTensor(self, setting):
result.append(temp)
return torch.cat(result)

def outputToTensor(self, setting):
def output_to_tensor(self, setting):
result = []
for output in self.outputs:
temp = torch.tensor(float(setting[output]))
Expand All @@ -293,16 +293,28 @@ def outputToTensor(self, setting):
result.append(temp)
return torch.cat(result)

def generate_factual_dataset(self, size, sampler=None, filter=None, device="cpu"):
def generate_factual_dataset(
self,
size,
sampler=None,
filter=None,
device="cpu",
inputFunction=None,
outputFunction=None
):
if inputFunction is None:
inputFunction = self.input_to_tensor
if outputFunction is None:
outputFunction = self.output_to_tensor
if sampler is None:
sampler = self.sample_input
X, y = [], []
count = 0
while count < size:
input = sampler()
if filter is None or filter(input):
X.append(self.inputToTensor(input))
y.append(self.outputToTensor(self.run_forward(input)))
X.append(inputFunction(input))
y.append(outputFunction(self.run_forward(input)))
count += 1
return torch.stack(X).to(device), torch.stack(y).to(device)

Expand All @@ -315,6 +327,8 @@ def generate_counterfactual_dataset(
intervention_sampler=None,
filter=None,
device="cpu",
inputFunction=None,
outputFunction=None
):
maxlength = len(
[
Expand All @@ -323,6 +337,10 @@ def generate_counterfactual_dataset(
if var not in self.inputs and var not in self.outputs
]
)
if inputFunction is None:
inputFunction = self.input_to_tensor
if outputFunction is None:
outputFunction = self.output_to_tensor
if sampler is None:
sampler = self.sample_input
if intervention_sampler is None:
Expand All @@ -341,17 +359,17 @@ def generate_counterfactual_dataset(
if var not in intervention:
continue
source = sampler()
sources.append(self.inputToTensor(source))
sources.append(inputFunction(source))
source_dic[var] = source
for _ in range(maxlength - len(sources)):
sources.append(torch.zeros(self.inputToTensor(sampler()).shape))
example["labels"] = self.outputToTensor(
sources.append(torch.zeros(self.input_to_tensor(sampler()).shape))
example["labels"] = outputFunction(
self.run_interchange(base, source_dic)
).to(device)
example["base_labels"] = self.outputToTensor(
example["base_labels"] = outputFunction(
self.run_forward(base)
).to(device)
example["input_ids"] = self.inputToTensor(base).to(device)
example["input_ids"] = inputFunction(base).to(device)
example["source_input_ids"] = torch.stack(sources).to(device)
example["intervention_id"] = torch.tensor(
[intervention_id(intervention)]
Expand Down

0 comments on commit cdeab77

Please sign in to comment.