3d structure #895

Closed · wants to merge 2 commits

Changes from all commits
15 changes: 0 additions & 15 deletions lib/model/nms/_ext/nms/__init__.py
@@ -1,15 +0,0 @@

from torch.utils.ffi import _wrap_function
from ._nms import lib as _lib, ffi as _ffi

__all__ = []
def _import_symbols(locals):
    for symbol in dir(_lib):
        fn = getattr(_lib, symbol)
        if callable(fn):
            locals[symbol] = _wrap_function(fn, _ffi)
        else:
            locals[symbol] = fn
        __all__.append(symbol)

_import_symbols(locals())
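
Note: the deleted wrapper above is generated by torch.utils.ffi, which was removed from PyTorch after the 0.4 line, so these _ext packages no longer build on current releases. The replacement modules this PR imports (e.g. roi_align_cuda) would typically be compiled with torch.utils.cpp_extension instead. A minimal, hypothetical build sketch follows; the module name and source paths are placeholders, not files taken from this PR.

# Hypothetical sketch: JIT-compiling a CUDA extension with torch.utils.cpp_extension.
# The module name and source paths below are placeholders, not part of this PR.
from torch.utils.cpp_extension import load

roi_align_cuda = load(
    name='roi_align_cuda',
    sources=['src/roi_align_cuda.cpp', 'src/roi_align_kernel.cu'],
    verbose=True,
)
# The loaded module exposes the bound functions, e.g. roi_align_cuda.forward(...).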
19 changes: 0 additions & 19 deletions lib/model/nms/src/nms_cuda.c
@@ -1,19 +0,0 @@
#include <THC/THC.h>
#include <stdio.h>
#include "nms_cuda_kernel.h"

// this symbol will be resolved automatically from PyTorch libs
extern THCState *state;

int nms_cuda(THCudaIntTensor *keep_out, THCudaTensor *boxes_host,
             THCudaIntTensor *num_out, float nms_overlap_thresh) {

  nms_cuda_compute(THCudaIntTensor_data(state, keep_out),
                   THCudaIntTensor_data(state, num_out),
                   THCudaTensor_data(state, boxes_host),
                   THCudaTensor_size(state, boxes_host, 0),
                   THCudaTensor_size(state, boxes_host, 1),
                   nms_overlap_thresh);

  return 1;
}
5 changes: 0 additions & 5 deletions lib/model/nms/src/nms_cuda.h
@@ -1,5 +0,0 @@
// int nms_cuda(THCudaTensor *keep_out, THCudaTensor *num_out,
// THCudaTensor *boxes_host, THCudaTensor *nms_overlap_thresh);

int nms_cuda(THCudaIntTensor *keep_out, THCudaTensor *boxes_host,
             THCudaIntTensor *num_out, float nms_overlap_thresh);
161 changes: 0 additions & 161 deletions lib/model/nms/src/nms_cuda_kernel.cu
@@ -1,161 +0,0 @@
// ------------------------------------------------------------------
// Faster R-CNN
// Copyright (c) 2015 Microsoft
// Licensed under The MIT License [see fast-rcnn/LICENSE for details]
// Written by Shaoqing Ren
// ------------------------------------------------------------------

#include <stdbool.h>
#include <stdio.h>
#include <vector>
#include <iostream>
#include "nms_cuda_kernel.h"

#define CUDA_WARN(XXX) \
  do { if (XXX != cudaSuccess) std::cout << "CUDA Error: " << \
    cudaGetErrorString(XXX) << ", at line " << __LINE__ \
    << std::endl; cudaDeviceSynchronize(); } while (0)

#define CUDA_CHECK(condition) \
  /* Code block avoids redefinition of cudaError_t error */ \
  do { \
    cudaError_t error = condition; \
    if (error != cudaSuccess) { \
      std::cout << cudaGetErrorString(error) << std::endl; \
    } \
  } while (0)

#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
int const threadsPerBlock = sizeof(unsigned long long) * 8;

__device__ inline float devIoU(float const * const a, float const * const b) {
  float left = max(a[0], b[0]), right = min(a[2], b[2]);
  float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
  float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
  float interS = width * height;
  float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
  float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
  return interS / (Sa + Sb - interS);
}

__global__ void nms_kernel(int n_boxes, float nms_overlap_thresh,
                           float *dev_boxes, unsigned long long *dev_mask) {
  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // if (row_start > col_start) return;

  const int row_size =
      min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
  const int col_size =
      min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);

  __shared__ float block_boxes[threadsPerBlock * 5];
  if (threadIdx.x < col_size) {
    block_boxes[threadIdx.x * 5 + 0] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
    block_boxes[threadIdx.x * 5 + 1] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
    block_boxes[threadIdx.x * 5 + 2] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
    block_boxes[threadIdx.x * 5 + 3] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
    block_boxes[threadIdx.x * 5 + 4] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
  }
  __syncthreads();

  if (threadIdx.x < row_size) {
    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
    const float *cur_box = dev_boxes + cur_box_idx * 5;
    int i = 0;
    unsigned long long t = 0;
    int start = 0;
    if (row_start == col_start) {
      start = threadIdx.x + 1;
    }
    for (i = start; i < col_size; i++) {
      if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
        t |= 1ULL << i;
      }
    }
    const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
    dev_mask[cur_box_idx * col_blocks + col_start] = t;
  }
}

void nms_cuda_compute(int* keep_out, int *num_out, float* boxes_host, int boxes_num,
                      int boxes_dim, float nms_overlap_thresh) {

  float* boxes_dev = NULL;
  unsigned long long* mask_dev = NULL;

  const int col_blocks = DIVUP(boxes_num, threadsPerBlock);

  CUDA_CHECK(cudaMalloc(&boxes_dev,
                        boxes_num * boxes_dim * sizeof(float)));
  CUDA_CHECK(cudaMemcpy(boxes_dev,
                        boxes_host,
                        boxes_num * boxes_dim * sizeof(float),
                        cudaMemcpyHostToDevice));

  CUDA_CHECK(cudaMalloc(&mask_dev,
                        boxes_num * col_blocks * sizeof(unsigned long long)));

  dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
              DIVUP(boxes_num, threadsPerBlock));
  dim3 threads(threadsPerBlock);

  // printf("i am at line %d\n", boxes_num);
  // printf("i am at line %d\n", boxes_dim);

  nms_kernel<<<blocks, threads>>>(boxes_num,
                                  nms_overlap_thresh,
                                  boxes_dev,
                                  mask_dev);

  std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
  CUDA_CHECK(cudaMemcpy(&mask_host[0],
                        mask_dev,
                        sizeof(unsigned long long) * boxes_num * col_blocks,
                        cudaMemcpyDeviceToHost));

  std::vector<unsigned long long> remv(col_blocks);
  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);

  // we need to create a buffer for keep_out on the cpu,
  // otherwise the following code cannot run

  int* keep_out_cpu = new int[boxes_num];

  int num_to_keep = 0;
  for (int i = 0; i < boxes_num; i++) {
    int nblock = i / threadsPerBlock;
    int inblock = i % threadsPerBlock;

    if (!(remv[nblock] & (1ULL << inblock))) {
      // original: keep_out[num_to_keep++] = i;
      keep_out_cpu[num_to_keep++] = i;
      unsigned long long *p = &mask_host[0] + i * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        remv[j] |= p[j];
      }
    }
  }

  // copy keep_out_cpu to keep_out on gpu
  CUDA_WARN(cudaMemcpy(keep_out, keep_out_cpu, boxes_num * sizeof(int), cudaMemcpyHostToDevice));

  // original: *num_out = num_to_keep;
  // copy num_to_keep to num_out on gpu
  CUDA_WARN(cudaMemcpy(num_out, &num_to_keep, 1 * sizeof(int), cudaMemcpyHostToDevice));

  // release cuda memory
  CUDA_CHECK(cudaFree(boxes_dev));
  CUDA_CHECK(cudaFree(mask_dev));
  // release cpu memory
  delete []keep_out_cpu;
}
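
Note: in the deleted kernel above, threadsPerBlock is sizeof(unsigned long long) * 8 = 64, so each 64-bit word of dev_mask records which boxes in one block of 64 a given box overlaps beyond the threshold, and the host loop then walks the boxes in score order, keeping a box only if no already-kept box has set its bit. A plain-Python sketch of the same greedy logic is given below for reference; it assumes boxes are (x1, y1, x2, y2) tuples already sorted by descending score, as the CUDA path expects, and uses the same +1 pixel convention as devIoU.

# Reference-only Python sketch of the greedy NMS the deleted CUDA code implements.
# Assumes `boxes` is a list of (x1, y1, x2, y2) tuples sorted by descending score.
def iou(a, b):
    # Same +1 pixel convention as devIoU above.
    left, top = max(a[0], b[0]), max(a[1], b[1])
    right, bottom = min(a[2], b[2]), min(a[3], b[3])
    w, h = max(right - left + 1, 0.0), max(bottom - top + 1, 0.0)
    inter = w * h
    area_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1)
    area_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1)
    return inter / (area_a + area_b - inter)

def nms_reference(boxes, thresh):
    keep, suppressed = [], [False] * len(boxes)
    for i in range(len(boxes)):
        if suppressed[i]:
            continue
        keep.append(i)
        for j in range(i + 1, len(boxes)):
            if not suppressed[j] and iou(boxes[i], boxes[j]) > thresh:
                suppressed[j] = True
    return keep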
10 changes: 0 additions & 10 deletions lib/model/nms/src/nms_cuda_kernel.h
@@ -1,10 +0,0 @@
#ifdef __cplusplus
extern "C" {
#endif

void nms_cuda_compute(int* keep_out, int *num_out, float* boxes_host, int boxes_num,
int boxes_dim, float nms_overlap_thresh);

#ifdef __cplusplus
}
#endif
5 changes: 5 additions & 0 deletions lib/model/roi_align/__init__.py
100644 → 100755
@@ -0,0 +1,5 @@
from .functions.roi_align import roi_align
from .modules.roi_align import RoIAlign
from .modules.roi_align_3d import RoIAlign3D

__all__ = ['roi_align', 'RoIAlign', 'RoIAlign3D']
15 changes: 0 additions & 15 deletions lib/model/roi_align/_ext/roi_align/__init__.py

This file was deleted.

38 changes: 0 additions & 38 deletions lib/model/roi_align/build.py

This file was deleted.

Empty file modified lib/model/roi_align/functions/__init__.py
100644 → 100755
Empty file.
76 changes: 43 additions & 33 deletions lib/model/roi_align/functions/roi_align.py
100644 → 100755
@@ -1,51 +1,61 @@
 import torch
 from torch.autograd import Function
-from .._ext import roi_align
+
+from .. import roi_align_cuda


-# TODO use save_for_backward instead
 class RoIAlignFunction(Function):
-    def __init__(self, aligned_height, aligned_width, spatial_scale):
-        self.aligned_width = int(aligned_width)
-        self.aligned_height = int(aligned_height)
-        self.spatial_scale = float(spatial_scale)
-        self.rois = None
-        self.feature_size = None
-
-    def forward(self, features, rois):
-        self.rois = rois
-        self.feature_size = features.size()
+    @staticmethod
+    def forward(ctx, features, rois, out_size, spatial_scale, sample_num=0):
+        if isinstance(out_size, int):
+            out_h = out_size
+            out_w = out_size
+        elif isinstance(out_size, tuple):
+            assert len(out_size) == 2
+            assert isinstance(out_size[0], int)
+            assert isinstance(out_size[1], int)
+            out_h, out_w = out_size
+        else:
+            raise TypeError(
+                '"out_size" must be an integer or tuple of integers')
+        ctx.spatial_scale = spatial_scale
+        ctx.sample_num = sample_num
+        ctx.save_for_backward(rois)
+        ctx.feature_size = features.size()

         batch_size, num_channels, data_height, data_width = features.size()
         num_rois = rois.size(0)

-        output = features.new(num_rois, num_channels, self.aligned_height, self.aligned_width).zero_()
+        output = features.new_zeros(num_rois, num_channels, out_h, out_w)
         if features.is_cuda:
-            roi_align.roi_align_forward_cuda(self.aligned_height,
-                                             self.aligned_width,
-                                             self.spatial_scale, features,
-                                             rois, output)
+            roi_align_cuda.forward(features, rois, out_h, out_w, spatial_scale,
+                                   sample_num, output)
         else:
-            roi_align.roi_align_forward(self.aligned_height,
-                                        self.aligned_width,
-                                        self.spatial_scale, features,
-                                        rois, output)
-            # raise NotImplementedError
+            raise NotImplementedError

         return output

-    def backward(self, grad_output):
-        assert(self.feature_size is not None and grad_output.is_cuda)
+    @staticmethod
+    def backward(ctx, grad_output):
+        feature_size = ctx.feature_size
+        spatial_scale = ctx.spatial_scale
+        sample_num = ctx.sample_num
+        rois = ctx.saved_tensors[0]
+        assert (feature_size is not None and grad_output.is_cuda)
+
+        batch_size, num_channels, data_height, data_width = feature_size
+        out_w = grad_output.size(3)
+        out_h = grad_output.size(2)

-        batch_size, num_channels, data_height, data_width = self.feature_size
+        grad_input = grad_rois = None
+        if ctx.needs_input_grad[0]:
+            grad_input = rois.new_zeros(batch_size, num_channels, data_height,
+                                        data_width)
+            roi_align_cuda.backward(grad_output.contiguous(), rois, out_h,
+                                    out_w, spatial_scale, sample_num,
+                                    grad_input)

-        grad_input = self.rois.new(batch_size, num_channels, data_height,
-                                   data_width).zero_()
-        roi_align.roi_align_backward_cuda(self.aligned_height,
-                                          self.aligned_width,
-                                          self.spatial_scale, grad_output,
-                                          self.rois, grad_input)
+        return grad_input, grad_rois, None, None, None

-        # print grad_input

-        return grad_input, None
+roi_align = RoIAlignFunction.apply
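
Note: with this rewrite the op is exposed as a functional, roi_align = RoIAlignFunction.apply. A hypothetical usage sketch follows; it assumes the compiled roi_align_cuda extension is importable and a CUDA device is available, and it assumes RoIs in (batch_idx, x1, y1, x2, y2) layout as in the mmdetection-style kernels this interface mirrors. Shapes and values are illustrative only.

# Hypothetical usage of the rewritten functional interface (illustrative shapes).
import torch

feats = torch.randn(2, 256, 64, 64, device='cuda')             # batch of feature maps
rois = torch.tensor([[0., 4., 4., 60., 60.],                   # (batch_idx, x1, y1, x2, y2)
                     [1., 10., 8., 48., 40.]], device='cuda')

# Positional arguments: features, rois, out_size, spatial_scale, sample_num.
pooled = roi_align(feats, rois, (7, 7), 1.0 / 16, 2)
print(pooled.shape)  # expected: torch.Size([2, 256, 7, 7])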
10 changes: 0 additions & 10 deletions lib/model/roi_align/make.sh

This file was deleted.

Empty file modified lib/model/roi_align/modules/__init__.py
100644 → 100755
Empty file.