-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetric.py
70 lines (61 loc) · 2.28 KB
/
metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import itertools
import pandas as pd
from net.rew import BaseCNN
from utils.eval_utils import distance
from utils.constants import seq_to_one_hot
class OnehotDataset(Dataset):
def __init__(self, seqs):
self.seqs = seqs
def __len__(self):
return len(self.seqs)
def __getitem__(self, index):
return seq_to_one_hot(self.seqs[index])
class Evaluator:
def __init__(self, protein, max_target, min_target, device, batch_size = 16):
self.device = device
self.batch_size = batch_size
self.max_target, self.min_target = max_target, min_target
oracle = BaseCNN(make_one_hot=False)
oracle_ckpt = torch.load(f'ckpt/{protein}/oracle.ckpt', map_location=self.device)
if "state_dict" in oracle_ckpt.keys():
oracle_ckpt = oracle_ckpt["state_dict"]
oracle.load_state_dict({ k.replace('predictor.',''):v for k,v in oracle_ckpt.items() })
oracle.eval()
self.oracle = oracle.to(device)
high = pd.read_csv(f'data/{protein}/all.csv')[['sequence', 'target']]
high = high[high['target'] > high['target'].quantile(q=0.9).item()]
self.high = high['sequence'].tolist()
self.high = self.high[:128]
def evaluate(self, seqs, inits):
dataset = OnehotDataset(seqs)
dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
targets = []
with torch.no_grad():
for batch in dataloader:
_, target = self.oracle(batch.to(self.device), get_embed=True)
target = (target - self.min_target) / (self.max_target - self.min_target)
targets.extend(list(target.cpu().flatten()))
fitness = np.median(targets)
distances = []
for s1, s2 in itertools.combinations(seqs, 2):
distances.append(distance(s1, s2))
diversity = np.median(distances)
distances = []
for j in seqs:
dist_j = []
for i in inits:
dist_j.append(distance(i,j))
distances.append(min(dist_j))
novelty = np.median(distances)
distances = []
for j in seqs:
dist_j = []
for i in self.high:
dist_j.append(distance(i,j))
distances.append(min(dist_j))
high = np.median(distances)
return fitness, diversity, novelty, high