# main_attn.py
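"""Compute correlations between attention heads across models.

Loads attention maps for each file listed in `attention_files`, scores
head-to-head similarity with the requested correlation methods, and
writes one output file per method.
"""
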
import torch
import argparse
from attention_corr_methods import (load_attentions, MaxCorr, MinCorr,
PearsonMaxCorr, PearsonMinCorr,
JSMaxCorr, JSMinCorr, AttnLinCKA,
AttnCCA,
PearsonMaxCorr2, PearsonMinCorr2)


def get_options(opt_fname, attention_fname_l):
    """Read per-file layer specs; default every file to the last layer (-1)."""
    if opt_fname is None:
        layerspec_l = [-1] * len(attention_fname_l)
    else:
        with open(opt_fname, 'r') as f:
            # Each line is either "all" or an integer layer index.
            layerspec_l = [ls if ls == "all" else int(ls)
                           for ls in (line.strip() for line in f)]
    return layerspec_l
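
# A hypothetical layer-spec file (one line per attention file, matching
# the parsing above): each entry is an integer layer index or "all", e.g.
#
#     -1
#     all
#     11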


# Mapping from CLI method names to their implementing classes.
METHOD_CLASSES = {
    'maxcorr': MaxCorr,
    'mincorr': MinCorr,
    'pearsonmaxcorr': PearsonMaxCorr,
    'pearsonmincorr': PearsonMinCorr,
    'pearsonmaxcorr2': PearsonMaxCorr2,
    'pearsonmincorr2': PearsonMinCorr2,
    'jsmaxcorr': JSMaxCorr,
    'jsmincorr': JSMinCorr,
    'attn_lincka': AttnLinCKA,
    'attn_cca': AttnCCA,
}


def get_method_l(methods, num_heads_d, attentions_d, device):
    """Instantiate the requested correlation methods ('all' runs every one)."""
    names = METHOD_CLASSES.keys() if 'all' in methods else methods
    method_l = []
    for name in names:
        if name not in METHOD_CLASSES:
            # Fail loudly on unrecognized method names.
            raise ValueError("Unknown method: {0}".format(name))
        method_l.append(METHOD_CLASSES[name](num_heads_d, attentions_d, device))
    return method_l


def main(methods, attention_files, output_file, opt_fname=None,
         limit=None, disable_cuda=False, ar_mask=False):
    """Load attentions, compute head correlations, and write results."""
    if not disable_cuda and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print("Using device: {0}".format(device))

    # Set `attention_fname_l` (one attention file per line) and layer specs
    with open(attention_files) as f:
        attention_fname_l = [line.strip() for line in f]
    layerspec_l = get_options(opt_fname, attention_fname_l)

    # Load attention maps
    print("Loading attentions")
    num_heads_d, attentions_d = load_attentions(
        attention_fname_l, limit=limit, layerspec_l=layerspec_l,
        ar_mask=ar_mask)

    # Set `method_l`, the list of Method objects
    print('\nInitializing methods ' + str(methods))
    method_l = get_method_l(methods, num_heads_d, attentions_d, device)

    # Run every method, then write one output file per method
    print('\nComputing correlations')
    for method in method_l:
        print('For method: ', str(method))
        method.compute_correlations()

    print('\nWriting correlations')
    for method in method_l:
        print('For method: ', str(method))
        out_fname = output_file + '_' + str(method)
        method.write_correlations(out_fname)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--methods", nargs="+", default=['all'],
                        help="correlation methods to run, or 'all'")
    parser.add_argument("attention_files",
                        help="file listing one attention file per line")
    parser.add_argument("output_file",
                        help="prefix for the per-method output files")
    parser.add_argument("--opt_fname", default=None,
                        help="optional per-file layer-spec file")
    parser.add_argument("--limit", type=int, default=None,
                        help="cap on the number of examples to load")
    parser.add_argument("--disable_cuda", action="store_true")
    parser.add_argument("--ar_mask", action="store_true")
    args = parser.parse_args()

    main(args.methods, args.attention_files, args.output_file,
         opt_fname=args.opt_fname, limit=args.limit,
         disable_cuda=args.disable_cuda, ar_mask=args.ar_mask)
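
# A hypothetical invocation (file names are placeholders):
#   python main_attn.py attention_files.txt results/corr \
#       --methods maxcorr attn_cca --limit 1000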