Merge pull request #172 from stanfordnlp/zen/use_cache_fix
[Minor] Fix use_cache flag propagation
frankaging authored Jul 17, 2024
2 parents 5b37b24 + 8e75602 commit 8279b5b
Showing 1 changed file with 2 additions and 2 deletions.

pyvene/models/intervenable_base.py
@@ -974,7 +974,7 @@ def forward(
         labels: Optional[torch.LongTensor] = None,
         output_original_output: Optional[bool] = False,
         return_dict: Optional[bool] = None,
-        use_cache: Optional[bool] = True,
+        use_cache: Optional[bool] = None,
     ):
         activations_sources = source_representations
         if sources is not None and not isinstance(sources, list):
@@ -1017,7 +1017,7 @@ def forward(
         model_kwargs = {}
         if labels is not None:  # for training
             model_kwargs["labels"] = labels
-        if 'use_cache' in self.model.config.to_dict():  # for transformer models
+        if use_cache is not None and 'use_cache' in self.model.config.to_dict():  # for transformer models
             model_kwargs["use_cache"] = use_cache

         if self.mode == "parallel":
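
In short: use_cache now defaults to None and is forwarded to the wrapped transformer only when the caller sets it explicitly, so the model config's own default applies otherwise. Previously the flag defaulted to True and was always forwarded, silently overriding the wrapped model's setting. Below is a minimal, self-contained sketch of the patched propagation pattern; build_model_kwargs and config_dict are illustrative stand-ins, not pyvene API.

from typing import Any, Optional

def build_model_kwargs(
    config_dict: dict,
    labels: Optional[Any] = None,
    use_cache: Optional[bool] = None,
) -> dict:
    # Mirrors the patched logic above: only forward flags the caller set.
    model_kwargs = {}
    if labels is not None:  # for training
        model_kwargs["labels"] = labels
    # Before the fix, use_cache defaulted to True and was forwarded whenever
    # the config exposed it, overriding the wrapped model's own default.
    # With None as the default, "unset" now defers to the model config.
    if use_cache is not None and "use_cache" in config_dict:
        model_kwargs["use_cache"] = use_cache
    return model_kwargs

# None defers to the model config; an explicit value is forwarded.
assert build_model_kwargs({"use_cache": True}) == {}
assert build_model_kwargs({"use_cache": True}, use_cache=False) == {"use_cache": False}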
