Skip to content

Commit

Permalink
noop guards
Browse files Browse the repository at this point in the history
  • Loading branch information
nathankim7 committed Apr 22, 2024
1 parent 78ba2f6 commit 0cdd268
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 9 deletions.
8 changes: 3 additions & 5 deletions pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,6 @@ def _cleanup_states(self, skip_activation_gc=False):
"""
Clean up all old in memo states of interventions
"""
self._skip_forward = False
self._remove_forward_hooks()
self._reset_hook_count()
if not skip_activation_gc:
Expand Down Expand Up @@ -857,8 +856,7 @@ def _intervention_setter(

def hook_callback(model, args, kwargs, output=None):
if (
not self.is_model_stateless
and self._skip_forward
self._skip_forward
and state.setter_timestep <= 0
):
state.setter_timestep += 1
Expand All @@ -878,11 +876,11 @@ def hook_callback(model, args, kwargs, output=None):
# in this code we assume that output is batched along its first axis.
int_unit_loc = (
unit_locations_base[key_i]
if state.setter_timestep <= 0 or not timestep_selector
if state.setter_timestep <= 0
else [
(
[0]
if timestep_selector[key_i](
if timestep_selector != None and timestep_selector[key_i](
state.setter_timestep, output[i]
)
else None
Expand Down
93 changes: 90 additions & 3 deletions tests/integration_tests/GenerationInterventionTestCase.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def setUpClass(cls):
cls.device = DEVICE

cls.config, cls.tokenizer, cls.tinystory = pv.create_gpt_neo()
cls.tinystory.to(cls.device)

@classmethod
def tearDownClass(cls):
Expand Down Expand Up @@ -65,7 +66,7 @@ def test_lm_generation(self):

prompt = tokenizer("Once upon a time there was", return_tensors="pt")
_, intervened_story = pv_tinystory.generate(
prompt, source_representations=emb_happy, max_length=32
prompt, source_representations=emb_happy, unit_locations={"sources->base": (0, [0, 1, 2])}, max_length=32
)
print(tokenizer.decode(intervened_story[0], skip_special_tokens=True))

Expand All @@ -81,7 +82,7 @@ def test_generation_with_source_intervened_prompt(self):
}
for l in range(self.config.num_layers)
],
model=self.tinystory.to(self.device),
model=self.tinystory,
)

prompt = self.tokenizer("Once upon a time there was", return_tensors="pt").to(
Expand Down Expand Up @@ -118,7 +119,7 @@ def test_dynamic_static_generation_intervention_parity(self):
}
for l in range(self.config.num_layers)
],
model=self.tinystory.to(self.device),
model=self.tinystory,
)

prompt = self.tokenizer("Once upon a time there was", return_tensors="pt").to(
Expand All @@ -144,6 +145,92 @@ def test_dynamic_static_generation_intervention_parity(self):
orig_text != intervened_text
), "Aggressive intervention did not change the output. Probably something wrong."

def test_generation_noops(self):
    """Generation-time interventions that should be no-ops must leave output unchanged.

    Covers three cases against the same prompt:
      1. An identity intervention (``lambda b, s: b``) on every layer's
         ``mlp_output`` with ``intervene_on_prompt=True`` — output must equal
         the original generation.
      2. An aggressive intervention (``lambda b, s: s * 1000``) with
         ``intervene_on_prompt=False`` and no timestep selector — presumably
         the intervention never fires, so output must be unchanged
         (NOTE(review): relies on the skip-forward guard added in this
         commit; confirm against pyvene's generate semantics).
      3. The same aggressive model with a ``timestep_selector`` that returns
         False for every layer/timestep — again no intervention applied,
         output unchanged.
    """
    # Seed for reproducibility; orig/intervened come from the same generate()
    # call (output_original_output=True), so the comparison itself is
    # seed-independent.
    torch.manual_seed(0)

    # No-op intervention: identity on the base activation at every layer.
    pv_model = pv.IntervenableModel(
        [
            {
                "layer": l,
                "component": "mlp_output",
                "intervention": lambda b, s: b,
            }
            for l in range(self.config.num_layers)
        ],
        model=self.tinystory,
    )

    prompt = self.tokenizer("Once upon a time there was", return_tensors="pt").to(
        self.device
    )
    sources = self.tokenizer(" love", return_tensors="pt").to(self.device)

    # Intervene on the first three prompt positions of layer-0-style unit 0.
    # (NOTE(review): exact meaning of (0, [0, 1, 2]) follows pyvene's
    # "sources->base" unit_locations convention — confirm.)
    orig, intervened = pv_model.generate(
        prompt,
        max_length=20,
        sources=sources,
        intervene_on_prompt=True,
        unit_locations={"sources->base": (0, [0, 1, 2])},
        output_original_output=True,
    )
    orig_text, intervened_text = (
        self.tokenizer.decode(orig[0], skip_special_tokens=True),
        self.tokenizer.decode(intervened[0], skip_special_tokens=True),
    )

    print(intervened_text)
    assert (
        orig_text == intervened_text
    ), "No-op intervention changed the output. Probably something wrong."

    # Aggressive intervention with intervene_on_prompt=False: the source
    # activation scaled by 1000 would visibly corrupt output if it fired.
    aggressive_model = pv.IntervenableModel(
        [
            {
                "layer": l,
                "component": "mlp_output",
                "intervention": lambda b, s: s * 1000,
            }
            for l in range(self.config.num_layers)
        ],
        model=self.tinystory,
    )

    # No unit_locations / selector given, and prompt intervention disabled —
    # expected to be a complete no-op.
    orig, intervened = aggressive_model.generate(
        prompt,
        max_length=20,
        sources=sources,
        intervene_on_prompt=False,
        output_original_output=True,
    )

    orig_text, intervened_text = (
        self.tokenizer.decode(orig[0], skip_special_tokens=True),
        self.tokenizer.decode(intervened[0], skip_special_tokens=True),
    )
    print(orig_text)
    print(intervened_text)
    assert (
        orig_text == intervened_text
    ), "Aggressive intervention changed the output. Probably something wrong."

    # Aggressive intervention with no prompt intervention and per-layer
    # timestep selectors that always decline — must also be a no-op.
    orig, intervened = aggressive_model.generate(
        prompt,
        max_length=20,
        sources=sources,
        intervene_on_prompt=False,
        output_original_output=True,
        timestep_selector=[lambda idx, o: False] * self.config.num_layers,
    )
    orig_text, intervened_text = (
        self.tokenizer.decode(orig[0], skip_special_tokens=True),
        self.tokenizer.decode(intervened[0], skip_special_tokens=True),
    )
    assert (
        orig_text == intervened_text
    ), "Aggressive intervention changed the output. Probably something wrong."

# Allow running this test module directly (e.g. `python GenerationInterventionTestCase.py`).
if __name__ == "__main__":
    unittest.main()
2 changes: 1 addition & 1 deletion tests/unit_tests/InterventionUtilsTestCase.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def test_low_rank_gradient_positive(self):
loss = F.mse_loss(output, golden)
loss.backward()
optimizer.step()
print(output)

self.assertTrue(torch.allclose(golden, output, rtol=1e-02, atol=1e-02))
except:
pass # retry
Expand Down

0 comments on commit 0cdd268

Please sign in to comment.