Skip to content

Commit

Permalink
Merge pull request #153 from aryamanarora/main
Browse files: browse the repository at this point in the history
[P2] Save/load trainable params in `IntervenableBase` methods
  • Loading branch information
frankaging authored Jul 12, 2024
2 parents e1e8b39 + 185a412 commit 3ec61ea
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,7 +1128,8 @@ def _cleanup_states(self, skip_activation_gc=False):
self._batched_setter_activation_select = {}

def save(
self, save_directory, save_to_hf_hub=False, hf_repo_name="my-awesome-model"
self, save_directory, save_to_hf_hub=False, hf_repo_name="my-awesome-model",
include_model=False
):
"""
Save interventions to disk or hub
Expand Down Expand Up @@ -1205,6 +1206,15 @@ def save(
else:
saving_config.intervention_dimensions += [intervention.interchange_dim.tolist()]
saving_config.intervention_constant_sources += [intervention.is_source_constant]

# save model's trainable parameters as well
if include_model:
model_state_dict = {}
model_binary_filename = "pytorch_model.bin"
for n, p in self.model.named_parameters():
if p.requires_grad:
model_state_dict[n] = p
torch.save(model_state_dict, os.path.join(save_directory, model_binary_filename))

# save metadata config
saving_config.save_pretrained(save_directory)
Expand All @@ -1225,7 +1235,10 @@ def save(
)

@staticmethod
def load(load_directory, model, local_directory=None, from_huggingface_hub=False):
def load(
load_directory, model, local_directory=None, from_huggingface_hub=False,
include_model=False
):
"""
Load interventions from disk or hub
"""
Expand Down Expand Up @@ -1279,6 +1292,12 @@ def load(load_directory, model, local_directory=None, from_huggingface_hub=False
saved_state_dict = torch.load(os.path.join(load_directory, binary_filename))
intervention.load_state_dict(saved_state_dict)

# load model's trainable parameters as well
if include_model:
model_binary_filename = "pytorch_model.bin"
saved_model_state_dict = torch.load(os.path.join(load_directory, model_binary_filename))
intervenable.model.load_state_dict(saved_model_state_dict, strict=False)

return intervenable

def save_intervention(self, save_directory, include_model=True):
Expand Down

0 comments on commit 3ec61ea

Please sign in to comment.