namespace caffe2 {
+// Returns how many blocks of CAFFE_CUDA_NUM_THREADS threads are needed to
+// cover N elements, i.e. ceil(N / CAFFE_CUDA_NUM_THREADS), clamped to a
+// minimum of 1. For N <= 0 the result is 1.
+inline int CaffeGetBlocksSGD(const int N) {
+  // Ceiling division written as quotient plus a remainder check instead of
+  // (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS: the sum in that
+  // form overflows int when N is within CAFFE_CUDA_NUM_THREADS of INT_MAX.
+  const int full_blocks = N / CAFFE_CUDA_NUM_THREADS;
+  const int tail_block = (N % CAFFE_CUDA_NUM_THREADS != 0) ? 1 : 0;
+  return std::max(
+      full_blocks + tail_block,
+      // Use at least 1 block, since CUDA does not allow an empty grid.
+      1);
+}
+template <bool nesterov> // nesterov becomes a compile-time flag; this is a declaration only, the bodies live in the <true> and <false> specializations
__global__ void MomentumSGDKernel(
const int N,
const float* g,
float* nm,
const float* lr,
const float momentum,
- const bool nesterov,
+ float* param); // NOTE(review): no m / ng parameters appear in this declaration as shown, yet both specializations and both launches use them - confirm against the full file
+
+template <> // Nesterov specialization: nm = momentum * m + LR * g, ng = nm + momentum * (nm - m)
+__global__ void MomentumSGDKernel<true>(
+ const int N, // loop bound for the element index i
+ const float* g, // read at g[i]
+ const float* m, // read at m[i]
+ float* ng, // written at ng[i]
+ float* nm, // written at nm[i]
+ const float* lr, // only lr[0] is read
+ const float momentum, // scales m[i] and the (mi_new - mi) look-ahead term
float* param) {
const float LR = lr[0];
- if (!nesterov) {
- CUDA_1D_KERNEL_LOOP(i, N) {
- const float adjusted_gradient = LR * g[i] + momentum * m[i];
- nm[i] = adjusted_gradient;
- ng[i] = adjusted_gradient;
- if (param) {
- param[i] -= adjusted_gradient;
- }
+ CUDA_1D_KERNEL_LOOP(i, N) {
+ const float mi = m[i]; // keep the old momentum; it is needed again after nm[i] is stored
+ const float mi_new = momentum * mi + LR * g[i];
+ nm[i] = mi_new;
+ ng[i] = fmaf(momentum, mi_new - mi, mi_new); // algebraically equal to the removed (1 + momentum) * mi_new - momentum * mi; rounding may differ
+ if (param != nullptr) { // param may be null; the in-place step runs only when it is not
+ param[i] -= ng[i]; // uses the ng[i] value stored two statements earlier
}
- } else {
- CUDA_1D_KERNEL_LOOP(i, N) {
- const float mi = m[i];
- const float mi_new = momentum * mi + LR * g[i];
- nm[i] = mi_new;
- ng[i] = (1 + momentum) * mi_new - momentum * mi;
- if (param) {
- param[i] -= ng[i];
- }
+ }
+}
+
+template <> // Plain-momentum specialization: nm and ng both receive LR * g + momentum * m
+__global__ void MomentumSGDKernel<false>(
+ const int N, // loop bound for the element index i
+ const float* g, // read at g[i]
+ const float* m, // read at m[i]
+ float* ng, // written at ng[i]
+ float* nm, // written at nm[i]
+ const float* lr, // only lr[0] is read
+ const float momentum, // scales m[i] in the update
+ float* param) { // may be null; when non-null, param[i] is decremented by the update
+ const float LR = lr[0]; // read the learning rate once, before the loop
+ CUDA_1D_KERNEL_LOOP(i, N) {
+ const float adjusted_gradient = LR * g[i] + momentum * m[i];
+ nm[i] = adjusted_gradient; // the same value is stored to both outputs
+ ng[i] = adjusted_gradient;
+ if (param != nullptr) { // the in-place parameter step runs only when param is non-null
+ param[i] -= adjusted_gradient;
}
}
}
const bool nesterov,
float* param,
CUDAContext* context) {
- MomentumSGDKernel<<<
- CAFFE_GET_BLOCKS(N),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context->cuda_stream()>>>(
- N, g, m, ng, nm, lr, momentum, nesterov, param);
+ if (nesterov) {
+ MomentumSGDKernel<true>
+ <<<CaffeGetBlocksSGD(N),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context->cuda_stream()>>>(N, g, m, ng, nm, lr, momentum, param);
+ } else {
+ MomentumSGDKernel<false>
+ <<<CaffeGetBlocksSGD(N),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context->cuda_stream()>>>(N, g, m, ng, nm, lr, momentum, param);
+ }
}