Skip to content

Commit

Permalink
Fix Nan gradients in Force model with padded_disjoint representaiton
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Feb 21, 2024
1 parent 58e43b5 commit 177f133
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 6 deletions.
4 changes: 4 additions & 0 deletions kgcnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@

# Behaviour for backend functions.
__safe_scatter_max_min_to_zero__ = True

# Geometry
__geom_euclidean_norm_add_eps__ = False
__geom_euclidean_norm_no_nan__ = True # Only used for inverse norm.
1 change: 0 additions & 1 deletion kgcnn/layers/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@


def _pad_left(t):
# return ops.concatenate([ops.zeros_like(t[:1]), t], axis=0)
return ops.pad(t, [[1, 0]] + [[0, 0] for _ in range(len(ops.shape(t)) - 1)])


Expand Down
20 changes: 15 additions & 5 deletions kgcnn/layers/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from kgcnn.layers.polynom import SphericalBesselJnExplicit, SphericalHarmonicsYl
from kgcnn.ops.axis import get_positive_axis
from kgcnn.ops.core import cross as kgcnn_cross
from kgcnn import __geom_euclidean_norm_add_eps__ as global_geom_euclidean_norm_add_eps
from kgcnn import __geom_euclidean_norm_no_nan__ as global_geom_euclidean_norm_no_nan


class NodePosition(Layer):
Expand Down Expand Up @@ -142,8 +144,11 @@ class EuclideanNorm(Layer):
with :obj:`invert_norm` layer arguments.
"""

def __init__(self, axis: int = -1, keepdims: bool = False, invert_norm: bool = False, add_eps: bool = False,
no_nan: bool = True, square_norm: bool = False, **kwargs):
def __init__(self, axis: int = -1, keepdims: bool = False,
invert_norm: bool = False,
add_eps: bool = global_geom_euclidean_norm_add_eps,
no_nan: bool = global_geom_euclidean_norm_no_nan,
square_norm: bool = False, **kwargs):
"""Initialize layer.
Args:
Expand Down Expand Up @@ -177,7 +182,7 @@ def compute_output_shape(self, input_shape):

@staticmethod
def _compute_euclidean_norm(inputs, axis: int = -1, keepdims: bool = False, invert_norm: bool = False,
add_eps: bool = False, no_nan: bool = True, square_norm: bool = False):
add_eps: bool = False, no_nan: bool = False, square_norm: bool = False):
"""Function to compute euclidean norm for inputs.
Args:
Expand Down Expand Up @@ -306,7 +311,10 @@ class NodeDistanceEuclidean(Layer):
the output of :obj:`NodePosition`.
"""

def __init__(self, add_eps: bool = False, no_nan: bool = True, **kwargs):
def __init__(self,
add_eps: bool = global_geom_euclidean_norm_add_eps,
no_nan: bool = global_geom_euclidean_norm_no_nan,
**kwargs):
r"""Initialize layer instance of :obj:`NodeDistanceEuclidean`. """
super(NodeDistanceEuclidean, self).__init__(**kwargs)
self.layer_subtract = Subtract()
Expand Down Expand Up @@ -354,7 +362,9 @@ class EdgeDirectionNormalized(Layer):
As the first index defines the incoming edge.
"""

def __init__(self, add_eps: bool = False, no_nan: bool = True, **kwargs):
def __init__(self, add_eps: bool = global_geom_euclidean_norm_add_eps,
no_nan: bool = global_geom_euclidean_norm_no_nan,
**kwargs):
"""Initialize layer."""
super(EdgeDirectionNormalized, self).__init__(**kwargs)
self.layer_subtract = Subtract()
Expand Down
4 changes: 4 additions & 0 deletions training/train_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import argparse
import keras as ks
from datetime import timedelta
import kgcnn
import kgcnn.training.schedule
import kgcnn.training.scheduler
from kgcnn.data.utils import save_pickle_file
Expand All @@ -18,6 +19,9 @@
from kgcnn.metrics.metrics import ScaledMeanAbsoluteError, ScaledForceMeanAbsoluteError
from kgcnn.data.transform.scaler.force import EnergyForceExtensiveLabelScaler

# For force gradients
kgcnn.__geom_euclidean_norm_add_eps__ = True

# Input arguments from command line.
parser = argparse.ArgumentParser(description='Train a GNN on an Energy-Force Dataset.')
parser.add_argument("--hyper", required=False, help="Filepath to hyper-parameter config file (.py or .json).",
Expand Down

0 comments on commit 177f133

Please sign in to comment.