diff --git a/model.py b/model.py index 87bf937..d1ab4ae 100644 --- a/model.py +++ b/model.py @@ -88,7 +88,7 @@ def forward(self, coords, features, knn_output): expanded_coords = coords.transpose(-2,-1).unsqueeze(-1).expand(B, 3, N, K) neighbor_coords = torch.gather(expanded_coords, 2, expanded_idx) # shape (B, 3, N, K) - expanded_idx = idx.unsqueeze(1).expand(B, features.size(1), N, K) + expanded_idx = idx.unsqueeze(1).expand(B, features.size(1), N, K).to(self.device) expanded_features = features.expand(B, -1, N, K) neighbor_features = torch.gather(expanded_features, 2, expanded_idx) # if USE_CUDA: