inline adagrad functions (#14194)
authorJongsoo Park <jongsoo@fb.com>
Mon, 3 Dec 2018 04:20:24 +0000 (20:20 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 3 Dec 2018 04:23:02 +0000 (20:23 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14194

Inline some of perfkernels/adagrad.h functions for better performance

Reviewed By: hyuen

Differential Revision: D13096351

fbshipit-source-id: b4da8053278d585eabc5389b8a8dcae0f253b413

caffe2/perfkernels/adagrad.cc
caffe2/perfkernels/adagrad.h
caffe2/perfkernels/adagrad_avx.cc

index 752202e..c629cb6 100644 (file)
@@ -16,11 +16,7 @@ void adagrad_update__base(
     float epsilon,
     float decay,
     const float lr) {
-  for (auto i = 0; i < N; ++i) {
-    float gi = g[i];
-    float hi = nh[i] = decay * h[i] + gi * gi;
-    nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);
-  }
+  internal::adagrad_update_base_inlined(N, w, g, h, nw, nh, decay, epsilon, lr);
 }
 
 void adagrad_update_prefetch__base(
@@ -57,39 +53,22 @@ void adagrad_fp16_update_prefetch__base(
     at::Half* /* nh_n */, // prefetch ptr
     float epsilon,
     float lr) {
-  for (auto i = 0; i < N; ++i) {
-    float gi = g[i];
-    float hi = h[i] + gi * gi;
-    nh[i] = hi;
-    nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);
-  }
+  internal::adagrad_update_base_inlined(N, w, g, h, nw, nh, 1.0f, epsilon, lr);
 }
 
 void rowwise_adagrad_update__base(
     int N,
     float* w,
-    float* /* w_n */, // prefetch ptr
+    float* w_n, // prefetch ptr
 
     const float* g,
 
     float* h,
-    float* /* h_n */, // prefetch ptr
+    float* h_n, // prefetch ptr
 
     float epsilon,
     float lr) {
-  float sum = 0.0f;
-  for (auto i = 0; i < N; ++i) {
-    sum += g[i] * g[i];
-  }
-  sum /= N;
-
-  float hi = *h = *h + sum;
-  float float_step = lr / (std::sqrt(hi) + epsilon);
-
-  for (auto i = 0; i < N; ++i) {
-    float gi = g[i];
-    w[i] = w[i] + gi * float_step;
-  }
+  internal::rowwise_adagrad_update_inlined(N, w, w_n, g, h, h_n, epsilon, lr);
 }
 
 void adagrad_update_prefetch(
index cb36742..efd4eff 100644 (file)
@@ -1,9 +1,165 @@
 #pragma once
 
+#if defined(__AVX__) && !defined(__NVCC__) && \
+    (defined(__x86_64__) || defined(_M_X64) || defined(__i386__))
+#define CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+#include <immintrin.h>
+#endif
 #include "caffe2/core/types.h"
 
 namespace caffe2 {
 
+namespace internal {
+
+template <typename T>
+static inline void adagrad_update_base_inlined(
+    int N,
+    const T* w,
+    const float* g,
+    const T* h,
+    T* nw,
+    T* nh,
+    float decay,
+    float epsilon,
+    float lr) {
+  for (auto i = 0; i < N; ++i) {
+    float gi = g[i];
+    float hi = decay * h[i] + gi * gi;
+    nh[i] = hi;
+    nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);
+  }
+}
+
+inline void adagrad_update_prefetch_inlined(
+    int N,
+    const float* w,
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+    const float* w_n, // prefetch ptr
+#else
+    const float* /* unused */,
+#endif
+
+    const float* g,
+
+    const float* h,
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+    const float* h_n, // prefetch ptr
+#else
+    const float* /* unused */,
+#endif
+
+    float* nw,
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+    float* nw_n, // prefetch ptr
+#else
+    float* /* unused */,
+#endif
+
+    float* nh,
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+    float* nh_n, // prefetch ptr
+#else
+    float* /* unused */,
+#endif
+
+    float epsilon,
+    float lr) {
+  auto i = 0;
+
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+  constexpr size_t kSize = 8;
+  for (; i + kSize <= N; i += kSize) {
+    _mm_prefetch(&w_n[i], _MM_HINT_T0);
+    _mm_prefetch(&h_n[i], _MM_HINT_T0);
+    _mm_prefetch(&nw_n[i], _MM_HINT_T0);
+    _mm_prefetch(&nh_n[i], _MM_HINT_T0);
+
+    __m256 gi = _mm256_loadu_ps(g + i);
+    __m256 hi = _mm256_loadu_ps(h + i);
+    __m256 wi = _mm256_loadu_ps(w + i);
+
+    __m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi));
+    _mm256_storeu_ps(nh + i, nhi);
+    __m256 vtmp = _mm256_div_ps(
+        gi, _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon)));
+    _mm256_storeu_ps(
+        nw + i, _mm256_add_ps(wi, _mm256_mul_ps(_mm256_set1_ps(lr), vtmp)));
+  }
+#endif
+
+  adagrad_update_base_inlined(
+      N - i, w + i, g + i, h + i, nw + i, nh + i, 1.0f, epsilon, lr);
+}
+
+inline void rowwise_adagrad_update_inlined(
+    int N,
+    float* w,
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+    float* w_n, // prefetch ptr
+#else
+    float* /* unused */,
+#endif
+
+    const float* g,
+
+    float* h,
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+    float* h_n, // prefetch ptr
+#else
+    float* /* unused */,
+#endif
+
+    float epsilon,
+    float lr) {
+  auto i = 0;
+
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+  constexpr size_t kSize = 8;
+  _mm_prefetch(h_n, _MM_HINT_T0);
+  __m256 partial_sum = _mm256_setzero_ps();
+  for (; i + kSize <= N; i += kSize) {
+    __m256 gi = _mm256_loadu_ps(g + i);
+    partial_sum = _mm256_add_ps(partial_sum, _mm256_mul_ps(gi, gi));
+  }
+  // Reduce sum to 1 value
+  __m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum);
+  __m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2);
+  float final_sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) +
+      _mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1));
+#else
+  float final_sum = 0.0f;
+#endif
+
+  for (; i < N; ++i) {
+    final_sum += g[i] * g[i];
+  }
+  final_sum /= N;
+
+  float hi = *h = *h + final_sum;
+  float float_step = lr / (std::sqrt(hi) + epsilon);
+
+  i = 0;
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+  __m256 step = _mm256_set1_ps(float_step);
+
+  for (i = 0; i + kSize <= N; i += kSize) {
+    _mm_prefetch(&w_n[i], _MM_HINT_T0);
+
+    __m256 gi = _mm256_loadu_ps(g + i);
+    __m256 wi = _mm256_loadu_ps(w + i);
+
+    _mm256_storeu_ps(w + i, _mm256_add_ps(wi, _mm256_mul_ps(gi, step)));
+  }
+#endif
+
+  for (; i < N; ++i) {
+    float gi = g[i];
+    w[i] = w[i] + gi * float_step;
+  }
+}
+
+} // namespace internal
+
 // version with prefetching
 // TODO(msmelyan)
 // Crux of the computation is computing a  / (sqrt(b) + epsilon),
@@ -155,3 +311,7 @@ void sparse_adagrad(
   };
 
 } // namespace caffe2
+
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+#undef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+#endif
index b4b575b..36e355f 100644 (file)
@@ -1,6 +1,3 @@
-#include "caffe2/core/common.h"
-#include "caffe2/core/context.h"
-#include "caffe2/core/types.h"
 #include "caffe2/perfkernels/adagrad.h"
 #include "caffe2/perfkernels/cvtsh_ss_bugfix.h"
 
@@ -44,31 +41,8 @@ void adagrad_update_prefetch__avx_f16c(
 
     float epsilon,
     float lr) {
-  constexpr size_t kSize = 8;
-  auto i = 0;
-  for (; i + kSize <= N; i += kSize) {
-    _mm_prefetch(&w_n[i], _MM_HINT_T0);
-    _mm_prefetch(&h_n[i], _MM_HINT_T0);
-    _mm_prefetch(&nw_n[i], _MM_HINT_T0);
-    _mm_prefetch(&nh_n[i], _MM_HINT_T0);
-
-    __m256 gi = _mm256_loadu_ps(g + i);
-    __m256 hi = _mm256_loadu_ps(h + i);
-    __m256 wi = _mm256_loadu_ps(w + i);
-
-    __m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi));
-    _mm256_storeu_ps(nh + i, nhi);
-    __m256 vtmp = _mm256_div_ps(
-        gi, _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon)));
-    _mm256_storeu_ps(
-        nw + i, _mm256_add_ps(wi, _mm256_mul_ps(_mm256_set1_ps(lr), vtmp)));
-  }
-
-  for (; i < N; ++i) {
-    float gi = g[i];
-    float hi = nh[i] = h[i] + gi * gi;
-    nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);
-  }
+  internal::adagrad_update_prefetch_inlined(
+      N, w, w_n, g, h, h_n, nw, nw_n, nh, nh_n, epsilon, lr);
 }
 
 // Compute adagrad sparse, assumes embedding and momentum are at::Half
@@ -131,45 +105,7 @@ void rowwise_adagrad_update__avx_f16c(
 
     float epsilon,
     float lr) {
-  _mm_prefetch(h_n, _MM_HINT_T0);
-
-  constexpr size_t kSize = 8;
-  auto i = 0;
-
-  __m256 partial_sum = _mm256_setzero_ps();
-  for (; i + kSize <= N; i += kSize) {
-    __m256 gi = _mm256_loadu_ps(g + i);
-    partial_sum = _mm256_add_ps(partial_sum, _mm256_mul_ps(gi, gi));
-  }
-  // Reduce sum to 1 value
-  __m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum);
-  __m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2);
-  float final_sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) +
-      _mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1));
-
-  for (; i < N; ++i) {
-    final_sum += g[i] * g[i];
-  }
-  final_sum /= N;
-
-  float hi = *h = *h + final_sum;
-  float float_step = lr / (std::sqrt(hi) + epsilon);
-
-  __m256 step = _mm256_set1_ps(float_step);
-
-  for (i = 0; i + kSize <= N; i += kSize) {
-    _mm_prefetch(&w_n[i], _MM_HINT_T0);
-
-    __m256 gi = _mm256_loadu_ps(g + i);
-    __m256 wi = _mm256_loadu_ps(w + i);
-
-    _mm256_storeu_ps(w + i, _mm256_add_ps(wi, _mm256_mul_ps(gi, step)));
-  }
-
-  for (; i < N; ++i) {
-    float gi = g[i];
-    w[i] = w[i] + gi * float_step;
-  }
+  internal::rowwise_adagrad_update_inlined(N, w, w_n, g, h, h_n, epsilon, lr);
 }
 
 // version without prefetching