Skip to content

Commit

Permalink
Merge pull request #126 from stanfordnlp/zen/interventionloading
Browse files Browse the repository at this point in the history
[Minor] Loading interventions as non-static method
  • Loading branch information
frankaging authored Mar 7, 2024
2 parents 83fe20f + cbdc459 commit 586ace8
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,19 @@ def load(load_directory, model, local_directory=None, from_huggingface_hub=False

return intervenable

def load_intervention(self, load_directory):
    """
    Load saved intervention weights onto the current object, in place.

    Instead of creating a new object, this function loads existing weights
    onto the interventions already held in ``self.interventions``. This is
    not a static method, and it returns nothing.

    Args:
        load_directory: Directory containing one ``intkey_<key>.bin`` file
            per trainable intervention, where ``<key>`` is that
            intervention's key in ``self.interventions``.

    Entries whose intervention is not a ``TrainableIntervention`` are
    skipped and no file is looked up for them.
    """
    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects -- only point this at trusted directories.
    for key, entry in self.interventions.items():
        # The first element of each entry is the intervention object.
        intervention = entry[0]
        if not isinstance(intervention, TrainableIntervention):
            continue
        binary_path = os.path.join(load_directory, f"intkey_{key}.bin")
        intervention.load_state_dict(torch.load(binary_path))

def _gather_intervention_output(
self, output, representations_key, unit_locations
) -> torch.Tensor:
Expand Down

0 comments on commit 586ace8

Please sign in to comment.