optimize elementwise sum (#17456)
authorJongsoo Park <jongsoo@fb.com>
Wed, 27 Feb 2019 18:09:53 +0000 (10:09 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 27 Feb 2019 18:12:41 +0000 (10:12 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17456

Uses an instruction sequence similar to the requantizeOutputProcessingAvx2 function in fbgemm/src/QuantUtilsAvx2.cc: the main loop is unrolled 4x and the results are packed to uint8 with saturating pack/permute instructions and a single 256-bit store, instead of one 64-bit store per 8 elements.
Also adds elementwise_sum_benchmark.
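
For reference, the kernel dequantizes both uint8 inputs, accumulates in fp32, and requantizes into the output scale/zero point. A scalar sketch of the per-element math (hypothetical helper, not part of this diff; parameter names follow ElementWiseSumAVX2, and the non-ReluFused clamp is shown):

  // Scalar reference for one element; this is what each AVX2 lane computes.
  uint8_t sum_one(uint8_t a, uint8_t b,
                  float a_scale, int32_t a_zero_point,
                  float b_scale, int32_t b_zero_point,
                  float c_scale, int32_t c_zero_point) {
    // Dequantize both inputs and add in fp32.
    float acc = a_scale * (a - a_zero_point) + b_scale * (b - b_zero_point);
    // Requantize into the output quantization, round, and clamp to uint8.
    float transformed = acc / c_scale + c_zero_point;
    return static_cast<uint8_t>(
        std::min(255.0f, std::max(0.0f, std::nearbyintf(transformed))));
  }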

Reviewed By: protonu

Differential Revision: D14205695

fbshipit-source-id: 84939c9d3551f123deec3baf7086c8d31fbc873e

caffe2/quantization/server/elementwise_sum_benchmark.cc [new file with mode: 0644]
caffe2/quantization/server/elementwise_sum_dnnlowp_op_avx2.cc

diff --git a/caffe2/quantization/server/elementwise_sum_benchmark.cc b/caffe2/quantization/server/elementwise_sum_benchmark.cc
new file mode 100644 (file)
index 0000000..795075d
--- /dev/null
+++ b/caffe2/quantization/server/elementwise_sum_benchmark.cc
@@ -0,0 +1,37 @@
+#include <chrono>
+#include <cstdint>
+#include <cstdlib>
+#include <iostream>
+#include <vector>
+
+#include "utility_dnnlowp_ops.h"
+
+using namespace std;
+
+int main(int argc, const char* argv[]) {
+  int LEN = argc > 1 ? atoi(argv[1]) : 65536;
+
+  vector<uint8_t> a(LEN), b(LEN), c_avx2(LEN), c_avx512(LEN);
+  for (int i = 0; i < LEN; ++i) {
+    a[i] = i % 256;
+    b[i] = (i * 2) % 256;
+  }
+
+  chrono::time_point<chrono::system_clock> t = chrono::system_clock::now();
+  caffe2::internal::ElementWiseSumAVX2<uint8_t, false>(
+      a.data(),
+      b.data(),
+      c_avx2.data(),
+      a.size(),
+      1.0f,
+      11,
+      2.0f,
+      22,
+      3.0f,
+      33);
+  double dt = chrono::duration<double>(chrono::system_clock::now() - t).count();
+  double bytes = 3. * LEN * sizeof(a[0]);
+  cout << bytes / dt / 1e9 << " GB/s" << endl;
+
+  return 0;
+}
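
The benchmark takes an optional element count (default 65536), times a single ElementWiseSumAVX2 call, and prints the effective bandwidth. A possible invocation (binary name assumed from the source file; the actual build target may differ):

  ./elementwise_sum_benchmark 1048576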
diff --git a/caffe2/quantization/server/elementwise_sum_dnnlowp_op_avx2.cc b/caffe2/quantization/server/elementwise_sum_dnnlowp_op_avx2.cc
index 69346c2..156fff5 100644 (file)
--- a/caffe2/quantization/server/elementwise_sum_dnnlowp_op_avx2.cc
+++ b/caffe2/quantization/server/elementwise_sum_dnnlowp_op_avx2.cc
@@ -1,7 +1,6 @@
 #include <algorithm>
 #include <cmath>
 #include <cstdint>
-#include <limits>
 
 #include <immintrin.h>
 
@@ -25,82 +24,122 @@ void ElementWiseSumAVX2(
     int32_t b_zero_point,
     float c_scale,
     int32_t c_zero_point) {
-  // TODO: this intrinsic code is replicated in dnnlowp.cc,
-  // fbgemm_i8i8_acc32.cc, conv_dnnlowp_op.cc, and here.
-  // We need to somehow refactor this.
-  __m256 min_v = _mm256_set1_ps(numeric_limits<uint8_t>::min());
-  __m256 max_v = _mm256_set1_ps(numeric_limits<uint8_t>::max());
-
-  __m256i shuffle_mask_v = _mm256_set_epi8(
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0x0c,
-      0x08,
-      0x04,
-      0x00,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0xff,
-      0x0c,
-      0x08,
-      0x04,
-      0x00);
   __m256i permute_mask_v =
-      _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
+      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
 
-  int len_aligned = len / VLEN * VLEN;
+  int len_aligned = len / (VLEN * 4) * (VLEN * 4);
   int j = 0;
-  for (; j < len_aligned; j += VLEN) {
+  for (; j < len_aligned; j += VLEN * 4) {
     // Input is uint8_t but cvtepi8_epi32 assumes the input is int8_t,
     // so we subtract 0x80, cvtepi8_epi32, and then add 0x80
-    __m256 in_v = _mm256_cvtepi32_ps(_mm256_add_epi32(
+    // x
+    __m256 in_v0 = _mm256_cvtepi32_ps(_mm256_add_epi32(
         _mm256_cvtepi8_epi32(_mm_sub_epi8(
             _mm_loadl_epi64(reinterpret_cast<const __m128i*>(input0 + j)),
             _mm_set1_epi8(0x80))),
         _mm256_set1_epi32(0x80)));
-    in_v = _mm256_fmadd_ps(
-        in_v, _mm256_set1_ps(a_scale), _mm256_set1_ps(-a_zero_point * a_scale));
-    __m256 acc_v = in_v;
+    in_v0 = _mm256_fmadd_ps(
+        in_v0,
+        _mm256_set1_ps(a_scale),
+        _mm256_set1_ps(-a_zero_point * a_scale - b_zero_point * b_scale));
 
-    in_v = _mm256_cvtepi32_ps(_mm256_add_epi32(
+    __m256 in_v1 = _mm256_cvtepi32_ps(_mm256_add_epi32(
         _mm256_cvtepi8_epi32(_mm_sub_epi8(
             _mm_loadl_epi64(reinterpret_cast<const __m128i*>(input1 + j)),
             _mm_set1_epi8(0x80))),
         _mm256_set1_epi32(0x80)));
-    in_v = _mm256_fmadd_ps(
-        in_v, _mm256_set1_ps(b_scale), _mm256_set1_ps(-b_zero_point * b_scale));
-    acc_v = _mm256_add_ps(acc_v, in_v);
+    __m256 acc_v = _mm256_fmadd_ps(in_v1, _mm256_set1_ps(b_scale), in_v0);
 
-    __m256 transformed_v = _mm256_fmadd_ps(
+    __m256 x_transformed_v = _mm256_fmadd_ps(
         acc_v, _mm256_set1_ps(1.0 / c_scale), _mm256_set1_ps(c_zero_point));
-    __m256 clipped_v = _mm256_min_ps(
-        _mm256_max_ps(
-            transformed_v, ReluFused ? _mm256_set1_ps(c_zero_point) : min_v),
-        max_v);
-    __m256i rounded_v = _mm256_cvtps_epi32(clipped_v);
-    rounded_v = _mm256_shuffle_epi8(rounded_v, shuffle_mask_v);
-    rounded_v = _mm256_permutevar8x32_epi32(rounded_v, permute_mask_v);
-    *reinterpret_cast<int64_t*>(output + j) =
-        _mm256_extract_epi64(rounded_v, 0);
+
+    // y
+    in_v0 = _mm256_cvtepi32_ps(_mm256_add_epi32(
+        _mm256_cvtepi8_epi32(_mm_sub_epi8(
+            _mm_loadl_epi64(
+                reinterpret_cast<const __m128i*>(input0 + j + VLEN)),
+            _mm_set1_epi8(0x80))),
+        _mm256_set1_epi32(0x80)));
+    in_v0 = _mm256_fmadd_ps(
+        in_v0,
+        _mm256_set1_ps(a_scale),
+        _mm256_set1_ps(-a_zero_point * a_scale - b_zero_point * b_scale));
+
+    in_v1 = _mm256_cvtepi32_ps(_mm256_add_epi32(
+        _mm256_cvtepi8_epi32(_mm_sub_epi8(
+            _mm_loadl_epi64(
+                reinterpret_cast<const __m128i*>(input1 + j + VLEN)),
+            _mm_set1_epi8(0x80))),
+        _mm256_set1_epi32(0x80)));
+    acc_v = _mm256_fmadd_ps(in_v1, _mm256_set1_ps(b_scale), in_v0);
+
+    __m256 y_transformed_v = _mm256_fmadd_ps(
+        acc_v, _mm256_set1_ps(1.0 / c_scale), _mm256_set1_ps(c_zero_point));
+
+    // z
+    in_v0 = _mm256_cvtepi32_ps(_mm256_add_epi32(
+        _mm256_cvtepi8_epi32(_mm_sub_epi8(
+            _mm_loadl_epi64(
+                reinterpret_cast<const __m128i*>(input0 + j + 2 * VLEN)),
+            _mm_set1_epi8(0x80))),
+        _mm256_set1_epi32(0x80)));
+    in_v0 = _mm256_fmadd_ps(
+        in_v0,
+        _mm256_set1_ps(a_scale),
+        _mm256_set1_ps(-a_zero_point * a_scale - b_zero_point * b_scale));
+
+    in_v1 = _mm256_cvtepi32_ps(_mm256_add_epi32(
+        _mm256_cvtepi8_epi32(_mm_sub_epi8(
+            _mm_loadl_epi64(
+                reinterpret_cast<const __m128i*>(input1 + j + 2 * VLEN)),
+            _mm_set1_epi8(0x80))),
+        _mm256_set1_epi32(0x80)));
+    acc_v = _mm256_fmadd_ps(in_v1, _mm256_set1_ps(b_scale), in_v0);
+
+    __m256 z_transformed_v = _mm256_fmadd_ps(
+        acc_v, _mm256_set1_ps(1.0 / c_scale), _mm256_set1_ps(c_zero_point));
+
+    // w
+    in_v0 = _mm256_cvtepi32_ps(_mm256_add_epi32(
+        _mm256_cvtepi8_epi32(_mm_sub_epi8(
+            _mm_loadl_epi64(
+                reinterpret_cast<const __m128i*>(input0 + j + 3 * VLEN)),
+            _mm_set1_epi8(0x80))),
+        _mm256_set1_epi32(0x80)));
+    in_v0 = _mm256_fmadd_ps(
+        in_v0,
+        _mm256_set1_ps(a_scale),
+        _mm256_set1_ps(-a_zero_point * a_scale - b_zero_point * b_scale));
+
+    in_v1 = _mm256_cvtepi32_ps(_mm256_add_epi32(
+        _mm256_cvtepi8_epi32(_mm_sub_epi8(
+            _mm_loadl_epi64(
+                reinterpret_cast<const __m128i*>(input1 + j + 3 * VLEN)),
+            _mm_set1_epi8(0x80))),
+        _mm256_set1_epi32(0x80)));
+    acc_v = _mm256_fmadd_ps(in_v1, _mm256_set1_ps(b_scale), in_v0);
+
+    __m256 w_transformed_v = _mm256_fmadd_ps(
+        acc_v, _mm256_set1_ps(1.0 / c_scale), _mm256_set1_ps(c_zero_point));
+
+    // See fbgemm/src/QuantUtilsAvx2.cc requantizeOutputProcessingAvx2 function
+    // for more details on this instruction sequence
+    __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v);
+    __m256i y_rounded_v = _mm256_cvtps_epi32(y_transformed_v);
+    __m256i z_rounded_v = _mm256_cvtps_epi32(z_transformed_v);
+    __m256i w_rounded_v = _mm256_cvtps_epi32(w_transformed_v);
+
+    __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v);
+    __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v);
+    __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
+    __m256i xyzw_clamped_v = _mm256_max_epu8(
+        ReluFused ? _mm256_set1_epi8(c_zero_point) : _mm256_setzero_si256(),
+        _mm256_min_epu8(
+            xyzw_packed_v, _mm256_set1_epi8(static_cast<uint8_t>(255))));
+
+    xyzw_clamped_v =
+        _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>(output + j), xyzw_clamped_v);
   }
   for (; j < len; ++j) {
     float acc = 0;