-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathacdc.py
89 lines (71 loc) · 2.64 KB
/
acdc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# %%
import torch as t
from torch import nn
def get_hook(layer, patch, clean, threshold):
    """Build a backward hook that gates and reports per-feature effects.

    The returned hook receives the gradient of the loss w.r.t. a layer's
    feature activations and keeps only the entries whose linear-approximation
    effect, ``grad * (patch - clean)``, exceeds ``threshold``; all other
    entries are zeroed. Because the hook returns a tensor, registering it via
    ``Tensor.register_hook`` REPLACES the gradient flowing backward with the
    gated version.

    Args:
        layer: integer layer index, used only for the printed report.
        patch: saved activation proxy; ``patch.value`` holds the patched-run
            features (assumed same shape as ``grad`` — nnsight saved proxy).
        clean: saved activation proxy; ``clean.value`` holds the clean-run
            features.
        threshold: minimum ``grad * (patch - clean)`` for an entry to survive.

    Returns:
        ``hook(grad) -> Tensor`` suitable for ``Tensor.register_hook``.
    """
    def hook(grad):
        # Hoist the loop-invariant feature delta: it was previously
        # recomputed once for the gate and once per nonzero index.
        delta = patch.value - clean.value
        effects = t.where(
            t.gt(grad * delta, threshold),
            grad,
            t.zeros_like(grad)
        )
        print(f"Layer {layer}")
        for idx in t.nonzero(effects):
            # Report the approximate effect (grad * delta) of each surviving
            # feature, not the raw gradient.
            value = effects[tuple(idx)] * delta[tuple(idx)]
            print(f"Multindex: {tuple(idx.tolist())}, Value: {value}")
        print()
        return effects
    return hook
def find_circuit(
    clean,
    patch,
    model,
    submodules,
    autoencoders,
    threshold=0.5,
):
    """Attribute a clean-vs-patch logit difference to dictionary features.

    Runs the model twice (nnsight ``invoke`` tracing API — assumed; confirm
    against the installed nnsight version): once on the patched input to
    record each layer's autoencoder features, then once on the clean input
    with gradients enabled. During the clean run, a backward hook on each
    layer's features gates the gradient to entries whose approximate effect
    exceeds ``threshold`` (see ``get_hook``) and prints them. Finally the
    clean logit difference is backpropagated to trigger the hooks.

    Args:
        clean: ``(input_text, answer_text)`` pair for the clean run.
        patch: ``(input_text, answer_text)`` pair for the patched run.
        model: nnsight LanguageModel wrapper exposing ``invoke`` and
            ``tokenizer``.
        submodules: per-layer submodules whose outputs are replaced by
            autoencoder reconstructions.
        autoencoders: one dictionary/autoencoder per submodule, aligned
            with ``submodules``.
        threshold: effect cutoff passed through to ``get_hook``.

    Returns:
        None — results are printed as a side effect.
    """
    clean_input, clean_answer = clean
    patch_input, patch_answer = patch
    # NOTE(review): takes the FIRST token id of each answer string — assumes
    # the answers tokenize to a single leading token of interest; confirm.
    clean_answer_idx = model.tokenizer(clean_answer)['input_ids'][0]
    patch_answer_idx = model.tokenizer(patch_answer)['input_ids'][0]
    # Patched forward pass: record features and splice reconstructions in.
    patched_features = []
    with model.invoke(patch_input) as invoker:
        for submodule, ae in zip(submodules, autoencoders):
            f = ae.encode(submodule.output)
            patched_features.append(f.save())
            submodule.output = ae.decode(f)
    logits = invoker.output.logits
    patch_logit_diff = logits[0, -1, patch_answer_idx] - logits[0, -1, clean_answer_idx]
    # Clean forward pass with gradients enabled so the hooks fire on backward.
    clean_features = []
    with model.invoke(clean_input, fwd_args={'inference' : False}) as invoker:
        for i, (submodule, ae) in enumerate(zip(submodules, autoencoders)):
            f = ae.encode(submodule.output)
            clean_features.append(f.save())
            # Renamed from `patch, clean`, which shadowed this function's own
            # parameters of the same names.
            patch_feats, clean_feats = patched_features[i], clean_features[i]
            hook = get_hook(i, patch_feats, clean_feats, threshold)
            f.register_hook(hook)
            submodule.output = ae.decode(f)
    logits = invoker.output.logits
    clean_logit_diff = logits[0, -1, patch_answer_idx] - logits[0, -1, clean_answer_idx]
    # Backprop through the clean run; each registered hook gates/prints here.
    clean_logit_diff.backward()
    print(f'Total change: {patch_logit_diff.item() - clean_logit_diff.item()}')
# %%
from nnsight import LanguageModel
from dictionary_learning.dictionary import AutoEncoder
# Load the target model and hook each layer's MLP output projection.
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map='cuda:0')
layers = len(model.gpt_neox.layers)
submodules = [layer.mlp.dense_4h_to_h for layer in model.gpt_neox.layers]

# One pretrained sparse autoencoder (512 -> 16*512 features) per layer.
autoencoders = []
for layer_idx in range(layers):
    dictionary = AutoEncoder(512, 16 * 512)
    state = t.load(f'/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped/mlp_out_layer{layer_idx}/0_8192/ae.pt')
    dictionary.load_state_dict(state)
    autoencoders.append(dictionary)

# Minimal pair: singular vs. plural subject flips the agreeing verb.
clean = ("The man", " is")
patch = ("The men", " are")

grads = find_circuit(clean, patch, model, submodules, autoencoders, threshold=0.5)
# %%