Resolve errors in perfkernel for Windows (#16031)
author     Tongliang Liao <xkszltl@gmail.com>
           Thu, 17 Jan 2019 05:38:13 +0000 (21:38 -0800)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
           Thu, 17 Jan 2019 05:51:00 +0000 (21:51 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16031

1. MSVC only provides _mm_prefetch(const char*, int), so prefetch arguments are now cast to const char* in both the Python codegen (hp_emblookup_codegen.py) and the hand-written C++ files; see the sketch right after this list.
2. The uint32_t usage in "cvtsh_ss_bugfix.h" requires "#include <cstdint>", which is now added in the MSVC branch.
3. Some files use gflags headers; that dependency is now satisfied by linking against c10.
4. Arch-specific flags are isolated: each kernel variant is built as its own static library with private compile options and exposed through caffe2_interface_library, so AVX/AVX2/AVX512 flags stay out of the rest of the build.
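
For reference, a minimal sketch of the portable prefetch call this patch standardizes on (the helper name prefetch_t0 is made up for illustration; the real code calls _mm_prefetch inline, as the diffs below show):

    #include <immintrin.h>

    // MSVC declares _mm_prefetch(const char*, int) while GCC/Clang accept a
    // const void* (or any pointer); an explicit cast to const char* compiles
    // cleanly on all of them.
    inline void prefetch_t0(const float* p) {
      _mm_prefetch(reinterpret_cast<const char*>(p), _MM_HINT_T0);
    }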
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15753

Reviewed By: dskhudia

Differential Revision: D13636233

Pulled By: jspark1105

fbshipit-source-id: cdcbd4240e07b749554a2a5676c11af88f23c31d

17 files changed:
caffe2/perfkernels/CMakeLists.txt
caffe2/perfkernels/adagrad.cc
caffe2/perfkernels/adagrad.h
caffe2/perfkernels/adagrad_avx.cc
caffe2/perfkernels/common.h
caffe2/perfkernels/cvtsh_ss_bugfix.h
caffe2/perfkernels/embedding_lookup.cc
caffe2/perfkernels/embedding_lookup_avx2.cc
caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc
caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc
caffe2/perfkernels/hp_emblookup_codegen.py
caffe2/perfkernels/math.h
caffe2/perfkernels/math_cpu_avx2.cc
caffe2/perfkernels/math_cpu_base.cc
caffe2/perfkernels/typed_axpy.cc
cmake/Dependencies.cmake
cmake/MiscCheck.cmake

diff --git a/caffe2/perfkernels/CMakeLists.txt b/caffe2/perfkernels/CMakeLists.txt
index f2e2f86..3ca9ae5 100644
@@ -13,37 +13,69 @@ set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${common_srcs})
 
 # We will only build the perf kernel files if the compiler supports avx2
 # extensions.
-# Currently MSVC seems to have a symbol not found error while linking (related
-# to source file order?). As a result we will currently disable the perfkernel
-# in msvc.
-if (NOT MSVC AND CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS)
-  add_library(Caffe2_perfkernels_avx OBJECT ${avx_srcs})
-  add_library(Caffe2_perfkernels_avx2 OBJECT ${avx2_srcs})
-  add_dependencies(Caffe2_perfkernels_avx Caffe2_PROTO c10)
-  add_dependencies(Caffe2_perfkernels_avx2 Caffe2_PROTO c10)
+if (CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS)
+  add_library(Caffe2_perfkernels_avx STATIC ${avx_srcs})
+  add_library(Caffe2_perfkernels_avx2 STATIC ${avx2_srcs})
+  add_dependencies(Caffe2_perfkernels_avx Caffe2_PROTO)
+  add_dependencies(Caffe2_perfkernels_avx2 Caffe2_PROTO)
+  target_link_libraries(Caffe2_perfkernels_avx PRIVATE c10)
+  target_link_libraries(Caffe2_perfkernels_avx2 PRIVATE c10)
   if (MSVC)
-    set_target_properties(
-        Caffe2_perfkernels_avx PROPERTIES COMPILE_FLAGS "/arch:AVX")
-    set_target_properties(
-        Caffe2_perfkernels_avx2 PROPERTIES COMPILE_FLAGS "/arch:AVX2")
-    # Currently MSVC doesn't support AVX512
+    target_compile_options(Caffe2_perfkernels_avx
+        PRIVATE "/arch:AVX"
+        PRIVATE "/D__F16C__")
+    target_compile_options(Caffe2_perfkernels_avx2
+        PRIVATE "/arch:AVX2"
+        PRIVATE "/D__FMA__"
+        PRIVATE "/D__F16C__")
   else()
-    set_target_properties(
-        Caffe2_perfkernels_avx PROPERTIES COMPILE_FLAGS "-mavx -mf16c")
-    set_target_properties(
-        Caffe2_perfkernels_avx2 PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavx -mf16c")
+    target_compile_options(Caffe2_perfkernels_avx
+        PRIVATE "-mavx"
+        PRIVATE "-mf16c")
+    target_compile_options(Caffe2_perfkernels_avx2
+        PRIVATE "-mavx2"
+        PRIVATE "-mfma"
+        PRIVATE "-mavx"
+        PRIVATE "-mf16c")
   endif()
-  set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS}
-      $<TARGET_OBJECTS:Caffe2_perfkernels_avx>
-      $<TARGET_OBJECTS:Caffe2_perfkernels_avx2>)
+  caffe2_interface_library(
+      Caffe2_perfkernels_avx Caffe2_perfkernels_avx_interface)
+  caffe2_interface_library(
+      Caffe2_perfkernels_avx2 Caffe2_perfkernels_avx2_interface)
+  list(APPEND
+       Caffe2_DEPENDENCY_WHOLE_LINK_LIBS
+       "Caffe2_perfkernels_avx_interface")
+  list(APPEND
+       Caffe2_DEPENDENCY_WHOLE_LINK_LIBS
+       "Caffe2_perfkernels_avx2_interface")
 
   if (CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS)
-      add_library(Caffe2_perfkernels_avx512 OBJECT ${avx512_srcs})
-      add_dependencies(Caffe2_perfkernels_avx512 Caffe2_PROTO c10)
-      set_target_properties(
-          Caffe2_perfkernels_avx512 PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512dq -mavx512vl -mavx2 -mfma -mavx -mf16c")
-      set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS}
-          $<TARGET_OBJECTS:Caffe2_perfkernels_avx512>)
+    add_library(Caffe2_perfkernels_avx512 STATIC ${avx512_srcs})
+    add_dependencies(Caffe2_perfkernels_avx512 Caffe2_PROTO)
+    target_link_libraries(Caffe2_perfkernels_avx512 PRIVATE c10)
+    if (MSVC)
+      target_compile_options(Caffe2_perfkernels_avx512
+          PRIVATE "/D__AVX512F__"
+          PRIVATE "/D__AVX512DQ__"
+          PRIVATE "/D__AVX512VL__"
+          PRIVATE "/arch:AVX2"
+          PRIVATE "/D__FMA__"
+          PRIVATE "/D__F16C__")
+    else()
+      target_compile_options(Caffe2_perfkernels_avx512
+          PRIVATE "-mavx512f"
+          PRIVATE "-mavx512dq"
+          PRIVATE "-mavx512vl"
+          PRIVATE "-mavx2"
+          PRIVATE "-mfma"
+          PRIVATE "-mavx"
+          PRIVATE "-mf16c")
+    endif()
+    caffe2_interface_library(
+        Caffe2_perfkernels_avx512 Caffe2_perfkernels_avx512_interface)
+    list(APPEND
+         Caffe2_DEPENDENCY_WHOLE_LINK_LIBS
+         "Caffe2_perfkernels_avx512_interface")
   endif()
 endif()
 
@@ -54,3 +86,6 @@ endif()
 # more proper implementation.
 
 set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
+set(Caffe2_DEPENDENCY_WHOLE_LINK_LIBS
+    ${Caffe2_DEPENDENCY_WHOLE_LINK_LIBS}
+    PARENT_SCOPE)
diff --git a/caffe2/perfkernels/adagrad.cc b/caffe2/perfkernels/adagrad.cc
index 0d6e25e..2c65616 100644
@@ -71,6 +71,7 @@ void rowwise_adagrad_update__base(
   internal::rowwise_adagrad_update_inlined(N, w, w_n, g, h, h_n, epsilon, lr);
 }
 
+decltype(adagrad_update_prefetch__base) adagrad_update_prefetch__avx_f16c;
 void adagrad_update_prefetch(
     int N,
     const float* w,
@@ -121,6 +122,8 @@ void adagrad_update_prefetch(
 
 // Version with prefetching for embeddings and
 // momentum using fp16
+decltype(
+    adagrad_fp16_update_prefetch__base) adagrad_fp16_update_prefetch__avx_f16c;
 void adagrad_fp16_update_prefetch(
     int N,
     const at::Half* w,
@@ -164,6 +167,7 @@ void adagrad_fp16_update_prefetch(
       lr);
 }
 
+decltype(rowwise_adagrad_update__base) rowwise_adagrad_update__avx_f16c;
 void rowwise_adagrad_update(
     int N,
     float* w,
@@ -181,6 +185,7 @@ void rowwise_adagrad_update(
 }
 
 // version without prefetching
+decltype(adagrad_update__base) adagrad_update__avx_f16c;
 void adagrad_update(
     int N,
     const float* w,
@@ -197,11 +202,12 @@ void adagrad_update(
 
 SPARSE_ADAGRAD_SPECIALIZATION(int32_t, base);
 
+decltype(sparse_adagrad_int32_t__base) sparse_adagrad_int32_t__avx_f16c;
 template <>
 void sparse_adagrad(
     int num_rows,
     int block_size,
-    size_t param_size,
+    uint64_t param_size,
     const float* w,
     const float* g,
     const float* h,
@@ -243,11 +249,12 @@ void sparse_adagrad(
 
 SPARSE_ADAGRAD_SPECIALIZATION(int64_t, base);
 
+decltype(sparse_adagrad_int64_t__base) sparse_adagrad_int64_t__avx_f16c;
 template <>
 void sparse_adagrad(
     int num_rows,
     int block_size,
-    size_t param_size,
+    uint64_t param_size,
     const float* w,
     const float* g,
     const float* h,
diff --git a/caffe2/perfkernels/adagrad.h b/caffe2/perfkernels/adagrad.h
index 6ce1965..c39a1b3 100644
@@ -68,12 +68,12 @@ inline void adagrad_update_prefetch_inlined(
   auto i = 0;
 
 #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
-  constexpr size_t kSize = 8;
+  constexpr int kSize = 8;
   for (; i + kSize <= N; i += kSize) {
-    _mm_prefetch(&w_n[i], _MM_HINT_T0);
-    _mm_prefetch(&h_n[i], _MM_HINT_T0);
-    _mm_prefetch(&nw_n[i], _MM_HINT_T0);
-    _mm_prefetch(&nh_n[i], _MM_HINT_T0);
+    _mm_prefetch(reinterpret_cast<const char*>(&w_n[i]), _MM_HINT_T0);
+    _mm_prefetch(reinterpret_cast<const char*>(&h_n[i]), _MM_HINT_T0);
+    _mm_prefetch(reinterpret_cast<const char*>(&nw_n[i]), _MM_HINT_T0);
+    _mm_prefetch(reinterpret_cast<const char*>(&nh_n[i]), _MM_HINT_T0);
 
     __m256 gi = _mm256_loadu_ps(g + i);
     __m256 hi = _mm256_loadu_ps(h + i);
@@ -115,8 +115,8 @@ inline void rowwise_adagrad_update_inlined(
   auto i = 0;
 
 #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
-  constexpr size_t kSize = 8;
-  _mm_prefetch(h_n, _MM_HINT_T0);
+  constexpr int kSize = 8;
+  _mm_prefetch(reinterpret_cast<const char*>(h_n), _MM_HINT_T0);
   __m256 partial_sum = _mm256_setzero_ps();
   for (; i + kSize <= N; i += kSize) {
     __m256 gi = _mm256_loadu_ps(g + i);
@@ -144,7 +144,7 @@ inline void rowwise_adagrad_update_inlined(
   __m256 step = _mm256_set1_ps(float_step);
 
   for (i = 0; i + kSize <= N; i += kSize) {
-    _mm_prefetch(&w_n[i], _MM_HINT_T0);
+    _mm_prefetch(reinterpret_cast<const char*>(&w_n[i]), _MM_HINT_T0);
 
     __m256 gi = _mm256_loadu_ps(g + i);
     __m256 wi = _mm256_loadu_ps(w + i);
@@ -242,7 +242,7 @@ template <typename SIndex>
 void sparse_adagrad(
     int num_rows, // number of rows reading
     int block_size, // number of parameters per rows
-    std::size_t param_size, // total number of parameters
+    std::uint64_t param_size, // total number of parameters
     const float* w, // input parameters
     const float* g, // input gradients
     const float* h, // input momentums
@@ -257,7 +257,7 @@ void sparse_adagrad(
   void sparse_adagrad_##SIndex##__##ISA(                                 \
       int num_rows,                                                      \
       int block_size,                                                    \
-      std::size_t param_size,                                            \
+      std::uint64_t param_size,                                          \
       const float* w,                                                    \
       const float* g,                                                    \
       const float* h,                                                    \
diff --git a/caffe2/perfkernels/adagrad_avx.cc b/caffe2/perfkernels/adagrad_avx.cc
index 36e355f..3c225e3 100644
@@ -59,13 +59,13 @@ void adagrad_fp16_update_prefetch__avx_f16c(
     at::Half* nh_n, // prefetch ptr
     float epsilon,
     float lr) {
-  constexpr size_t kSize = 8;
+  constexpr int kSize = 8;
   auto i = 0;
   for (; i + kSize <= N; i += kSize) {
-    _mm_prefetch(&w_n[i], _MM_HINT_T0);
-    _mm_prefetch(&h_n[i], _MM_HINT_T0);
-    _mm_prefetch(&nw_n[i], _MM_HINT_T0);
-    _mm_prefetch(&nh_n[i], _MM_HINT_T0);
+    _mm_prefetch(reinterpret_cast<const char*>(&w_n[i]), _MM_HINT_T0);
+    _mm_prefetch(reinterpret_cast<const char*>(&h_n[i]), _MM_HINT_T0);
+    _mm_prefetch(reinterpret_cast<const char*>(&nw_n[i]), _MM_HINT_T0);
+    _mm_prefetch(reinterpret_cast<const char*>(&nh_n[i]), _MM_HINT_T0);
 
     // only convert momentum and embedding, gradient is fp32
     __m256 gi = _mm256_loadu_ps(g + i);
@@ -119,7 +119,7 @@ void adagrad_update__avx_f16c(
     float epsilon,
     float decay,
     float lr) {
-  constexpr size_t kSize = 8;
+  constexpr int kSize = 8;
   auto i = 0;
   for (; i + kSize <= N; i += kSize) {
     __m256 gi = _mm256_loadu_ps(g + i);
diff --git a/caffe2/perfkernels/common.h b/caffe2/perfkernels/common.h
index b128c76..1ceb9de 100644
@@ -33,6 +33,9 @@ In foo.cc, do:
    void foo__base(int a, float b) {
      [base, possibly slow implementation]
    }
+   decltype(foo__base) foo__avx512;
+   decltype(foo__base) foo__avx2;
+   decltype(foo__base) foo__avx;
    void foo(int a, float b) {
      // You should always order things by their preference, faster
      // implementations earlier in the function.
@@ -49,11 +52,11 @@ In foo.cc, do:
 // During build time:
 //    The build system should provide flags CAFFE2_PERF_WITH_AVX512,
 //    CAFFE2_PERF_WITH_AVX2, and CAFFE2_PERF_WITH_AVX that corresponds to the
-//    __AVX512F__, __AVX512DQ__, __AVX512VL__, __AVX__, and __AVX2__ flags the
+//    __AVX512F__, __AVX512DQ__, __AVX512VL__, __AVX2__, and __AVX__ flags the
 //    compiler provides. Note that we do not use the compiler flags but rely on
 //    the build system flags, because the common files (like foo.cc above) will
-//    always be built without __AVX512F__, __AVX512DQ__, __AVX512VL__, __AVX__
-//    and __AVX2__.
+//    always be built without __AVX512F__, __AVX512DQ__, __AVX512VL__, __AVX2__
+//    and __AVX__.
 // During run time:
 //    we use cpuid to identify cpu support and run the proper functions.
 
@@ -68,7 +71,6 @@ In foo.cc, do:
 
 #ifdef CAFFE2_PERF_WITH_AVX512
 #define AVX512_DO(funcname, ...)                       \
-  decltype(funcname##__base) funcname##__avx512;       \
   if (GetCpuId().avx512f() && GetCpuId().avx512dq() && \
       GetCpuId().avx512vl()) {                         \
     return funcname##__avx512(__VA_ARGS__);            \
@@ -78,15 +80,13 @@ In foo.cc, do:
 #endif // CAFFE2_PERF_WITH_AVX512
 
 #ifdef CAFFE2_PERF_WITH_AVX2
-#define AVX2_DO(funcname, ...)                 \
-  decltype(funcname##__base) funcname##__avx2; \
-  if (GetCpuId().avx2()) {                     \
-    return funcname##__avx2(__VA_ARGS__);      \
+#define AVX2_DO(funcname, ...)            \
+  if (GetCpuId().avx2()) {                \
+    return funcname##__avx2(__VA_ARGS__); \
   }
-#define AVX2_FMA_DO(funcname, ...)                 \
-  decltype(funcname##__base) funcname##__avx2_fma; \
-  if (GetCpuId().avx2() && GetCpuId().fma()) {     \
-    return funcname##__avx2_fma(__VA_ARGS__);      \
+#define AVX2_FMA_DO(funcname, ...)             \
+  if (GetCpuId().avx2() && GetCpuId().fma()) { \
+    return funcname##__avx2_fma(__VA_ARGS__);  \
   }
 #else // CAFFE2_PERF_WITH_AVX2
 #define AVX2_DO(funcname, ...)
@@ -94,15 +94,13 @@ In foo.cc, do:
 #endif // CAFFE2_PERF_WITH_AVX2
 
 #ifdef CAFFE2_PERF_WITH_AVX
-#define AVX_DO(funcname, ...)                 \
-  decltype(funcname##__base) funcname##__avx; \
-  if (GetCpuId().avx()) {                     \
-    return funcname##__avx(__VA_ARGS__);      \
+#define AVX_DO(funcname, ...)            \
+  if (GetCpuId().avx()) {                \
+    return funcname##__avx(__VA_ARGS__); \
   }
-#define AVX_F16C_DO(funcname, ...)                 \
-  decltype(funcname##__base) funcname##__avx_f16c; \
-  if (GetCpuId().avx() && GetCpuId().f16c()) {     \
-    return funcname##__avx_f16c(__VA_ARGS__);      \
+#define AVX_F16C_DO(funcname, ...)             \
+  if (GetCpuId().avx() && GetCpuId().f16c()) { \
+    return funcname##__avx_f16c(__VA_ARGS__);  \
   }
 #else // CAFFE2_PERF_WITH_AVX
 #define AVX_DO(funcname, ...)
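
To make the run-time dispatch documented in common.h concrete, here is a rough sketch of the pattern after this change: the decltype() forward declarations move out of the *_DO macros to namespace scope in the dispatching .cc file (as this commit does in adagrad.cc and the embedding lookup files), and the dispatcher falls through from the fastest supported ISA to the base implementation using caffe2's GetCpuId(). The names foo/a/b follow the worked example already in the header; the expanded wiring below is illustrative, not literal repository code:

    // foo.cc (sketch)
    void foo__base(int a, float b) {
      // portable, possibly slow fallback
    }

    // Forward declarations of the ISA-specific variants, which are defined in
    // the translation units built with /arch:AVX2 (-mavx2 -mfma) and
    // /arch:AVX (-mavx -mf16c) respectively.
    decltype(foo__base) foo__avx2_fma;
    decltype(foo__base) foo__avx_f16c;

    void foo(int a, float b) {
      // Prefer the fastest implementation the CPU supports, checked via cpuid.
    #ifdef CAFFE2_PERF_WITH_AVX2
      if (GetCpuId().avx2() && GetCpuId().fma()) {
        return foo__avx2_fma(a, b);
      }
    #endif
    #ifdef CAFFE2_PERF_WITH_AVX
      if (GetCpuId().avx() && GetCpuId().f16c()) {
        return foo__avx_f16c(a, b);
      }
    #endif
      foo__base(a, b);
    }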
diff --git a/caffe2/perfkernels/cvtsh_ss_bugfix.h b/caffe2/perfkernels/cvtsh_ss_bugfix.h
index ee20ce6..825e266 100644
@@ -32,16 +32,17 @@ _cvtsh_ss(unsigned short a)
 
 #ifdef _MSC_VER
 
+#include <cstdint>
+
 // It seems that microsoft msvc does not have a _cvtsh_ss implementation so
 // we will add a dummy version to it.
 
-static inline float
-_cvtsh_ss(unsigned short x) {
+static inline float _cvtsh_ss(unsigned short x) {
   union {
-    uint32_t intval;
+    std::uint32_t intval;
     float floatval;
   } t1;
-  uint32_t t2, t3;
+  std::uint32_t t2, t3;
   t1.intval = x & 0x7fff; // Non-sign bits
   t2 = x & 0x8000; // Sign bit
   t3 = x & 0x7c00; // Exponent
diff --git a/caffe2/perfkernels/embedding_lookup.cc b/caffe2/perfkernels/embedding_lookup.cc
index fa93ae7..e8c30a0 100644
@@ -81,83 +81,86 @@ static void EmbeddingLookupGenericSlow(
 }
 
 // Proxy back to generic implementation
-#define EMBEDDING_SPECIALIZATION(                                                                      \
-    IndexTypeName,                                                                                     \
-    IndexType,                                                                                         \
-    InTypeName,                                                                                        \
-    InType,                                                                                            \
-    OutTypeName,                                                                                       \
-    OutType,                                                                                           \
-    IS_WEIGHT_POSITIONAL)                                                                              \
-  void                                                                                                 \
-      EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL##__base( \
-          const int64_t block_size,                                                                    \
-          const int64_t output_size,                                                                   \
-          const int64_t index_size,                                                                    \
-          const int64_t data_size,                                                                     \
-          const InType* input,                                                                         \
-          const IndexType* indices,                                                                    \
-          const int* lengths,                                                                          \
-          const float* weights,                                                                        \
-          const float* scale_bias,                                                                     \
-          bool normalize_by_lengths,                                                                   \
-          OutType* out) {                                                                              \
-    EmbeddingLookupGenericSlow<                                                                        \
-        IndexType,                                                                                     \
-        InType,                                                                                        \
-        OutType,                                                                                       \
-        IS_WEIGHT_POSITIONAL>(                                                                         \
-        block_size,                                                                                    \
-        output_size,                                                                                   \
-        index_size,                                                                                    \
-        data_size,                                                                                     \
-        input,                                                                                         \
-        indices,                                                                                       \
-        lengths,                                                                                       \
-        weights,                                                                                       \
-        scale_bias,                                                                                    \
-        normalize_by_lengths,                                                                          \
-        out);                                                                                          \
-  }                                                                                                    \
-  template <>                                                                                          \
-  void EmbeddingLookup<IndexType, InType, OutType, IS_WEIGHT_POSITIONAL>(                              \
-      const int64_t block_size,                                                                        \
-      const int64_t output_size,                                                                       \
-      const int64_t index_size,                                                                        \
-      const int64_t data_size,                                                                         \
-      const InType* input,                                                                             \
-      const IndexType* indices,                                                                        \
-      const int* lengths,                                                                              \
-      const float* weights,                                                                            \
-      const float* scale_bias,                                                                         \
-      bool normalize_by_lengths,                                                                       \
-      OutType* out) {                                                                                  \
-    AVX2_FMA_DO(                                                                                       \
-        EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL,       \
-        block_size,                                                                                    \
-        output_size,                                                                                   \
-        index_size,                                                                                    \
-        data_size,                                                                                     \
-        input,                                                                                         \
-        indices,                                                                                       \
-        lengths,                                                                                       \
-        weights,                                                                                       \
-        scale_bias,                                                                                    \
-        normalize_by_lengths,                                                                          \
-        out);                                                                                          \
-    BASE_DO(                                                                                           \
-        EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL,       \
-        block_size,                                                                                    \
-        output_size,                                                                                   \
-        index_size,                                                                                    \
-        data_size,                                                                                     \
-        input,                                                                                         \
-        indices,                                                                                       \
-        lengths,                                                                                       \
-        weights,                                                                                       \
-        scale_bias,                                                                                    \
-        normalize_by_lengths,                                                                          \
-        out);                                                                                          \
+#define EMBEDDING_SPECIALIZATION(                                                                          \
+    IndexTypeName,                                                                                         \
+    IndexType,                                                                                             \
+    InTypeName,                                                                                            \
+    InType,                                                                                                \
+    OutTypeName,                                                                                           \
+    OutType,                                                                                               \
+    IS_WEIGHT_POSITIONAL)                                                                                  \
+  void                                                                                                     \
+      EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL##__base(     \
+          const int64_t block_size,                                                                        \
+          const int64_t output_size,                                                                       \
+          const int64_t index_size,                                                                        \
+          const int64_t data_size,                                                                         \
+          const InType* input,                                                                             \
+          const IndexType* indices,                                                                        \
+          const int* lengths,                                                                              \
+          const float* weights,                                                                            \
+          const float* scale_bias,                                                                         \
+          bool normalize_by_lengths,                                                                       \
+          OutType* out) {                                                                                  \
+    EmbeddingLookupGenericSlow<                                                                            \
+        IndexType,                                                                                         \
+        InType,                                                                                            \
+        OutType,                                                                                           \
+        IS_WEIGHT_POSITIONAL>(                                                                             \
+        block_size,                                                                                        \
+        output_size,                                                                                       \
+        index_size,                                                                                        \
+        data_size,                                                                                         \
+        input,                                                                                             \
+        indices,                                                                                           \
+        lengths,                                                                                           \
+        weights,                                                                                           \
+        scale_bias,                                                                                        \
+        normalize_by_lengths,                                                                              \
+        out);                                                                                              \
+  }                                                                                                        \
+  decltype(                                                                                                \
+      EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL##__base)     \
+      EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL##__avx2_fma; \
+  template <>                                                                                              \
+  void EmbeddingLookup<IndexType, InType, OutType, IS_WEIGHT_POSITIONAL>(                                  \
+      const int64_t block_size,                                                                            \
+      const int64_t output_size,                                                                           \
+      const int64_t index_size,                                                                            \
+      const int64_t data_size,                                                                             \
+      const InType* input,                                                                                 \
+      const IndexType* indices,                                                                            \
+      const int* lengths,                                                                                  \
+      const float* weights,                                                                                \
+      const float* scale_bias,                                                                             \
+      bool normalize_by_lengths,                                                                           \
+      OutType* out) {                                                                                      \
+    AVX2_FMA_DO(                                                                                           \
+        EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL,           \
+        block_size,                                                                                        \
+        output_size,                                                                                       \
+        index_size,                                                                                        \
+        data_size,                                                                                         \
+        input,                                                                                             \
+        indices,                                                                                           \
+        lengths,                                                                                           \
+        weights,                                                                                           \
+        scale_bias,                                                                                        \
+        normalize_by_lengths,                                                                              \
+        out);                                                                                              \
+    BASE_DO(                                                                                               \
+        EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL,           \
+        block_size,                                                                                        \
+        output_size,                                                                                       \
+        index_size,                                                                                        \
+        data_size,                                                                                         \
+        input,                                                                                             \
+        indices,                                                                                           \
+        lengths,                                                                                           \
+        weights,                                                                                           \
+        scale_bias,                                                                                        \
+        normalize_by_lengths,                                                                              \
+        out);                                                                                              \
   }
 
 EMBEDDING_SPECIALIZATION(int32_t, int32_t, float, float, float, float, false);
diff --git a/caffe2/perfkernels/embedding_lookup_avx2.cc b/caffe2/perfkernels/embedding_lookup_avx2.cc
index 326818b..89fcc4b 100644
@@ -73,35 +73,43 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
         CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
         vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
-        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
         vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
         // skip unnecessary prefetch of (&ip_next_T0[40])
         vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
-        _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
         // skip unnecessary prefetch of (&ip_next_T0[56])
         vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
-        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
         vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
         // skip unnecessary prefetch of (&ip_next_T0[72])
         vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
-        _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
         vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
         // skip unnecessary prefetch of (&ip_next_T0[88])
         vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
-        _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
         vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
         // skip unnecessary prefetch of (&ip_next_T0[104])
         vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
-        _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
         vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
         // skip unnecessary prefetch of (&ip_next_T0[120])
       }
@@ -179,19 +187,23 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
         CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
         vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
-        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
         vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
         // skip unnecessary prefetch of (&ip_next_T0[40])
         vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
-        _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
         // skip unnecessary prefetch of (&ip_next_T0[56])
       }
@@ -249,11 +261,13 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
         CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
       }
@@ -301,7 +315,8 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
         CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
       }
@@ -355,7 +370,8 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
               &op[j],
               _mm256_fmadd_ps(
                   vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
-          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+          _mm_prefetch(
+              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
         }
         for (; j < block_size; j++) {
           op[j] += wgt * ip[j];
@@ -488,35 +504,43 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma(
         CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
         vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
-        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
         vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
         // skip unnecessary prefetch of (&ip_next_T0[40])
         vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
-        _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
         // skip unnecessary prefetch of (&ip_next_T0[56])
         vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
-        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
         vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
         // skip unnecessary prefetch of (&ip_next_T0[72])
         vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
-        _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
         vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
         // skip unnecessary prefetch of (&ip_next_T0[88])
         vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
-        _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
         vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
         // skip unnecessary prefetch of (&ip_next_T0[104])
         vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
-        _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
         vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
         // skip unnecessary prefetch of (&ip_next_T0[120])
       }
@@ -594,19 +618,23 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma(
         CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
         vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
-        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
         vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
         // skip unnecessary prefetch of (&ip_next_T0[40])
         vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
-        _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
         // skip unnecessary prefetch of (&ip_next_T0[56])
       }
@@ -664,11 +692,13 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma(
         CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
       }
@@ -716,7 +746,8 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma(
         CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
       }
@@ -770,7 +801,8 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma(
               &op[j],
               _mm256_fmadd_ps(
                   vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
-          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+          _mm_prefetch(
+              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
         }
         for (; j < block_size; j++) {
           op[j] += wgt * ip[j];
@@ -907,7 +939,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
             vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -931,7 +964,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
             vop32);
-        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
         vop40 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -955,7 +989,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
             vop64);
-        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
         vop72 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -979,7 +1014,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
             vop96);
-        _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
         vop104 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1077,7 +1113,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
             vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1101,7 +1138,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
             vop32);
-        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
         vop40 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1179,7 +1217,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
             vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1247,7 +1286,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
             vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1308,7 +1348,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma(
                   _mm256_cvtph_ps(_mm_loadu_si128(
                       reinterpret_cast<const __m128i*>(&ip[j]))),
                   _mm256_loadu_ps(&op[j])));
-          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+          _mm_prefetch(
+              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
         }
         alignas(64) at::Half vtmp1[8];
         for (; j < block_size; j++) {
@@ -1448,7 +1489,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
             vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1472,7 +1514,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
             vop32);
-        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
         vop40 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1496,7 +1539,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
             vop64);
-        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
         vop72 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1520,7 +1564,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
             vop96);
-        _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
         vop104 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1618,7 +1663,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
             vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1642,7 +1688,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
             vop32);
-        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
         vop40 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1720,7 +1767,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
             vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1788,7 +1836,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
             vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1849,7 +1898,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma(
                   _mm256_cvtph_ps(_mm_loadu_si128(
                       reinterpret_cast<const __m128i*>(&ip[j]))),
                   _mm256_loadu_ps(&op[j])));
-          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+          _mm_prefetch(
+              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
         }
         alignas(64) at::Half vtmp1[8];
         for (; j < block_size; j++) {
@@ -1993,7 +2043,8 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
             _mm256_add_ps(vop0, vbio));
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2041,7 +2092,8 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
             _mm256_add_ps(vop64, vbio));
-        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
         vop72 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2167,7 +2219,8 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
             _mm256_add_ps(vop0, vbio));
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2273,7 +2326,8 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
             _mm256_add_ps(vop0, vbio));
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2345,7 +2399,8 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
             _mm256_add_ps(vop0, vbio));
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2411,7 +2466,8 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
                   _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
                       reinterpret_cast<const __m128i*>(&ip[j])))),
                   _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
-          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+          _mm_prefetch(
+              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
         }
         for (; j < block_size; j++) {
           op[j] += wgt * ((float)ip[j]) + bio;
@@ -2552,7 +2608,8 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
             _mm256_add_ps(vop0, vbio));
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2600,7 +2657,8 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
             _mm256_add_ps(vop64, vbio));
-        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
         vop72 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2726,7 +2784,8 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
             _mm256_add_ps(vop0, vbio));
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2832,7 +2891,8 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
             _mm256_add_ps(vop0, vbio));
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2904,7 +2964,8 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
             _mm256_add_ps(vop0, vbio));
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2970,7 +3031,8 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
                   _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
                       reinterpret_cast<const __m128i*>(&ip[j])))),
                   _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
-          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+          _mm_prefetch(
+              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
         }
         for (; j < block_size; j++) {
           op[j] += wgt * ((float)ip[j]) + bio;
index 1f4a831..0ae15c8 100644 (file)
@@ -71,35 +71,43 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
         CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
         vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
-        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
         vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
         // skip unnecessary prefetch of (&ip_next_T0[40])
         vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
-        _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
         // skip unnecessary prefetch of (&ip_next_T0[56])
         vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
-        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
         vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
         // skip unnecessary prefetch of (&ip_next_T0[72])
         vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
-        _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
         vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
         // skip unnecessary prefetch of (&ip_next_T0[88])
         vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
-        _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
         vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
         // skip unnecessary prefetch of (&ip_next_T0[104])
         vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
-        _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
         vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
         // skip unnecessary prefetch of (&ip_next_T0[120])
       }
@@ -177,19 +185,23 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
         CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
         vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
-        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
         vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
         // skip unnecessary prefetch of (&ip_next_T0[40])
         vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
-        _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
         // skip unnecessary prefetch of (&ip_next_T0[56])
       }
@@ -247,11 +259,13 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
         CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
       }
@@ -299,7 +313,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
         CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
       }
@@ -353,7 +368,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
               &op[j],
               _mm256_fmadd_ps(
                   vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
-          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+          _mm_prefetch(
+              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
         }
         for (; j < block_size; j++) {
           op[j] += wgt * ip[j];
@@ -480,35 +496,43 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
         CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
         vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
-        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
         vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
         // skip unnecessary prefetch of (&ip_next_T0[40])
         vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
-        _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
         // skip unnecessary prefetch of (&ip_next_T0[56])
         vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
-        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
         vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
         // skip unnecessary prefetch of (&ip_next_T0[72])
         vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
-        _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
         vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
         // skip unnecessary prefetch of (&ip_next_T0[88])
         vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
-        _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
         vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
         // skip unnecessary prefetch of (&ip_next_T0[104])
         vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
-        _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
         vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
         // skip unnecessary prefetch of (&ip_next_T0[120])
       }
@@ -586,19 +610,23 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
         CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
         vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
-        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
         vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
         // skip unnecessary prefetch of (&ip_next_T0[40])
         vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
-        _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
         // skip unnecessary prefetch of (&ip_next_T0[56])
       }
@@ -656,11 +684,13 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
         CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
         vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
-        _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
       }
@@ -708,7 +738,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
         CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
         const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
         vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
       }
@@ -762,7 +793,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
               &op[j],
               _mm256_fmadd_ps(
                   vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
-          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+          _mm_prefetch(
+              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
         }
         for (; j < block_size; j++) {
           op[j] += wgt * ip[j];
@@ -893,7 +925,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
             vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -917,7 +950,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
             vop32);
-        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
         vop40 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -941,7 +975,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
             vop64);
-        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
         vop72 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -965,7 +1000,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
             vop96);
-        _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
         vop104 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1063,7 +1099,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
             vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1087,7 +1124,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
             vop32);
-        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
         vop40 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1165,7 +1203,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
             vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1233,7 +1272,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
             vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1294,7 +1334,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
                   _mm256_cvtph_ps(_mm_loadu_si128(
                       reinterpret_cast<const __m128i*>(&ip[j]))),
                   _mm256_loadu_ps(&op[j])));
-          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+          _mm_prefetch(
+              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
         }
         alignas(64) at::Half vtmp1[8];
         for (; j < block_size; j++) {
@@ -1428,7 +1469,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
             vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1452,7 +1494,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
             vop32);
-        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
         vop40 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1476,7 +1519,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
             vop64);
-        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
         vop72 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1500,7 +1544,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
             vop96);
-        _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
         vop104 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1598,7 +1643,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
             vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1622,7 +1668,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
             vop32);
-        _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
         vop40 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1700,7 +1747,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
             vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1768,7 +1816,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
             _mm256_cvtph_ps(
                 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
             vop0);
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtph_ps(
@@ -1829,7 +1878,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
                   _mm256_cvtph_ps(_mm_loadu_si128(
                       reinterpret_cast<const __m128i*>(&ip[j]))),
                   _mm256_loadu_ps(&op[j])));
-          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+          _mm_prefetch(
+              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
         }
         alignas(64) at::Half vtmp1[8];
         for (; j < block_size; j++) {
@@ -1969,7 +2019,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
             _mm256_add_ps(vop0, vbio));
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2017,7 +2068,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
             _mm256_add_ps(vop64, vbio));
-        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
         vop72 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2145,7 +2197,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
             _mm256_add_ps(vop0, vbio));
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2253,7 +2306,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
             _mm256_add_ps(vop0, vbio));
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2327,7 +2381,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
             _mm256_add_ps(vop0, vbio));
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2394,7 +2449,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
                   _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
                       reinterpret_cast<const __m128i*>(&ip[j])))),
                   _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
-          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+          _mm_prefetch(
+              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
         }
         for (; j < block_size; j++) {
           op[j] += wgt * ((float)ip[j]) + bio;
@@ -2531,7 +2587,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
             _mm256_add_ps(vop0, vbio));
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2579,7 +2636,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
             _mm256_add_ps(vop64, vbio));
-        _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
         vop72 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2707,7 +2765,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
             _mm256_add_ps(vop0, vbio));
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2815,7 +2874,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
             _mm256_add_ps(vop0, vbio));
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2889,7 +2949,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
             _mm256_add_ps(vop0, vbio));
-        _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+        _mm_prefetch(
+            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
         vop8 = _mm256_fmadd_ps(
             vwgt,
             _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
@@ -2956,7 +3017,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
                   _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
                       reinterpret_cast<const __m128i*>(&ip[j])))),
                   _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
-          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+          _mm_prefetch(
+              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
         }
         for (; j < block_size; j++) {
           op[j] += wgt * ((float)ip[j]) + bio;
index 68c8c87..d8f6a43 100644 (file)
@@ -85,77 +85,80 @@ static void Fused8BitRowwiseEmbeddingLookupGenericSlow(
 }
 
 // Proxy back to generic implementation
-#define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(                                    \
-    IndexType, InType, OutType)                                                         \
-  void                                                                                  \
-      Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false__base( \
-          const int64_t block_size,                                                      \
-          const int64_t output_size,                                                     \
-          const int64_t index_size,                                                      \
-          const int64_t data_size,                                                       \
-          const InType* input,                                                          \
-          const IndexType* indices,                                                     \
-          const int* lengths,                                                           \
-          const float* weights,                                                         \
-          bool normalize_by_lengths,                                                    \
-          OutType* out) {                                                               \
-    Fused8BitRowwiseEmbeddingLookupGenericSlow<                                         \
-        IndexType,                                                                      \
-        InType,                                                                         \
-        OutType,                                                                        \
-        false>(                                                                         \
-        block_size,                                                                     \
-        output_size,                                                                    \
-        index_size,                                                                     \
-        data_size,                                                                      \
-        input,                                                                          \
-        indices,                                                                        \
-        lengths,                                                                        \
-        weights,                                                                        \
-        normalize_by_lengths,                                                           \
-        out);                                                                           \
-  }                                                                                     \
-  template <>                                                                           \
-  void Fused8BitRowwiseEmbeddingLookup<IndexType, InType, OutType, false>(              \
-      const int64_t block_size,                                                          \
-      const int64_t output_size,                                                         \
-      const int64_t index_size,                                                          \
-      const int64_t data_size,                                                           \
-      const InType* input,                                                              \
-      const IndexType* indices,                                                         \
-      const int* lengths,                                                               \
-      const float* weights,                                                             \
-      bool normalize_by_lengths,                                                        \
-      OutType* out) {                                                                   \
-    const int32_t one = 1;                                                              \
-    CAFFE_ENFORCE_EQ(                                                                   \
-        reinterpret_cast<const uint8_t*>(&one)[0],                                      \
-        1,                                                                              \
-        "Fused8BitRowwiseEmbeddingLookup is not supported on this platform");           \
-    AVX2_FMA_DO(                                                                        \
-        Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false,     \
-        block_size,                                                                     \
-        output_size,                                                                    \
-        index_size,                                                                     \
-        data_size,                                                                      \
-        input,                                                                          \
-        indices,                                                                        \
-        lengths,                                                                        \
-        weights,                                                                        \
-        normalize_by_lengths,                                                           \
-        out);                                                                           \
-    BASE_DO(                                                                            \
-        Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false,     \
-        block_size,                                                                     \
-        output_size,                                                                    \
-        index_size,                                                                     \
-        data_size,                                                                      \
-        input,                                                                          \
-        indices,                                                                        \
-        lengths,                                                                        \
-        weights,                                                                        \
-        normalize_by_lengths,                                                           \
-        out);                                                                           \
+#define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(                                        \
+    IndexType, InType, OutType)                                                             \
+  void                                                                                      \
+      Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false__base(     \
+          const int64_t block_size,                                                         \
+          const int64_t output_size,                                                        \
+          const int64_t index_size,                                                         \
+          const int64_t data_size,                                                          \
+          const InType* input,                                                              \
+          const IndexType* indices,                                                         \
+          const int* lengths,                                                               \
+          const float* weights,                                                             \
+          bool normalize_by_lengths,                                                        \
+          OutType* out) {                                                                   \
+    Fused8BitRowwiseEmbeddingLookupGenericSlow<                                             \
+        IndexType,                                                                          \
+        InType,                                                                             \
+        OutType,                                                                            \
+        false>(                                                                             \
+        block_size,                                                                         \
+        output_size,                                                                        \
+        index_size,                                                                         \
+        data_size,                                                                          \
+        input,                                                                              \
+        indices,                                                                            \
+        lengths,                                                                            \
+        weights,                                                                            \
+        normalize_by_lengths,                                                               \
+        out);                                                                               \
+  }                                                                                         \
+  decltype(                                                                                 \
+      Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false__base)     \
+      Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false__avx2_fma; \
+  template <>                                                                               \
+  void Fused8BitRowwiseEmbeddingLookup<IndexType, InType, OutType, false>(                  \
+      const int64_t block_size,                                                             \
+      const int64_t output_size,                                                            \
+      const int64_t index_size,                                                             \
+      const int64_t data_size,                                                              \
+      const InType* input,                                                                  \
+      const IndexType* indices,                                                             \
+      const int* lengths,                                                                   \
+      const float* weights,                                                                 \
+      bool normalize_by_lengths,                                                            \
+      OutType* out) {                                                                       \
+    const int32_t one = 1;                                                                  \
+    CAFFE_ENFORCE_EQ(                                                                       \
+        reinterpret_cast<const uint8_t*>(&one)[0],                                          \
+        1,                                                                                  \
+        "Fused8BitRowwiseEmbeddingLookup is not supported on this platform");               \
+    AVX2_FMA_DO(                                                                            \
+        Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false,         \
+        block_size,                                                                         \
+        output_size,                                                                        \
+        index_size,                                                                         \
+        data_size,                                                                          \
+        input,                                                                              \
+        indices,                                                                            \
+        lengths,                                                                            \
+        weights,                                                                            \
+        normalize_by_lengths,                                                               \
+        out);                                                                               \
+    BASE_DO(                                                                                \
+        Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false,         \
+        block_size,                                                                         \
+        output_size,                                                                        \
+        index_size,                                                                         \
+        data_size,                                                                          \
+        input,                                                                              \
+        indices,                                                                            \
+        lengths,                                                                            \
+        weights,                                                                            \
+        normalize_by_lengths,                                                               \
+        out);                                                                               \
   }
 
 FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int32_t, uint8_t, float);
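The `decltype(..._false__base) ..._false__avx2_fma;` line added inside the macro above declares the AVX2+FMA variant with a signature guaranteed to match the base implementation, so the dispatching specialization can reference it without pulling in any architecture-specific header. A minimal sketch of that pattern follows; the names `example_kernel__base` and `example_kernel__avx2_fma` are hypothetical and the dispatch is shown inline rather than through the AVX2_FMA_DO/BASE_DO macros.

    #include <cstdint>

    // Baseline scalar implementation, always compiled.
    static void example_kernel__base(const float* in, float* out, std::int64_t n) {
      for (std::int64_t i = 0; i < n; ++i) {
        out[i] = in[i] * 2.0f;
      }
    }

    // Declares an AVX2+FMA variant whose signature is forced to match the base,
    // without repeating the parameter list. The definition would live in a
    // translation unit built with -mavx2 -mfma (or /arch:AVX2), so this file
    // needs no architecture-specific compile flags.
    decltype(example_kernel__base) example_kernel__avx2_fma;

    int main() {
      float in[4] = {1.f, 2.f, 3.f, 4.f};
      float out[4];
      // The real code picks a variant via runtime dispatch macros; here we call
      // the base variant directly so the sketch links on its own.
      example_kernel__base(in, out, 4);
      return 0;
    }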
index 20f759c..748f5ce 100644 (file)
@@ -37,7 +37,9 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused):
 
         if prefetch:
             code.append(
-                "        _mm_prefetch((&ip_next_T0[%d]), _MM_HINT_T0);" % (regid)
+                "        _mm_prefetch(\n"
+                "            reinterpret_cast<const char*>(&ip_next_T0[%d]), _MM_HINT_T0);"
+                % (regid)
             )
         else:
             code.append(
@@ -178,7 +180,10 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused):
         else:
             assert False
 
-        code.append("          _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);")
+        code.append(
+            "          _mm_prefetch(\n"
+            "              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);"
+        )
 
         return code
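With the codegen change above, every emitted `_mm_prefetch` call goes through `reinterpret_cast<const char*>`, since MSVC declares the intrinsic as `_mm_prefetch(const char*, int)` while the row pointers in these kernels are `float*`, `at::Half*`, or `uint8_t*`; GCC and Clang accept the cast as well. A minimal sketch of the resulting pattern, using a hypothetical `prefetch_row` helper that is not part of the patch:

    #include <immintrin.h>

    // Illustration only: casting to const char* satisfies MSVC's
    // _mm_prefetch(const char*, int) signature and also compiles with
    // GCC/Clang, so the same generated code builds on both toolchains.
    template <typename T>
    inline void prefetch_row(const T* addr) {
      _mm_prefetch(reinterpret_cast<const char*>(addr), _MM_HINT_T0);
    }

    int main() {
      float row[64] = {};
      prefetch_row(&row[0]);   // hint: bring this part of the row into L1
      prefetch_row(&row[16]);  // 16 floats = 64 bytes, i.e. one cache line on
      return 0;
    }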
 
index 14265f9..63380fc 100644 (file)
@@ -21,15 +21,15 @@ namespace math {
 void quantize_and_compress(
     const float* input_data,
     std::uint8_t* output_data,
-    std::size_t input_size,
-    std::size_t bitwidth,
+    std::uint64_t input_size,
+    std::uint64_t bitwidth,
     bool random,
     const float* random_buffer);
 
 void decompress_and_dequantize(
     const std::uint8_t* input_data,
     float* output_data,
-    std::size_t input_size);
+    std::uint64_t input_size);
 
 } // namespace math
 } // namespace caffe2
index 95292c3..cf99a67 100644 (file)
@@ -7,6 +7,9 @@
 #include <cmath>
 #include <cstdint>
 
+using std::uint64_t;
+using std::uint8_t;
+
 namespace caffe2 {
 
 namespace math {
@@ -16,8 +19,8 @@ static constexpr double QEPSILON = 1e-8;
 void quantize_and_compress__avx2(
     const float* input_data,
     uint8_t* output_data,
-    size_t input_size,
-    size_t bitwidth,
+    uint64_t input_size,
+    uint64_t bitwidth,
     bool random,
     const float* random_buffer) {
   __m256i shuffle_mask_v = _mm256_set_epi8(
@@ -56,10 +59,10 @@ void quantize_and_compress__avx2(
   __m256i permute_mask_v =
       _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
 
-  size_t data_per_byte = 8 / bitwidth;
-  size_t tail = input_size % data_per_byte;
+  uint64_t data_per_byte = 8 / bitwidth;
+  uint64_t tail = input_size % data_per_byte;
   tail = tail ? data_per_byte - tail : 0;
-  size_t segment_size = (input_size + data_per_byte - 1) / data_per_byte;
+  uint64_t segment_size = (input_size + data_per_byte - 1) / data_per_byte;
 
   // basic info
   float minimum_element = INFINITY, maximum_element = -INFINITY;
@@ -77,11 +80,11 @@ void quantize_and_compress__avx2(
   float gap = (maximum_element - minimum_element) / ((1 << bitwidth) - 1.0f);
   float gap_inverse = 1. / (gap + QEPSILON);
   uint8_t max_q = (1 << bitwidth) - 1;
-  size_t bit_start = 0;
+  uint64_t bit_start = 0;
   if (random) {
     for (int start = 0; start < input_size; start += segment_size) {
-      size_t stride = start + segment_size <= input_size ? segment_size
-                                                         : input_size - start;
+      uint64_t stride = start + segment_size <= input_size ? segment_size
+                                                           : input_size - start;
       int i = 0;
       constexpr int VLEN = 8;
       for (; i < stride / VLEN * VLEN; i += VLEN) {
@@ -122,8 +125,8 @@ void quantize_and_compress__avx2(
   } else {
     // !random
     for (int start = 0; start < input_size; start += segment_size) {
-      size_t stride = start + segment_size <= input_size ? segment_size
-                                                         : input_size - start;
+      uint64_t stride = start + segment_size <= input_size ? segment_size
+                                                           : input_size - start;
       int i = 0;
       constexpr int VLEN = 8;
       for (; i < stride / VLEN * VLEN; i += VLEN) {
@@ -165,26 +168,26 @@ void quantize_and_compress__avx2(
 void decompress_and_dequantize__avx2(
     const uint8_t* input_data,
     float* output_data,
-    size_t input_size) {
+    uint64_t input_size) {
   // basic info
   const float minimum_element =
       reinterpret_cast<const float*>(input_data + 2)[0];
   const float maximum_element =
       reinterpret_cast<const float*>(input_data + 2)[1];
-  const size_t bitwidth = input_data[0];
+  const uint64_t bitwidth = input_data[0];
   const float gap =
       (maximum_element - minimum_element) / ((1 << bitwidth) - 1.f) +
       QEPSILON; // for exact recovering
 
-  const size_t tail = input_data[1];
+  const uint64_t tail = input_data[1];
 
-  const size_t output_size = (input_size - 10) * (8 / bitwidth) - tail;
+  const uint64_t output_size = (input_size - 10) * (8 / bitwidth) - tail;
   // decoding
-  size_t bit_start = 0;
-  const size_t segment_size = input_size - 10;
+  uint64_t bit_start = 0;
+  const uint64_t segment_size = input_size - 10;
   for (int start = 0; start < output_size; start += segment_size) {
-    size_t stride = start + segment_size <= output_size ? segment_size
-                                                        : output_size - start;
+    uint64_t stride = start + segment_size <= output_size ? segment_size
+                                                          : output_size - start;
     uint8_t mask = (1 << bitwidth) - 1;
     int i = 0;
     // Can process 8 elements at a time because we need to expand uint8_t
index 1837641..6f2cc2b 100644 (file)
@@ -3,10 +3,15 @@
 // computation library to different compiler options (-mno-avx2 or -mavx2).
 
 #include <cfloat>
+#include <cmath>
+#include <cstdint>
 
 #include "common.h"
 #include "math.h"
 
+using std::uint64_t;
+using std::uint8_t;
+
 namespace caffe2 {
 
 namespace math {
@@ -16,14 +21,14 @@ static constexpr double QEPSILON = 1e-8;
 void quantize_and_compress__base(
     const float* input_data,
     uint8_t* output_data,
-    size_t input_size,
-    size_t bitwidth,
+    uint64_t input_size,
+    uint64_t bitwidth,
     bool random,
     const float* random_buffer) {
-  size_t data_per_byte = 8 / bitwidth;
-  size_t tail = input_size % data_per_byte;
+  uint64_t data_per_byte = 8 / bitwidth;
+  uint64_t tail = input_size % data_per_byte;
   tail = tail ? data_per_byte - tail : 0;
-  size_t segment_size = (input_size + data_per_byte - 1) / data_per_byte;
+  uint64_t segment_size = (input_size + data_per_byte - 1) / data_per_byte;
 
   // basic info
   float minimum_element = INFINITY, maximum_element = -INFINITY;
@@ -41,11 +46,11 @@ void quantize_and_compress__base(
   float gap = (maximum_element - minimum_element) / ((1 << bitwidth) - 1.0f);
   float gap_inverse = 1. / (gap + QEPSILON);
   uint8_t max_q = (1 << bitwidth) - 1;
-  size_t bit_start = 0;
+  uint64_t bit_start = 0;
   if (random) {
     for (int start = 0; start < input_size; start += segment_size) {
-      size_t stride = start + segment_size <= input_size ? segment_size
-                                                         : input_size - start;
+      uint64_t stride = start + segment_size <= input_size ? segment_size
+                                                           : input_size - start;
       int i = 0;
       for (; i < stride; ++i) {
         float fval = input_data[start + i];
@@ -64,8 +69,8 @@ void quantize_and_compress__base(
     }
   } else {
     for (int start = 0; start < input_size; start += segment_size) {
-      size_t stride = start + segment_size <= input_size ? segment_size
-                                                         : input_size - start;
+      uint64_t stride = start + segment_size <= input_size ? segment_size
+                                                           : input_size - start;
       int i = 0;
       for (; i < stride; ++i) {
         float fval = input_data[start + i];
@@ -84,11 +89,12 @@ void quantize_and_compress__base(
   }
 }
 
+decltype(quantize_and_compress__base) quantize_and_compress__avx2;
 void quantize_and_compress(
     const float* input_data,
     uint8_t* output_data,
-    size_t input_size,
-    size_t bitwidth,
+    uint64_t input_size,
+    uint64_t bitwidth,
     bool random,
     const float* random_buffer) {
   AVX2_DO(
@@ -112,26 +118,26 @@ void quantize_and_compress(
 void decompress_and_dequantize__base(
     const uint8_t* input_data,
     float* output_data,
-    size_t input_size) {
+    uint64_t input_size) {
   // basic info
   const float minimum_element =
       reinterpret_cast<const float*>(input_data + 2)[0];
   const float maximum_element =
       reinterpret_cast<const float*>(input_data + 2)[1];
-  const size_t bitwidth = input_data[0];
+  const uint64_t bitwidth = input_data[0];
   const float gap =
       (maximum_element - minimum_element) / ((1 << bitwidth) - 1.f) +
       QEPSILON; // for exact recovering
 
-  const size_t tail = input_data[1];
+  const uint64_t tail = input_data[1];
 
-  const size_t output_size = (input_size - 10) * (8 / bitwidth) - tail;
+  const uint64_t output_size = (input_size - 10) * (8 / bitwidth) - tail;
   // decoding
-  size_t bit_start = 0;
-  const size_t segment_size = input_size - 10;
+  uint64_t bit_start = 0;
+  const uint64_t segment_size = input_size - 10;
   for (int start = 0; start < output_size; start += segment_size) {
-    size_t stride = start + segment_size <= output_size ? segment_size
-                                                        : output_size - start;
+    uint64_t stride = start + segment_size <= output_size ? segment_size
+                                                          : output_size - start;
     uint8_t mask = (1 << bitwidth) - 1;
     int i = 0;
     for (; i < stride; ++i) {
@@ -142,10 +148,11 @@ void decompress_and_dequantize__base(
   }
 }
 
+decltype(decompress_and_dequantize__base) decompress_and_dequantize__avx2;
 void decompress_and_dequantize(
     const uint8_t* input_data,
     float* output_data,
-    size_t input_size) {
+    uint64_t input_size) {
   AVX2_DO(decompress_and_dequantize, input_data, output_data, input_size);
   BASE_DO(decompress_and_dequantize, input_data, output_data, input_size);
 }
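The scalar path above packs `8 / bitwidth` quantized values into each output byte: `tail` counts the unused slots in the final byte and `segment_size` is `ceil(input_size / data_per_byte)`. A small worked example of that arithmetic (the concrete values are chosen for illustration and are not taken from the patch):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Example: 100 input floats quantized to 2 bits each.
      std::uint64_t input_size = 100;
      std::uint64_t bitwidth = 2;

      std::uint64_t data_per_byte = 8 / bitwidth;       // 4 values per byte
      std::uint64_t tail = input_size % data_per_byte;  // 0 -> last byte is full
      tail = tail ? data_per_byte - tail : 0;           // unused slots in last byte
      std::uint64_t segment_size =
          (input_size + data_per_byte - 1) / data_per_byte;  // ceil(100 / 4) = 25

      std::printf("data_per_byte=%llu tail=%llu segment_size=%llu\n",
                  (unsigned long long)data_per_byte,
                  (unsigned long long)tail,
                  (unsigned long long)segment_size);
      return 0;
    }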
index 8bcbc06..2ca219a 100644 (file)
@@ -36,6 +36,8 @@ void TypedAxpyHalffloat__base(
   }
 }
 
+decltype(TypedAxpyHalffloat__base) TypedAxpyHalffloat__avx2_fma;
+decltype(TypedAxpyHalffloat__base) TypedAxpyHalffloat__avx_f16c;
 template <>
 void TypedAxpy<at::Half, float>(
     int N,
@@ -57,6 +59,8 @@ void TypedAxpy_uint8_float__base(
   }
 }
 
+decltype(TypedAxpy_uint8_float__base) TypedAxpy_uint8_float__avx2_fma;
+decltype(TypedAxpy_uint8_float__base) TypedAxpy_uint8_float__avx_f16c;
 template <>
 void TypedAxpy<std::uint8_t, float>(
     int N,
index a49e597..2e2add2 100644 (file)
@@ -321,7 +321,7 @@ if(USE_FBGEMM)
   if(NOT DEFINED FBGEMM_SOURCE_DIR)
     set(FBGEMM_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/fbgemm" CACHE STRING "FBGEMM source directory")
   endif()
-  if(NOT CAFFE2_COMPILER_SUPPORTS_AVX512F_EXTENSIONS)
+  if(NOT CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS)
     message(WARNING
       "A compiler with AVX512 support is required for FBGEMM. "
       "Not compiling with FBGEMM. "
index f35502d..0d2e61c 100644 (file)
@@ -171,41 +171,52 @@ CHECK_CXX_SOURCE_COMPILES(
      }" CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS)
 if (CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS)
   message(STATUS "Current compiler supports avx2 extension. Will build perfkernels.")
-  # Currently MSVC seems to have a symbol not found error while linking (related
-  # to source file order?). As a result we will currently disable the perfkernel
-  # in msvc.
   # Also see CMakeLists.txt under caffe2/perfkernels.
-  if (NOT MSVC)
-    set(CAFFE2_PERF_WITH_AVX 1)
-    set(CAFFE2_PERF_WITH_AVX2 1)
-  endif()
+  set(CAFFE2_PERF_WITH_AVX 1)
+  set(CAFFE2_PERF_WITH_AVX2 1)
 endif()
 cmake_pop_check_state()
 
-# ---[ Check if the compiler has AVX512F support.
+# ---[ Check if the compiler has AVX512 support.
 cmake_push_check_state(RESET)
 if (MSVC)
-  set(CMAKE_REQUIRED_FLAGS "/D__AVX512F__")
+  # We could've used MSVC's hidden option /arch:AVX512 that defines __AVX512F__,
+  # __AVX512DQ__, and __AVX512VL__, and /arch:AVX512F that defines __AVX512F__.
+  # But we chose not to do that, to avoid relying on hidden options.
+  set(CMAKE_REQUIRED_FLAGS "/D__AVX512F__ /D__AVX512DQ__ /D__AVX512VL__")
 else()
-  set(CMAKE_REQUIRED_FLAGS "-mavx512f")
+  # We only consider the case where all of avx512f, avx512dq, and avx512vl are
+  # supported.
+  # Platforms where avx512f is supported but not avx512dq and avx512vl, as of
+  # Jan 15, 2019: linux_manywheel_2.7mu_cpu_build and
+  # linux_conda_3.7_cu100_build
+  set(CMAKE_REQUIRED_FLAGS "-mavx512f -mavx512dq -mavx512vl")
 endif()
 CHECK_CXX_SOURCE_COMPILES(
     "#if defined(_MSC_VER)
      #include <intrin.h>
      #else
-     #include <x86intrin.h>
+     #include <immintrin.h>
      #endif
+     // check avx512f
      __m512 addConstant(__m512 arg) {
        return _mm512_add_ps(arg, _mm512_set1_ps(1.f));
      }
+     // check avx512dq
+     __m512 andConstant(__m512 arg) {
+       return _mm512_and_ps(arg, _mm512_set1_ps(1.f));
+     }
      int main() {
        __m512i a = _mm512_set1_epi32(1);
        __m256i ymm = _mm512_extracti64x4_epi64(a, 0);
+       ymm = _mm256_abs_epi64(ymm); // check avx512vl
        __mmask16 m = _mm512_cmp_epi32_mask(a, a, _MM_CMPINT_EQ);
        __m512i r = _mm512_andnot_si512(a, a);
-       }" CAFFE2_COMPILER_SUPPORTS_AVX512F_EXTENSIONS)
-if (CAFFE2_COMPILER_SUPPORTS_AVX512F_EXTENSIONS)
+     }" CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS)
+if (CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS)
   message(STATUS "Current compiler supports avx512f extension. Will build fbgemm.")
+  # Also see CMakeLists.txt under caffe2/perfkernels.
+  set(CAFFE2_PERF_WITH_AVX512 1)
 endif()
 cmake_pop_check_state()