forked from Robert-K/MAChINE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path ml_gnns.py
60 lines (50 loc) · 2.61 KB
/
ml_gnns.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import tensorflow as tf
from backend.utils.molecule_formats import smiles_to_mol_graph
from backend.machine_learning.models.schnet import make_schnet
def create_schnet_with_dataset(parameters, dataset, labels, loss, optimizer, metrics, batch_size):
    """
    Creates a compiled SchNet model and a batched tf.data.Dataset for it to train on.

    :param parameters: dict containing keys depth, readoutSize and embeddingDimension;
        each value is converted with int()
    :param dataset: iterable of entries where entry["x"]["mol_graph"] is a
        (nodes, edges, edges_i) triple and entry["y"] maps label names to target values
    :param labels: array of string labels to train on. Only the first label is used;
        a warning is emitted if more than one is given.
    :param loss: keras loss function
    :param optimizer: keras optimizer
    :param metrics: array of keras metrics
    :param batch_size: int size of data batches
    :return: tuple (compiled tf model, dataset of ((nodes, edges, edges_i), y) batches)
    :raises IndexError: if labels is empty
    :raises KeyError: if a required key is missing from parameters
    :raises ValueError: if dataset contains no entries
    """
    import warnings

    label = labels[0]  # SchNets do not support multiple labels
    if len(labels) > 1:
        # Dropping labels silently would hide a misconfiguration from the caller
        warnings.warn(
            f"SchNet supports only one label; training on '{label}' and ignoring {list(labels[1:])}"
        )

    # Read hyperparameters before the (potentially expensive) dataset conversion so a
    # missing key fails fast with a KeyError naming it, instead of as int(None) later.
    embedding_dim = int(parameters['embeddingDimension'])
    readout_size = int(parameters['readoutSize'])
    depth = int(parameters['depth'])

    # Gets data for the selected label from the dataset
    samples = [(mol["x"]['mol_graph'], mol["y"][label]) for mol in dataset]
    if not samples:
        # zip(*[]) would otherwise fail with an unhelpful "not enough values to unpack"
        raise ValueError("Cannot create a SchNet dataset from an empty dataset")
    x, y = zip(*samples)
    # Splits the mol graphs into their 3 parts
    nodes, edges, edges_i = zip(*x)
    # Needed to properly set the inner dimension of the model input
    node_dim = nodes[0].shape[-1]
    edge_dim = edges[0].shape[-1]
    # Converts dataset data to model input: one ragged row per molecule
    nodes = tf.ragged.constant(nodes, dtype="float32", ragged_rank=1, inner_shape=(node_dim,))
    edges = tf.ragged.constant(edges, dtype="float32", ragged_rank=1, inner_shape=(edge_dim,))
    edges_i = tf.ragged.constant(edges_i, dtype="int32", ragged_rank=1, inner_shape=(2,))
    y = tf.constant(y)
    # Creates the actual Dataset
    ds = tf.data.Dataset.from_tensor_slices(((nodes, edges, edges_i), y)).batch(batch_size)
    # Creates a new SchNet Model
    model = make_schnet(
        input_node_shape=[None, node_dim],
        input_edge_shape=[None, edge_dim],
        embedding_dim=embedding_dim,
        readout_size=readout_size,
        depth=depth,
    )
    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
    return model, ds
def smiles_to_schnet_input(smiles):
    """
    Converts a SMILES string into the input expected by a SchNet model.

    The molecule is turned into a mol graph via smiles_to_mol_graph. Each of the graph's
    three parts is then wrapped as a RaggedTensor holding a batch of exactly one molecule.

    :param smiles: SMILES string of the molecule
    :return: list [nodes, edges, edges_i] of RaggedTensors (float32, float32, int32)
    """
    def as_single_item_batch(part, dtype, inner_dim):
        # Lay the data out as shape [1, len(part), inner_dim], then make axis 1 ragged
        dense = tf.constant(part, dtype=dtype, shape=[1, len(part), inner_dim])
        return tf.RaggedTensor.from_tensor(dense)

    graph_nodes, graph_edges, graph_edge_indices = smiles_to_mol_graph(smiles)
    return [
        as_single_item_batch(graph_nodes, "float32", graph_nodes.shape[-1]),
        as_single_item_batch(graph_edges, "float32", graph_edges.shape[-1]),
        # Edge indices are pairs, hence the fixed inner dimension of 2
        as_single_item_batch(graph_edge_indices, "int32", 2),
    ]