Skip to content

Commit

Permalink
[Minor] Allow other field passing in generate besides input_ids
Browse files Browse the repository at this point in the history
  • Loading branch information
frankaging committed Mar 5, 2024
1 parent fcf4870 commit e32bc0a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,7 +1496,7 @@ def generate(
base_outputs = None
if output_original_output:
# returning un-intervened output
base_outputs = self.model.generate(inputs=base["input_ids"], **kwargs)
base_outputs = self.model.generate(**base, **kwargs)

set_handlers_to_remove = None
try:
Expand All @@ -1522,7 +1522,7 @@ def generate(

# run intervened generate
counterfactual_outputs = self.model.generate(
inputs=base["input_ids"], **kwargs
**base, **kwargs
)

collected_activations = []
Expand Down
10 changes: 5 additions & 5 deletions pyvene_101.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
}
],
"source": [
"import torch\n",
"import pyvene as pv\n",
"\n",
"_, tokenizer, gpt2 = pv.create_gpt2()\n",
Expand Down Expand Up @@ -222,6 +223,7 @@
}
],
"source": [
"import torch\n",
"import pyvene as pv\n",
"\n",
"_, tokenizer, gpt2 = pv.create_gpt2()\n",
Expand Down Expand Up @@ -262,6 +264,7 @@
}
],
"source": [
"import torch\n",
"import copy\n",
"import pyvene as pv\n",
"\n",
Expand Down Expand Up @@ -1068,17 +1071,14 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"id": "f718e2d6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/sailhome/wuzhengx/.local/lib/python3.8/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
" return self.fget.__get__(instance, owner)()\n",
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
Expand Down Expand Up @@ -2647,7 +2647,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
"version": "3.10.13"
},
"toc-autonumbering": true,
"toc-showcode": false,
Expand Down

0 comments on commit e32bc0a

Please sign in to comment.