|
1 | 1 | import torch |
2 | 2 | from torch import nn |
3 | 3 |
|
4 | | -from .utilities import scatter_add, scatter_mean, scatter_softmax |
| 4 | +from .utilities import scatter_mean, scatter_softmax, scatter_sum |
5 | 5 |
|
6 | 6 | """ |
7 | 7 | Cosmo can be implemented with various filter functions. The underlying principle is always to compute the filter under transformation of a local reference frame (hood_coords) which is derived from neighboring input points. The forward signature of the layer is always the same and inputs can be obtained from a Lift2D or Lift3D module. |
@@ -41,7 +41,7 @@ def forward( |
41 | 41 | w = self.w[:, nn_idx] # use closest kernel point |
42 | 42 | f = features[source] |
43 | 43 | out_channels = torch.einsum("ni,oni->no", f, w) # m x out |
44 | | - features = scatter_add(out_channels, target, dim=0, dim_size=m) |
| 44 | + features = scatter_sum(out_channels, target, m) |
45 | 45 | return features # Updated features of shape (m, out_channels) |
46 | 46 |
|
47 | 47 |
|
@@ -94,7 +94,7 @@ def forward( |
94 | 94 | ) |
95 | 95 | f = features[source] |
96 | 96 | out_channels = torch.einsum("ni,noi->no", f, w) # m x out |
97 | | - features = scatter_mean(out_channels, target, dim_size=m, dim=0) |
| 97 | + features = scatter_mean(out_channels, target, m) |
98 | 98 | return features # Updated features of shape (m, out_channels) |
99 | 99 |
|
100 | 100 |
|
@@ -136,6 +136,6 @@ def forward( |
136 | 136 | w1 = self.w1(features) |
137 | 137 | w2 = self.w2(features) |
138 | 138 | w3 = self.w3(features) |
139 | | - a = scatter_softmax(w1[target] - w2[source] + d, target, dim=0, dim_size=m) |
140 | | - features = scatter_add(a * (w3[source] + d), target, dim=0, dim_size=m) |
| 139 | + a = scatter_softmax(w1[target] - w2[source] + d, target, m) |
| 140 | + features = scatter_sum(a * (w3[source] + d), target, m) |
141 | 141 | return features # Updated features of shape (m, out_channels) |
0 commit comments