Skip to content

Commit dbc1569

Browse files
committed
fix scatter
1 parent a30c48f commit dbc1569

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

cosmic/cosmo.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
from torch import nn
33

4-
from .utilities import scatter_add, scatter_mean, scatter_softmax
4+
from .utilities import scatter_mean, scatter_softmax, scatter_sum
55

66
"""
77
Cosmo can be implemented with various filter functions. The underlying principle is always to compute the filter under transformation of a local reference frame (hood_coords) which is derived from neighboring input points. The forward signature of the layer is always the same and inputs can be obtained from a Lift2D or Lift3D module.
@@ -41,7 +41,7 @@ def forward(
4141
w = self.w[:, nn_idx] # use closest kernel point
4242
f = features[source]
4343
out_channels = torch.einsum("ni,oni->no", f, w) # m x out
44-
features = scatter_add(out_channels, target, dim=0, dim_size=m)
44+
features = scatter_sum(out_channels, target, m)
4545
return features # Updated features of shape (m, out_channels)
4646

4747

@@ -94,7 +94,7 @@ def forward(
9494
)
9595
f = features[source]
9696
out_channels = torch.einsum("ni,noi->no", f, w) # m x out
97-
features = scatter_mean(out_channels, target, dim_size=m, dim=0)
97+
features = scatter_mean(out_channels, target, m)
9898
return features # Updated features of shape (m, out_channels)
9999

100100

@@ -136,6 +136,6 @@ def forward(
136136
w1 = self.w1(features)
137137
w2 = self.w2(features)
138138
w3 = self.w3(features)
139-
a = scatter_softmax(w1[target] - w2[source] + d, target, dim=0, dim_size=m)
140-
features = scatter_add(a * (w3[source] + d), target, dim=0, dim_size=m)
139+
a = scatter_softmax(w1[target] - w2[source] + d, target, m)
140+
features = scatter_sum(a * (w3[source] + d), target, m)
141141
return features # Updated features of shape (m, out_channels)

0 commit comments

Comments
 (0)