From c38b447878a9f6bbc30de1b12d64a3ff4169b450 Mon Sep 17 00:00:00 2001 From: zergzzlun Date: Fri, 27 Aug 2021 02:31:03 +0800 Subject: [PATCH] fixed bug: features and indexes on diffent devices when using cuda. --- model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model.py b/model.py index 87bf937..d1ab4ae 100644 --- a/model.py +++ b/model.py @@ -88,7 +88,7 @@ def forward(self, coords, features, knn_output): expanded_coords = coords.transpose(-2,-1).unsqueeze(-1).expand(B, 3, N, K) neighbor_coords = torch.gather(expanded_coords, 2, expanded_idx) # shape (B, 3, N, K) - expanded_idx = idx.unsqueeze(1).expand(B, features.size(1), N, K) + expanded_idx = idx.unsqueeze(1).expand(B, features.size(1), N, K).to(self.device) expanded_features = features.expand(B, -1, N, K) neighbor_features = torch.gather(expanded_features, 2, expanded_idx) # if USE_CUDA: