Skip to content

Commit

Permalink
Merge pull request #97 from stanfordnlp/zen/gradientfix
Browse files Browse the repository at this point in the history
[Minor] Fix gradient backprop trainables with upstream interventions
  • Loading branch information
frankaging authored Jan 29, 2024
2 parents 1b261df + 6e8bdfe commit 8c0e544
Show file tree
Hide file tree
Showing 7 changed files with 377 additions and 67 deletions.
6 changes: 3 additions & 3 deletions .github/pull_request_template.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
## 📝 Description
## Description

[Describe Your Changes Here ...]

## Testing Done

[Describe Your Changes Here ...]

# Checklist:
## Checklist:

- [ ] My PR title strictly follows the format: `[Your Priority] Your Title`
- [ ] I have attached the testing log above
- [ ] I provide enough comments to my code
- [ ] I have updated the documentation
- [ ] I have added tests for my changes
- [ ] I have added tests for my changes
2 changes: 1 addition & 1 deletion pyvene/models/gpt2/modelings_intervenable_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"attention_value_output": ("h[%s].attn.c_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("h[%s].attn.c_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
"attention_weight": ("h[%s].attn.attn_dropout", CONST_INPUT_HOOK),
"attention_output": ("h[%s].attn", CONST_OUTPUT_HOOK),
"attention_output": ("h[%s].attn.resid_dropout", CONST_OUTPUT_HOOK),
"attention_input": ("h[%s].attn", CONST_INPUT_HOOK),
"query_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 0)),
"key_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 1)),
Expand Down
79 changes: 31 additions & 48 deletions pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,11 +555,11 @@ def _gather_intervention_output(
self._intervention_reverse_link[representations_key]
]
else:
# cold gather
original_output = output
# data structure casting
if isinstance(output, tuple):
original_output = output[0]
original_output = output[0].clone()
else:
original_output = output.clone()
# gather subcomponent
original_output = output_to_subcomponent(
original_output,
Expand Down Expand Up @@ -788,7 +788,7 @@ def hook_callback(model, args, kwargs, output=None):
output = kwargs[list(kwargs.keys())[0]]
else:
output = args

selected_output = self._gather_intervention_output(
output, key, unit_locations_base[key_i]
)
Expand Down Expand Up @@ -838,26 +838,15 @@ def hook_callback(model, args, kwargs, output=None):
self._intervention_reverse_link[key]
] = intervened_representation.clone()

# very buggy due to tensor version
if self.model_has_grad:
# TODO: figure out how to allow this!
if isinstance(output, tuple):
raise ValueError(
"Model grad is not allowed when "
"intervening output is tuple type."
)
output_c = output.clone()
# patch in the intervened activations in-place
if isinstance(output, tuple):
_ = self._scatter_intervention_output(
output_c, intervened_representation, key, unit_locations_base[key_i]
output[0], intervened_representation, key, unit_locations_base[key_i]
)
output = output_c.clone()
else:
# patch in the intervened activations in-place
_ = self._scatter_intervention_output(
output, intervened_representation, key, unit_locations_base[key_i]
)

self._intervention_state[key].inc_setter_version()

handlers.append(module_hook(hook_callback, with_kwargs=True))
Expand Down Expand Up @@ -1073,7 +1062,6 @@ def _wait_for_forward_with_serial_intervention(
def _broadcast_unit_locations(
self,
batch_size,
intervention_group_size,
unit_locations
):
if self.mode == "parallel":
Expand All @@ -1086,33 +1074,33 @@ def _broadcast_unit_locations(
k = "sources->base"
if isinstance(v, int):
if is_base_only:
_unit_locations[k] = (None, [[[v]]*batch_size]*intervention_group_size)
_unit_locations[k] = (None, [[[v]]*batch_size]*len(self.interventions))
else:
_unit_locations[k] = (
[[[v]]*batch_size]*intervention_group_size,
[[[v]]*batch_size]*intervention_group_size
[[[v]]*batch_size]*len(self.interventions),
[[[v]]*batch_size]*len(self.interventions)
)
self.use_fast = True
elif len(v) == 2 and isinstance(v[0], int) and isinstance(v[1], int):
_unit_locations[k] = (
[[[v[0]]]*batch_size]*intervention_group_size,
[[[v[1]]]*batch_size]*intervention_group_size
[[[v[0]]]*batch_size]*len(self.interventions),
[[[v[1]]]*batch_size]*len(self.interventions)
)
self.use_fast = True
elif len(v) == 2 and v[0] == None and isinstance(v[1], int):
_unit_locations[k] = (None, [[[v[1]]]*batch_size]*intervention_group_size)
_unit_locations[k] = (None, [[[v[1]]]*batch_size]*len(self.interventions))
self.use_fast = True
elif len(v) == 2 and isinstance(v[0], int) and v[1] == None:
_unit_locations[k] = ([[[v[0]]]*batch_size]*intervention_group_size, None)
_unit_locations[k] = ([[[v[0]]]*batch_size]*len(self.interventions), None)
self.use_fast = True
elif isinstance(v, list) and get_list_depth(v) == 1:
# [0,1,2,3] -> [[[0,1,2,3]]], ...
if is_base_only:
_unit_locations[k] = (None, [[v]*batch_size]*intervention_group_size)
_unit_locations[k] = (None, [[v]*batch_size]*len(self.interventions))
else:
_unit_locations[k] = (
[[v]*batch_size]*intervention_group_size,
[[v]*batch_size]*intervention_group_size
[[v]*batch_size]*len(self.interventions),
[[v]*batch_size]*len(self.interventions)
)
self.use_fast = True
else:
Expand All @@ -1125,27 +1113,27 @@ def _broadcast_unit_locations(
for k, v in unit_locations.items():
if isinstance(v, int):
_unit_locations[k] = (
[[[v]]*batch_size]*intervention_group_size,
[[[v]]*batch_size]*intervention_group_size
[[[v]]*batch_size]*len(self.interventions),
[[[v]]*batch_size]*len(self.interventions)
)
self.use_fast = True
elif len(v) == 2 and isinstance(v[0], int) and isinstance(v[1], int):
_unit_locations[k] = (
[[[v[0]]]*batch_size]*intervention_group_size,
[[[v[1]]]*batch_size]*intervention_group_size
[[[v[0]]]*batch_size]*len(self.interventions),
[[[v[1]]]*batch_size]*len(self.interventions)
)
self.use_fast = True
elif len(v) == 2 and v[0] == None and isinstance(v[1], int):
_unit_locations[k] = (None, [[[v[1]]]*batch_size]*intervention_group_size)
_unit_locations[k] = (None, [[[v[1]]]*batch_size]*len(self.interventions))
self.use_fast = True
elif len(v) == 2 and isinstance(v[0], int) and v[1] == None:
_unit_locations[k] = ([[[v[0]]]*batch_size]*intervention_group_size, None)
_unit_locations[k] = ([[[v[0]]]*batch_size]*len(self.interventions), None)
self.use_fast = True
elif isinstance(v, list) and get_list_depth(v) == 1:
# [0,1,2,3] -> [[[0,1,2,3]]], ...
_unit_locations[k] = (
[[v]*batch_size]*intervention_group_size,
[[v]*batch_size]*intervention_group_size
[[v]*batch_size]*len(self.interventions),
[[v]*batch_size]*len(self.interventions)
)
self.use_fast = True
else:
Expand Down Expand Up @@ -1191,16 +1179,15 @@ def _broadcast_sources(
def _broadcast_subspaces(
self,
batch_size,
intervention_group_size,
subspaces
):
"""Broadcast simple subspaces input"""
_subspaces = subspaces
if isinstance(subspaces, int):
_subspaces = [[[subspaces]]*batch_size]*intervention_group_size
_subspaces = [[[subspaces]]*batch_size]*len(self.interventions)

elif isinstance(subspaces, list) and isinstance(subspaces[0], int):
_subspaces = [[subspaces]*batch_size]*intervention_group_size
_subspaces = [[subspaces]*batch_size]*len(self.interventions)
else:
# TODO: subspaces could support more broadcast magic.
pass
Expand Down Expand Up @@ -1292,13 +1279,11 @@ def forward(
return self.model(**base), None

# broadcast
unit_locations = self._broadcast_unit_locations(
get_batch_size(base), len(self._intervention_group), unit_locations)
unit_locations = self._broadcast_unit_locations(get_batch_size(base), unit_locations)
sources = [None]*len(self._intervention_group) if sources is None else sources
sources = self._broadcast_sources(sources)
activations_sources = self._broadcast_source_representations(activations_sources)
subspaces = self._broadcast_subspaces(
get_batch_size(base), len(self._intervention_group), subspaces)
subspaces = self._broadcast_subspaces(get_batch_size(base), subspaces)

self._input_validation(
base,
Expand Down Expand Up @@ -1412,13 +1397,11 @@ def generate(
unit_locations = {"base": 0}

# broadcast
unit_locations = self._broadcast_unit_locations(
get_batch_size(base), len(self._intervention_group), unit_locations)
unit_locations = self._broadcast_unit_locations(get_batch_size(base), unit_locations)
sources = [None]*len(self._intervention_group) if sources is None else sources
sources = self._broadcast_sources(sources)
activations_sources = self._broadcast_source_representations(activations_sources)
subspaces = self._broadcast_subspaces(
get_batch_size(base), len(self._intervention_group), subspaces)
subspaces = self._broadcast_subspaces(get_batch_size(base), subspaces)

self._input_validation(
base,
Expand Down
8 changes: 4 additions & 4 deletions pyvene/models/interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Intervention(torch.nn.Module):

def __init__(self, **kwargs):
super().__init__()
self.trainble = False
self.trainable = False
self.is_source_constant = False

self.use_fast = kwargs["use_fast"] if "use_fast" in kwargs else False
Expand Down Expand Up @@ -87,7 +87,7 @@ class TrainableIntervention(Intervention):

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.trainble = True
self.trainable = True
self.is_source_constant = False

def tie_weight(self, linked_intervention):
Expand Down Expand Up @@ -204,7 +204,7 @@ class VanillaIntervention(Intervention, LocalistRepresentationIntervention):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def forward(self, base, source, subspaces=None):
def forward(self, base, source, subspaces=None):
return _do_intervention_by_swap(
base,
source if self.source_representation is None else self.source_representation,
Expand Down Expand Up @@ -478,7 +478,7 @@ def __init__(self, **kwargs):
self.pca_std = torch.nn.Parameter(
torch.tensor(pca_std, dtype=torch.float32), requires_grad=False
)
self.trainble = False
self.trainable = False

def forward(self, base, source, subspaces=None):
base_norm = (base - self.pca_mean) / self.pca_std
Expand Down
Loading

0 comments on commit 8c0e544

Please sign in to comment.