float epsilon,
float decay,
const float lr) {
- for (auto i = 0; i < N; ++i) {
- float gi = g[i];
- float hi = nh[i] = decay * h[i] + gi * gi;
- nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);
- }
+ internal::adagrad_update_base_inlined(N, w, g, h, nw, nh, decay, epsilon, lr);
}
void adagrad_update_prefetch__base(
at::Half* /* nh_n */, // prefetch ptr
float epsilon,
float lr) {
- for (auto i = 0; i < N; ++i) {
- float gi = g[i];
- float hi = h[i] + gi * gi;
- nh[i] = hi;
- nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);
- }
+ internal::adagrad_update_base_inlined(N, w, g, h, nw, nh, 1.0f, epsilon, lr);
}
void rowwise_adagrad_update__base(
int N,
float* w,
- float* /* w_n */, // prefetch ptr
+ float* w_n, // prefetch ptr; named (no longer commented out) so it can be forwarded below
const float* g,
float* h,
- float* /* h_n */, // prefetch ptr
+ float* h_n, // prefetch ptr; named (no longer commented out) so it can be forwarded below
float epsilon,
float lr) {
- float sum = 0.0f;
- for (auto i = 0; i < N; ++i) {
- sum += g[i] * g[i];
- }
- sum /= N;
-
- float hi = *h = *h + sum;
- float float_step = lr / (std::sqrt(hi) + epsilon);
-
- for (auto i = 0; i < N; ++i) {
- float gi = g[i];
- w[i] = w[i] + gi * float_step;
- }
+ internal::rowwise_adagrad_update_inlined(N, w, w_n, g, h, h_n, epsilon, lr); // the removed loop body is replaced by this call into namespace internal; all arguments are passed through unchanged
}
void adagrad_update_prefetch(
#pragma once
+#if defined(__AVX__) && !defined(__NVCC__) && \
+ (defined(__x86_64__) || defined(_M_X64) || defined(__i386__))
+#define CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+#include <immintrin.h>
+#endif
#include "caffe2/core/types.h"
namespace caffe2 {
+namespace internal {
+
+// Scalar Adagrad step with a decay factor applied to the previous moment.
+// For every index i in [0, N):
+//   hi    = decay * h[i] + g[i] * g[i]   (computed in float)
+//   nh[i] = hi                           (converted to T on store)
+//   nw[i] = w[i] + lr * g[i] / (sqrt(hi) + epsilon)
+// The loop body never runs when N <= 0.
+template <typename T>
+static inline void adagrad_update_base_inlined(
+    int N,
+    const T* w,
+    const float* g,
+    const T* h,
+    T* nw,
+    T* nh,
+    float decay,
+    float epsilon,
+    float lr) {
+  for (int idx = 0; idx < N; ++idx) {
+    const float grad = g[idx];
+    const float moment = decay * h[idx] + grad * grad;
+    // Store the moment before touching the weights, as the original did.
+    nh[idx] = moment;
+    const float denom = std::sqrt(moment) + epsilon;
+    nw[idx] = w[idx] + lr * grad / denom;
+  }
+}
+
+// Adagrad step over N floats, vectorized when AVX is available:
+//   nh[i] = h[i] + g[i] * g[i]
+//   nw[i] = w[i] + lr * g[i] / (sqrt(nh[i]) + epsilon)
+//
+// With CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC defined, full groups of 8
+// elements are handled with 256-bit unaligned loads/stores, and w_n, h_n,
+// nw_n, nh_n are used only as prefetch addresses (offset by the current
+// index). Without it those four parameters are unnamed and ignored. Whatever
+// the vector loop leaves over is finished by adagrad_update_base_inlined with
+// decay == 1.0f. Nothing is read or written when N <= 0.
+inline void adagrad_update_prefetch_inlined(
+    int N,
+    const float* w,
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+    const float* w_n, // prefetch ptr
+#else
+    const float* /* unused */,
+#endif
+
+    const float* g,
+
+    const float* h,
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+    const float* h_n, // prefetch ptr
+#else
+    const float* /* unused */,
+#endif
+
+    float* nw,
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+    float* nw_n, // prefetch ptr
+#else
+    float* /* unused */,
+#endif
+
+    float* nh,
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+    float* nh_n, // prefetch ptr
+#else
+    float* /* unused */,
+#endif
+
+    float epsilon,
+    float lr) {
+  int i = 0;
+
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+  // Kept signed on purpose: comparing against a size_t width would convert a
+  // negative N to a huge unsigned value and send the loop out of bounds.
+  constexpr int kSize = 8;
+  // Loop-invariant broadcasts, built once instead of on every iteration.
+  const __m256 veps = _mm256_set1_ps(epsilon);
+  const __m256 vlr = _mm256_set1_ps(lr);
+
+  // "N - i >= kSize" (rather than "i + kSize <= N") cannot overflow int.
+  for (; N - i >= kSize; i += kSize) {
+    // _mm_prefetch is specified to take char const*; cast explicitly so this
+    // also builds where it is a real function and not a forgiving macro.
+    _mm_prefetch(reinterpret_cast<const char*>(w_n + i), _MM_HINT_T0);
+    _mm_prefetch(reinterpret_cast<const char*>(h_n + i), _MM_HINT_T0);
+    _mm_prefetch(reinterpret_cast<const char*>(nw_n + i), _MM_HINT_T0);
+    _mm_prefetch(reinterpret_cast<const char*>(nh_n + i), _MM_HINT_T0);
+
+    const __m256 gi = _mm256_loadu_ps(g + i);
+    const __m256 hi = _mm256_loadu_ps(h + i);
+    const __m256 wi = _mm256_loadu_ps(w + i);
+
+    const __m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi));
+    _mm256_storeu_ps(nh + i, nhi);
+    const __m256 vtmp =
+        _mm256_div_ps(gi, _mm256_add_ps(_mm256_sqrt_ps(nhi), veps));
+    _mm256_storeu_ps(nw + i, _mm256_add_ps(wi, _mm256_mul_ps(vlr, vtmp)));
+  }
+#endif
+
+  // Scalar tail (or the whole range when intrinsics are unavailable).
+  adagrad_update_base_inlined(
+      N - i, w + i, g + i, h + i, nw + i, nh + i, 1.0f, epsilon, lr);
+}
+
+// Row-wise Adagrad: a single moment *h is shared by the whole row of N
+// weights, which are updated in place.
+//   *h   = *h + (sum_i g[i] * g[i]) / N
+//   w[i] = w[i] + g[i] * (lr / (sqrt(*h) + epsilon))
+//
+// With CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC defined, both passes handle
+// full groups of 8 elements with AVX and finish the remainder with scalar
+// code; h_n and w_n are used only as prefetch addresses. Without it those two
+// parameters are unnamed and ignored and everything runs in the scalar loops.
+// The call is a no-op when N <= 0.
+inline void rowwise_adagrad_update_inlined(
+    int N,
+    float* w,
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+    float* w_n, // prefetch ptr
+#else
+    float* /* unused */,
+#endif
+
+    const float* g,
+
+    float* h,
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+    float* h_n, // prefetch ptr
+#else
+    float* /* unused */,
+#endif
+
+    float epsilon,
+    float lr) {
+  // An empty row has no mean: 0.0f / 0 is NaN and would be stored into *h,
+  // corrupting the moment for every later update. Bail out before that.
+  if (N <= 0) {
+    return;
+  }
+
+  int i = 0;
+
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+  // Kept signed so loop bounds are compared as int, never as size_t.
+  constexpr int kSize = 8;
+  // _mm_prefetch is specified to take char const*; cast explicitly.
+  _mm_prefetch(reinterpret_cast<const char*>(h_n), _MM_HINT_T0);
+  __m256 partial_sum = _mm256_setzero_ps();
+  for (; N - i >= kSize; i += kSize) {
+    const __m256 gi = _mm256_loadu_ps(g + i);
+    partial_sum = _mm256_add_ps(partial_sum, _mm256_mul_ps(gi, gi));
+  }
+  // Reduce sum to 1 value: two horizontal adds leave each 128-bit lane's
+  // total in its element 0; add the low and high lane totals.
+  const __m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum);
+  const __m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2);
+  float final_sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) +
+      _mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1));
+#else
+  float final_sum = 0.0f;
+#endif
+
+  // Scalar remainder of the squared-gradient sum.
+  for (; i < N; ++i) {
+    final_sum += g[i] * g[i];
+  }
+  final_sum /= N;
+
+  // Update the shared moment, then derive the single step size for the row.
+  float hi = *h = *h + final_sum;
+  float float_step = lr / (std::sqrt(hi) + epsilon);
+
+  // Second pass restarts from the beginning of the row.
+  i = 0;
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+  const __m256 step = _mm256_set1_ps(float_step);
+
+  for (; N - i >= kSize; i += kSize) {
+    _mm_prefetch(reinterpret_cast<const char*>(w_n + i), _MM_HINT_T0);
+
+    const __m256 gi = _mm256_loadu_ps(g + i);
+    const __m256 wi = _mm256_loadu_ps(w + i);
+
+    _mm256_storeu_ps(w + i, _mm256_add_ps(wi, _mm256_mul_ps(gi, step)));
+  }
+#endif
+
+  // Scalar remainder of the weight update.
+  for (; i < N; ++i) {
+    float gi = g[i];
+    w[i] = w[i] + gi * float_step;
+  }
+}
+
+} // namespace internal
+
// version with prefetching
// TODO(msmelyan)
// Crux of the computation is computing a / (sqrt(b) + epsilon),
};
} // namespace caffe2
+
+#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+#undef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
+#endif
-#include "caffe2/core/common.h"
-#include "caffe2/core/context.h"
-#include "caffe2/core/types.h"
#include "caffe2/perfkernels/adagrad.h"
#include "caffe2/perfkernels/cvtsh_ss_bugfix.h"
float epsilon,
float lr) {
- constexpr size_t kSize = 8;
- auto i = 0;
- for (; i + kSize <= N; i += kSize) {
- _mm_prefetch(&w_n[i], _MM_HINT_T0);
- _mm_prefetch(&h_n[i], _MM_HINT_T0);
- _mm_prefetch(&nw_n[i], _MM_HINT_T0);
- _mm_prefetch(&nh_n[i], _MM_HINT_T0);
-
- __m256 gi = _mm256_loadu_ps(g + i);
- __m256 hi = _mm256_loadu_ps(h + i);
- __m256 wi = _mm256_loadu_ps(w + i);
-
- __m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi));
- _mm256_storeu_ps(nh + i, nhi);
- __m256 vtmp = _mm256_div_ps(
- gi, _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon)));
- _mm256_storeu_ps(
- nw + i, _mm256_add_ps(wi, _mm256_mul_ps(_mm256_set1_ps(lr), vtmp)));
- }
-
- for (; i < N; ++i) {
- float gi = g[i];
- float hi = nh[i] = h[i] + gi * gi;
- nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);
- }
+ internal::adagrad_update_prefetch_inlined(
+ N, w, w_n, g, h, h_n, nw, nw_n, nh, nh_n, epsilon, lr);
}
// Compute adagrad sparse, assumes embedding and momentum are at::Half
float epsilon,
float lr) {
- _mm_prefetch(h_n, _MM_HINT_T0);
-
- constexpr size_t kSize = 8;
- auto i = 0;
-
- __m256 partial_sum = _mm256_setzero_ps();
- for (; i + kSize <= N; i += kSize) {
- __m256 gi = _mm256_loadu_ps(g + i);
- partial_sum = _mm256_add_ps(partial_sum, _mm256_mul_ps(gi, gi));
- }
- // Reduce sum to 1 value
- __m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum);
- __m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2);
- float final_sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) +
- _mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1));
-
- for (; i < N; ++i) {
- final_sum += g[i] * g[i];
- }
- final_sum /= N;
-
- float hi = *h = *h + final_sum;
- float float_step = lr / (std::sqrt(hi) + epsilon);
-
- __m256 step = _mm256_set1_ps(float_step);
-
- for (i = 0; i + kSize <= N; i += kSize) {
- _mm_prefetch(&w_n[i], _MM_HINT_T0);
-
- __m256 gi = _mm256_loadu_ps(g + i);
- __m256 wi = _mm256_loadu_ps(w + i);
-
- _mm256_storeu_ps(w + i, _mm256_add_ps(wi, _mm256_mul_ps(gi, step)));
- }
-
- for (; i < N; ++i) {
- float gi = g[i];
- w[i] = w[i] + gi * float_step;
- }
+ internal::rowwise_adagrad_update_inlined(N, w, w_n, g, h, h_n, epsilon, lr);
}
// version without prefetching