-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathAdelaiDet_CUDA_fix.patch
96 lines (84 loc) · 3.98 KB
/
AdelaiDet_CUDA_fix.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
diff --git a/adet/layers/csrc/ml_nms/ml_nms.cu b/adet/layers/csrc/ml_nms/ml_nms.cu
index f1c1a42..a380b75 100644
--- a/adet/layers/csrc/ml_nms/ml_nms.cu
+++ b/adet/layers/csrc/ml_nms/ml_nms.cu
@@ -1,12 +1,17 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <ATen/ATen.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <THC/THC.h>
-#include <THC/THCDeviceUtils.cuh>
-
+//#include <ATen/cuda/CUDAContext.h>
+//#include <THC/THC.h>
+//#include <THC/THCDeviceUtils.cuh>
+#include "ATen/cuda/DeviceUtils.cuh"
+#include "ATen/ceil_div.h"
+
+#include <c10/cuda/CUDACachingAllocator.h>
+#include <ATen/core/Tensor.h>
#include <vector>
#include <iostream>
+
int const threadsPerBlock = sizeof(unsigned long long) * 8;
__device__ inline float devIoU(float const * const a, float const * const b) {
@@ -65,7 +70,8 @@ __global__ void ml_nms_kernel(const int n_boxes, const float nms_overlap_thresh,
t |= 1ULL << i;
}
}
- const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
+ //const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
+ const int col_blocks = at::ceil_div(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
@@ -82,20 +88,21 @@ at::Tensor ml_nms_cuda(const at::Tensor boxes, const float nms_overlap_thresh) {
int boxes_num = boxes.size(0);
- const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
+ const int col_blocks = at::ceil_div(boxes_num, threadsPerBlock);//THCCeilDiv(boxes_num, threadsPerBlock);
scalar_t* boxes_dev = boxes_sorted.data<scalar_t>();
- THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
-
+ //THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
+ //THCState *state = at::globalContext().lazyInitCUDA();
unsigned long long* mask_dev = NULL;
//THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
// boxes_num * col_blocks * sizeof(unsigned long long)));
- mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
+ //mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
+ mask_dev = (unsigned long long*) c10::cuda::CUDACachingAllocator::raw_alloc(boxes_num * col_blocks * sizeof(unsigned long long));
- dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
- THCCeilDiv(boxes_num, threadsPerBlock));
+ dim3 blocks(at::ceil_div(boxes_num, threadsPerBlock),
+ at::ceil_div(boxes_num, threadsPerBlock));
dim3 threads(threadsPerBlock);
ml_nms_kernel<<<blocks, threads>>>(boxes_num,
nms_overlap_thresh,
@@ -103,10 +110,10 @@ at::Tensor ml_nms_cuda(const at::Tensor boxes, const float nms_overlap_thresh) {
mask_dev);
std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
- THCudaCheck(cudaMemcpy(&mask_host[0],
+ /*THCudaCheck(cudaMemcpy(&mask_host[0],
mask_dev,
sizeof(unsigned long long) * boxes_num * col_blocks,
- cudaMemcpyDeviceToHost));
+ cudaMemcpyDeviceToHost));*/
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
@@ -128,7 +135,10 @@ at::Tensor ml_nms_cuda(const at::Tensor boxes, const float nms_overlap_thresh) {
}
}
- THCudaFree(state, mask_dev);
+ c10::cuda::CUDACachingAllocator::raw_delete(mask_dev);
+ //THCudaFree(state, mask_dev);
+ //C10_CUDA_FREE(state, mask_dev);
+ // C10_CUDA_CHECK
// TODO improve this part
return std::get<0>(order_t.index({
keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
@@ -136,4 +146,4 @@ at::Tensor ml_nms_cuda(const at::Tensor boxes, const float nms_overlap_thresh) {
}).sort(0, false));
}
-} // namespace adet
\ No newline at end of file
+} // namespace adet