From 7ab3cf7bde51d07deb89a87664d1b34755a3ddd9 Mon Sep 17 00:00:00 2001 From: frankaging Date: Mon, 4 Mar 2024 05:42:16 -0800 Subject: [PATCH] [Minor] Support zero_grad as a module --- pyvene/models/intervenable_base.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py index d5712765..6bbd2c1c 100644 --- a/pyvene/models/intervenable_base.py +++ b/pyvene/models/intervenable_base.py @@ -310,7 +310,7 @@ def named_parameters(self, recurse=True): if isinstance(v[0], TrainableIntervention): ret_params += [(k + '.' + n, p) for n, p in v[0].named_parameters()] return ret_params - + def get_cached_activations(self): """ Return the cached activations with keys @@ -399,6 +399,14 @@ def set_zero_grad(self): if isinstance(v[0], TrainableIntervention): v[0].zero_grad() + def zero_grad(self): + """ + The above, but for HuggingFace. + """ + for k, v in self.interventions.items(): + if isinstance(v[0], TrainableIntervention): + v[0].zero_grad() + def save( self, save_directory, save_to_hf_hub=False, hf_repo_name="my-awesome-model" ):