minimize header file includes from _avx2.cc (#14950)
authorJongsoo Park <jongsoo@fb.com>
Thu, 13 Dec 2018 08:15:51 +0000 (00:15 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 13 Dec 2018 08:18:11 +0000 (00:18 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14950

Minimize the number of headers included from _avx2.cc files to avoid accidental compilation of functions defined the header files reused by other translation units that can lead to illegal instruction errors.

Reviewed By: dskhudia

Differential Revision: D13394483

fbshipit-source-id: 67149a6fb51f7f047e745bfe395cb6dd4ae7c1ae

12 files changed:
caffe2/operators/fused_rowwise_random_quantization_ops.cc
caffe2/operators/fused_rowwise_random_quantization_ops.h
caffe2/perfkernels/adagrad.cc
caffe2/perfkernels/adagrad.h
caffe2/perfkernels/common.h
caffe2/perfkernels/embedding_lookup_avx2.cc
caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc
caffe2/perfkernels/hp_emblookup_codegen.py
caffe2/perfkernels/math.h
caffe2/perfkernels/math_cpu_avx2.cc
caffe2/perfkernels/math_cpu_base.cc
caffe2/perfkernels/typed_axpy_avx2.cc

index 7b5070a..e7cb974 100644 (file)
@@ -1,5 +1,5 @@
 #include "caffe2/operators/fused_rowwise_random_quantization_ops.h"
-#include "c10/util/Registry.h"
+#include <c10/util/Registry.h>
 #include "caffe2/utils/math.h"
 
 namespace caffe2 {
@@ -48,26 +48,36 @@ bool FloatToFusedRandRowwiseQuantizedOp<Context>::RunOnDevice() {
   memset(output_data, 0, output->numel());
 
   if (random_) {
-#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
     random_buffer_.resize(input_columns);
-#endif
   }
 
   for (size_t row = 0; row < input_rows; ++row) {
+    if (random_) {
+#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
+      int status = vsRngUniform(
+          VSL_RNG_METHOD_UNIFORM_STD,
+          vslStream_,
+          input_columns,
+          random_buffer_.data(),
+          0.0f,
+          1.0f);
+      if (status != VSL_ERROR_OK) {
+        LOG(WARNING) << "vsRngUniform returns " << status;
+      }
+#else
+      for (int i = 0; i < input_columns; ++i) {
+        random_buffer_[i] = (*dis_)(gen_);
+      }
+#endif
+    }
+
     math::quantize_and_compress(
         input_data + row * input_columns,
         output_data + row * output_columns,
         input_columns,
         bitwidth_,
         random_,
-#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
-        vslStream_,
-        random_buffer_
-#else
-        dis_,
-        gen_
-#endif
-    );
+        random_buffer_.data());
   }
 
   return true;
index 1cbb69a..e1c5cb6 100644 (file)
 #include "caffe2/perfkernels/math.h"
 #include "caffe2/utils/math.h"
 
+#ifdef CAFFE2_USE_MKL
+#include <mkl.h>
+#define FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
+#endif
+
 namespace caffe2 {
 
 template <class Context>
@@ -61,9 +66,10 @@ class FloatToFusedRandRowwiseQuantizedOp : public Operator<Context> {
  protected:
   size_t bitwidth_{8};
   bool random_{true};
+  std::vector<float> random_buffer_;
+
 #ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
   VSLStreamStatePtr vslStream_;
-  std::vector<float> random_buffer_;
 #else
   std::unique_ptr<std::uniform_real_distribution<float>> dis_;
   std::minstd_rand gen_;
index c629cb6..0d6e25e 100644 (file)
@@ -195,21 +195,6 @@ void adagrad_update(
   BASE_DO(adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr);
 }
 
-template <typename SIndex>
-void sparse_adagrad(
-    int num_rows,
-    int block_size,
-    size_t param_size,
-    const float* w,
-    const float* g,
-    const float* h,
-    const SIndex* indices,
-    float* nw,
-    float* nh,
-    float epsilon,
-    float lr,
-    const std::string& param_name);
-
 SPARSE_ADAGRAD_SPECIALIZATION(int32_t, base);
 
 template <>
index efd4eff..6ce1965 100644 (file)
@@ -5,7 +5,8 @@
 #define CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
 #include <immintrin.h>
 #endif
-#include "caffe2/core/types.h"
+#include <ATen/core/Half.h>
+#include <c10/util/Logging.h>
 
 namespace caffe2 {
 
index 91e9e86..b128c76 100644 (file)
@@ -1,3 +1,9 @@
+// !!!! PLEASE READ !!!!
+// Minimize (transitively) included headers from _avx*.cc because some of the
+// functions defined in the headers compiled with platform dependent compiler
+// options can be reused by other translation units generating illegal
+// instruction run-time error.
+
 // Common utilities for writing performance kernels and easy dispatching of
 // different backends.
 /*
index cd5cb73..e470779 100644 (file)
@@ -5,9 +5,10 @@
 //// DO NOT MODIFY!!!
 //// --------------------------
 
-#include <caffe2/core/common.h>
-#include <caffe2/core/types.h>
+#include <ATen/core/Half.h>
+#include <c10/util/Logging.h>
 #include <immintrin.h>
+#include <cassert>
 
 namespace caffe2 {
 
@@ -1309,7 +1310,7 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma(
                   _mm256_loadu_ps(&op[j])));
           _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
         }
-        at::Half vtmp1[8] CAFFE2_ALIGNED(64);
+        alignas(64) at::Half vtmp1[8];
         for (; j < block_size; j++) {
           vtmp1[0] = ip[j];
           __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
@@ -1850,7 +1851,7 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma(
                   _mm256_loadu_ps(&op[j])));
           _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
         }
-        at::Half vtmp1[8] CAFFE2_ALIGNED(64);
+        alignas(64) at::Half vtmp1[8];
         for (; j < block_size; j++) {
           vtmp1[0] = ip[j];
           __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
index 5eeb4ef..650a2dc 100644 (file)
@@ -5,9 +5,10 @@
 //// DO NOT MODIFY!!!
 //// --------------------------
 
-#include <caffe2/core/common.h>
-#include <caffe2/core/types.h>
+#include <ATen/core/Half.h>
+#include <c10/util/Logging.h>
 #include <immintrin.h>
+#include <cassert>
 
 namespace caffe2 {
 
@@ -1295,7 +1296,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
                   _mm256_loadu_ps(&op[j])));
           _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
         }
-        at::Half vtmp1[8] CAFFE2_ALIGNED(64);
+        alignas(64) at::Half vtmp1[8];
         for (; j < block_size; j++) {
           vtmp1[0] = ip[j];
           __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
@@ -1830,7 +1831,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
                   _mm256_loadu_ps(&op[j])));
           _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
         }
-        at::Half vtmp1[8] CAFFE2_ALIGNED(64);
+        alignas(64) at::Half vtmp1[8];
         for (; j < block_size; j++) {
           vtmp1[0] = ip[j];
           __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
index 2db1cee..887a975 100644 (file)
@@ -1,11 +1,10 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
+from __future__ import absolute_import, division, print_function, unicode_literals
+
 import argparse
 import sys
 
-sizeof = {'float': 4, 'at::Half': 2, 'uint8_t': 1}
+
+sizeof = {"float": 4, "at::Half": 2, "uint8_t": 1}
 
 
 def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused):
@@ -14,90 +13,105 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused):
 
         if InType == "float":
             code.append(
-                "vop%d = _mm256_fmadd_ps(vwgt,  \
-                  _mm256_loadu_ps(ip + (%d)), vop%d);"
-                                                       % (regid, regid, regid)
+                "        vop%d = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (%d)), vop%d);"  # noqa
+                % (regid, regid, regid)
             )
         elif InType == "at::Half":
             code.append(
-                "vop%d = _mm256_fmadd_ps(vwgt,  \
-                   _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))), \
-                   vop%d);"
-                            % (regid, regid, regid)
+                "        vop%d = _mm256_fmadd_ps(\n"
+                "            vwgt,\n"
+                "            _mm256_cvtph_ps(\n"
+                "                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))),\n"  # noqa
+                "            vop%d);" % (regid, regid, regid)
             )
         elif InType == "uint8_t":
             code.append(
-                "vop%d = _mm256_fmadd_ps(vwgt,  \
-                   _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (%d))))), \
-                   _mm256_add_ps(vop%d, vbio));"
-                                                 % (regid, regid, regid)
+                "        vop%d = _mm256_fmadd_ps(\n"
+                "            vwgt,\n"
+                "            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(\n"
+                "                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (%d))))),\n"  # noqa
+                "            _mm256_add_ps(vop%d, vbio));" % (regid, regid, regid)
             )
         else:
             assert False
 
         if prefetch:
-            code.append("_mm_prefetch((&ip_next_T0[%d]), _MM_HINT_T0);" % (regid))
+            code.append(
+                "        _mm_prefetch((&ip_next_T0[%d]), _MM_HINT_T0);" % (regid)
+            )
         else:
-            code.append("// skip unnecessary prefetch of (&ip_next_T0[%d])" % (regid))
+            code.append(
+                "        // skip unnecessary prefetch of (&ip_next_T0[%d])" % (regid)
+            )
 
         return code
 
     code = []
-    code.append("// unrolling " + str(uf) + " times")
-    code.append(IndexType + " dataInd = 0;")
-    code.append("for (" + IndexType +
-                " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {")
-    code.append(OutType + " *op = &out[rangeIndex * block_size];")
+    code.append("    // unrolling " + str(uf) + " times")
+    code.append("    " + IndexType + " dataInd = 0;")
+    code.append(
+        "    for ("
+        + IndexType
+        + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
+    )
+    code.append("      " + OutType + "* op = &out[rangeIndex * block_size];")
     for i in range(0, uf):
         j = 8 * i
-        code.append("__m256 vop" + str(j) + " = _mm256_setzero_ps();")
+        code.append("      __m256 vop" + str(j) + " = _mm256_setzero_ps();")
 
     # inner loop
-    code.append("for (" + IndexType +
-                " start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) {")
-    code.append("const  " + IndexType + " idx = indices[dataInd];")
     code.append(
-        'CAFFE_ENFORCE(idx >=0 && idx < data_size, "Index ", dataInd, "'
-        ' is out of bounds: ", idx, ", range 0 to ", data_size);')
+        "      for ("
+        + IndexType
+        + " start = dataInd; dataInd < start + lengths[rangeIndex];\n           ++dataInd) {"  # noqa
+    )
+    code.append("        const " + IndexType + " idx = indices[dataInd];")
+    code.append(
+        '        CAFFE_ENFORCE(\n            idx >= 0 && idx < data_size,\n            "Index ",\n            dataInd,\n'  # noqa
+        '            " is out of bounds: ",\n            idx,\n            ", range 0 to ",\n            data_size);'  # noqa
+    )
 
     if InType == "uint8_t":
-        code.append(OutType + " wgt = 1.f;")
-        code.append(OutType + " bio;")
-        code.append("if (weights) {")
+        code.append("        " + OutType + " wgt = 1.f;")
+        code.append("        " + OutType + " bio;")
+        code.append("        if (weights) {")
         code.append(
-            "wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];")
-        code.append("}")
+            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
+        )
+        code.append("        }")
         if fused:
             code.append(
-                'const float* scale_bias = reinterpret_cast<'
-                'const float*>(&input[idx * fused_block_size + block_size]);'
+                "        const float* scale_bias = reinterpret_cast<const float*>(\n"
+                "            &input[idx * fused_block_size + block_size]);"
             )
-            code.append("bio = wgt * scale_bias[1];")
-            code.append("wgt = wgt * scale_bias[0];")
+            code.append("        bio = wgt * scale_bias[1];")
+            code.append("        wgt = wgt * scale_bias[0];")
         else:
-            code.append("bio = wgt * scale_bias[2 * idx + 1];")
-            code.append("wgt = wgt * scale_bias[2 * idx];")
-        code.append("__m256 vbio = _mm256_set1_ps(bio);")
+            code.append("        bio = wgt * scale_bias[2 * idx + 1];")
+            code.append("        wgt = wgt * scale_bias[2 * idx];")
+        code.append("        __m256 vbio = _mm256_set1_ps(bio);")
     else:
-        code.append(OutType + " wgt = 1.f;")
-        code.append("if (weights) {")
+        code.append("        " + OutType + " wgt = 1.f;")
+        code.append("        if (weights) {")
         code.append(
-            "wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];")
-        code.append("}")
-    code.append("__m256 vwgt = _mm256_set1_ps(wgt);")
+            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
+        )
+        code.append("        }")
+    code.append("        __m256 vwgt = _mm256_set1_ps(wgt);")
 
-    code.append("const {} *ip = &input[idx * fused_block_size];".format(InType))
+    code.append("        const {}* ip = &input[idx * fused_block_size];".format(InType))
     code.append(
-        'const {} next_T0 = (dataInd < index_size - prefdist_T0)'
-        ' ? (dataInd + prefdist_T0) : dataInd;'.format(IndexType)
+        "        const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
+        "            ? (dataInd + prefdist_T0)\n            : dataInd;".format(
+            IndexType
+        )
     )
-    code.append("const  " + IndexType + " idx_pref_T0 = indices[next_T0];")
-    code.append(
-        "CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);")
+    code.append("        const " + IndexType + " idx_pref_T0 = indices[next_T0];")
+    code.append("        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);")
 
     code.append(
-        'const {} *ip_next_T0 = &input[idx_pref_T0'
-        ' * fused_block_size];'.format(InType)
+        "        const {}* ip_next_T0 = &input[idx_pref_T0"
+        " * fused_block_size];".format(InType)
     )
 
     for i in range(0, uf):
@@ -106,168 +120,190 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused):
         byteoffset = sizeof[InType] * j
         prefetch = (byteoffset % cachelinesize) == 0
         code.extend(compute(j, InType, use_weights, isa, prefetch))
-    code.append("}")
+    code.append("      }")
 
-    code.append("if (normalize_by_lengths == false) {")
+    code.append("      if (normalize_by_lengths == false) {")
     for i in range(0, uf):
         j = 8 * i
-        code.append(
-            "_mm256_storeu_ps(&op[" + str(j) + "], vop" + str(j) + ");")
-    code.append("} else if (lengths[rangeIndex]) {")
+        code.append("        _mm256_storeu_ps(&op[" + str(j) + "], vop" + str(j) + ");")
+    code.append("      } else if (lengths[rangeIndex]) {")
     # inv of length
-    code.append(
-        "__m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);")
+    code.append("        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);")
     for i in range(0, uf):
         j = 8 * i
         code.append(
-            "_mm256_storeu_ps(&op[" + str(j) + "], _mm256_mul_ps(" + "vop" + str(j) + ", vlen_inv));")
-    code.append("}")
-
-    code.append("}")
+            "        _mm256_storeu_ps(&op["
+            + str(j)
+            + "], _mm256_mul_ps("
+            + "vop"
+            + str(j)
+            + ", vlen_inv));"
+        )
+    code.append("      }")
+
+    code.append("    }")
     return code
 
 
 def generic(IndexType, InType, OutType, use_weights, isa, fused):
-
     def compute(InType, use_weights, isa):
         code = []
         if InType == "float":
             code.append(
-                "_mm256_storeu_ps(&op[j], \
-                                 _mm256_fmadd_ps(vwgt,_mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])) \
-                                   );"
+                "          _mm256_storeu_ps(\n"
+                "              &op[j],\n"
+                "              _mm256_fmadd_ps(\n"
+                "                  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));"  # noqa
             )
         elif InType == "at::Half":
             code.append(
-                "_mm256_storeu_ps(&op[j], \
-                   _mm256_fmadd_ps(vwgt, \
-                     _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(&ip[j]))), _mm256_loadu_ps(&op[j])) \
-                                   );"
+                "          _mm256_storeu_ps(\n"
+                "              &op[j],\n"
+                "              _mm256_fmadd_ps(\n"
+                "                  vwgt,\n"
+                "                  _mm256_cvtph_ps(_mm_loadu_si128(\n"
+                "                      reinterpret_cast<const __m128i*>(&ip[j]))),\n"
+                "                  _mm256_loadu_ps(&op[j])));"
             )
         elif InType == "uint8_t":
             code.append(
-                "_mm256_storeu_ps(&op[j], \
-                   _mm256_fmadd_ps(vwgt, \
-                     _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(&ip[j])))), \
-                     _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio) ) \
-                                   );"
+                "          _mm256_storeu_ps(\n"
+                "              &op[j],\n"
+                "              _mm256_fmadd_ps(\n"
+                "                  vwgt,\n"
+                "                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(\n"  # noqa
+                "                      reinterpret_cast<const __m128i*>(&ip[j])))),\n"
+                "                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));"
             )
         else:
             assert False
 
-        code.append("_mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);")
+        code.append("          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);")
 
         return code
 
     code = []
-    code.append(IndexType + " dataInd = 0;")
-    code.append("for (" + IndexType +
-                " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {")
-    code.append(OutType + " *op = &out[rangeIndex * block_size];")
+    code.append("    " + IndexType + " dataInd = 0;")
+    code.append(
+        "    for ("
+        + IndexType
+        + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
+    )
+    code.append("      " + OutType + "* op = &out[rangeIndex * block_size];")
 
     # initialize to 0
-    code.append("int64_t j = 0;")
-    code.append("for(; j + 8 <= block_size; j += 8) {")
-    code.append("_mm256_storeu_ps(op + j, _mm256_setzero_ps());")
-    code.append("}")
-    code.append("for(; j < block_size; j++) {")
-    code.append("op[j] = 0.0f;")
-    code.append("}")
+    code.append("      int64_t j = 0;")
+    code.append("      for (; j + 8 <= block_size; j += 8) {")
+    code.append("        _mm256_storeu_ps(op + j, _mm256_setzero_ps());")
+    code.append("      }")
+    code.append("      for (; j < block_size; j++) {")
+    code.append("        op[j] = 0.0f;")
+    code.append("      }")
 
     # inner loop
-    code.append("for (" + IndexType +
-                " start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) {")
-    code.append("const  " + IndexType + " idx = indices[dataInd];")
     code.append(
-        'CAFFE_ENFORCE(idx >=0 && idx < data_size, "Index ", dataInd, "' +
-        ' is out of bounds: ", idx, ", range 0 to ", data_size);')
+        "      for ("
+        + IndexType
+        + " start = dataInd; dataInd < start + lengths[rangeIndex];\n           ++dataInd) {"  # noqa
+    )
+    code.append("        const " + IndexType + " idx = indices[dataInd];")
+    code.append(
+        '        CAFFE_ENFORCE(\n            idx >= 0 && idx < data_size,\n            "Index ",\n            dataInd,\n'  # noqa
+        + '            " is out of bounds: ",\n            idx,\n            ", range 0 to ",\n            data_size);'  # noqa
+    )
 
     if InType == "uint8_t":
-        code.append(OutType + " wgt = 1.f;")
-        code.append(OutType + " bio;")
-        code.append("if (weights) {")
+        code.append("        " + OutType + " wgt = 1.f;")
+        code.append("        " + OutType + " bio;")
+        code.append("        if (weights) {")
         code.append(
-            "wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];")
-        code.append("}")
+            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
+        )
+        code.append("        }")
         if fused:
             code.append(
-                'const float* scale_bias = reinterpret_cast<'
-                'const float*>(&input[idx * fused_block_size + block_size]);'
+                "        const float* scale_bias = reinterpret_cast<const float*>(\n"
+                "            &input[idx * fused_block_size + block_size]);"
             )
-            code.append("bio = wgt * scale_bias[1];")
-            code.append("wgt = wgt * scale_bias[0];")
+            code.append("        bio = wgt * scale_bias[1];")
+            code.append("        wgt = wgt * scale_bias[0];")
         else:
-            code.append("assert (scale_bias);")
-            code.append("bio = wgt * scale_bias[2 * idx + 1];")
-            code.append("wgt = wgt * scale_bias[2 * idx];")
-        code.append("__m256 vbio = _mm256_set1_ps(bio);")
+            code.append("        assert(scale_bias);")
+            code.append("        bio = wgt * scale_bias[2 * idx + 1];")
+            code.append("        wgt = wgt * scale_bias[2 * idx];")
+        code.append("        __m256 vbio = _mm256_set1_ps(bio);")
     else:
-        code.append(OutType + " wgt = 1.f;")
-        code.append("if (weights) {")
+        code.append("        " + OutType + " wgt = 1.f;")
+        code.append("        if (weights) {")
         code.append(
-            "wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];")
-        code.append("}")
-    code.append("__m256 vwgt = _mm256_set1_ps(wgt);")
+            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
+        )
+        code.append("        }")
+    code.append("        __m256 vwgt = _mm256_set1_ps(wgt);")
 
-    code.append("const {} *ip = &input[idx * fused_block_size];".format(InType))
+    code.append("        const {}* ip = &input[idx * fused_block_size];".format(InType))
     code.append(
-        'const {} next_T0 = (dataInd < index_size - prefdist_T0)'
-        ' ? (dataInd + prefdist_T0) : dataInd;'.format(IndexType)
+        "        const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
+        "            ? (dataInd + prefdist_T0)\n            : dataInd;".format(
+            IndexType
+        )
     )
-    code.append("const  " + IndexType + " idx_pref_T0 = indices[next_T0];")
-    code.append(
-        "CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);")
+    code.append("        const " + IndexType + " idx_pref_T0 = indices[next_T0];")
+    code.append("        CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);")
     code.append(
-        "const {} *ip_next_T0 = &input[idx_pref_T0 * fused_block_size];".
-        format(InType)
+        "        const {}* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];".format(
+            InType
+        )
     )
 
     # compute and store main loop
-    code.append("j = 0;")
-    code.append("for(; j + 8 <= block_size; j += 8) {")
+    code.append("        j = 0;")
+    code.append("        for (; j + 8 <= block_size; j += 8) {")
     code.extend(compute(InType, use_weights, isa))
-    code.append("}")
+    code.append("        }")
     # leftover
     if InType == "at::Half":
-        code.append("at::Half vtmp1[8] CAFFE2_ALIGNED(64);")
-    code.append("for(; j < block_size; j++) {")
+        code.append("        alignas(64) at::Half vtmp1[8];")
+    code.append("        for (; j < block_size; j++) {")
     if InType == "float":
-        code.append("op[j] += wgt * ip[j];")
+        code.append("          op[j] += wgt * ip[j];")
     elif InType == "at::Half":
-        code.append("vtmp1[0] = ip[j];")
-        code.append("__m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));")
-        code.append("op[j] += wgt * ((float*)(&vtmp2))[0];")
+        code.append("          vtmp1[0] = ip[j];")
+        code.append("          __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));")
+        code.append("          op[j] += wgt * ((float*)(&vtmp2))[0];")
     elif InType == "uint8_t":
-        code.append("op[j] += wgt * ((float)ip[j]) + bio;")
+        code.append("          op[j] += wgt * ((float)ip[j]) + bio;")
     else:
         assert False
 
-    code.append("}")
+    code.append("        }")
 
-    code.append("}")
+    code.append("      }")
 
-    code.append("if (normalize_by_lengths && lengths[rangeIndex]) {")
-    code.append("float len_inv = 1.0f / lengths[rangeIndex];")
-    code.append("__m256 vlen_inv = _mm256_set1_ps(len_inv);")
-    code.append("j = 0;")
-    code.append("for(; j + 8 <= block_size; j += 8) {")
+    code.append("      if (normalize_by_lengths && lengths[rangeIndex]) {")
+    code.append("        float len_inv = 1.0f / lengths[rangeIndex];")
+    code.append("        __m256 vlen_inv = _mm256_set1_ps(len_inv);")
+    code.append("        j = 0;")
+    code.append("        for (; j + 8 <= block_size; j += 8) {")
     code.append(
-        "_mm256_storeu_ps(&op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));")
-    code.append("}")
-    code.append("for(; j < block_size; j++) {")
-    code.append("op[j] = len_inv * op[j];")
-    code.append("}")
+        "          _mm256_storeu_ps(\n"
+        "              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));"
+    )
+    code.append("        }")
+    code.append("        for (; j < block_size; j++) {")
+    code.append("          op[j] = len_inv * op[j];")
+    code.append("        }")
 
-    code.append("}")
+    code.append("      }")
 
-    code.append("}")
+    code.append("    }")
     return code
 
 
 # start main code
 parser = argparse.ArgumentParser()
-parser.add_argument('-f', '--filename', help="file name")
-parser.add_argument('--fused', action='store_true')
+parser.add_argument("-f", "--filename", help="file name")
+parser.add_argument("--fused", action="store_true")
 opts = parser.parse_args()
 if opts.filename:
     filename = opts.filename
@@ -275,14 +311,16 @@ elif opts.fused:
     filename = "embedding_lookup_fused_8bit_rowwise_avx2.cc"
 else:
     filename = "embedding_lookup_avx2.cc"
-fout = open(filename, 'w')
+fout = open(filename, "w")
 
-options = [["int32_t", "int32_t", "float", "float", "float", "float"],
-           ["int64_t", "int64_t", "float", "float", "float", "float"],
-           ["int32_t", "int32_t", "half", "at::Half", "float", "float"],
-           ["int64_t", "int64_t", "half", "at::Half", "float", "float"],
-           ["int32_t", "int32_t", "uint8_t", "uint8_t", "float"],
-           ["int64_t", "int64_t", "uint8_t", "uint8_t", "float"]]
+options = [
+    ["int32_t", "int32_t", "float", "float", "float", "float"],
+    ["int64_t", "int64_t", "float", "float", "float", "float"],
+    ["int32_t", "int32_t", "half", "at::Half", "float", "float"],
+    ["int64_t", "int64_t", "half", "at::Half", "float", "float"],
+    ["int32_t", "int32_t", "uint8_t", "uint8_t", "float", "float"],
+    ["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"],
+]
 
 code = []
 # includes
@@ -291,105 +329,101 @@ code.append("//// ATTENTION:")
 code.append("//// THIS CODE IS AUTOGENERATED")
 code.append("//// BY {}".format(sys.argv[0]))
 code.append("//// DO NOT MODIFY!!!")
-code.append("//// --------------------------\n\n")
+code.append("//// --------------------------\n")
 
-code.append("#include <caffe2/core/types.h>")
-code.append("#include <caffe2/core/common.h>")
+code.append("#include <ATen/core/Half.h>")
+code.append("#include <c10/util/Logging.h>")
 code.append("#include <immintrin.h>")
-code.append("\n")
+code.append("#include <cassert>\n")
 
 code.append("namespace caffe2 {\n")
 for o in options:
     [IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType] = o
 
-    prefix = 'Fused8BitRowwise' if opts.fused else ''
-    code.append('template <bool IS_WEIGHT_POSITIONAL>')
-    fn_base = '{}EmbeddingLookup_{}_{}_{}'.format(
+    prefix = "Fused8BitRowwise" if opts.fused else ""
+    code.append("template <bool IS_WEIGHT_POSITIONAL>")
+    fn_base = "{}EmbeddingLookup_{}_{}_{}".format(
         prefix, IndexTypeName, InTypeName, OutTypeName
     )
-    suffix = '__avx2_fma'
+    suffix = "__avx2_fma"
     fn = "static void " + fn_base + suffix
     code.append(fn + "(")
 
     args = []
-    args.append("const int64_t block_size,")
-    args.append("const int64_t output_size,")
-    args.append("const int64_t index_size,")
-    args.append("const int64_t data_size,")
-    args.append("const " + InType + "* input,")
-    args.append("const " + IndexType + "* indices,")
-    args.append("const int* lengths,")
-    args.append("const float* weights,")
+    args.append("    const int64_t block_size,")
+    args.append("    const int64_t output_size,")
+    args.append("    const int64_t index_size,")
+    args.append("    const int64_t data_size,")
+    args.append("    const " + InType + "* input,")
+    args.append("    const " + IndexType + "* indices,")
+    args.append("    const int* lengths,")
+    args.append("    const float* weights,")
     if not opts.fused:
-        args.append("const float* scale_bias,")
-    args.append("bool normalize_by_lengths,")
-    args.append(OutType + "* out)")
+        args.append("    const float* scale_bias,")
+    args.append("    bool normalize_by_lengths,")
+    args.append("    " + OutType + "* out) {")
     code += args
 
-    code.append("{")
-    code.append("const " + IndexType + " prefdist_T0 = 16;")
+    code.append("  const " + IndexType + " prefdist_T0 = 16;")
     # block_size is the number of elements and fused_block_size is the size of
     # an entire row, including scale and bias.
     offset = (8 // sizeof[InType]) if opts.fused else 0
     code.append(
-        "const {} fused_block_size = block_size + {};".
-        format(IndexType, offset)
+        "  const {} fused_block_size = block_size + {};".format(IndexType, offset)
     )
 
-    #code.append("printf(\"calling " + fn + "\\n\");");
+    # code.append("printf(\"calling " + fn + "\\n\");");
     if not opts.fused:
         if InType != "uint8_t":
             code.append(
-                'CAFFE_ENFORCE(scale_bias == nullptr,'
+                "  CAFFE_ENFORCE(scale_bias == nullptr,"
                 ' "scale_bias must be nullptr");'
             )
         else:
             code.append(
-                'CAFFE_ENFORCE(scale_bias != nullptr,'
+                "  CAFFE_ENFORCE(scale_bias != nullptr,"
                 ' "scale_bias must not be nullptr");'
             )
 
-    code.append("if (block_size == 128) {")
+    code.append("  if (block_size == 128) {")
     code += unroll(16, IndexType, InType, OutType, True, "AVX2", opts.fused)
-    code.append("} else if (block_size == 64) {")
+    code.append("  } else if (block_size == 64) {")
     code += unroll(8, IndexType, InType, OutType, True, "AVX2", opts.fused)
-    code.append("} else if (block_size == 32) {")
+    code.append("  } else if (block_size == 32) {")
     code += unroll(4, IndexType, InType, OutType, True, "AVX2", opts.fused)
-    code.append("} else if (block_size == 16) {")
+    code.append("  } else if (block_size == 16) {")
     code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused)
-    code.append("} else {")
-    code.append("// generic code")
+    code.append("  } else {")
+    code.append("    // generic code")
     code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused)
-    code.append("}")
+    code.append("  }")
 
     code.append("}")
 
-    for is_weight_positional in ['false', 'true']:
-        code.append(
-            "void " + fn_base + "_" + is_weight_positional + suffix + "(")
+    for is_weight_positional in ["false", "true"]:
+        code.append("void " + fn_base + "_" + is_weight_positional + suffix + "(")
         code += args
-        code.append("{")
-        code.append(fn_base + suffix + "<" + is_weight_positional + ">(")
-        code.append("block_size,")
-        code.append("output_size,")
-        code.append("index_size,")
-        code.append("data_size,")
-        code.append("input,")
-        code.append("indices,")
-        code.append("lengths,")
-        code.append("weights,")
+        code.append("  " + fn_base + suffix + "<" + is_weight_positional + ">(")
+        code.append("      block_size,")
+        code.append("      output_size,")
+        code.append("      index_size,")
+        code.append("      data_size,")
+        code.append("      input,")
+        code.append("      indices,")
+        code.append("      lengths,")
+        code.append("      weights,")
         if not opts.fused:
-            code.append("scale_bias,")
-        code.append("normalize_by_lengths,")
-        code.append("out);")
+            code.append("      scale_bias,")
+        code.append("      normalize_by_lengths,")
+        code.append("      out);")
         code.append("}")
 
-    code.append("\n")
+    code.append("")
 
 code.append("} // namespace caffe2")
 
 for c in code:
-    #print(c, file = fout)
+    # print(c, file = fout)
     fout.write(c + "\n")
 fout.close()
 
index abbe6ca..14265f9 100644 (file)
@@ -1,10 +1,6 @@
 #pragma once
 
-#ifdef CAFFE2_USE_MKL
-#include <mkl.h>
-#endif // CAFFE2_USE_MKL
-
-#include <random>
+#include <cstdint>
 
 namespace caffe2 {
 
@@ -21,27 +17,19 @@ namespace math {
 // |    1B    |  1B  |  4B |  4B | ...output_data....|
 // In output_data: the b-th bucket of the i-th byte stores
 // the i-th data of the b-th segment of input row
-#ifdef CAFFE2_USE_MKL
-#define FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
-#endif
+
 void quantize_and_compress(
     const float* input_data,
-    uint8_t* output_data,
-    size_t input_size,
-    size_t bitwidth,
+    std::uint8_t* output_data,
+    std::size_t input_size,
+    std::size_t bitwidth,
     bool random,
-#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
-    VSLStreamStatePtr& vslStream,
-    std::vector<float>& random_buffer
-#else
-    std::unique_ptr<std::uniform_real_distribution<float>>& dis,
-    std::minstd_rand& gen
-#endif
-);
+    const float* random_buffer);
+
 void decompress_and_dequantize(
-    const uint8_t* input_data,
+    const std::uint8_t* input_data,
     float* output_data,
-    size_t input_size);
+    std::size_t input_size);
 
 } // namespace math
 } // namespace caffe2
index c71c5ac..95292c3 100644 (file)
@@ -2,46 +2,16 @@
 // The implementation in this file allows us to route the underlying numerical
 // computation library to different compiler options (-mno-avx2 or -mavx2).
 
-#include <algorithm>
-#include <array>
-#include <atomic>
-#include <chrono>
+#include <immintrin.h>
+#include <cfloat>
 #include <cmath>
-#include <cstring>
-#include <functional>
-#include <limits>
-#include <numeric>
-#include <random>
-#include <tuple>
-#include <unordered_set>
-#include <vector>
-
-#include "caffe2/core/context.h"
-#include "caffe2/perfkernels/math.h"
-#include "caffe2/utils/cpu_neon.h"
-#include "caffe2/utils/cpuid.h"
-#include "caffe2/utils/eigen_utils.h"
-#include "caffe2/utils/math.h"
-
-#include "Eigen/Core"
-#include "Eigen/Dense"
-
-#ifdef CAFFE2_USE_MKL
-#include <mkl.h>
-#endif // CAFFE2_USE_MKL
-
-#ifdef CAFFE2_USE_HPTT
-#include <hptt.h>
-#endif // CAFFE2_USE_HPTT
-
-#if defined(_MSC_VER)
-#include <process.h>
-#endif
+#include <cstdint>
 
 namespace caffe2 {
 
 namespace math {
-#define QEPSILON 1e-8
+
+static constexpr double QEPSILON = 1e-8;
 
 void quantize_and_compress__avx2(
     const float* input_data,
@@ -49,19 +19,7 @@ void quantize_and_compress__avx2(
     size_t input_size,
     size_t bitwidth,
     bool random,
-#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
-    VSLStreamStatePtr& vslStream,
-    std::vector<float>& random_buffer
-#else
-    std::unique_ptr<std::uniform_real_distribution<float>>& dis,
-    std::minstd_rand& gen
-#endif
-) {
-  CAFFE_ENFORCE(
-      bitwidth == 1 || bitwidth == 2 || bitwidth == 4 || bitwidth == 8,
-      "Unsupported bitwidth");
-
-#ifdef __AVX2__
+    const float* random_buffer) {
   __m256i shuffle_mask_v = _mm256_set_epi8(
       0xff,
       0xff,
@@ -97,14 +55,6 @@ void quantize_and_compress__avx2(
       0x00);
   __m256i permute_mask_v =
       _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
-#endif // __AVX2__
-
-  // memory pointers
-  ConstEigenVectorArrayMap<float> input_row(input_data, input_size);
-  uint8_t* output_row = output_data;
-  EigenVectorArrayMap<uint8_t> output_bitwidth_tail(output_row, 2);
-  EigenVectorArrayMap<float> output_row_min_max(
-      reinterpret_cast<float*>(output_row + 2), 2);
 
   size_t data_per_byte = 8 / bitwidth;
   size_t tail = input_size % data_per_byte;
@@ -112,46 +62,30 @@ void quantize_and_compress__avx2(
   size_t segment_size = (input_size + data_per_byte - 1) / data_per_byte;
 
   // basic info
-  const float minimum_element = input_row.minCoeff();
-  const float maximum_element = input_row.maxCoeff();
-  output_bitwidth_tail(0) = bitwidth;
-  output_bitwidth_tail(1) = tail;
-  output_row_min_max(0) = minimum_element;
-  output_row_min_max(1) = maximum_element;
+  float minimum_element = INFINITY, maximum_element = -INFINITY;
+  for (auto i = 0; i < input_size; ++i) {
+    minimum_element =
+        (input_data[i] < minimum_element) ? input_data[i] : minimum_element;
+    maximum_element =
+        (input_data[i] > maximum_element) ? input_data[i] : maximum_element;
+  }
+  output_data[0] = bitwidth;
+  output_data[1] = tail;
+  reinterpret_cast<float*>(output_data + 2)[0] = minimum_element;
+  reinterpret_cast<float*>(output_data + 2)[1] = maximum_element;
 
   float gap = (maximum_element - minimum_element) / ((1 << bitwidth) - 1.0f);
   float gap_inverse = 1. / (gap + QEPSILON);
   uint8_t max_q = (1 << bitwidth) - 1;
   size_t bit_start = 0;
   if (random) {
-#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
-    int status = vsRngUniform(
-        VSL_RNG_METHOD_UNIFORM_STD,
-        vslStream,
-        input_size,
-        random_buffer.data(),
-        0.0f,
-        1.0f);
-    if (status != VSL_ERROR_OK) {
-      LOG(WARNING) << "vsRngUniform returns " << status;
-    }
-#endif
     for (int start = 0; start < input_size; start += segment_size) {
       size_t stride = start + segment_size <= input_size ? segment_size
                                                          : input_size - start;
       int i = 0;
-#ifdef __AVX2__
       constexpr int VLEN = 8;
       for (; i < stride / VLEN * VLEN; i += VLEN) {
-#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
         __m256 r_v = _mm256_loadu_ps(&random_buffer[start + i]);
-#else
-        float random_buffer[VLEN];
-        for (int j = 0; j < VLEN; ++j) {
-          random_buffer[j] = (*dis)(gen);
-        }
-        __m256 r_v = _mm256_loadu_ps(random_buffer);
-#endif
         __m256 fval_v = _mm256_loadu_ps(input_data + start + i);
         __m256 thetimes_v = _mm256_mul_ps(
             _mm256_sub_ps(fval_v, _mm256_set1_ps(minimum_element)),
@@ -162,37 +96,35 @@ void quantize_and_compress__avx2(
             _mm256_min_ps(_mm256_set1_ps(max_q), rounded_v));
         __m256i qval_v = _mm256_cvtps_epi32(rounded_v);
         __m256i orval_v = _mm256_cvtepu8_epi32(_mm_lddqu_si128(
-            reinterpret_cast<const __m128i*>(output_row + 10 + i)));
+            reinterpret_cast<const __m128i*>(output_data + 10 + i)));
         orval_v =
             _mm256_or_si256(orval_v, _mm256_slli_epi32(qval_v, bit_start));
         orval_v = _mm256_shuffle_epi8(orval_v, shuffle_mask_v);
         orval_v = _mm256_permutevar8x32_epi32(orval_v, permute_mask_v);
-        *reinterpret_cast<int64_t*>(output_row + 10 + i) =
+        *reinterpret_cast<int64_t*>(output_data + 10 + i) =
             _mm256_extract_epi64(orval_v, 0);
       }
-#endif // __AVX2__
       for (; i < stride; ++i) {
         float fval = input_data[start + i];
         float thetimes = (fval - minimum_element) * gap_inverse;
-#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
         float rounded = floor(thetimes + random_buffer[start + i]);
-#else
-        float rounded = floor(thetimes + (*dis)(gen));
-#endif
-        rounded = std::max(0.0f, std::min(static_cast<float>(max_q), rounded));
+        rounded = rounded < static_cast<float>(max_q)
+            ? rounded
+            : static_cast<float>(max_q);
+        rounded = rounded > 0.0f ? rounded : 0.0f;
         uint8_t qval = rounded;
 
-        uint8_t orval = output_row[10 + i];
-        output_row[10 + i] = orval | static_cast<uint8_t>(qval << bit_start);
+        uint8_t orval = output_data[10 + i];
+        output_data[10 + i] = orval | static_cast<uint8_t>(qval << bit_start);
       }
       bit_start += bitwidth;
     }
   } else {
+    // !random
     for (int start = 0; start < input_size; start += segment_size) {
       size_t stride = start + segment_size <= input_size ? segment_size
                                                          : input_size - start;
       int i = 0;
-#ifdef __AVX2__
       constexpr int VLEN = 8;
       for (; i < stride / VLEN * VLEN; i += VLEN) {
         __m256 fval_v = _mm256_loadu_ps(input_data + start + i);
@@ -205,52 +137,48 @@ void quantize_and_compress__avx2(
         __m256i qval_v = _mm256_cvtps_epi32(_mm256_round_ps(
             thetimes_v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
         __m256i orval_v = _mm256_cvtepu8_epi32(_mm_lddqu_si128(
-            reinterpret_cast<const __m128i*>(output_row + 10 + i)));
+            reinterpret_cast<const __m128i*>(output_data + 10 + i)));
         orval_v =
             _mm256_or_si256(orval_v, _mm256_slli_epi32(qval_v, bit_start));
         orval_v = _mm256_shuffle_epi8(orval_v, shuffle_mask_v);
         orval_v = _mm256_permutevar8x32_epi32(orval_v, permute_mask_v);
-        *reinterpret_cast<int64_t*>(output_row + 10 + i) =
+        *reinterpret_cast<int64_t*>(output_data + 10 + i) =
             _mm256_extract_epi64(orval_v, 0);
       }
-#endif // __AVX2__
       for (; i < stride; ++i) {
         float fval = input_data[start + i];
         float thetimes = (fval - minimum_element) * gap_inverse;
-        thetimes =
-            std::max(0.0f, std::min(static_cast<float>(max_q), thetimes));
+        thetimes = thetimes < static_cast<float>(max_q)
+            ? thetimes
+            : static_cast<float>(max_q);
+        thetimes = thetimes > 0.0f ? thetimes : 0.0f;
         uint8_t qval = nearbyint(thetimes);
 
-        uint8_t orval = output_row[10 + i];
-        output_row[10 + i] = orval | static_cast<uint8_t>(qval << bit_start);
+        uint8_t orval = output_data[10 + i];
+        output_data[10 + i] = orval | static_cast<uint8_t>(qval << bit_start);
       }
       bit_start += bitwidth;
     }
-  }
+  } // !random
 }
 
 void decompress_and_dequantize__avx2(
     const uint8_t* input_data,
     float* output_data,
     size_t input_size) {
-  ConstEigenVectorArrayMap<float> input_row_min_max(
-      reinterpret_cast<const float*>(input_data + 2), 2);
-
   // basic info
-  const float minimum_element = input_row_min_max(0);
-  const float maximum_element = input_row_min_max(1);
+  const float minimum_element =
+      reinterpret_cast<const float*>(input_data + 2)[0];
+  const float maximum_element =
+      reinterpret_cast<const float*>(input_data + 2)[1];
   const size_t bitwidth = input_data[0];
   const float gap =
       (maximum_element - minimum_element) / ((1 << bitwidth) - 1.f) +
       QEPSILON; // for exact recovering
 
-  CAFFE_ENFORCE(
-      bitwidth == 1 || bitwidth == 2 || bitwidth == 4 || bitwidth == 8,
-      "Unsupported bitwidth");
   const size_t tail = input_data[1];
 
   const size_t output_size = (input_size - 10) * (8 / bitwidth) - tail;
-  EigenVectorArrayMap<float> output_row(output_data, output_size);
   // decoding
   size_t bit_start = 0;
   const size_t segment_size = input_size - 10;
@@ -259,7 +187,6 @@ void decompress_and_dequantize__avx2(
                                                         : output_size - start;
     uint8_t mask = (1 << bitwidth) - 1;
     int i = 0;
-#ifdef __AVX2__
     // Can process 8 elements at a time because we need to expand uint8_t
     // to int32_t to use epi32 vector instructions.
     constexpr int VLEN = 8;
@@ -269,19 +196,19 @@ void decompress_and_dequantize__avx2(
       __m256i out_epi32_v = _mm256_and_si256(
           _mm256_srli_epi32(_mm256_cvtepu8_epi32(in_v), bit_start),
           _mm256_set1_epi32(mask));
-      _mm256_storeu_ps(
-          output_data + start + i, _mm256_cvtepi32_ps(out_epi32_v));
+      __m256 out_v = _mm256_fmadd_ps(
+          _mm256_cvtepi32_ps(out_epi32_v),
+          _mm256_set1_ps(gap),
+          _mm256_set1_ps(minimum_element));
+      _mm256_storeu_ps(output_data + start + i, out_v);
     }
-#endif
     for (; i < stride; ++i) {
-      output_data[start + i] = ((input_data[10 + i] >> bit_start) & mask);
+      output_data[start + i] =
+          ((input_data[10 + i] >> bit_start) & mask) * gap + minimum_element;
     }
     bit_start += bitwidth;
   }
-  // scaling and biasing
-  output_row = output_row * gap + minimum_element;
 }
 
-#undef QEPSILON
 } // namespace math
 } // namespace caffe2
index 843cfa1..1837641 100644 (file)
@@ -2,47 +2,16 @@
 // The implementation in this file allows us to route the underlying numerical
 // computation library to different compiler options (-mno-avx2 or -mavx2).
 
-#include <algorithm>
-#include <array>
-#include <atomic>
-#include <chrono>
-#include <cmath>
-#include <cstring>
-#include <functional>
-#include <limits>
-#include <numeric>
-#include <random>
-#include <tuple>
-#include <unordered_set>
-#include <vector>
-
-#include "caffe2/core/context.h"
-#include "caffe2/perfkernels/common.h"
-#include "caffe2/perfkernels/math.h"
-#include "caffe2/utils/cpu_neon.h"
-#include "caffe2/utils/cpuid.h"
-#include "caffe2/utils/eigen_utils.h"
-#include "caffe2/utils/math.h"
-
-#include "Eigen/Core"
-#include "Eigen/Dense"
-
-#ifdef CAFFE2_USE_MKL
-#include <mkl.h>
-#endif // CAFFE2_USE_MKL
-
-#ifdef CAFFE2_USE_HPTT
-#include <hptt.h>
-#endif // CAFFE2_USE_HPTT
-
-#if defined(_MSC_VER)
-#include <process.h>
-#endif
+#include <cfloat>
+
+#include "common.h"
+#include "math.h"
 
 namespace caffe2 {
 
 namespace math {
-#define QEPSILON 1e-8
+
+static constexpr double QEPSILON = 1e-8;
 
 void quantize_and_compress__base(
     const float* input_data,
@@ -50,55 +19,30 @@ void quantize_and_compress__base(
     size_t input_size,
     size_t bitwidth,
     bool random,
-#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
-    VSLStreamStatePtr& vslStream,
-    std::vector<float>& random_buffer
-#else
-    std::unique_ptr<std::uniform_real_distribution<float>>& dis,
-    std::minstd_rand& gen
-#endif
-) {
-  CAFFE_ENFORCE(
-      bitwidth == 1 || bitwidth == 2 || bitwidth == 4 || bitwidth == 8,
-      "Unsupported bitwidth");
-
-  // memory pointers
-  ConstEigenVectorArrayMap<float> input_row(input_data, input_size);
-  uint8_t* output_row = output_data;
-  EigenVectorArrayMap<uint8_t> output_bitwidth_tail(output_row, 2);
-  EigenVectorArrayMap<float> output_row_min_max(
-      reinterpret_cast<float*>(output_row + 2), 2);
-
+    const float* random_buffer) {
   size_t data_per_byte = 8 / bitwidth;
   size_t tail = input_size % data_per_byte;
   tail = tail ? data_per_byte - tail : 0;
   size_t segment_size = (input_size + data_per_byte - 1) / data_per_byte;
 
   // basic info
-  const float minimum_element = input_row.minCoeff();
-  const float maximum_element = input_row.maxCoeff();
-  output_bitwidth_tail(0) = bitwidth;
-  output_bitwidth_tail(1) = tail;
-  output_row_min_max(0) = minimum_element;
-  output_row_min_max(1) = maximum_element;
+  float minimum_element = INFINITY, maximum_element = -INFINITY;
+  for (auto i = 0; i < input_size; ++i) {
+    minimum_element =
+        input_data[i] < minimum_element ? input_data[i] : minimum_element;
+    maximum_element =
+        input_data[i] > maximum_element ? input_data[i] : maximum_element;
+  }
+  output_data[0] = bitwidth;
+  output_data[1] = tail;
+  reinterpret_cast<float*>(output_data + 2)[0] = minimum_element;
+  reinterpret_cast<float*>(output_data + 2)[1] = maximum_element;
 
   float gap = (maximum_element - minimum_element) / ((1 << bitwidth) - 1.0f);
   float gap_inverse = 1. / (gap + QEPSILON);
   uint8_t max_q = (1 << bitwidth) - 1;
   size_t bit_start = 0;
   if (random) {
-#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
-    int status = vsRngUniform(
-        VSL_RNG_METHOD_UNIFORM_STD,
-        vslStream,
-        input_size,
-        random_buffer.data(),
-        0.0f,
-        1.0f);
-    if (status != VSL_ERROR_OK) {
-      LOG(WARNING) << "vsRngUniform returns " << status;
-    }
-#endif
     for (int start = 0; start < input_size; start += segment_size) {
       size_t stride = start + segment_size <= input_size ? segment_size
                                                          : input_size - start;
@@ -106,16 +50,15 @@ void quantize_and_compress__base(
       for (; i < stride; ++i) {
         float fval = input_data[start + i];
         float thetimes = (fval - minimum_element) * gap_inverse;
-#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
         float rounded = floor(thetimes + random_buffer[start + i]);
-#else
-        float rounded = floor(thetimes + (*dis)(gen));
-#endif
-        rounded = std::max(0.0f, std::min(static_cast<float>(max_q), rounded));
+        rounded = rounded < static_cast<float>(max_q)
+            ? rounded
+            : static_cast<float>(max_q);
+        rounded = rounded > 0.0f ? rounded : 0.0f;
         uint8_t qval = rounded;
 
-        uint8_t orval = output_row[10 + i];
-        output_row[10 + i] = orval | static_cast<uint8_t>(qval << bit_start);
+        uint8_t orval = output_data[10 + i];
+        output_data[10 + i] = orval | static_cast<uint8_t>(qval << bit_start);
       }
       bit_start += bitwidth;
     }
@@ -127,12 +70,14 @@ void quantize_and_compress__base(
       for (; i < stride; ++i) {
         float fval = input_data[start + i];
         float thetimes = (fval - minimum_element) * gap_inverse;
-        thetimes =
-            std::max(0.0f, std::min(static_cast<float>(max_q), thetimes));
+        thetimes = thetimes < static_cast<float>(max_q)
+            ? thetimes
+            : static_cast<float>(max_q);
+        thetimes = thetimes > 0.0f ? thetimes : 0.0f;
         uint8_t qval = nearbyint(thetimes);
 
-        uint8_t orval = output_row[10 + i];
-        output_row[10 + i] = orval | static_cast<uint8_t>(qval << bit_start);
+        uint8_t orval = output_data[10 + i];
+        output_data[10 + i] = orval | static_cast<uint8_t>(qval << bit_start);
       }
       bit_start += bitwidth;
     }
@@ -145,15 +90,7 @@ void quantize_and_compress(
     size_t input_size,
     size_t bitwidth,
     bool random,
-#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
-    VSLStreamStatePtr& vslStream,
-    std::vector<float>& random_buffer
-#else
-    std::unique_ptr<std::uniform_real_distribution<float>>& dis,
-    std::minstd_rand& gen
-#endif
-) {
-#ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
+    const float* random_buffer) {
   AVX2_DO(
       quantize_and_compress,
       input_data,
@@ -161,7 +98,6 @@ void quantize_and_compress(
       input_size,
       bitwidth,
       random,
-      vslStream,
       random_buffer);
   BASE_DO(
       quantize_and_compress,
@@ -170,54 +106,26 @@ void quantize_and_compress(
       input_size,
       bitwidth,
       random,
-      vslStream,
       random_buffer);
-#else
-  AVX2_DO(
-      quantize_and_compress,
-      input_data,
-      output_data,
-      input_size,
-      bitwidth,
-      random,
-      dis,
-      gen);
-  BASE_DO(
-      quantize_and_compress,
-      input_data,
-      output_data,
-      input_size,
-      bitwidth,
-      random,
-      dis,
-      gen);
-#endif
 }
 
 void decompress_and_dequantize__base(
     const uint8_t* input_data,
     float* output_data,
     size_t input_size) {
-  // memory pointers ///
-  ConstEigenVectorArrayMap<uint8_t> input_bitwidth_tail(input_data, 2);
-  ConstEigenVectorArrayMap<float> input_row_min_max(
-      reinterpret_cast<const float*>(input_data + 2), 2);
-
   // basic info
-  const float minimum_element = input_row_min_max(0);
-  const float maximum_element = input_row_min_max(1);
+  const float minimum_element =
+      reinterpret_cast<const float*>(input_data + 2)[0];
+  const float maximum_element =
+      reinterpret_cast<const float*>(input_data + 2)[1];
   const size_t bitwidth = input_data[0];
   const float gap =
       (maximum_element - minimum_element) / ((1 << bitwidth) - 1.f) +
       QEPSILON; // for exact recovering
 
-  CAFFE_ENFORCE(
-      bitwidth == 1 || bitwidth == 2 || bitwidth == 4 || bitwidth == 8,
-      "Unsupported bitwidth");
   const size_t tail = input_data[1];
 
   const size_t output_size = (input_size - 10) * (8 / bitwidth) - tail;
-  EigenVectorArrayMap<float> output_row(output_data, output_size);
   // decoding
   size_t bit_start = 0;
   const size_t segment_size = input_size - 10;
@@ -227,12 +135,11 @@ void decompress_and_dequantize__base(
     uint8_t mask = (1 << bitwidth) - 1;
     int i = 0;
     for (; i < stride; ++i) {
-      output_data[start + i] = ((input_data[10 + i] >> bit_start) & mask);
+      output_data[start + i] =
+          ((input_data[10 + i] >> bit_start) & mask) * gap + minimum_element;
     }
     bit_start += bitwidth;
   }
-  // scaling and biasing
-  output_row = output_row * gap + minimum_element;
 }
 
 void decompress_and_dequantize(
@@ -243,6 +150,5 @@ void decompress_and_dequantize(
   BASE_DO(decompress_and_dequantize, input_data, output_data, input_size);
 }
 
-#undef QEPSILON
 } // namespace math
 } // namespace caffe2
index e4d4cbb..fa6c17e 100644 (file)
@@ -1,8 +1,6 @@
-#include "caffe2/core/types.h"
 #include "caffe2/perfkernels/cvtsh_ss_bugfix.h"
-#include "caffe2/perfkernels/typed_axpy.h"
-#include "caffe2/utils/math.h"
 
+#include <ATen/core/Half.h>
 #include <emmintrin.h>
 #include <immintrin.h>