#cmakedefine CAFFE2_HAS_MKL_SGEMM_PACK
#cmakedefine CAFFE2_PERF_WITH_AVX
#cmakedefine CAFFE2_PERF_WITH_AVX2
+#cmakedefine CAFFE2_PERF_WITH_AVX512
#cmakedefine CAFFE2_THREADPOOL_MAIN_IMBALANCE
#cmakedefine CAFFE2_THREADPOOL_STATS
#cmakedefine CAFFE2_USE_EXCEPTION_PTR
{"HAS_MKL_SGEMM_PACK", "${CAFFE2_HAS_MKL_SGEMM_PACK}"}, \
{"PERF_WITH_AVX", "${CAFFE2_PERF_WITH_AVX}"}, \
{"PERF_WITH_AVX2", "${CAFFE2_PERF_WITH_AVX2}"}, \
+ {"PERF_WITH_AVX512", "${CAFFE2_PERF_WITH_AVX512}"}, \
{"USE_EXCEPTION_PTR", "${CAFFE2_USE_EXCEPTION_PTR}"}, \
{"USE_ACCELERATE", "${CAFFE2_USE_ACCELERATE}"}, \
{"USE_EIGEN_FOR_BLAS", "${CAFFE2_USE_EIGEN_FOR_BLAS}"}, \
file(GLOB common_srcs *.cc)
file(GLOB avx_srcs *_avx.cc)
file(GLOB avx2_srcs *_avx2.cc)
-# exclude avx and avx2 srcs from common_srcs
+file(GLOB avx512_srcs *_avx512.cc)
+# exclude avx, avx2, and avx512 srcs from common_srcs
exclude(common_srcs "${common_srcs}" ${avx_srcs})
exclude(common_srcs "${common_srcs}" ${avx2_srcs})
+exclude(common_srcs "${common_srcs}" ${avx512_srcs})
# We will always build common srcs.
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${common_srcs})
Caffe2_perfkernels_avx PROPERTIES COMPILE_FLAGS "/arch:AVX")
set_target_properties(
Caffe2_perfkernels_avx2 PROPERTIES COMPILE_FLAGS "/arch:AVX2")
+ # Currently MSVC doesn't support AVX512
else()
set_target_properties(
Caffe2_perfkernels_avx PROPERTIES COMPILE_FLAGS "-mavx -mf16c")
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS}
$<TARGET_OBJECTS:Caffe2_perfkernels_avx>
$<TARGET_OBJECTS:Caffe2_perfkernels_avx2>)
+
+ if (CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS)
+ add_library(Caffe2_perfkernels_avx512 OBJECT ${avx512_srcs})
+ add_dependencies(Caffe2_perfkernels_avx512 Caffe2_PROTO c10)
+ set_target_properties(
+ Caffe2_perfkernels_avx512 PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512dq -mavx2 -mfma -mavx -mf16c")
+ set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS}
+ $<TARGET_OBJECTS:Caffe2_perfkernels_avx512>)
+ endif()
endif()
# TODO(jiayq): currently, we only implement the very base files for the
In foo.h, do:
void foo(int a, float b);
+In foo_avx512.cc, do:
+ void foo__avx512(int a, float b) {
+ [actual avx512 implementation]
+ }
+
In foo_avx2.cc, do:
void foo__avx2(int a, float b) {
[actual avx2 implementation]
void foo(int a, float b) {
// You should always order things by their preference, faster
// implementations earlier in the function.
+ AVX512_DO(foo, a, b);
AVX2_DO(foo, a, b);
AVX_DO(foo, a, b);
BASE_DO(foo, a, b);
// and run time architecture support.
//
// During build time:
-// The build system should provide flags CAFFE2_PERF_WITH_AVX2 and
-// CAFFE2_PERF_WITH_AVX that corresponds to the __AVX__ and __AVX2__ flags
-// the compiler provides. Note that we do not use the compiler flags but
-// rely on the build system flags, because the common files (like foo.cc
-// above) will always be built without __AVX__ and __AVX2__.
+// The build system should provide flags CAFFE2_PERF_WITH_AVX512,
+// CAFFE2_PERF_WITH_AVX2, and CAFFE2_PERF_WITH_AVX that corresponds to the
+// __AVX512F__, __AVX512DQ__, __AVX__, and __AVX2__ flags the compiler
+// provides. Note that we do not use the compiler flags but rely on the build
+// system flags, because the common files (like foo.cc above) will always be
+// built without __AVX512F__, __AVX512DQ__, __AVX__ and __AVX2__.
// During run time:
// we use cpuid to identify cpu support and run the proper functions.
#define BASE_DO(funcname, ...) return funcname##__base(__VA_ARGS__);
+#ifdef CAFFE2_PERF_WITH_AVX512
+#define AVX512_DO(funcname, ...) \
+ decltype(funcname##__base) funcname##__avx512; \
+ if (GetCpuId().avx512f() && GetCpuId().avx512dq()) { \
+ return funcname##__avx512(__VA_ARGS__); \
+ }
+#else // CAFFE2_PERF_WITH_AVX512
+#define AVX512_DO(funcname, ...)
+#endif // CAFFE2_PERF_WITH_AVX512
+
#ifdef CAFFE2_PERF_WITH_AVX2
#define AVX2_DO(funcname, ...) \
decltype(funcname##__base) funcname##__avx2; \
--- /dev/null
+// This file is here merely to check that the flags are not mixed up: for
+// example, if your compiler did not specify -mavx512f and -mavx512dq,
+// you should not provide the CAFFE2_PERF_WITH_AVX512 macro.
+
+#include "caffe2/core/common.h"
+
+#ifdef CAFFE2_PERF_WITH_AVX512
+#if !defined(__AVX512F__) || !defined(__AVX512DQ__)
+#error( \
+ "You found a build system error: CAFFE2_PERF_WITH_AVX512 is defined" \
+ "but __AVX512F__ or __AVX512DQ__ is not defined" \
+ "(via e.g. -mavx512f and -mavx512dq).");
+#endif
+#endif // CAFFE2_PERF_WITH_AVX512
+
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+#ifndef CAFFE2_PERF_WITH_AVX512
+#error( \
+ "You found a build system error: __AVX512F__ and __AVX512DQ__ is defined" \
+ "(via e.g. -mavx512f and -mavx512dq) " \
+ "but CAFFE2_PERF_WITH_AVX512 is not defined.");
+#endif // CAFFE2_PERF_WITH_AVX512
+#endif
// Proxy back to generic implementation
#define EMBEDDING_SPECIALIZATION( \
- IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType, IS_WEIGHT_POSITIONAL) \
+ IndexTypeName, \
+ IndexType, \
+ InTypeName, \
+ InType, \
+ OutTypeName, \
+ OutType, \
+ IS_WEIGHT_POSITIONAL) \
void \
EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL##__base( \
- const int64_t block_size, \
- const int64_t output_size, \
- const int64_t index_size, \
- const int64_t data_size, \
+ const int64_t block_size, \
+ const int64_t output_size, \
+ const int64_t index_size, \
+ const int64_t data_size, \
const InType* input, \
const IndexType* indices, \
const int* lengths, \
} \
template <> \
void EmbeddingLookup<IndexType, InType, OutType, IS_WEIGHT_POSITIONAL>( \
- const int64_t block_size, \
- const int64_t output_size, \
- const int64_t index_size, \
- const int64_t data_size, \
+ const int64_t block_size, \
+ const int64_t output_size, \
+ const int64_t index_size, \
+ const int64_t data_size, \
const InType* input, \
const IndexType* indices, \
const int* lengths, \
EMBEDDING_SPECIALIZATION(int64_t, int64_t, float, float, float, float, false);
EMBEDDING_SPECIALIZATION(int32_t, int32_t, half, at::Half, float, float, false);
EMBEDDING_SPECIALIZATION(int64_t, int64_t, half, at::Half, float, float, false);
-EMBEDDING_SPECIALIZATION(int32_t, int32_t, uint8_t, uint8_t, float, float, false);
-EMBEDDING_SPECIALIZATION(int64_t, int64_t, uint8_t, uint8_t, float, float, false);
+EMBEDDING_SPECIALIZATION(
+ int32_t,
+ int32_t,
+ uint8_t,
+ uint8_t,
+ float,
+ float,
+ false);
+EMBEDDING_SPECIALIZATION(
+ int64_t,
+ int64_t,
+ uint8_t,
+ uint8_t,
+ float,
+ float,
+ false);
EMBEDDING_SPECIALIZATION(int32_t, int32_t, float, float, float, float, true);
EMBEDDING_SPECIALIZATION(int64_t, int64_t, float, float, float, float, true);
EMBEDDING_SPECIALIZATION(int32_t, int32_t, half, at::Half, float, float, true);
EMBEDDING_SPECIALIZATION(int64_t, int64_t, half, at::Half, float, float, true);
-EMBEDDING_SPECIALIZATION(int32_t, int32_t, uint8_t, uint8_t, float, float, true);
-EMBEDDING_SPECIALIZATION(int64_t, int64_t, uint8_t, uint8_t, float, float, true);
+EMBEDDING_SPECIALIZATION(
+ int32_t,
+ int32_t,
+ uint8_t,
+ uint8_t,
+ float,
+ float,
+ true);
+EMBEDDING_SPECIALIZATION(
+ int64_t,
+ int64_t,
+ uint8_t,
+ uint8_t,
+ float,
+ float,
+ true);
#undef EMBEDDING_SPECIALIZATION