[P0] Prohibit SDPA Attention Use in Pyvene
Until we find a permanent solution
PinetreePantry committed Jan 28, 2025
1 parent 7a7a96a commit 22d4d60
Showing 3 changed files with 13 additions and 2 deletions.
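
The diffs below pin every GPT-2 factory in pyvene to the eager attention implementation so that hooks on submodules such as attn_dropout keep firing. A minimal standalone sketch of the same pattern outside the factory functions, assuming a recent transformers install (the model name and the final print are illustrative, not part of the commit):

    from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer

    # Pin the attention implementation before the weights are loaded; the SDPA
    # path bypasses the per-module attention dropout that pyvene hooks into.
    config = GPT2Config.from_pretrained("gpt2")
    if hasattr(config, '_attn_implementation'):
        config._attn_implementation = "eager"

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    gpt2 = GPT2LMHeadModel.from_pretrained("gpt2", config=config)
    print(gpt2.config._attn_implementation)  # should report "eager"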
10 changes: 10 additions & 0 deletions pyvene/models/gpt2/modelings_intervenable_gpt2.py
@@ -77,6 +77,8 @@ def create_gpt2(name="gpt2", cache_dir=None):
     from transformers import GPT2Model, GPT2Tokenizer, GPT2Config
 
     config = GPT2Config.from_pretrained(name)
+    if hasattr(config, '_attn_implementation'):
+        config._attn_implementation = "eager"
     tokenizer = GPT2Tokenizer.from_pretrained(name)
     gpt = GPT2Model.from_pretrained(name, config=config, cache_dir=cache_dir)
     print("loaded model")
@@ -90,8 +92,12 @@ def create_gpt2_lm(name="gpt2", config=None, cache_dir=None):
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
     if config is None:
         config = GPT2Config.from_pretrained(name)
+        if hasattr(config, '_attn_implementation'):
+            config._attn_implementation = "eager"
         gpt = GPT2LMHeadModel.from_pretrained(name, config=config, cache_dir=cache_dir)
     else:
+        if hasattr(config, '_attn_implementation'):
+            config._attn_implementation = "eager"
         gpt = GPT2LMHeadModel(config=config)
     print("loaded model")
     return config, tokenizer, gpt
@@ -103,8 +109,12 @@ def create_gpt2_classifier(name="gpt2", config=None, cache_dir=None):
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
     if config is None:
         config = GPT2Config.from_pretrained(name)
+        if hasattr(config, '_attn_implementation'):
+            config._attn_implementation = "eager"
         gpt = GPT2LMForSequenceClassification.from_pretrained(name, config=config, cache_dir=cache_dir)
     else:
+        if hasattr(config, '_attn_implementation'):
+            config._attn_implementation = "eager"
         gpt = GPT2LMForSequenceClassification(config=config)
     print("loaded model")
     return config, tokenizer, gpt
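
With the three helpers patched as above, the config they return is already pinned to eager attention. A hedged usage sketch, assuming the module path shown in this diff is importable and the gpt2 checkpoint can be downloaded:

    from pyvene.models.gpt2.modelings_intervenable_gpt2 import create_gpt2_lm

    # The helper now forces eager attention before loading the weights.
    config, tokenizer, gpt2 = create_gpt2_lm("gpt2")
    print(config._attn_implementation)  # should report "eager" on recent transformers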
Empty file added pyvene/models/qwen2/__init__.py
5 changes: 3 additions & 2 deletions pyvene_101.ipynb
@@ -146,7 +146,8 @@
 "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
 "\n",
 "model_name = \"gpt2\"\n",
-"gpt2 = AutoModelForCausalLM.from_pretrained(model_name)\n",
+"# Do not use SDPA attention because we cannot hook to attn_dropout\n",
+"gpt2 = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation=\"eager\")\n",
 "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
 "\n",
 "pv_gpt2 = pv.IntervenableModel({\n",
@@ -3032,7 +3033,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.13"
+"version": "3.11.9"
 },
 "toc-autonumbering": true,
 "toc-showcode": false,
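
The notebook comment spells out the motivation: pyvene intervenes by hooking GPT-2 submodules such as attn_dropout, and the fused SDPA kernel is not expected to call that module. A hedged sanity check of that behavior, assuming standard transformers and torch installs (the layer index, prompt, and hook body are illustrative, not from the commit):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Load GPT-2 on the eager attention path, as the notebook above now does.
    model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="eager")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Register a forward hook on the first block's attention dropout; this is the
    # module pyvene needs to reach, and the SDPA kernel is expected to skip it.
    calls = []
    handle = model.transformer.h[0].attn.attn_dropout.register_forward_hook(
        lambda module, inputs, output: calls.append(tuple(output.shape))
    )

    inputs = tokenizer("hello world", return_tensors="pt")
    with torch.no_grad():
        model(**inputs)
    handle.remove()

    print(len(calls))  # > 0 with eager attention; expected to be 0 under SDPA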
