From adf1116dba1e523746f28f809a98cdf8aea1f9d7 Mon Sep 17 00:00:00 2001
From: Zen
Date: Wed, 17 Jul 2024 16:02:36 -0700
Subject: [PATCH 1/2] [Minor] Fix use_cache flag propagation

---
 pyvene/models/intervenable_base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py
index e157c29e..1922b233 100644
--- a/pyvene/models/intervenable_base.py
+++ b/pyvene/models/intervenable_base.py
@@ -974,7 +974,7 @@ def forward(
         labels: Optional[torch.LongTensor] = None,
         output_original_output: Optional[bool] = False,
         return_dict: Optional[bool] = None,
-        use_cache: Optional[bool] = True,
+        use_cache: Optional[bool] = None,
     ):
         activations_sources = source_representations
         if sources is not None and not isinstance(sources, list):
@@ -1017,7 +1017,7 @@ def forward(
         model_kwargs = {}
         if labels is not None: # for training
            model_kwargs["labels"] = labels
-        if 'use_cache' in self.model.config.to_dict(): # for transformer models
+        if 'use_cache' is not None and 'use_cache' in self.model.config.to_dict(): # for transformer models
            model_kwargs["use_cache"] = use_cache

         if self.mode == "parallel":

From 8e75602a1207749d3652f49d540acbabdab83c65 Mon Sep 17 00:00:00 2001
From: Zen
Date: Wed, 17 Jul 2024 16:12:53 -0700
Subject: [PATCH 2/2] Update intervenable_base.py

---
 pyvene/models/intervenable_base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py
index 1922b233..b1a579f6 100644
--- a/pyvene/models/intervenable_base.py
+++ b/pyvene/models/intervenable_base.py
@@ -1017,7 +1017,7 @@ def forward(
         model_kwargs = {}
         if labels is not None: # for training
            model_kwargs["labels"] = labels
-        if 'use_cache' is not None and 'use_cache' in self.model.config.to_dict(): # for transformer models
+        if use_cache is not None and 'use_cache' in self.model.config.to_dict(): # for transformer models
            model_kwargs["use_cache"] = use_cache

         if self.mode == "parallel":
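
For context: Patch 1/2 changes the default of the use_cache parameter of forward() in
intervenable_base.py from True to None, so the flag is only forwarded to the wrapped
transformer when a caller sets it explicitly and the model config actually exposes
use_cache. Patch 1/2 also introduces a guard that compares the string literal
'use_cache' against None (always true); Patch 2/2 corrects it to test the use_cache
argument itself. The snippet below is a minimal, self-contained sketch of the resulting
kwarg-forwarding pattern; DummyConfig and build_model_kwargs are hypothetical stand-ins
for illustration only, not pyvene APIs.

# Illustrative sketch of the kwarg-forwarding pattern from the patches above.
# DummyConfig and build_model_kwargs are hypothetical, not part of pyvene.
from typing import Optional


class DummyConfig:
    """Pretend model config; real transformer configs expose to_dict()."""

    def __init__(self, supports_cache: bool):
        self._supports_cache = supports_cache

    def to_dict(self):
        # Only cache-capable models list a use_cache entry in their config.
        return {"use_cache": True} if self._supports_cache else {}


def build_model_kwargs(config: DummyConfig, labels=None, use_cache: Optional[bool] = None):
    model_kwargs = {}
    if labels is not None:  # for training
        model_kwargs["labels"] = labels
    # Patch 2/2 behavior: forward use_cache only when the caller set it
    # explicitly AND the wrapped model's config knows about the flag.
    if use_cache is not None and "use_cache" in config.to_dict():
        model_kwargs["use_cache"] = use_cache
    return model_kwargs


if __name__ == "__main__":
    cfg = DummyConfig(supports_cache=True)
    assert build_model_kwargs(cfg) == {}  # default: nothing injected
    assert build_model_kwargs(cfg, use_cache=False) == {"use_cache": False}
    cfg_no_cache = DummyConfig(supports_cache=False)
    assert build_model_kwargs(cfg_no_cache, use_cache=True) == {}  # unsupported flag dropped

Running the sketch shows that omitting use_cache no longer injects use_cache=True into
model_kwargs, which matches the intent suggested by the patch subject.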