Updated notebooks; fixed GNNExplain changes for pytorch in the latest keras release. Added jax for the force model.
PatReis committed Feb 21, 2024
1 parent 74902a5 commit 58e43b5
Showing 22 changed files with 1,249 additions and 2,006 deletions.
68 changes: 34 additions & 34 deletions docs/source/forces.ipynb

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions docs/source/layers.ipynb
@@ -330,10 +330,10 @@
"text": [
"tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
" 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [ 0.1194, -0.0455, 0.2886, -0.2412, 0.5782, -0.0314, 0.0691, -0.0883,\n",
" -0.0427, 0.4910, 0.5271, 0.1340, 0.1813, -0.1867, -0.0742, 0.1329],\n",
" [ 0.3357, 0.2820, 0.3638, -0.1271, -0.1052, -0.3881, 0.3368, -0.5367,\n",
" -0.0495, 0.7888, 0.6429, -0.6687, 0.0486, -0.1593, -0.5402, 0.4580]],\n",
" [ 0.0927, -0.1268, -0.1730, -0.0280, 0.2832, 0.3245, -0.1873, -0.2191,\n",
" -0.0106, 0.2115, -0.0539, -0.0450, 0.0180, 0.0578, 0.0857, 0.1617],\n",
" [-0.0629, -0.4794, -0.3772, 2.1843, -0.5275, 0.6220, -0.2883, -0.6898,\n",
" 0.5452, -0.4738, 0.8087, 1.2719, 0.0853, -0.6812, 0.0123, -0.1148]],\n",
" device='cuda:0', grad_fn=<ScatterReduceBackward0>)\n"
]
}
@@ -363,10 +363,10 @@
 {
 "data": {
 "text/plain": [
-"tensor([[ 0.1194, -0.0455, 0.2886, -0.2412, 0.5782, -0.0314, 0.0691, -0.0883,\n",
-" -0.0427, 0.4910, 0.5271, 0.1340, 0.1813, -0.1867, -0.0742, 0.1329],\n",
-" [ 0.3357, 0.2820, 0.3638, -0.1271, -0.1052, -0.3881, 0.3368, -0.5367,\n",
-" -0.0495, 0.7888, 0.6429, -0.6687, 0.0486, -0.1593, -0.5402, 0.4580]],\n",
+"tensor([[ 0.0927, -0.1268, -0.1730, -0.0280, 0.2832, 0.3245, -0.1873, -0.2191,\n",
+" -0.0106, 0.2115, -0.0539, -0.0450, 0.0180, 0.0578, 0.0857, 0.1617],\n",
+" [-0.0629, -0.4794, -0.3772, 2.1843, -0.5275, 0.6220, -0.2883, -0.6898,\n",
+" 0.5452, -0.4738, 0.8087, 1.2719, 0.0853, -0.6812, 0.0123, -0.1148]],\n",
 " device='cuda:0', grad_fn=<SliceBackward0>)"
 ]
 },
14 changes: 7 additions & 7 deletions docs/source/molecules.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions kgcnn/layers/casting.py
@@ -7,6 +7,7 @@


 def _pad_left(t):
+    # return ops.concatenate([ops.zeros_like(t[:1]), t], axis=0)
     return ops.pad(t, [[1, 0]] + [[0, 0] for _ in range(len(ops.shape(t)) - 1)])


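The _pad_left helper prepends one zero slice along the first axis and leaves all trailing axes untouched; the commented-out line records an equivalent concatenate-based formulation. A minimal usage sketch, assuming keras 3 ops:

import numpy as np
from keras import ops

t = ops.convert_to_tensor(np.arange(6.0).reshape(3, 2))
# Pad one zero row at the front of axis 0, no padding on other axes.
padded = ops.pad(t, [[1, 0]] + [[0, 0] for _ in range(len(ops.shape(t)) - 1)])
print(ops.shape(padded))  # (4, 2); padded[0] is all zeros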
5 changes: 4 additions & 1 deletion kgcnn/literature/GNNExplain/_model.py
@@ -207,6 +207,9 @@ def explain(self, graph_instance, output_to_explain=None, inspection=False, **kw

         gnnx_optimizer.compile(**self.compile_options)

+        # Build gnnx_optimizer with example graph instance.
+        gnnx_optimizer.predict(graph_instance, steps=1)
+
         gnnx_optimizer.fit(x=graph_instance, y=gnnx_optimizer.output_to_explain, **fit_options)

         # Read out information from inspection_callback
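The single predict step before fit is presumably needed because keras 3 models with lazily created weights may not be built yet under the torch backend; one forward pass materializes all variables before training starts. A minimal sketch of the same pattern, with a toy model standing in for gnnx_optimizer:

import numpy as np
import keras

model = keras.Sequential([keras.layers.Dense(8, activation="relu"), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(16, 4).astype("float32")
y = np.random.rand(16, 1).astype("float32")
model.predict(x, verbose=0)  # one forward pass builds all weights
model.fit(x, y, epochs=1, verbose=0)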
Expand Down Expand Up @@ -375,7 +378,7 @@ def __init__(self, gnn_model, graph_instance,
             )
             self.output_to_explain.assign(output_to_explain)
         else:
-            self.output_to_explain = output_to_explain
+            self.output_to_explain = ops.stop_gradient(ops.convert_to_tensor(output_to_explain))

         # Configuration Parameters
         self.edge_mask_loss_weight = edge_mask_loss_weight
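Wrapping the user-supplied target in stop_gradient ensures the optimization only updates the explanation masks, never the target itself. A short sketch of the idiom, assuming keras 3 ops:

from keras import ops

output_to_explain = [0.1, 0.9]
# Convert to a tensor and cut it out of the gradient computation.
target = ops.stop_gradient(ops.convert_to_tensor(output_to_explain))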
5 changes: 3 additions & 2 deletions kgcnn/losses/losses.py
@@ -13,7 +13,8 @@ def __init__(self, reduction="sum_over_batch_size", name="mean_absolute_error",
         super(MeanAbsoluteError, self).__init__(reduction=reduction, name=name, dtype=dtype)

     def call(self, y_true, y_pred):
-        return mean_absolute_error(y_true, y_pred)
+        out = mean_absolute_error(y_true, y_pred)
+        return out

     def get_config(self):
         config = super(MeanAbsoluteError, self).get_config()
@@ -35,7 +36,7 @@ def call(self, y_true, y_pred):
             check_nonzero = ops.cast(ops.logical_not(
                 ops.all(ops.isclose(y_true, ops.convert_to_tensor(0., dtype=y_true.dtype)), axis=2)), dtype="int32")
             row_count = ops.sum(check_nonzero, axis=1)
-            row_count = ops.where(row_count < 1, 1, row_count)
+            row_count = ops.where(row_count < 1, 1, row_count)  # Prevent divide by 0.
             norm = 1/ops.cast(row_count, dtype=y_true.dtype)
         else:
             norm = 1/ops.shape(y_true)[1]
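This branch normalizes each sample by its number of non-padded rows; clamping zero counts to one keeps all-padding samples from producing a division by zero (their rows contribute nothing to the numerator anyway). A standalone sketch with assumed force-style shapes (batch, atoms, 3):

import numpy as np
from keras import ops

y_true = ops.convert_to_tensor(np.zeros((2, 4, 3), dtype="float32"))  # all rows padding here
check_nonzero = ops.cast(ops.logical_not(
    ops.all(ops.isclose(y_true, ops.convert_to_tensor(0., dtype=y_true.dtype)), axis=2)), dtype="int32")
row_count = ops.sum(check_nonzero, axis=1)          # (batch,) rows that carry data
row_count = ops.where(row_count < 1, 1, row_count)  # empty samples divide by 1, not 0
norm = 1 / ops.cast(row_count, dtype=y_true.dtype)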
29 changes: 26 additions & 3 deletions kgcnn/models/force.py
@@ -12,6 +12,9 @@
     import tensorflow as tf
 elif backend() == "torch":
     import torch
+elif backend() == "jax":
+    import jax.numpy as jnp
+    import jax
 else:
     raise NotImplementedError("Backend '%s' not supported for force model." % backend())

@@ -60,8 +63,7 @@ def __init__(self,
             is_physical_force (bool): Whether to return the physical force, e.g. the negative gradient of the energy.
             use_batch_jacobian: Deprecated.
             name (str): Name of the model.
-            outputs: List of outputs as dictionary kwargs similar to inputs. Not used by the model but can be passed
-                for external use.
+            outputs: List of outputs as dictionary kwargs similar to inputs.
         """
         super().__init__()
         if model_energy is None:
@@ -101,6 +103,9 @@
         else:
             self.output_as_dict_use = False

+        energy_output_config = outputs[self.output_as_dict_names[0]] if self.output_as_dict_use else outputs[0]
+        self._expected_energy_states = energy_output_config["shape"][0]
+
         # We can try to infer the model inputs from energy model, if not given explicit.
         self._inputs_to_force_model = inputs
         if self._inputs_to_force_model is None:
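The number of energy states is read from the shape entry of the energy output configuration. A hypothetical outputs config (keys and names assumed, matching the dict case above):

# Dict-style outputs: energy shape (1,) means one energy state per structure.
outputs = {
    "energy": {"name": "energy", "shape": (1,)},
    "force": {"name": "force", "shape": (None, 3)},
}
energy_output_config = outputs["energy"]
expected_energy_states = energy_output_config["shape"][0]  # -> 1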
@@ -111,6 +116,8 @@
             self._call_grad_backend = self._call_grad_tf
         elif backend() == "torch":
             self._call_grad_backend = self._call_grad_torch
+        elif backend() == "jax":
+            self._call_grad_backend = self._call_grad_jax
         else:
             raise NotImplementedError("Backend '%s' not supported for force model." % backend())

@@ -149,13 +156,29 @@ def _call_grad_torch(self, inputs, training=False, **kwargs):
         eng = self.energy_model.call(inputs, training=training, **kwargs)
         eng_sum = eng.sum(dim=0)
         e_grad = torch.cat([
-            torch.unsqueeze(torch.autograd.grad(eng_sum[i], x, create_graph=True)[0], dim=-1) for i in
+            torch.unsqueeze(torch.autograd.grad(eng_sum[i], x, create_graph=True, allow_unused=True)[0], dim=-1) for i in
             range(eng.shape[-1])], dim=-1)

         if self.output_squeeze_states:
             e_grad = torch.squeeze(e_grad, dim=-1)
         return eng, e_grad
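allow_unused=True matters when a stacked energy state does not depend on the coordinate input: torch.autograd.grad then returns None for that input instead of raising an error. A small self-contained example of the flag's behavior:

import torch

x = torch.ones(3, requires_grad=True)
y = torch.ones(3, requires_grad=True)
out = (2.0 * x).sum()  # does not involve y at all

gx, gy = torch.autograd.grad(out, (x, y), allow_unused=True)
print(gx)  # tensor([2., 2., 2.])
print(gy)  # None, instead of a RuntimeError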

+    def _call_grad_jax(self, inputs, training=False, **kwargs):
+
+        def energy_reduce(*inputs, pos: int = 0):
+            eng = self.energy_model(inputs, training=training, **kwargs)
+            eng_sum = jnp.sum(eng, axis=0)[pos]
+            return eng_sum
+
+        grad_fn = jax.value_and_grad(energy_reduce, argnums=self.coordinate_input)
+        states = [grad_fn(*inputs, pos=i) for i in range(self._expected_energy_states)]
+        eng = jnp.concatenate([jnp.expand_dims(x[0], axis=-1) for x in states], axis=-1)
+        e_grad = jnp.concatenate([jnp.expand_dims(x[1], axis=-1) for x in states], axis=-1)
+
+        if self.output_squeeze_states:
+            e_grad = jnp.squeeze(e_grad, axis=-1)
+        return eng, e_grad
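The jax path reduces the energy to one scalar per output state and differentiates it with jax.value_and_grad, so each state yields its energy and gradient in a single pass. A standalone sketch with a toy quadratic energy in place of the real model (all names illustrative):

import jax
import jax.numpy as jnp

def energy(charges, coords):
    # Toy scalar energy; the real model maps a whole graph to energies.
    return jnp.sum(charges) + jnp.sum(coords ** 2)

grad_fn = jax.value_and_grad(energy, argnums=1)  # differentiate w.r.t. coords
charges = jnp.ones((4,))
coords = jnp.arange(12.0).reshape(4, 3)
eng, de_dx = grad_fn(charges, coords)
force = -de_dx  # the physical force is the negative energy gradient
print(eng, force.shape)  # scalar energy, (4, 3) force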

     def call(self, inputs, training=False, **kwargs):

         eng, e_grad = self._call_grad_backend(inputs, training=training, **kwargs)
2 changes: 1 addition & 1 deletion kgcnn/utils/devices.py
@@ -33,7 +33,7 @@ def check_device():
elif backend() == "jax":
import jax
jax_devices = jax.devices()
cuda_is_available = any([x.device_kind == "gpu" for x in jax_devices])
cuda_is_available = any([x.device_kind in ["gpu", "cuda"] for x in jax_devices])
physical_device_name = [x for x in jax_devices]
logical_device_list = [x.id for x in jax_devices]
memory_info = [x.memory_stats() for x in jax_devices]
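Depending on the jax version and installed plugin, a GPU device can report its device_kind as "gpu" or "cuda", so the check accepts both. A quick sketch of the adjusted check:

import jax

jax_devices = jax.devices()
cuda_is_available = any(x.device_kind in ["gpu", "cuda"] for x in jax_devices)
print(cuda_is_available, [x.id for x in jax_devices])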
22 changes: 11 additions & 11 deletions notebooks/example_transfer_learning.ipynb

Large diffs are not rendered by default.
