Optimize relu op on GPU (#18506)
authorXiaomeng Yang <yangxm@fb.com>
Fri, 29 Mar 2019 07:20:25 +0000 (00:20 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 29 Mar 2019 07:23:24 +0000 (00:23 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18506

Optimize relu op on GPU

Reviewed By: houseroad

Differential Revision: D14633171

fbshipit-source-id: bd3afa9a0bae1325d32ad4153736a0c7ecb0ec64

caffe2/operators/relu_op.cu
caffe2/operators/relu_op_cudnn.cc [deleted file]

index 6ee8b15..613ed39 100644 (file)
@@ -4,29 +4,35 @@
 #include <functional>
 
 #include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/math.h"
 
 namespace caffe2 {
 
 namespace {
 
 #ifdef __HIPCC__
-typedef __half2 half2;
-#endif
+using half2 = __half2;
+#endif // __HIPCC__
 
 template <typename T>
-__global__ void ReluCUDAKernel(const int N, const T* X, T* Y) {
-  CUDA_1D_KERNEL_LOOP(i, N) {
-#if __CUDA_ARCH__ >= 350
-    Y[i] = __ldg(X + i) > 0 ? __ldg(X + i) : T(0);
-#else
-    Y[i] = X[i] > 0 ? X[i] : T(0);
-#endif
+__global__ void ReluCUDAKernel(const int N, const T* X, T* Y);
+
+#define DELEGATE_RELU_CUDA_KERNEL(T, MaxFunc)                        \
+  template <>                                                        \
+  __global__ void ReluCUDAKernel<T>(const int N, const T* X, T* Y) { \
+    const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; \
+    if (i < N) {                                                     \
+      Y[i] = MaxFunc(X[i], T(0));                                    \
+    }                                                                \
   }
-}
+DELEGATE_RELU_CUDA_KERNEL(float, fmaxf)
+#undef DELEGATE_RELU_CUDA_KERNEL
 
-__global__ void ReluHalfCUDAKernel(const int N, const half* X, half* Y) {
-  const half kZero = __float2half(0.0f);
-  CUDA_1D_KERNEL_LOOP(i, N) {
+template <>
+__global__ void ReluCUDAKernel<half>(const int N, const half* X, half* Y) {
+  const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (i < N) {
+    const half kZero = __float2half(0.0f);
 #if __CUDA_ARCH__ >= 530
     Y[i] = __hgt(__ldg(X + i), kZero) ? __ldg(X + i) : kZero;
 #else
@@ -35,14 +41,17 @@ __global__ void ReluHalfCUDAKernel(const int N, const half* X, half* Y) {
   }
 }
 
-__global__ void ReluHalf2CUDAKernel(const int N, const half2* X, half2* Y) {
-  const half2 kZero = __float2half2_rn(0.0f);
-  CUDA_1D_KERNEL_LOOP(i, N) {
+template <>
+__global__ void ReluCUDAKernel<half2>(const int N, const half2* X, half2* Y) {
+  const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (i < N) {
+    const half2 kZero = __float2half2_rn(0.0f);
 #if __CUDA_ARCH__ >= 530
     Y[i] = __hmul2(__hgt2(__ldg(X + i), kZero), __ldg(X + i));
 #else
     const float2 xx = __half22float2(X[i]);
-    Y[i] = __floats2half2_rn(xx.x > 0 ? xx.x : 0.f, xx.y > 0 ? xx.y : 0.f);
+    Y[i] =
+        __floats2half2_rn(xx.x > 0.0f ? xx.x : 0.0f, xx.y > 0.0f ? xx.y : 0.0f);
 #endif
   }
 }
@@ -50,22 +59,25 @@ __global__ void ReluHalf2CUDAKernel(const int N, const half2* X, half2* Y) {
 template <typename T>
 __global__ void
 ReluGradientCUDAKernel(const int N, const T* dY, const T* Y, T* dX) {
-  CUDA_1D_KERNEL_LOOP(i, N) {
+  const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (i < N) {
 #if __CUDA_ARCH__ >= 350
-    dX[i] = __ldg(Y + i) > 0 ? __ldg(dY + i) : 0;
+    dX[i] = __ldg(Y + i) > T(0) ? __ldg(dY + i) : T(0);
 #else
-    dX[i] = Y[i] > 0 ? dY[i] : 0;
+    dX[i] = Y[i] > T(0) ? dY[i] : T(0);
 #endif
   }
 }
 
-__global__ void ReluGradientHalfCUDAKernel(
+template <>
+__global__ void ReluGradientCUDAKernel<half>(
     const int N,
     const half* dY,
     const half* Y,
     half* dX) {
-  const half kZero = __float2half(0.0f);
-  CUDA_1D_KERNEL_LOOP(i, N) {
+  const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (i < N) {
+    const half kZero = __float2half(0.0f);
 #if __CUDA_ARCH__ >= 530
     dX[i] = __hgt(__ldg(Y + i), kZero) ? __ldg(dY + i) : kZero;
 #else
@@ -74,19 +86,22 @@ __global__ void ReluGradientHalfCUDAKernel(
   }
 }
 
-__global__ void ReluGradientHalf2CUDAKernel(
+template <>
+__global__ void ReluGradientCUDAKernel<half2>(
     const int N,
     const half2* dY,
     const half2* Y,
     half2* dX) {
-  const half2 kZero = __float2half2_rn(0.0f);
-  CUDA_1D_KERNEL_LOOP(i, N) {
+  const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (i < N) {
+    const half2 kZero = __float2half2_rn(0.0f);
 #if __CUDA_ARCH__ >= 530
     dX[i] = __hmul2(__hgt2(__ldg(Y + i), kZero), __ldg(dY + i));
 #else
     const float2 dy = __half22float2(dY[i]);
     const float2 yy = __half22float2(Y[i]);
-    dX[i] = __floats2half2_rn(yy.x > 0 ? dy.x : 0.f, yy.y > 0 ? dy.y : 0.f);
+    dX[i] =
+        __floats2half2_rn(yy.x > 0.0f ? dy.x : 0.0f, yy.y > 0.0f ? dy.y : 0.0f);
 #endif
   }
 }
@@ -97,11 +112,9 @@ template <>
 template <typename T>
 bool ReluFunctor<CUDAContext>::
 operator()(const int N, const T* X, T* Y, CUDAContext* context) const {
+  const int M = math::DivUp(N, CAFFE_CUDA_NUM_THREADS);
   ReluCUDAKernel<T>
-      <<<CAFFE_GET_BLOCKS(N),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context->cuda_stream()>>>(N, X, Y);
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(N, X, Y);
   return true;
 }
 
@@ -112,22 +125,18 @@ bool ReluFunctor<CUDAContext>::operator()<at::Half>(
     const at::Half* X,
     at::Half* Y,
     CUDAContext* context) const {
-  if ((N & 1) == 0) {
-    ReluHalf2CUDAKernel<<<
-        CAFFE_GET_BLOCKS((N >> 1)),
-        CAFFE_CUDA_NUM_THREADS,
-        0,
-        context->cuda_stream()>>>(
-        (N >> 1),
-        reinterpret_cast<const half2*>(X),
-        reinterpret_cast<half2*>(Y));
+  if (N % 2 == 0) {
+    const int M = math::DivUp(N / 2, CAFFE_CUDA_NUM_THREADS);
+    ReluCUDAKernel<half2>
+        <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+            N / 2,
+            reinterpret_cast<const half2*>(X),
+            reinterpret_cast<half2*>(Y));
   } else {
-    ReluHalfCUDAKernel<<<
-        CAFFE_GET_BLOCKS(N),
-        CAFFE_CUDA_NUM_THREADS,
-        0,
-        context->cuda_stream()>>>(
-        N, reinterpret_cast<const half*>(X), reinterpret_cast<half*>(Y));
+    const int M = math::DivUp(N, CAFFE_CUDA_NUM_THREADS);
+    ReluCUDAKernel<half>
+        <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+            N, reinterpret_cast<const half*>(X), reinterpret_cast<half*>(Y));
   }
   return true;
 }
@@ -141,13 +150,11 @@ bool ReluGradientFunctor<CUDAContext>::Forward(
     const T* dY,
     T* dX,
     CUDAContext* context) const {
-  const int size = std::accumulate(
+  const int N = std::accumulate(
       Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
+  const int M = math::DivUp(N, CAFFE_CUDA_NUM_THREADS);
   ReluGradientCUDAKernel<T>
-      <<<CAFFE_GET_BLOCKS(size),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context->cuda_stream()>>>(size, dY, Y, dX);
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(N, dY, Y, dX);
   return true;
 }
 
@@ -160,28 +167,24 @@ bool ReluGradientFunctor<CUDAContext>::Forward<at::Half>(
     const at::Half* dY,
     at::Half* dX,
     CUDAContext* context) const {
-  const int size = std::accumulate(
+  const int N = std::accumulate(
       Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
-  if ((size & 1) == 0) {
-    ReluGradientHalf2CUDAKernel<<<
-        CAFFE_GET_BLOCKS((size >> 1)),
-        CAFFE_CUDA_NUM_THREADS,
-        0,
-        context->cuda_stream()>>>(
-        (size >> 1),
-        reinterpret_cast<const half2*>(dY),
-        reinterpret_cast<const half2*>(Y),
-        reinterpret_cast<half2*>(dX));
+  if (N % 2 == 0) {
+    const int M = math::DivUp(N / 2, CAFFE_CUDA_NUM_THREADS);
+    ReluGradientCUDAKernel<half2>
+        <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+            N / 2,
+            reinterpret_cast<const half2*>(dY),
+            reinterpret_cast<const half2*>(Y),
+            reinterpret_cast<half2*>(dX));
   } else {
-    ReluGradientHalfCUDAKernel<<<
-        CAFFE_GET_BLOCKS(size),
-        CAFFE_CUDA_NUM_THREADS,
-        0,
-        context->cuda_stream()>>>(
-        size,
-        reinterpret_cast<const half*>(dY),
-        reinterpret_cast<const half*>(Y),
-        reinterpret_cast<half*>(dX));
+    const int M = math::DivUp(N, CAFFE_CUDA_NUM_THREADS);
+    ReluGradientCUDAKernel<half>
+        <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+            N,
+            reinterpret_cast<const half*>(dY),
+            reinterpret_cast<const half*>(Y),
+            reinterpret_cast<half*>(dX));
   }
   return true;
 }
diff --git a/caffe2/operators/relu_op_cudnn.cc b/caffe2/operators/relu_op_cudnn.cc
deleted file mode 100644 (file)
index 75dc6cd..0000000
+++ /dev/null
@@ -1,12 +0,0 @@
-#include "caffe2/operators/relu_op.h"
-
-#include "caffe2/operators/activation_ops_cudnn.h"
-
-namespace caffe2 {
-
-REGISTER_CUDNN_OPERATOR(Relu, CuDNNActivationOp<CUDNN_ACTIVATION_RELU>);
-REGISTER_CUDNN_OPERATOR(
-    ReluGradient,
-    CuDNNActivationGradientOp<CUDNN_ACTIVATION_RELU>);
-
-} // namespace caffe2