Skip to content

Commit

Permalink
save/load model params too (for classifier heads)
Browse files: browse the repository at this point in the history
  • Loading branch information
aryamanarora committed May 2, 2024
1 parent 1dc9243 commit 2219d82
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions pyvene/models/intervenable_base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -416,7 +416,8 @@ def zero_grad(self):
v[0].zero_grad()

def save(
self, save_directory, save_to_hf_hub=False, hf_repo_name="my-awesome-model"
self, save_directory, save_to_hf_hub=False, hf_repo_name="my-awesome-model",
include_model=True
):
"""
Save interventions to disk or hub
Expand Down Expand Up @@ -493,6 +494,15 @@ def save(
else:
saving_config.intervention_dimensions += [intervention.interchange_dim.tolist()]
saving_config.intervention_constant_sources += [intervention.is_source_constant]

# save model's trainable parameters as well
if include_model:
model_state_dict = {}
model_binary_filename = "pytorch_model.bin"
for n, p in self.model.named_parameters():
if p.requires_grad:
model_state_dict[n] = p
torch.save(model_state_dict, os.path.join(save_directory, model_binary_filename))

# save metadata config
saving_config.save_pretrained(save_directory)
Expand All @@ -513,7 +523,10 @@ def save(
)

@staticmethod
def load(load_directory, model, local_directory=None, from_huggingface_hub=False):
def load(
load_directory, model, local_directory=None, from_huggingface_hub=False,
include_model=True
):
"""
Load interventions from disk or hub
"""
Expand Down Expand Up @@ -567,6 +580,12 @@ def load(load_directory, model, local_directory=None, from_huggingface_hub=False
saved_state_dict = torch.load(os.path.join(load_directory, binary_filename))
intervention.load_state_dict(saved_state_dict)

# load model's trainable parameters as well
if include_model:
model_binary_filename = "pytorch_model.bin"
saved_model_state_dict = torch.load(os.path.join(load_directory, model_binary_filename))
intervenable.model.load_state_dict(saved_model_state_dict, strict=False)

return intervenable

def save_intervention(self, save_directory, include_model=True):
Expand Down

0 comments on commit 2219d82

Please sign in to comment.