diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py index d5712765..6bbd2c1c 100644 --- a/pyvene/models/intervenable_base.py +++ b/pyvene/models/intervenable_base.py @@ -310,7 +310,7 @@ def named_parameters(self, recurse=True): if isinstance(v[0], TrainableIntervention): ret_params += [(k + '.' + n, p) for n, p in v[0].named_parameters()] return ret_params - + def get_cached_activations(self): """ Return the cached activations with keys @@ -399,6 +399,14 @@ def set_zero_grad(self): if isinstance(v[0], TrainableIntervention): v[0].zero_grad() + def zero_grad(self): + """ + The above, but for HuggingFace. + """ + for k, v in self.interventions.items(): + if isinstance(v[0], TrainableIntervention): + v[0].zero_grad() + def save( self, save_directory, save_to_hf_hub=False, hf_repo_name="my-awesome-model" ):