From cd9af56b886a40745998044ed079ac06c575aba6 Mon Sep 17 00:00:00 2001
From: Ali Ramlaoui
Date: Wed, 12 Feb 2025 18:25:58 +0100
Subject: [PATCH] fix: Send batch info to graph_norm

---
 faenet/model.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/faenet/model.py b/faenet/model.py
index 18f7b5d..c7cca89 100644
--- a/faenet/model.py
+++ b/faenet/model.py
@@ -1,8 +1,9 @@
 """
-FAENet: Frame Averaging Equivariant graph neural Network 
+FAENet: Frame Averaging Equivariant graph neural Network
 Simple, scalable and expressive model for property prediction on 3D atomic systems.
 """
-from typing import Dict, Optional, Union
+
+from typing import Optional, Union
 
 import torch
 from torch import nn
@@ -14,7 +15,7 @@
 from faenet.base_model import BaseModel
 from faenet.embedding import PhysEmbedding
 from faenet.force_decoder import ForceDecoder
-from faenet.utils import GaussianSmearing, swish, pbc_preprocess, base_preprocess
+from faenet.utils import GaussianSmearing, swish
 
 
 class EmbeddingBlock(nn.Module):
@@ -232,7 +233,7 @@ def reset_parameters(self):
         nn.init.xavier_uniform_(self.lin_h.weight)
         self.lin_h.bias.data.fill_(0)
 
-    def forward(self, h, edge_index, e):
+    def forward(self, h, edge_index, e, batch=None):
         """Forward pass of the Interaction block.
         Called in FAENet forward pass to update atom representations.
 
@@ -261,7 +262,7 @@ def forward(self, h, edge_index, e):
             h = self.act(self.lin_down(h))  # downscale node rep.
             h = self.propagate(edge_index, x=h, W=e)  # propagate
             if self.graph_norm:
-                h = self.act(self.graph_norm(h))
+                h = self.act(self.graph_norm(h, batch=batch))
             h = self.act(self.lin_up(h))  # upscale node rep.
 
         elif self.mp_type == "updown_local_env":
@@ -270,14 +271,14 @@ def forward(self, h, edge_index, e):
             e = self.lin_geom(e)
             h = self.propagate(edge_index, x=h, W=e)  # propagate
             if self.graph_norm:
-                h = self.act(self.graph_norm(h))
+                h = self.act(self.graph_norm(h, batch=batch))
             h = torch.cat((h, chi), dim=1)
             h = self.lin_up(h)
 
         elif self.mp_type in {"base", "simple"}:
             h = self.propagate(edge_index, x=h, W=e)  # propagate
             if self.graph_norm:
-                h = self.act(self.graph_norm(h))
+                h = self.act(self.graph_norm(h, batch=batch))
             h = self.act(self.lin_h(h))
 
         else:
@@ -616,7 +617,7 @@ def energy_forward(self, data, preproc=True):
             energy_skip_co.append(
                 self.output_block(h, edge_index, edge_weight, batch, alpha)
             )
-            h = h + interaction(h, edge_index, e)
+            h = h + interaction(h, edge_index, e, batch)
 
         # Atom skip-co
         if self.skip_co == "concat_atom":
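
Note on why the fix matters: graph normalization computes per-graph statistics, and
when the batch assignment vector is omitted it falls back to treating every node in
the mini-batch as belonging to a single graph. The sketch below illustrates the
behavioral difference, assuming the `graph_norm` module follows PyTorch Geometric's
`GraphNorm` API (`forward(x, batch=None)`); the tensor shapes and values are made up
for illustration, not taken from FAENet.

    import torch
    from torch_geometric.nn.norm import GraphNorm

    norm = GraphNorm(in_channels=8)

    h = torch.randn(6, 8)                     # node features for a mini-batch
    batch = torch.tensor([0, 0, 0, 1, 1, 1])  # nodes 0-2 -> graph 0, nodes 3-5 -> graph 1

    # Before the patch: `batch` never reached graph_norm, so all six nodes were
    # normalized with the statistics of one pooled "graph".
    before = norm(h)

    # After the patch: mean/variance are computed per graph, as GraphNorm intends.
    after = norm(h, batch=batch)

    print(torch.allclose(before, after))  # typically False when the batch holds >1 graph

Giving the new `batch` argument a default of `None` keeps `InteractionBlock.forward`
backward compatible: existing single-graph callers are unaffected, while
`energy_forward` now threads the batch vector through.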