-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathnms_cupy.py
126 lines (108 loc) · 4.05 KB
/
nms_cupy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from __future__ import division
import numpy as np
import cupy as cp
import torch as t
from _nms_gpu_post import _nms_gpu_post
@cp.util.memoize(for_each_device=True)
def _load_kernel(kernel_name, code, options=()):
    """Compile CUDA *code* (cached per device) and return the named kernel.

    Args:
        kernel_name (str): Name of the ``extern "C"`` kernel to fetch.
        code (str): CUDA C source string to compile.
        options (tuple): Extra nvcc compiler flags.

    Returns:
        The compiled kernel function object.
    """
    # NOTE(review): cudaFree(0) presumably forces lazy CUDA context
    # initialization before compilation -- confirm it is still needed.
    cp.cuda.runtime.free(0)
    assert isinstance(options, tuple)
    module = cp.cuda.compile_with_cache(code, options=options)
    return module.get_function(kernel_name)
def _non_maximum_suppression_gpu(bbox, thresh, score=None, limit=None):
    """Run non-maximum suppression on GPU and return kept box indices.

    Args:
        bbox (cupy.ndarray): Boxes of shape ``(n, 4)`` as
            ``(y_min, x_min, y_max, x_max)`` rows (implied by the IoU
            kernel's coordinate order).
        thresh (float): IoU threshold; boxes overlapping a kept box by
            at least this much are suppressed.
        score (cupy.ndarray or None): Optional per-box scores; when given,
            boxes are processed highest-score first.
        limit (int or None): Optional cap on the number of returned indices.

    Returns:
        numpy.ndarray: Indices (into ``bbox``) of the kept boxes, as a
        host-side array (converted via ``cp.asnumpy``).
    """
    # Empty input: nothing to suppress.
    if len(bbox) == 0:
        return cp.zeros((0,), dtype=np.int32)

    n_bbox = bbox.shape[0]
    if score is not None:
        # Descending score order so higher-scoring boxes suppress lower ones.
        order = score.argsort()[::-1].astype(np.int32)
    else:
        order = cp.arange(n_bbox, dtype=np.int32)
    sorted_bbox = bbox[order, :]
    # BUG FIX: the original called the undefined name ``_call_nms_kernel``
    # (NameError at runtime); the launcher defined in this file is
    # ``cupy_call_nms_kernel``.
    selec, n_selec = cupy_call_nms_kernel(sorted_bbox, thresh)
    selec = selec[:n_selec]
    # Map positions in the sorted array back to original box indices.
    selec = order[selec]
    if limit is not None:
        selec = selec[:limit]
    return cp.asnumpy(selec)
# CUDA source for the pairwise NMS kernel (compiled at runtime by
# _load_kernel).  One thread block compares a 64-box "row" tile against a
# 64-box "column" tile; each thread owns one row box and sets bit i of a
# 64-bit word when column box i overlaps it with IoU >= thresh.  The
# resulting bitmask matrix (n_bbox x DIVUP(n_bbox, 64) words) is reduced on
# the host by _nms_gpu_post.  threadsPerBlock is 64 because the mask word is
# an unsigned long long (sizeof * 8 bits); the Python launcher must use the
# same block size.  devIoU reads boxes as (y_min, x_min, y_max, x_max).
# For row_start == col_start, `start = threadIdx.x + 1` skips self-overlap
# and duplicate symmetric pairs.
_nms_gpu_code = '''
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
int const threadsPerBlock = sizeof(unsigned long long) * 8;
__device__
inline float devIoU(float const *const bbox_a, float const *const bbox_b) {
float top = max(bbox_a[0], bbox_b[0]);
float bottom = min(bbox_a[2], bbox_b[2]);
float left = max(bbox_a[1], bbox_b[1]);
float right = min(bbox_a[3], bbox_b[3]);
float height = max(bottom - top, 0.f);
float width = max(right - left, 0.f);
float area_i = height * width;
float area_a = (bbox_a[2] - bbox_a[0]) * (bbox_a[3] - bbox_a[1]);
float area_b = (bbox_b[2] - bbox_b[0]) * (bbox_b[3] - bbox_b[1]);
return area_i / (area_a + area_b - area_i);
}
extern "C"
__global__
void nms_kernel(const int n_bbox, const float thresh,
const float *dev_bbox,
unsigned long long *dev_mask) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
const int row_size =
min(n_bbox - row_start * threadsPerBlock, threadsPerBlock);
const int col_size =
min(n_bbox - col_start * threadsPerBlock, threadsPerBlock);
__shared__ float block_bbox[threadsPerBlock * 4];
if (threadIdx.x < col_size) {
block_bbox[threadIdx.x * 4 + 0] =
dev_bbox[(threadsPerBlock * col_start + threadIdx.x) * 4 + 0];
block_bbox[threadIdx.x * 4 + 1] =
dev_bbox[(threadsPerBlock * col_start + threadIdx.x) * 4 + 1];
block_bbox[threadIdx.x * 4 + 2] =
dev_bbox[(threadsPerBlock * col_start + threadIdx.x) * 4 + 2];
block_bbox[threadIdx.x * 4 + 3] =
dev_bbox[(threadsPerBlock * col_start + threadIdx.x) * 4 + 3];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const float *cur_box = dev_bbox + cur_box_idx * 4;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (devIoU(cur_box, block_bbox + i * 4) >= thresh) {
t |= 1ULL << i;
}
}
const int col_blocks = DIVUP(n_bbox, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
'''
def cupy_call_nms_kernel(bbox, thresh):
    """Launch the NMS CUDA kernel and reduce its bitmask on the host.

    Args:
        bbox (cupy.ndarray): Sorted boxes of shape ``(n, 4)``; cast to a
            contiguous float32 array before launch.
        thresh (float): IoU suppression threshold.

    Returns:
        tuple: ``(selection, n_selec)`` as produced by ``_nms_gpu_post`` --
        the selected indices buffer and the count of valid entries.
    """
    # PyTorch does not support unsigned long long tensors, but that doesn't
    # matter here: the mask lives in a CuPy array and the final result is a
    # plain ndarray, so it is kept as uint64.
    n_bbox = bbox.shape[0]
    # Must match ``threadsPerBlock`` in the CUDA source
    # (sizeof(unsigned long long) * 8 = 64 mask bits per word).
    threads_per_block = 64
    # IDIOM FIX: exact integer ceiling division instead of the original
    # float round-trip through np.ceil(...).astype(np.int32).
    col_blocks = (n_bbox + threads_per_block - 1) // threads_per_block
    blocks = (col_blocks, col_blocks, 1)
    threads = (threads_per_block, 1, 1)

    # One uint64 bitmask word per (box, column-block) pair.
    mask_dev = cp.zeros((n_bbox * col_blocks,), dtype=np.uint64)
    bbox = cp.ascontiguousarray(bbox, dtype=np.float32)
    kern = _load_kernel('nms_kernel', _nms_gpu_code)
    kern(blocks, threads, args=(cp.int32(n_bbox), cp.float32(thresh),
                                bbox, mask_dev))

    # Copy the mask to the host and do the sequential suppression pass there.
    mask_host = mask_dev.get()
    selection, n_selec = _nms_gpu_post(
        mask_host, n_bbox, threads_per_block, col_blocks)
    return selection, n_selec
if __name__ == "__main__":
bbox=np.load("bbox.npy")
bbox=cp.asarray(bbox)
mask_dev= cupy_call_nms_kernel(bbox,0.7)
print(mask_dev)