# pretrain-hs-mfgb.py
import os
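# Restrict TensorFlow to the first GPU; this must be set before TF initializes CUDA.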
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import time
import tensorflow as tf
import tensorflow.keras as keras
from model import BertModel
from dataset import NoHM2VGraph_Bert_Dataset, HsM2VGraph_Bert_Dataset
keras.backend.clear_session()
# class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
#     def __init__(self, d_model, warmup_steps=4000):
#         super(CustomSchedule, self).__init__()
#
#         self.d_model = d_model
#         self.d_model = tf.cast(self.d_model, tf.float32)
#
#         self.warmup_steps = warmup_steps
#
#     def __call__(self, step):
#         arg1 = tf.math.rsqrt(step)
#         arg2 = step * (self.warmup_steps ** -1.5)
#
#         return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
#
#
# learning_rate = CustomSchedule(128)
optimizer = tf.keras.optimizers.Adam(1e-4)
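# A minimal sketch of wiring the commented-out warmup schedule above into the
# optimizer instead of the fixed learning rate; the beta/epsilon values follow
# the original Transformer paper and are an assumption, not part of this script:
# learning_rate = CustomSchedule(128)
# optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)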
small = {'name': 'Small', 'num_layers': 3, 'num_heads': 4, 'd_model': 128, 'path': 'small_weights','addH':True}
medium = {'name': 'Medium', 'num_layers': 6, 'num_heads': 8, 'd_model': 256, 'path': 'medium_weights','addH':True}
medium3 = {'name': 'Medium', 'num_layers': 6, 'num_heads': 4, 'd_model': 256, 'path': 'medium_weights3','addH':True}
large = {'name': 'Large', 'num_layers': 12, 'num_heads': 12, 'd_model': 576, 'path': 'large_weights','addH':True}
mfgb_medium_balanced = {'name':'Medium','num_layers': 6, 'num_heads': 8, 'd_model': 256,'path':'weights_mfgb_hs_balanced_1000w','addH':True}
arch = mfgb_medium_balanced  # small: 3 layers / 4 heads / d_model 128; medium: 6 / 8 / 256; large: 12 / 12 / 576
num_layers = arch['num_layers']
num_heads = arch['num_heads']
d_model = arch['d_model']
addH = arch['addH']
dff = d_model*2
vocab_size = 717
dropout_rate = 0.1
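# Note: dropout_rate is defined here but not passed to BertModel below; it is
# presumably identical to the model's internal default (an assumption about model.py).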
model = BertModel(num_layers=num_layers,d_model=d_model,dff=dff,num_heads=num_heads,vocab_size=vocab_size)
train_dataset, test_dataset = HsM2VGraph_Bert_Dataset(path='data/pubchem_200w.txt',smiles_field='CAN_SMILES',addH=addH).get_data()
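# Each batch from the dataset is an (x, adjoin_matrix, y, char_weight) tuple:
# token ids, an adjacency matrix, target labels, and per-token sample weights,
# matching what the training loop below unpacks.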
train_step_signature = [
tf.TensorSpec(shape=(None, None), dtype=tf.int64),
tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
tf.TensorSpec(shape=(None, None), dtype=tf.int64),
tf.TensorSpec(shape=(None, None), dtype=tf.float32),
]
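# One TensorSpec per element of a training batch: x (int64 token ids),
# adjoin_matrix (float32), y (int64 labels), char_weight (float32 weights).
# The @tf.function decorator on train_step uses it to avoid retracing.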
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
# test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
loss_function = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
@tf.function(input_signature=train_step_signature)
def train_step(x, adjoin_matrix, y, char_weight):
    # Padding mask: token id 0 marks padded positions.
    seq = tf.cast(tf.math.equal(x, 0), tf.float32)
    mask = seq[:, tf.newaxis, tf.newaxis, :]
    with tf.GradientTape() as tape:
        predictions = model(x, adjoin_matrix=adjoin_matrix, mask=mask, training=True)
        loss = loss_function(y, predictions, sample_weight=char_weight)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss.update_state(loss)
    train_accuracy.update_state(y, predictions, sample_weight=char_weight)
# @tf.function(input_signature=train_step_signature)
# def test_step(x, adjoin_matrix, y, char_weight):
#     seq = tf.cast(tf.math.equal(x, 0), tf.float32)
#     mask = seq[:, tf.newaxis, tf.newaxis, :]
#     predictions = model(x, adjoin_matrix=adjoin_matrix, mask=mask, training=False)
#     test_accuracy.update_state(y, predictions, sample_weight=char_weight)
for epoch in range(3):
    start = time.time()
    train_loss.reset_states()
    for (batch, (x, adjoin_matrix, y, char_weight)) in enumerate(train_dataset):
        train_step(x, adjoin_matrix, y, char_weight)
        # Report the running loss/accuracy every 128 batches, then reset accuracy.
        if batch % 128 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(
                epoch + 1, batch, train_loss.result()))
            print('Accuracy: {:.4f}'.format(train_accuracy.result()))
            #
            # for x, adjoin_matrix, y, char_weight in test_dataset:
            #     test_step(x, adjoin_matrix, y, char_weight)
            # print('Test Accuracy: {:.4f}'.format(test_accuracy.result()))
            # test_accuracy.reset_states()
            train_accuracy.reset_states()
    # End-of-epoch summary and checkpoint.
    print(arch['path'] + '/bert_weights_{}_{}.h5'.format(arch['name'], epoch + 1))
    print('Epoch {} Loss {:.4f}'.format(epoch + 1, train_loss.result()))
    print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
    print('Accuracy: {:.4f}'.format(train_accuracy.result()))
    model.save_weights(arch['path'] + '/bert_weights_{}_{}.h5'.format(arch['name'], epoch + 1))
    print('Saving checkpoint')
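# A minimal sketch (epoch number assumed) of reloading a saved checkpoint, e.g.
# from a fine-tuning script: rebuild the model with identical hyperparameters
# and call it on one batch so its variables exist before load_weights:
# model = BertModel(num_layers=num_layers, d_model=d_model, dff=dff,
#                   num_heads=num_heads, vocab_size=vocab_size)
# x, adjoin_matrix, y, char_weight = next(iter(train_dataset))
# seq = tf.cast(tf.math.equal(x, 0), tf.float32)
# mask = seq[:, tf.newaxis, tf.newaxis, :]
# _ = model(x, adjoin_matrix=adjoin_matrix, mask=mask, training=False)
# model.load_weights(arch['path'] + '/bert_weights_{}_{}.h5'.format(arch['name'], 3))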