Optimize MomentumSGDUpdate maximum block size and make it templated
authorBilge Acun <acun@fb.com>
Fri, 22 Mar 2019 16:51:27 +0000 (09:51 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Mar 2019 16:54:25 +0000 (09:54 -0700)
Summary: Removing the maximum number of blocks limit from the operator and making the nesterov parameter templated to remove branching.

Reviewed By: BIT-silence

Differential Revision: D14567003

fbshipit-source-id: 394c2039ee214adc6ccd2e562e4e9563d307131f

caffe2/sgd/momentum_sgd_op_gpu.cu

index 74d84f1..ebf0aba 100644 (file)
@@ -4,6 +4,13 @@
 
 namespace caffe2 {
 
+inline int CaffeGetBlocksSGD(const int N) {
+  return std::max(
+      (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS,
+      // Use at least 1 block, since CUDA does not allow empty block
+      1);
+}
+template <bool nesterov>
 __global__ void MomentumSGDKernel(
     const int N,
     const float* g,
@@ -12,27 +19,47 @@ __global__ void MomentumSGDKernel(
     float* nm,
     const float* lr,
     const float momentum,
-    const bool nesterov,
+    float* param);
+
+template <>
+__global__ void MomentumSGDKernel<true>(
+    const int N,
+    const float* g,
+    const float* m,
+    float* ng,
+    float* nm,
+    const float* lr,
+    const float momentum,
     float* param) {
   const float LR = lr[0];
-  if (!nesterov) {
-    CUDA_1D_KERNEL_LOOP(i, N) {
-      const float adjusted_gradient =  LR * g[i] + momentum * m[i];
-      nm[i] = adjusted_gradient;
-      ng[i] = adjusted_gradient;
-      if (param) {
-        param[i] -= adjusted_gradient;
-      }
+  CUDA_1D_KERNEL_LOOP(i, N) {
+    const float mi = m[i];
+    const float mi_new = momentum * mi + LR * g[i];
+    nm[i] = mi_new;
+    ng[i] = fmaf(momentum, mi_new - mi, mi_new);
+    if (param != nullptr) {
+      param[i] -= ng[i];
     }
-  } else {
-    CUDA_1D_KERNEL_LOOP(i, N) {
-      const float mi = m[i];
-      const float mi_new = momentum * mi + LR * g[i];
-      nm[i] = mi_new;
-      ng[i] = (1 + momentum) * mi_new - momentum * mi;
-      if (param) {
-        param[i] -= ng[i];
-      }
+  }
+}
+
+template <>
+__global__ void MomentumSGDKernel<false>(
+    const int N,
+    const float* g,
+    const float* m,
+    float* ng,
+    float* nm,
+    const float* lr,
+    const float momentum,
+    float* param) {
+  const float LR = lr[0];
+  CUDA_1D_KERNEL_LOOP(i, N) {
+    const float adjusted_gradient = LR * g[i] + momentum * m[i];
+    nm[i] = adjusted_gradient;
+    ng[i] = adjusted_gradient;
+    if (param != nullptr) {
+      param[i] -= adjusted_gradient;
     }
   }
 }
@@ -49,12 +76,19 @@ void momentum_sgd_update<CUDAContext>(
     const bool nesterov,
     float* param,
     CUDAContext* context) {
-  MomentumSGDKernel<<<
-      CAFFE_GET_BLOCKS(N),
-      CAFFE_CUDA_NUM_THREADS,
-      0,
-      context->cuda_stream()>>>(
-      N, g, m, ng, nm, lr, momentum, nesterov, param);
+  if (nesterov) {
+    MomentumSGDKernel<true>
+        <<<CaffeGetBlocksSGD(N),
+           CAFFE_CUDA_NUM_THREADS,
+           0,
+           context->cuda_stream()>>>(N, g, m, ng, nm, lr, momentum, param);
+  } else {
+    MomentumSGDKernel<false>
+        <<<CaffeGetBlocksSGD(N),
+           CAFFE_CUDA_NUM_THREADS,
+           0,
+           context->cuda_stream()>>>(N, g, m, ng, nm, lr, momentum, param);
+  }
 }