# data_utils.py
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
def get_samples(pos_path, neg_path):
    """Load positive and negative (SMILES, sequence) pairs from saved sample dicts."""
    pos_samples = torch.load(pos_path)
    neg_samples = torch.load(neg_path)

    pos_mols, pos_seqs = [], []
    for unis, values in pos_samples.items():
        mol, seq = values[0], values[1]
        # '*' is a wildcard/attachment-point atom in SMILES; replace it with carbon
        pos_mols.append(mol.replace('*', 'C'))
        pos_seqs.append(seq)

    neg_mols, neg_seqs = [], []
    for unis, values in neg_samples.items():
        mol, seq = values[0], values[1]
        neg_mols.append(mol.replace('*', 'C'))
        neg_seqs.append(seq)

    assert len(pos_mols) == len(pos_seqs)
    assert len(neg_mols) == len(neg_seqs)
    return pos_mols, pos_seqs, neg_mols, neg_seqs
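

# The file layout below is an assumption inferred from the loop in
# get_samples(): each saved file maps an identifier (e.g. a UniProt id) to a
# (SMILES, sequence) tuple. A minimal, self-contained sketch:
def _demo_get_samples():
    """Hypothetical helper, not part of the original module."""
    samples = {'P00000': ('CC(=O)O*', 'MKTAYIAK')}
    torch.save(samples, 'pos_demo.pt')
    torch.save(samples, 'neg_demo.pt')
    pos_mols, pos_seqs, neg_mols, neg_seqs = get_samples('pos_demo.pt', 'neg_demo.pt')
    # pos_mols == ['CC(=O)OC']  ('*' replaced by 'C')
    return pos_mols, pos_seqs, neg_mols, neg_seqs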
def collate_fn(batch):
    """Pad per-sample token tensors to the longest item in the batch."""
    mols, seqs, labels = zip(*batch)
    # the padding values must match the pad-token ids of the two tokenizers
    batch_mols = pad_sequence(mols, batch_first=True, padding_value=0)
    batch_seqs = pad_sequence(seqs, batch_first=True, padding_value=1)
    batch_labels = torch.stack(labels)
    return batch_mols, batch_seqs, batch_labels
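

# A quick worked example of collate_fn on toy tensors (the 0/1 padding ids
# mirror the values hard-coded above):
#
#     batch = [(torch.tensor([5, 6]), torch.tensor([7]), torch.tensor(1.0)),
#              (torch.tensor([8]), torch.tensor([9, 10]), torch.tensor(0.0))]
#     mols, seqs, labels = collate_fn(batch)
#     # mols   -> tensor([[5, 6], [8, 0]])
#     # seqs   -> tensor([[7, 1], [9, 10]])
#     # labels -> tensor([1., 0.])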
class EnzymeDataset(Dataset):
    """Pairs of raw SMILES and protein sequences, tokenized on the fly.

    Every item carries the same label: 1 if the dataset was built from
    positive samples, 0 otherwise.
    """

    def __init__(self, molecules, sequences, mol_tokenizer, seq_tokenizer,
                 positive_sample=True, max_len=7000):
        assert len(molecules) == len(sequences)
        self.len = len(sequences)
        self.mols = molecules
        self.seqs = sequences
        self.mol_tokenizer = mol_tokenizer
        self.seq_tokenizer = seq_tokenizer
        self.max_len = max_len
        self.labels = torch.ones(self.len) if positive_sample else torch.zeros(self.len)

    def __len__(self):
        return self.len

    def __getitem__(self, item):
        mol = self.mols[item]
        seq = self.seqs[item]
        label = self.labels[item]
        # batch-level padding happens later in collate_fn
        mol_tok = self.mol_tokenizer(mol, truncation=True, max_length=self.max_len)['input_ids']
        seq_tok = self.seq_tokenizer(seq, truncation=True, max_length=self.max_len)['input_ids']
        return torch.tensor(mol_tok), torch.tensor(seq_tok), label
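

# Usage sketch for EnzymeDataset with collate_fn. The tokenizers are assumed
# to follow the Hugging Face convention (callable, returning a dict with
# 'input_ids'); the toy tokenizer below is a hypothetical stand-in that keeps
# the example self-contained.
def _demo_enzyme_dataset():
    """Hypothetical helper, not part of the original module."""
    from torch.utils.data import DataLoader

    def toy_tokenizer(text, truncation=True, max_length=16):
        # map each character to a small integer id
        return {'input_ids': [ord(c) % 100 for c in text[:max_length]]}

    ds = EnzymeDataset(['CCO', 'CCN'], ['MKT', 'GAV'], toy_tokenizer, toy_tokenizer)
    loader = DataLoader(ds, batch_size=2, collate_fn=collate_fn)
    mols, seqs, labels = next(iter(loader))
    return mols.shape, seqs.shape, labels  # (2, 3), (2, 3), tensor([1., 1.])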
def collate_fn_pretrained(batch):
    """Stack fixed-size embedding vectors; no padding is needed."""
    mols, seqs, labels = zip(*batch)
    batch_mols = torch.stack(mols)
    batch_seqs = torch.stack(seqs)
    batch_labels = torch.stack(labels)
    return batch_mols, batch_seqs, batch_labels
class EnzymeDatasetPretrained(Dataset):
    """Like EnzymeDataset, but looks items up in precomputed embedding tables
    instead of tokenizing raw strings."""

    def __init__(self, molecules, sequences, mol_embedding, seq_embedding,
                 positive_sample=True, max_len=5000):
        assert len(molecules) == len(sequences)
        self.len = len(sequences)
        self.mols = molecules
        self.seqs = sequences
        self.mol_embedding = mol_embedding
        self.seq_embedding = seq_embedding
        self.max_len = max_len
        self.labels = torch.ones(self.len) if positive_sample else torch.zeros(self.len)

    def __len__(self):
        return self.len

    def __getitem__(self, item):
        mol = self.mols[item]
        seq = self.seqs[item]
        label = self.labels[item]
        # truncate overlong sequences before the lookup
        seq = seq[:self.max_len]
        mol_emb = self.mol_embedding[mol]
        seq_emb = self.seq_embedding[seq]
        # a 2-D lookup result holds one embedding per fragment; pool by summing
        if mol_emb.dim() == 2:
            mol_emb = mol_emb.sum(0)
        return mol_emb, seq_emb, label
def collate_fn_pretrained_single(batch):
    """Collate for the single-input variant: stack embeddings and labels."""
    data, labels = zip(*batch)
    batch_data = torch.stack(data)
    batch_labels = torch.stack(labels)
    return batch_data, batch_labels


class EnzymeDatasetPretrainedSingle(Dataset):
    """Single-input version: one embedding table serves molecules or sequences."""

    def __init__(self, data, embedding, positive_sample=True, max_len=5000):
        self.len = len(data)
        self.data = data
        self.embedding = embedding
        self.max_len = max_len
        self.labels = torch.ones(self.len) if positive_sample else torch.zeros(self.len)

    def __len__(self):
        return self.len

    def __getitem__(self, item):
        data = self.data[item]
        label = self.labels[item]
        data = data[:self.max_len]
        emb = self.embedding[data]
        # pool fragment embeddings into a single vector
        if emb.dim() == 2:
            emb = emb.sum(0)
        return emb, label
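

# Sketch of how the pretrained datasets are driven. The embedding tables are
# assumed to behave like mappings from a raw string to a precomputed tensor
# (1-D for one item, 2-D for several fragments); the random vectors below are
# hypothetical stand-ins.
def _demo_pretrained_single():
    """Hypothetical helper, not part of the original module."""
    from torch.utils.data import DataLoader

    embedding = {'CCO': torch.randn(8), 'CCN': torch.randn(8)}
    ds = EnzymeDatasetPretrainedSingle(['CCO', 'CCN'], embedding)
    loader = DataLoader(ds, batch_size=2, collate_fn=collate_fn_pretrained_single)
    data, labels = next(iter(loader))
    return data.shape, labels  # (2, 8), tensor([1., 1.])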
def graph_collate_fn(batch):
    """Graphs have no common tensor shape, so the per-sample tuples are
    returned as-is for the model to iterate over."""
    mols, seqs, labels = zip(*batch)
    return mols, seqs, labels


class GraphEnzymeDataset(Dataset):
    """Pairs precomputed molecular graphs with alphabet-encoded sequences."""

    def __init__(self, molecules, sequences, alphabet, mol_graphs_dict,
                 positive_sample=True, max_len=7000):
        assert len(molecules) == len(sequences)
        self.len = len(sequences)
        self.mols = molecules
        self.seqs = sequences
        self.alphabet = alphabet
        self.mol_graphs_dict = mol_graphs_dict
        self.max_len = max_len
        self.labels = torch.ones(self.len) if positive_sample else torch.zeros(self.len)

    def __len__(self):
        return self.len

    def __getitem__(self, item):
        mol = self.mols[item]
        seq = self.seqs[item]
        label = self.labels[item]
        # a multi-component SMILES ('.'-separated) yields one graph per fragment
        split_mols = mol.replace('*', 'C').split('.')
        seq_tok = self.alphabet.encode(seq[:self.max_len])
        graph_mols = [self.mol_graphs_dict[m] for m in split_mols]
        return graph_mols, torch.tensor(seq_tok).view(1, -1), label
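

# Usage sketch for the graph dataset. The alphabet is assumed to expose an
# encode(str) -> list[int] method, and mol_graphs_dict to map a SMILES
# fragment to its precomputed graph object; the stubs below are hypothetical.
def _demo_graph_dataset():
    """Hypothetical helper, not part of the original module."""
    from torch.utils.data import DataLoader

    class ToyAlphabet:
        def encode(self, seq):
            return [ord(c) % 100 for c in seq]

    graphs = {'CCO': 'graph(CCO)', 'CCC': 'graph(CCC)'}
    ds = GraphEnzymeDataset(['CCO.CCC'], ['MKT'], ToyAlphabet(), graphs)
    loader = DataLoader(ds, batch_size=1, collate_fn=graph_collate_fn)
    mols, seqs, labels = next(iter(loader))
    # mols[0] -> ['graph(CCO)', 'graph(CCC)'], seqs[0].shape -> (1, 3)
    return mols, seqs, labels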