From f6df6aed89c00f3baa270998417ce2b8ca5756c9 Mon Sep 17 00:00:00 2001 From: Bilge Acun Date: Fri, 22 Mar 2019 09:51:27 -0700 Subject: [PATCH] Optimize MomentumSGDUpdate maximum block size and make it templated Summary: Removing the maximum number of blocks limit from the operator and making the nesterov parameter templated to remove branching. Reviewed By: BIT-silence Differential Revision: D14567003 fbshipit-source-id: 394c2039ee214adc6ccd2e562e4e9563d307131f --- caffe2/sgd/momentum_sgd_op_gpu.cu | 82 +++++++++++++++++++++++++++------------ 1 file changed, 58 insertions(+), 24 deletions(-) diff --git a/caffe2/sgd/momentum_sgd_op_gpu.cu b/caffe2/sgd/momentum_sgd_op_gpu.cu index 74d84f1..ebf0aba 100644 --- a/caffe2/sgd/momentum_sgd_op_gpu.cu +++ b/caffe2/sgd/momentum_sgd_op_gpu.cu @@ -4,6 +4,13 @@ namespace caffe2 { +inline int CaffeGetBlocksSGD(const int N) { + return std::max( + (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS, + // Use at least 1 block, since CUDA does not allow empty block + 1); +} +template __global__ void MomentumSGDKernel( const int N, const float* g, @@ -12,27 +19,47 @@ __global__ void MomentumSGDKernel( float* nm, const float* lr, const float momentum, - const bool nesterov, + float* param); + +template <> +__global__ void MomentumSGDKernel( + const int N, + const float* g, + const float* m, + float* ng, + float* nm, + const float* lr, + const float momentum, float* param) { const float LR = lr[0]; - if (!nesterov) { - CUDA_1D_KERNEL_LOOP(i, N) { - const float adjusted_gradient = LR * g[i] + momentum * m[i]; - nm[i] = adjusted_gradient; - ng[i] = adjusted_gradient; - if (param) { - param[i] -= adjusted_gradient; - } + CUDA_1D_KERNEL_LOOP(i, N) { + const float mi = m[i]; + const float mi_new = momentum * mi + LR * g[i]; + nm[i] = mi_new; + ng[i] = fmaf(momentum, mi_new - mi, mi_new); + if (param != nullptr) { + param[i] -= ng[i]; } - } else { - CUDA_1D_KERNEL_LOOP(i, N) { - const float mi = m[i]; - const float mi_new = momentum * mi + LR * g[i]; - nm[i] = mi_new; - ng[i] = (1 + momentum) * mi_new - momentum * mi; - if (param) { - param[i] -= ng[i]; - } + } +} + +template <> +__global__ void MomentumSGDKernel( + const int N, + const float* g, + const float* m, + float* ng, + float* nm, + const float* lr, + const float momentum, + float* param) { + const float LR = lr[0]; + CUDA_1D_KERNEL_LOOP(i, N) { + const float adjusted_gradient = LR * g[i] + momentum * m[i]; + nm[i] = adjusted_gradient; + ng[i] = adjusted_gradient; + if (param != nullptr) { + param[i] -= adjusted_gradient; } } } @@ -49,12 +76,19 @@ void momentum_sgd_update( const bool nesterov, float* param, CUDAContext* context) { - MomentumSGDKernel<<< - CAFFE_GET_BLOCKS(N), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>( - N, g, m, ng, nm, lr, momentum, nesterov, param); + if (nesterov) { + MomentumSGDKernel + <<cuda_stream()>>>(N, g, m, ng, nm, lr, momentum, param); + } else { + MomentumSGDKernel + <<cuda_stream()>>>(N, g, m, ng, nm, lr, momentum, param); + } } -- 2.7.4