faenet/model.py (17 changes: 9 additions & 8 deletions)

--- a/faenet/model.py
+++ b/faenet/model.py
@@ -1,8 +1,9 @@
 """
-FAENet: Frame Averaging Equivariant graph neural Network
+FAENet: Frame Averaging Equivariant graph neural Network
+Simple, scalable and expressive model for property prediction on 3D atomic systems.
 """
-from typing import Dict, Optional, Union
+from typing import Optional, Union
 
 import torch
 from torch import nn
@@ -14,7 +15,7 @@
 from faenet.base_model import BaseModel
 from faenet.embedding import PhysEmbedding
 from faenet.force_decoder import ForceDecoder
-from faenet.utils import GaussianSmearing, swish, pbc_preprocess, base_preprocess
+from faenet.utils import GaussianSmearing, swish
 
 
 class EmbeddingBlock(nn.Module):
@@ -232,7 +233,7 @@ def reset_parameters(self):
         nn.init.xavier_uniform_(self.lin_h.weight)
         self.lin_h.bias.data.fill_(0)
 
-    def forward(self, h, edge_index, e):
+    def forward(self, h, edge_index, e, batch=None):
        """Forward pass of the Interaction block.
        Called in FAENet forward pass to update atom representations.
 
@@ -261,7 +262,7 @@ def forward(self, h, edge_index, e):
             h = self.act(self.lin_down(h))  # downscale node rep.
             h = self.propagate(edge_index, x=h, W=e)  # propagate
             if self.graph_norm:
-                h = self.act(self.graph_norm(h))
+                h = self.act(self.graph_norm(h, batch=batch))
             h = self.act(self.lin_up(h))  # upscale node rep.
 
         elif self.mp_type == "updown_local_env":
@@ -270,14 +271,14 @@ def forward(self, h, edge_index, e):
             e = self.lin_geom(e)
             h = self.propagate(edge_index, x=h, W=e)  # propagate
             if self.graph_norm:
-                h = self.act(self.graph_norm(h))
+                h = self.act(self.graph_norm(h, batch=batch))
             h = torch.cat((h, chi), dim=1)
             h = self.lin_up(h)
 
         elif self.mp_type in {"base", "simple"}:
             h = self.propagate(edge_index, x=h, W=e)  # propagate
             if self.graph_norm:
-                h = self.act(self.graph_norm(h))
+                h = self.act(self.graph_norm(h, batch=batch))
             h = self.act(self.lin_h(h))
 
         else:
@@ -616,7 +617,7 @@ def energy_forward(self, data, preproc=True):
             energy_skip_co.append(
                 self.output_block(h, edge_index, edge_weight, batch, alpha)
             )
-            h = h + interaction(h, edge_index, e)
+            h = h + interaction(h, edge_index, e, batch)
 
         # Atom skip-co
         if self.skip_co == "concat_atom":
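
Context for the graph_norm change: graph normalization layers such as PyTorch Geometric's GraphNorm compute statistics per graph, so when several graphs are collated into one mini-batch they need the node-to-graph assignment vector; without it, all nodes are normalized as if they belonged to a single graph. A minimal sketch, not part of this diff and assuming the `graph_norm` attribute is `torch_geometric.nn.GraphNorm`, of what threading `batch` through the interaction blocks enables:

import torch
from torch_geometric.nn import GraphNorm

norm = GraphNorm(in_channels=8)
h = torch.randn(5, 8)                  # node features for two graphs in one mini-batch
batch = torch.tensor([0, 0, 0, 1, 1])  # node-to-graph assignment, analogous to the batch tensor in energy_forward
out = norm(h, batch=batch)             # statistics computed per graph rather than over the whole batch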