add avx512 option (but no avx512 kernel yet) (#14664)
author Jongsoo Park <jongsoo@fb.com>
Mon, 3 Dec 2018 20:14:47 +0000 (12:14 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 3 Dec 2018 20:18:19 +0000 (12:18 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14664

This diff only adds the framework needed to add avx512 kernels; it does not add any avx512 kernel yet.
Please be very careful about using avx512 kernels: use them only when you are convinced that avx512 brings a good enough *overall* speedup, because it can backfire when the CPU frequency goes down.

Reviewed By: duc0

Differential Revision: D13281944

fbshipit-source-id: 04fce8619c63f814944b727a99fbd7d35538eac6

caffe2/core/macros.h.in
caffe2/perfkernels/CMakeLists.txt
caffe2/perfkernels/common.h
caffe2/perfkernels/common_avx512.cc [new file with mode: 0644]
caffe2/perfkernels/embedding_lookup.cc

index 71c9be5..f8a5b4e 100644 (file)
@@ -30,6 +30,7 @@ static_assert(
 #cmakedefine CAFFE2_HAS_MKL_SGEMM_PACK
 #cmakedefine CAFFE2_PERF_WITH_AVX
 #cmakedefine CAFFE2_PERF_WITH_AVX2
+#cmakedefine CAFFE2_PERF_WITH_AVX512
 #cmakedefine CAFFE2_THREADPOOL_MAIN_IMBALANCE
 #cmakedefine CAFFE2_THREADPOOL_STATS
 #cmakedefine CAFFE2_USE_EXCEPTION_PTR
@@ -71,6 +72,7 @@ static_assert(
   {"HAS_MKL_SGEMM_PACK", "${CAFFE2_HAS_MKL_SGEMM_PACK}"}, \
   {"PERF_WITH_AVX", "${CAFFE2_PERF_WITH_AVX}"}, \
   {"PERF_WITH_AVX2", "${CAFFE2_PERF_WITH_AVX2}"}, \
+  {"PERF_WITH_AVX512", "${CAFFE2_PERF_WITH_AVX512}"}, \
   {"USE_EXCEPTION_PTR", "${CAFFE2_USE_EXCEPTION_PTR}"}, \
   {"USE_ACCELERATE", "${CAFFE2_USE_ACCELERATE}"}, \
   {"USE_EIGEN_FOR_BLAS", "${CAFFE2_USE_EIGEN_FOR_BLAS}"}, \
index a5701da..18ae10d 100644 (file)
@@ -2,9 +2,11 @@
 file(GLOB common_srcs *.cc)
 file(GLOB avx_srcs *_avx.cc)
 file(GLOB avx2_srcs *_avx2.cc)
-# exclude avx and avx2 srcs from common_srcs
+file(GLOB avx512_srcs *_avx512.cc)
+# exclude avx, avx2, and avx512 srcs from common_srcs
 exclude(common_srcs "${common_srcs}" ${avx_srcs})
 exclude(common_srcs "${common_srcs}" ${avx2_srcs})
+exclude(common_srcs "${common_srcs}" ${avx512_srcs})
 
 # We will always build common srcs.
 set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${common_srcs})
@@ -24,6 +26,7 @@ if (NOT MSVC AND CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS)
         Caffe2_perfkernels_avx PROPERTIES COMPILE_FLAGS "/arch:AVX")
     set_target_properties(
         Caffe2_perfkernels_avx2 PROPERTIES COMPILE_FLAGS "/arch:AVX2")
+    # Currently MSVC doesn't support AVX512
   else()
     set_target_properties(
         Caffe2_perfkernels_avx PROPERTIES COMPILE_FLAGS "-mavx -mf16c")
@@ -33,6 +36,15 @@ if (NOT MSVC AND CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS)
   set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS}
       $<TARGET_OBJECTS:Caffe2_perfkernels_avx>
       $<TARGET_OBJECTS:Caffe2_perfkernels_avx2>)
+
+  # Build the *_avx512.cc sources (avx512_srcs) as their own object library so
+  # that the AVX512 compile flags are applied to those files only, then append
+  # the resulting object files to Caffe2_CPU_SRCS.
+  if (CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS)
+      add_library(Caffe2_perfkernels_avx512 OBJECT ${avx512_srcs})
+      # Order this target after Caffe2_PROTO and c10 in the build.
+      add_dependencies(Caffe2_perfkernels_avx512 Caffe2_PROTO c10)
+      # The AVX512 flags are combined with the AVX2/FMA/AVX/F16C flags.
+      set_target_properties(
+          Caffe2_perfkernels_avx512 PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512dq -mavx2 -mfma -mavx -mf16c")
+      set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS}
+          $<TARGET_OBJECTS:Caffe2_perfkernels_avx512>)
+  endif()
 endif()
 
 # TODO(jiayq): currently, we only implement the very base files for the
index 70d5945..7fdf64f 100644 (file)
@@ -7,6 +7,11 @@ implement a functionality called void foo(int a, float b).
 In foo.h, do:
    void foo(int a, float b);
 
+In foo_avx512.cc, do:
+   void foo__avx512(int a, float b) {
+     [actual avx512 implementation]
+   }
+
 In foo_avx2.cc, do:
    void foo__avx2(int a, float b) {
      [actual avx2 implementation]
@@ -25,6 +30,7 @@ In foo.cc, do:
    void foo(int a, float b) {
      // You should always order things by their preference, faster
      // implementations earlier in the function.
+     AVX512_DO(foo, a, b);
      AVX2_DO(foo, a, b);
      AVX_DO(foo, a, b);
      BASE_DO(foo, a, b);
@@ -35,11 +41,12 @@ In foo.cc, do:
 // and run time architecture support.
 //
 // During build time:
-//    The build system should provide flags CAFFE2_PERF_WITH_AVX2 and
-//    CAFFE2_PERF_WITH_AVX that corresponds to the __AVX__ and __AVX2__ flags
-//    the compiler provides. Note that we do not use the compiler flags but
-//    rely on the build system flags, because the common files (like foo.cc
-//    above) will always be built without __AVX__ and __AVX2__.
+//    The build system should provide flags CAFFE2_PERF_WITH_AVX512,
+//    CAFFE2_PERF_WITH_AVX2, and CAFFE2_PERF_WITH_AVX that correspond to the
+//    __AVX512F__/__AVX512DQ__, __AVX2__, and __AVX__ flags the compiler
+//    provides. Note that we do not use the compiler flags but rely on the build
+//    system flags, because the common files (like foo.cc above) will always be
+//    built without __AVX512F__, __AVX512DQ__, __AVX2__, and __AVX__.
 // During run time:
 //    we use cpuid to identify cpu support and run the proper functions.
 
@@ -52,6 +59,16 @@ In foo.cc, do:
 
 #define BASE_DO(funcname, ...) return funcname##__base(__VA_ARGS__);
 
+// AVX512_DO(funcname, args...)
+// When CAFFE2_PERF_WITH_AVX512 is defined, this declares funcname##__avx512
+// with the same type as funcname##__base and, if GetCpuId() reports both
+// avx512f and avx512dq, returns the result of calling it with the given
+// arguments. If the CPU check fails, control continues with whatever follows
+// the macro in the caller. When CAFFE2_PERF_WITH_AVX512 is not defined, the
+// macro expands to nothing.
+#ifdef CAFFE2_PERF_WITH_AVX512
+#define AVX512_DO(funcname, ...)                       \
+  decltype(funcname##__base) funcname##__avx512;       \
+  if (GetCpuId().avx512f() && GetCpuId().avx512dq()) { \
+    return funcname##__avx512(__VA_ARGS__);            \
+  }
+#else // CAFFE2_PERF_WITH_AVX512
+#define AVX512_DO(funcname, ...)
+#endif // CAFFE2_PERF_WITH_AVX512
+
 #ifdef CAFFE2_PERF_WITH_AVX2
 #define AVX2_DO(funcname, ...)                 \
   decltype(funcname##__base) funcname##__avx2; \
diff --git a/caffe2/perfkernels/common_avx512.cc b/caffe2/perfkernels/common_avx512.cc
new file mode 100644 (file)
index 0000000..055f95d
--- /dev/null
@@ -0,0 +1,23 @@
+// This file exists only to check that the build flags are consistent: the
+// CAFFE2_PERF_WITH_AVX512 macro must be defined if and only if this file is
+// compiled with both __AVX512F__ and __AVX512DQ__ (e.g. -mavx512f -mavx512dq).
+
+#include "caffe2/core/common.h"
+
+#ifdef CAFFE2_PERF_WITH_AVX512
+#if !defined(__AVX512F__) || !defined(__AVX512DQ__)
+#error( \
+    "You found a build system error: CAFFE2_PERF_WITH_AVX512 is defined " \
+    "but __AVX512F__ or __AVX512DQ__ is not defined " \
+    "(via e.g. -mavx512f and -mavx512dq).");
+#endif
+#endif // CAFFE2_PERF_WITH_AVX512
+
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+#ifndef CAFFE2_PERF_WITH_AVX512
+#error( \
+    "You found a build system error: __AVX512F__ and __AVX512DQ__ are defined " \
+    "(via e.g. -mavx512f and -mavx512dq) " \
+    "but CAFFE2_PERF_WITH_AVX512 is not defined.");
+#endif // CAFFE2_PERF_WITH_AVX512
+#endif
index e98bc51..fa93ae7 100644 (file)
@@ -82,13 +82,19 @@ static void EmbeddingLookupGenericSlow(
 
 // Proxy back to generic implementation
 #define EMBEDDING_SPECIALIZATION(                                                                      \
-    IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType, IS_WEIGHT_POSITIONAL)          \
+    IndexTypeName,                                                                                     \
+    IndexType,                                                                                         \
+    InTypeName,                                                                                        \
+    InType,                                                                                            \
+    OutTypeName,                                                                                       \
+    OutType,                                                                                           \
+    IS_WEIGHT_POSITIONAL)                                                                              \
   void                                                                                                 \
       EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL##__base( \
-          const int64_t block_size,                                                                     \
-          const int64_t output_size,                                                                    \
-          const int64_t index_size,                                                                     \
-          const int64_t data_size,                                                                      \
+          const int64_t block_size,                                                                    \
+          const int64_t output_size,                                                                   \
+          const int64_t index_size,                                                                    \
+          const int64_t data_size,                                                                     \
           const InType* input,                                                                         \
           const IndexType* indices,                                                                    \
           const int* lengths,                                                                          \
@@ -115,10 +121,10 @@ static void EmbeddingLookupGenericSlow(
   }                                                                                                    \
   template <>                                                                                          \
   void EmbeddingLookup<IndexType, InType, OutType, IS_WEIGHT_POSITIONAL>(                              \
-      const int64_t block_size,                                                                         \
-      const int64_t output_size,                                                                        \
-      const int64_t index_size,                                                                         \
-      const int64_t data_size,                                                                          \
+      const int64_t block_size,                                                                        \
+      const int64_t output_size,                                                                       \
+      const int64_t index_size,                                                                        \
+      const int64_t data_size,                                                                         \
       const InType* input,                                                                             \
       const IndexType* indices,                                                                        \
       const int* lengths,                                                                              \
@@ -158,15 +164,43 @@ EMBEDDING_SPECIALIZATION(int32_t, int32_t, float, float, float, float, false);
 EMBEDDING_SPECIALIZATION(int64_t, int64_t, float, float, float, float, false);
 EMBEDDING_SPECIALIZATION(int32_t, int32_t, half, at::Half, float, float, false);
 EMBEDDING_SPECIALIZATION(int64_t, int64_t, half, at::Half, float, float, false);
-EMBEDDING_SPECIALIZATION(int32_t, int32_t, uint8_t, uint8_t, float, float, false);
-EMBEDDING_SPECIALIZATION(int64_t, int64_t, uint8_t, uint8_t, float, float, false);
+EMBEDDING_SPECIALIZATION(
+    int32_t,
+    int32_t,
+    uint8_t,
+    uint8_t,
+    float,
+    float,
+    false);
+EMBEDDING_SPECIALIZATION(
+    int64_t,
+    int64_t,
+    uint8_t,
+    uint8_t,
+    float,
+    float,
+    false);
 
 EMBEDDING_SPECIALIZATION(int32_t, int32_t, float, float, float, float, true);
 EMBEDDING_SPECIALIZATION(int64_t, int64_t, float, float, float, float, true);
 EMBEDDING_SPECIALIZATION(int32_t, int32_t, half, at::Half, float, float, true);
 EMBEDDING_SPECIALIZATION(int64_t, int64_t, half, at::Half, float, float, true);
-EMBEDDING_SPECIALIZATION(int32_t, int32_t, uint8_t, uint8_t, float, float, true);
-EMBEDDING_SPECIALIZATION(int64_t, int64_t, uint8_t, uint8_t, float, float, true);
+EMBEDDING_SPECIALIZATION(
+    int32_t,
+    int32_t,
+    uint8_t,
+    uint8_t,
+    float,
+    float,
+    true);
+EMBEDDING_SPECIALIZATION(
+    int64_t,
+    int64_t,
+    uint8_t,
+    uint8_t,
+    float,
+    float,
+    true);
 
 #undef EMBEDDING_SPECIALIZATION