Implement CUDA version of GRU operator
Summary: Add CUDA version of GRU operator

Reviewed By: jamesr66a

Differential Revision: D5571043

fbshipit-source-id: 332aa64fc8a9116cc33382f2b2907080e58c13b3
Jianlong Zhong authored and facebook-github-bot committed Aug 8, 2017
1 parent 3196d8b commit 3383b68
Showing 4 changed files with 261 additions and 134 deletions.
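As context for the diffs below, here is a minimal, self-contained sketch (not part of the commit; the helper name gru_unit_forward and the hard-coded inputs are illustrative only) of the per-element update that both the CPU path and the new CUDA kernel compute, where X packs the reset, update, and output gate pre-activations for one batch item:

#include <cmath>
#include <cstdio>

float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// X packs [reset | update | output] pre-activations, each of width D.
float gru_unit_forward(const float* X, float h_prev, int D, int d,
                       bool valid, bool drop_states) {
  if (!valid) { // timestep is past this sequence's length
    return drop_states ? 0.0f : h_prev;
  }
  const float u = sigmoid(X[1 * D + d]);   // update gate
  const float o = std::tanh(X[2 * D + d]); // candidate activation
  return h_prev * u + o * (1.0f - u);
}

int main() {
  const int D = 2;
  const float X[3 * D] = {0.1f, -0.3f, 0.5f, 1.2f, -0.7f, 0.4f};
  std::printf("%f\n", gru_unit_forward(X, 0.25f, D, 0, true, false));
  return 0;
}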
102 changes: 0 additions & 102 deletions caffe2/operators/gru_unit_op.cc
@@ -1,107 +1,6 @@
#include "gru_unit_op.h"

namespace caffe2 {
namespace detail {

template <typename T>
inline T sigmoid(T x) {
return 1.0f / (1.0f + exp(-x));
}

template <typename T>
inline T host_tanh(T x) {
return 2.0f * sigmoid(2.0f * x) - 1.0f;
}

template <typename T, typename Context>
void GRUUnit(
int N,
int D,
int t,
const T* H_prev,
const T* X,
const int32_t* seqLengths,
bool drop_states,
T* H) {
for (int n = 0; n < N; ++n) {
const bool valid = t < seqLengths[n];

for (int d = 0; d < D; ++d) {
if (valid == false) {
if (drop_states) {
H[d] = 0;
} else {
H[d] = H_prev[d];
}
} else {
const T update = X[1 * D + d];
const T output = X[2 * D + d];
H[d] = H_prev[d] * sigmoid(update) +
host_tanh(output) * (1.0f - sigmoid(update));
}
}

H_prev += D;
X += 3 * D;
H += D;
}
}

template <typename T, typename Context>
void GRUUnitGradient(
int N,
int D,
int t,
const T* H_prev,
const T* X,
const int32_t* seqLengths,
const T* H,
const T* H_diff,
bool drop_states,
T* H_prev_diff,
T* X_diff) {
for (int n = 0; n < N; ++n) {
const bool valid = t < seqLengths[n];

for (int d = 0; d < D; ++d) {
T* h_prev_diff = H_prev_diff + d;
T* reset_diff = X_diff + 0 * D + d;
T* update_diff = X_diff + 1 * D + d;
T* output_diff = X_diff + 2 * D + d;

if (!valid) {
if (drop_states) {
*h_prev_diff = 0;
} else {
*h_prev_diff = H_diff[d];
}
*reset_diff = 0;
*update_diff = 0;
*output_diff = 0;
} else {
// Calculate Gate Outputs
const T u = sigmoid(X[1 * D + d]);
const T o = host_tanh(X[2 * D + d]);

*h_prev_diff = H_diff[d] * u;
*reset_diff = 0; // 0 contribution to gradient from this operation
*update_diff = (H_diff[d] * H_prev[d] - H_diff[d] * o) * u * (1.0f - u);
*output_diff = H_diff[d] * (1.0f - u) * (1.0f - o * o);
}
}

H_prev += D;
X += 3 * D;
H += D;
H_diff += D;
X_diff += 3 * D;
H_prev_diff += D;
}
}

} // namespace detail

namespace {
REGISTER_CPU_OPERATOR(GRUUnit, GRUUnitOp<float, CPUContext>);
OPERATOR_SCHEMA(GRUUnit)
.NumInputs(4)
@@ -147,5 +46,4 @@ class GetGRUUnitGradient : public GradientMakerBase {
  }
};
REGISTER_GRADIENT(GRUUnit, GetGRUUnitGradient);
}
} // namespace caffe2
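The gradient formulas above (kept verbatim as the definitions move into the header below) follow directly from H = H_prev * sigmoid(update) + tanh(output) * (1 - sigmoid(update)). A quick finite-difference check of the update-gate term, as a hypothetical standalone program (names and values are illustrative, not from the commit):

#include <cmath>
#include <cstdio>

static float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

static float forward(float h_prev, float update, float output) {
  const float u = sigmoid(update);
  return h_prev * u + std::tanh(output) * (1.0f - u);
}

int main() {
  const float h_prev = 0.3f, update = -0.2f, output = 0.7f, eps = 1e-3f;
  const float u = sigmoid(update);
  const float o = std::tanh(output);

  // Analytic dH/dupdate, matching *update_diff above with H_diff
  // factored out: (H_prev - o) * u * (1 - u).
  const float analytic = (h_prev - o) * u * (1.0f - u);
  const float numeric = (forward(h_prev, update + eps, output) -
                         forward(h_prev, update - eps, output)) /
      (2.0f * eps);

  std::printf("dH/dupdate: analytic=%f numeric=%f\n", analytic, numeric);
  return 0;
}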
85 changes: 80 additions & 5 deletions caffe2/operators/gru_unit_op.h
@@ -8,6 +8,16 @@
namespace caffe2 {
namespace detail {

template <typename T>
inline T sigmoid(T x) {
  return 1.0f / (1.0f + exp(-x));
}

template <typename T>
inline T host_tanh(T x) {
  return 2.0f * sigmoid(2.0f * x) - 1.0f;
}

template <typename T, typename Context>
void GRUUnit(
    int N,
@@ -17,7 +27,32 @@ void GRUUnit(
    const T* X,
    const int32_t* seqLengths,
    bool drop_states,
    T* H);
    T* H,
    Context* /*context*/) {
  for (int n = 0; n < N; ++n) {
    const bool valid = t < seqLengths[n];

    for (int d = 0; d < D; ++d) {
      if (!valid) {
        if (drop_states) {
          H[d] = 0;
        } else {
          H[d] = H_prev[d];
        }
      } else {
        const T update = X[1 * D + d];
        const T output = X[2 * D + d];
        T sigmoid_update = sigmoid(update);
        H[d] = H_prev[d] * sigmoid_update +
            host_tanh(output) * (1.0f - sigmoid_update);
      }
    }

    H_prev += D;
    X += 3 * D;
    H += D;
  }
}

template <typename T, typename Context>
void GRUUnitGradient(
@@ -31,9 +66,48 @@ void GRUUnitGradient(
    const T* H_diff,
    bool drop_states,
    T* H_prev_diff,
    T* X_diff);
    T* X_diff,
    Context* /*context*/) {
  for (int n = 0; n < N; ++n) {
    const bool valid = t < seqLengths[n];

    for (int d = 0; d < D; ++d) {
      T* h_prev_diff = H_prev_diff + d;
      T* reset_diff = X_diff + 0 * D + d;
      T* update_diff = X_diff + 1 * D + d;
      T* output_diff = X_diff + 2 * D + d;

      if (!valid) {
        if (drop_states) {
          *h_prev_diff = 0;
        } else {
          *h_prev_diff = H_diff[d];
        }
        *reset_diff = 0;
        *update_diff = 0;
        *output_diff = 0;
      } else {
        // Calculate Gate Outputs
        const T u = sigmoid(X[1 * D + d]);
        const T o = host_tanh(X[2 * D + d]);

        *h_prev_diff = H_diff[d] * u;
        *reset_diff = 0; // 0 contribution to gradient from this operation
        *update_diff = (H_diff[d] * H_prev[d] - H_diff[d] * o) * u * (1.0f - u);
        *output_diff = H_diff[d] * (1.0f - u) * (1.0f - o * o);
      }
    }

    H_prev += D;
    X += 3 * D;
    H += D;
    H_diff += D;
    X_diff += 3 * D;
    H_prev_diff += D;
  }
}

}; // namespace detail
} // namespace detail

template <typename T, typename Context>
class GRUUnitOp : public Operator<Context> {
@@ -64,7 +138,7 @@ class GRUUnitOp : public Operator<Context> {
    auto* H = Output(HIDDEN_T)->template mutable_data<T>();

    detail::GRUUnit<T, Context>(
        N, D, t, H_prev, X, seqLengths, drop_states_, H);
        N, D, t, H_prev, X, seqLengths, drop_states_, H, &context_);
    return true;
  }

@@ -118,7 +192,8 @@ class GRUUnitGradientOp : public Operator<Context> {
        H_diff,
        drop_states_,
        H_prev_diff,
        X_diff);
        X_diff,
        &context_);
    return true;
  }

140 changes: 140 additions & 0 deletions caffe2/operators/gru_unit_op_gpu.cu
@@ -0,0 +1,140 @@
#include <algorithm>
#include <cmath>
#include <vector>
#include "caffe2/core/context_gpu.h"
#include "gru_unit_op.h"

namespace caffe2 {

namespace detail {

template <typename Dtype>
__device__ Dtype cuda_sigmoid(const Dtype x) {
  return Dtype(1) / (Dtype(1) + exp(-x));
}

template <typename T>
__global__ void GRUUnitKernel(
    const int ND,
    const int dim,
    const int t,
    const T* H_prev,
    const T* X,
    const int32_t* seqLengths,
    bool drop_states,
    T* H) {
  // index is virtual thread ID in range [0, ND)
  CUDA_1D_KERNEL_LOOP(index, ND) {
    const int n = index / dim;
    const int d = index % dim;
    const bool valid = t < seqLengths[n];
    if (!valid) {
      H[index] = H_prev[index] * !drop_states;
    } else {
      const T* X_offset = X + 3 * dim * n;
      const T update = X_offset[1 * dim + d];
      const T output = X_offset[2 * dim + d];
      T sigmoid_update = cuda_sigmoid(update);
      H[index] = H_prev[index] * sigmoid_update +
          tanh(output) * (1.0f - sigmoid_update);
    }
  }
}

template <typename T>
__global__ void GRUUnitGradientKernel(
    const int ND,
    const int dim,
    const int t,
    const T* H_prev,
    const T* X,
    const int32_t* seqLengths,
    const T* H,
    const T* H_diff,
    bool drop_states,
    T* H_prev_diff,
    T* X_diff) {
  CUDA_1D_KERNEL_LOOP(index, ND) {
    const int n = index / dim;
    const bool valid = t < seqLengths[n];
    const int d = index % dim;
    const T* X_offset = X + 3 * dim * n;
    T* h_prev_diff = H_prev_diff + index;
    T* X_diff_offset = X_diff + 3 * dim * n;
    T* reset_diff = X_diff_offset + 0 * dim + d;
    T* update_diff = X_diff_offset + 1 * dim + d;
    T* output_diff = X_diff_offset + 2 * dim + d;

    if (!valid) {
      *h_prev_diff = H_diff[index] * !drop_states;
      *reset_diff = 0;
      *update_diff = 0;
      *output_diff = 0;
    } else {
      const T u = cuda_sigmoid(X_offset[1 * dim + d]);
      const T o = tanh(X_offset[2 * dim + d]);

      *h_prev_diff = H_diff[index] * u;
      *reset_diff = 0; // 0 contribution to gradient from this operation
      *update_diff =
          (H_diff[index] * H_prev[index] - H_diff[index] * o) * u * (1.0f - u);
      *output_diff = H_diff[index] * (1.0f - u) * (1.0f - o * o);
    }
  }
}

template <>
void GRUUnit<float, CUDAContext>(
    int N,
    int D,
    int t,
    const float* H_prev,
    const float* X,
    const int32_t* seqLengths,
    bool drop_states,
    float* H,
    CUDAContext* context) {
  GRUUnitKernel<float>
      <<<CAFFE_GET_BLOCKS(N * D),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context->cuda_stream()>>>(
          N * D, D, t, H_prev, X, seqLengths, drop_states, H);
}
template <>
void GRUUnitGradient<float, CUDAContext>(
    int N,
    int D,
    int t,
    const float* H_prev,
    const float* X,
    const int32_t* seqLengths,
    const float* H,
    const float* H_diff,
    bool drop_states,
    float* H_prev_diff,
    float* X_diff,
    CUDAContext* context) {
  GRUUnitGradientKernel<float>
      <<<CAFFE_GET_BLOCKS(N * D),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context->cuda_stream()>>>(
          N * D,
          D,
          t,
          H_prev,
          X,
          seqLengths,
          H,
          H_diff,
          drop_states,
          H_prev_diff,
          X_diff);
}
}
REGISTER_CUDA_OPERATOR(GRUUnit, GRUUnitOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(GRUUnitGradient, GRUUnitGradientOp<float, CUDAContext>);
}
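For readers new to the kernels' flat indexing: assuming CUDA_1D_KERNEL_LOOP is Caffe2's usual grid-stride loop macro (for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x)), each virtual thread index covers one (batch item, hidden unit) pair. A hypothetical standalone sketch of that mapping, not part of the commit:

#include <cstdio>
#include <cuda_runtime.h>

// One virtual thread per (n, d) pair; a grid-stride loop lets a fixed
// launch configuration cover any ND = N * D.
__global__ void IndexDemoKernel(const int ND, const int dim, int2* out) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < ND;
       index += blockDim.x * gridDim.x) {
    out[index] = make_int2(index / dim /* n */, index % dim /* d */);
  }
}

int main() {
  const int N = 2, D = 3, ND = N * D;
  int2* out;
  cudaMallocManaged(&out, ND * sizeof(int2));
  IndexDemoKernel<<<1, 32>>>(ND, D, out);
  cudaDeviceSynchronize();
  for (int i = 0; i < ND; ++i) {
    std::printf("index %d -> n=%d d=%d\n", i, out[i].x, out[i].y);
  }
  cudaFree(out);
  return 0;
}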