From 55511004d17bd3e0e36e88efa6abdc9a5a03dec1 Mon Sep 17 00:00:00 2001
From: Tongliang Liao
Date: Wed, 16 Jan 2019 21:38:13 -0800
Subject: [PATCH] Resolve errors in perfkernel for Windows (#16031)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16031

1. MSVC only has _mm_prefetch(const char*, int). Fixed in both the Python codegen and the C++ files (see the sketch below).
2. uint32_t in "cvtsh_ss_bugfix.h" requires "#include <cstdint>".
3. Some files use gflags headers. Add the dependency via c10.
4. Isolate arch flags with an interface library and private compile options.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/15753

Reviewed By: dskhudia

Differential Revision: D13636233

Pulled By: jspark1105

fbshipit-source-id: cdcbd4240e07b749554a2a5676c11af88f23c31d
---
 caffe2/perfkernels/CMakeLists.txt | 87 +++++++---
 caffe2/perfkernels/adagrad.cc | 11 +-
 caffe2/perfkernels/adagrad.h | 20 +--
 caffe2/perfkernels/adagrad_avx.cc | 12 +-
 caffe2/perfkernels/common.h | 38 ++---
 caffe2/perfkernels/cvtsh_ss_bugfix.h | 9 +-
 caffe2/perfkernels/embedding_lookup.cc | 157 ++++++++---------
 caffe2/perfkernels/embedding_lookup_avx2.cc | 186 ++++++++++++++-------
 .../embedding_lookup_fused_8bit_rowwise_avx2.cc | 186 ++++++++++++++-------
 .../fused_8bit_rowwise_embedding_lookup.cc | 145 ++++++++--------
 caffe2/perfkernels/hp_emblookup_codegen.py | 9 +-
 caffe2/perfkernels/math.h | 6 +-
 caffe2/perfkernels/math_cpu_avx2.cc | 39 +++--
 caffe2/perfkernels/math_cpu_base.cc | 49 +++---
 caffe2/perfkernels/typed_axpy.cc | 4 +
 cmake/Dependencies.cmake | 2 +-
 cmake/MiscCheck.cmake | 37 ++--
 17 files changed, 599 insertions(+), 398 deletions(-)

diff --git a/caffe2/perfkernels/CMakeLists.txt b/caffe2/perfkernels/CMakeLists.txt
index f2e2f86..3ca9ae5 100644
--- a/caffe2/perfkernels/CMakeLists.txt
+++ b/caffe2/perfkernels/CMakeLists.txt
@@ -13,37 +13,69 @@ set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${common_srcs})
 # We will only build the perf kernel files if the compiler supports avx2
 # extensions.
-# Currently MSVC seems to have a symbol not found error while linking (related
-# to source file order?). As a result we will currently disable the perfkernel
-# in msvc.
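For illustration only, not part of the patch: a minimal standalone sketch of the portability pattern behind points 1 and 2 of the summary above. MSVC declares the intrinsic as _mm_prefetch(const char*, int), so every prefetch call site gains an explicit const char* cast, and code that spells std::uint32_t includes <cstdint> instead of relying on transitive includes. The helper name prefetch_row is hypothetical.

    #include <cstdint>      // std::uint32_t and friends (summary point 2)
    #include <immintrin.h>  // _mm_prefetch, _MM_HINT_T0

    // Hypothetical helper showing the cast this patch applies throughout the
    // kernels: GCC/Clang accept a generic pointer argument, but MSVC requires
    // const char*, so the cast keeps a single code path for both compilers.
    inline void prefetch_row(const float* row) {
      _mm_prefetch(reinterpret_cast<const char*>(row), _MM_HINT_T0);
    }
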
-if (NOT MSVC AND CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS) - add_library(Caffe2_perfkernels_avx OBJECT ${avx_srcs}) - add_library(Caffe2_perfkernels_avx2 OBJECT ${avx2_srcs}) - add_dependencies(Caffe2_perfkernels_avx Caffe2_PROTO c10) - add_dependencies(Caffe2_perfkernels_avx2 Caffe2_PROTO c10) +if (CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS) + add_library(Caffe2_perfkernels_avx STATIC ${avx_srcs}) + add_library(Caffe2_perfkernels_avx2 STATIC ${avx2_srcs}) + add_dependencies(Caffe2_perfkernels_avx Caffe2_PROTO) + add_dependencies(Caffe2_perfkernels_avx2 Caffe2_PROTO) + target_link_libraries(Caffe2_perfkernels_avx PRIVATE c10) + target_link_libraries(Caffe2_perfkernels_avx2 PRIVATE c10) if (MSVC) - set_target_properties( - Caffe2_perfkernels_avx PROPERTIES COMPILE_FLAGS "/arch:AVX") - set_target_properties( - Caffe2_perfkernels_avx2 PROPERTIES COMPILE_FLAGS "/arch:AVX2") - # Currently MSVC doesn't support AVX512 + target_compile_options(Caffe2_perfkernels_avx + PRIVATE "/arch:AVX" + PRIVATE "/D__F16C__") + target_compile_options(Caffe2_perfkernels_avx2 + PRIVATE "/arch:AVX2" + PRIVATE "/D__FMA__" + PRIVATE "/D__F16C__") else() - set_target_properties( - Caffe2_perfkernels_avx PROPERTIES COMPILE_FLAGS "-mavx -mf16c") - set_target_properties( - Caffe2_perfkernels_avx2 PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavx -mf16c") + target_compile_options(Caffe2_perfkernels_avx + PRIVATE "-mavx" + PRIVATE "-mf16c") + target_compile_options(Caffe2_perfkernels_avx2 + PRIVATE "-mavx2" + PRIVATE "-mfma" + PRIVATE "-mavx" + PRIVATE "-mf16c") endif() - set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} - $ - $) + caffe2_interface_library( + Caffe2_perfkernels_avx Caffe2_perfkernels_avx_interface) + caffe2_interface_library( + Caffe2_perfkernels_avx2 Caffe2_perfkernels_avx2_interface) + list(APPEND + Caffe2_DEPENDENCY_WHOLE_LINK_LIBS + "Caffe2_perfkernels_avx_interface") + list(APPEND + Caffe2_DEPENDENCY_WHOLE_LINK_LIBS + "Caffe2_perfkernels_avx2_interface") if (CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS) - add_library(Caffe2_perfkernels_avx512 OBJECT ${avx512_srcs}) - add_dependencies(Caffe2_perfkernels_avx512 Caffe2_PROTO c10) - set_target_properties( - Caffe2_perfkernels_avx512 PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512dq -mavx512vl -mavx2 -mfma -mavx -mf16c") - set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} - $) + add_library(Caffe2_perfkernels_avx512 STATIC ${avx512_srcs}) + add_dependencies(Caffe2_perfkernels_avx512 Caffe2_PROTO) + target_link_libraries(Caffe2_perfkernels_avx512 PRIVATE c10) + if (MSVC) + target_compile_options(Caffe2_perfkernels_avx512 + PRIVATE "/D__AVX512F__" + PRIVATE "/D__AVX512DQ__" + PRIVATE "/D__AVX512VL__" + PRIVATE "/arch:AVX2" + PRIVATE "/D__FMA__" + PRIVATE "/D__F16C__") + else() + target_compile_options(Caffe2_perfkernels_avx512 + PRIVATE "-mavx512f" + PRIVATE "-mavx512dq" + PRIVATE "-mavx512vl" + PRIVATE "-mavx2" + PRIVATE "-mfma" + PRIVATE "-mavx" + PRIVATE "-mf16c") + endif() + caffe2_interface_library( + Caffe2_perfkernels_avx512 Caffe2_perfkernels_avx512_interface) + list(APPEND + Caffe2_DEPENDENCY_WHOLE_LINK_LIBS + "Caffe2_perfkernels_avx512_interface") endif() endif() @@ -54,3 +86,6 @@ endif() # more proper implementation. 
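As an aside, not part of the patch: a minimal self-contained sketch of the runtime dispatch scheme the perfkernels use, which the common.h and *.cc changes below adjust by moving the decltype(...) forward declarations of the ISA-specific variants out of the AVX*_DO macros and next to the __base definitions. The names foo and cpu_has_avx2_fma are hypothetical; the real code checks GetCpuId() inside the AVX2_FMA_DO/BASE_DO macros.

    #include <cstdio>

    // Portable fallback, always compiled without special arch flags.
    void foo__base(int a, float b) { std::printf("base: %d %f\n", a, b); }

    // Same signature as the base version; in the real build this definition
    // lives in a translation unit compiled with -mavx2 -mfma (or /arch:AVX2).
    decltype(foo__base) foo__avx2_fma;
    void foo__avx2_fma(int a, float b) { std::printf("avx2+fma: %d %f\n", a, b); }

    // Hypothetical stand-in for GetCpuId().avx2() && GetCpuId().fma().
    bool cpu_has_avx2_fma() { return false; }

    // Public entry point: pick the fastest supported variant at run time,
    // falling back to the base implementation.
    void foo(int a, float b) {
      if (cpu_has_avx2_fma()) {
        return foo__avx2_fma(a, b);
      }
      return foo__base(a, b);
    }
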
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE) +set(Caffe2_DEPENDENCY_WHOLE_LINK_LIBS + ${Caffe2_DEPENDENCY_WHOLE_LINK_LIBS} + PARENT_SCOPE) diff --git a/caffe2/perfkernels/adagrad.cc b/caffe2/perfkernels/adagrad.cc index 0d6e25e..2c65616 100644 --- a/caffe2/perfkernels/adagrad.cc +++ b/caffe2/perfkernels/adagrad.cc @@ -71,6 +71,7 @@ void rowwise_adagrad_update__base( internal::rowwise_adagrad_update_inlined(N, w, w_n, g, h, h_n, epsilon, lr); } +decltype(adagrad_update_prefetch__base) adagrad_update_prefetch__avx_f16c; void adagrad_update_prefetch( int N, const float* w, @@ -121,6 +122,8 @@ void adagrad_update_prefetch( // Version with prefetching for embeddings and // momentum using fp16 +decltype( + adagrad_fp16_update_prefetch__base) adagrad_fp16_update_prefetch__avx_f16c; void adagrad_fp16_update_prefetch( int N, const at::Half* w, @@ -164,6 +167,7 @@ void adagrad_fp16_update_prefetch( lr); } +decltype(rowwise_adagrad_update__base) rowwise_adagrad_update__avx_f16c; void rowwise_adagrad_update( int N, float* w, @@ -181,6 +185,7 @@ void rowwise_adagrad_update( } // version without prefetching +decltype(adagrad_update__base) adagrad_update__avx_f16c; void adagrad_update( int N, const float* w, @@ -197,11 +202,12 @@ void adagrad_update( SPARSE_ADAGRAD_SPECIALIZATION(int32_t, base); +decltype(sparse_adagrad_int32_t__base) sparse_adagrad_int32_t__avx_f16c; template <> void sparse_adagrad( int num_rows, int block_size, - size_t param_size, + uint64_t param_size, const float* w, const float* g, const float* h, @@ -243,11 +249,12 @@ void sparse_adagrad( SPARSE_ADAGRAD_SPECIALIZATION(int64_t, base); +decltype(sparse_adagrad_int64_t__base) sparse_adagrad_int64_t__avx_f16c; template <> void sparse_adagrad( int num_rows, int block_size, - size_t param_size, + uint64_t param_size, const float* w, const float* g, const float* h, diff --git a/caffe2/perfkernels/adagrad.h b/caffe2/perfkernels/adagrad.h index 6ce1965..c39a1b3 100644 --- a/caffe2/perfkernels/adagrad.h +++ b/caffe2/perfkernels/adagrad.h @@ -68,12 +68,12 @@ inline void adagrad_update_prefetch_inlined( auto i = 0; #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC - constexpr size_t kSize = 8; + constexpr int kSize = 8; for (; i + kSize <= N; i += kSize) { - _mm_prefetch(&w_n[i], _MM_HINT_T0); - _mm_prefetch(&h_n[i], _MM_HINT_T0); - _mm_prefetch(&nw_n[i], _MM_HINT_T0); - _mm_prefetch(&nh_n[i], _MM_HINT_T0); + _mm_prefetch(reinterpret_cast(&w_n[i]), _MM_HINT_T0); + _mm_prefetch(reinterpret_cast(&h_n[i]), _MM_HINT_T0); + _mm_prefetch(reinterpret_cast(&nw_n[i]), _MM_HINT_T0); + _mm_prefetch(reinterpret_cast(&nh_n[i]), _MM_HINT_T0); __m256 gi = _mm256_loadu_ps(g + i); __m256 hi = _mm256_loadu_ps(h + i); @@ -115,8 +115,8 @@ inline void rowwise_adagrad_update_inlined( auto i = 0; #ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC - constexpr size_t kSize = 8; - _mm_prefetch(h_n, _MM_HINT_T0); + constexpr int kSize = 8; + _mm_prefetch(reinterpret_cast(h_n), _MM_HINT_T0); __m256 partial_sum = _mm256_setzero_ps(); for (; i + kSize <= N; i += kSize) { __m256 gi = _mm256_loadu_ps(g + i); @@ -144,7 +144,7 @@ inline void rowwise_adagrad_update_inlined( __m256 step = _mm256_set1_ps(float_step); for (i = 0; i + kSize <= N; i += kSize) { - _mm_prefetch(&w_n[i], _MM_HINT_T0); + _mm_prefetch(reinterpret_cast(&w_n[i]), _MM_HINT_T0); __m256 gi = _mm256_loadu_ps(g + i); __m256 wi = _mm256_loadu_ps(w + i); @@ -242,7 +242,7 @@ template void sparse_adagrad( int num_rows, // number of rows reading int block_size, // number of parameters per rows - std::size_t 
param_size, // total number of parameters + std::uint64_t param_size, // total number of parameters const float* w, // input parameters const float* g, // input gradients const float* h, // input momentums @@ -257,7 +257,7 @@ void sparse_adagrad( void sparse_adagrad_##SIndex##__##ISA( \ int num_rows, \ int block_size, \ - std::size_t param_size, \ + std::uint64_t param_size, \ const float* w, \ const float* g, \ const float* h, \ diff --git a/caffe2/perfkernels/adagrad_avx.cc b/caffe2/perfkernels/adagrad_avx.cc index 36e355f..3c225e3 100644 --- a/caffe2/perfkernels/adagrad_avx.cc +++ b/caffe2/perfkernels/adagrad_avx.cc @@ -59,13 +59,13 @@ void adagrad_fp16_update_prefetch__avx_f16c( at::Half* nh_n, // prefetch ptr float epsilon, float lr) { - constexpr size_t kSize = 8; + constexpr int kSize = 8; auto i = 0; for (; i + kSize <= N; i += kSize) { - _mm_prefetch(&w_n[i], _MM_HINT_T0); - _mm_prefetch(&h_n[i], _MM_HINT_T0); - _mm_prefetch(&nw_n[i], _MM_HINT_T0); - _mm_prefetch(&nh_n[i], _MM_HINT_T0); + _mm_prefetch(reinterpret_cast(&w_n[i]), _MM_HINT_T0); + _mm_prefetch(reinterpret_cast(&h_n[i]), _MM_HINT_T0); + _mm_prefetch(reinterpret_cast(&nw_n[i]), _MM_HINT_T0); + _mm_prefetch(reinterpret_cast(&nh_n[i]), _MM_HINT_T0); // only convert momentum and embedding, gradient is fp32 __m256 gi = _mm256_loadu_ps(g + i); @@ -119,7 +119,7 @@ void adagrad_update__avx_f16c( float epsilon, float decay, float lr) { - constexpr size_t kSize = 8; + constexpr int kSize = 8; auto i = 0; for (; i + kSize <= N; i += kSize) { __m256 gi = _mm256_loadu_ps(g + i); diff --git a/caffe2/perfkernels/common.h b/caffe2/perfkernels/common.h index b128c76..1ceb9de 100644 --- a/caffe2/perfkernels/common.h +++ b/caffe2/perfkernels/common.h @@ -33,6 +33,9 @@ In foo.cc, do: void foo__base(int a, float b) { [base, possibly slow implementation] } + decltype(foo__base) foo__avx512; + decltype(foo__base) foo__avx2; + decltype(foo__base) foo__avx; void foo(int a, float b) { // You should always order things by their preference, faster // implementations earlier in the function. @@ -49,11 +52,11 @@ In foo.cc, do: // During build time: // The build system should provide flags CAFFE2_PERF_WITH_AVX512, // CAFFE2_PERF_WITH_AVX2, and CAFFE2_PERF_WITH_AVX that corresponds to the -// __AVX512F__, __AVX512DQ__, __AVX512VL__, __AVX__, and __AVX2__ flags the +// __AVX512F__, __AVX512DQ__, __AVX512VL__, __AVX2__, and __AVX__ flags the // compiler provides. Note that we do not use the compiler flags but rely on // the build system flags, because the common files (like foo.cc above) will -// always be built without __AVX512F__, __AVX512DQ__, __AVX512VL__, __AVX__ -// and __AVX2__. +// always be built without __AVX512F__, __AVX512DQ__, __AVX512VL__, __AVX2__ +// and __AVX__. // During run time: // we use cpuid to identify cpu support and run the proper functions. @@ -68,7 +71,6 @@ In foo.cc, do: #ifdef CAFFE2_PERF_WITH_AVX512 #define AVX512_DO(funcname, ...) \ - decltype(funcname##__base) funcname##__avx512; \ if (GetCpuId().avx512f() && GetCpuId().avx512dq() && \ GetCpuId().avx512vl()) { \ return funcname##__avx512(__VA_ARGS__); \ @@ -78,15 +80,13 @@ In foo.cc, do: #endif // CAFFE2_PERF_WITH_AVX512 #ifdef CAFFE2_PERF_WITH_AVX2 -#define AVX2_DO(funcname, ...) \ - decltype(funcname##__base) funcname##__avx2; \ - if (GetCpuId().avx2()) { \ - return funcname##__avx2(__VA_ARGS__); \ +#define AVX2_DO(funcname, ...) \ + if (GetCpuId().avx2()) { \ + return funcname##__avx2(__VA_ARGS__); \ } -#define AVX2_FMA_DO(funcname, ...) 
\ - decltype(funcname##__base) funcname##__avx2_fma; \ - if (GetCpuId().avx2() && GetCpuId().fma()) { \ - return funcname##__avx2_fma(__VA_ARGS__); \ +#define AVX2_FMA_DO(funcname, ...) \ + if (GetCpuId().avx2() && GetCpuId().fma()) { \ + return funcname##__avx2_fma(__VA_ARGS__); \ } #else // CAFFE2_PERF_WITH_AVX2 #define AVX2_DO(funcname, ...) @@ -94,15 +94,13 @@ In foo.cc, do: #endif // CAFFE2_PERF_WITH_AVX2 #ifdef CAFFE2_PERF_WITH_AVX -#define AVX_DO(funcname, ...) \ - decltype(funcname##__base) funcname##__avx; \ - if (GetCpuId().avx()) { \ - return funcname##__avx(__VA_ARGS__); \ +#define AVX_DO(funcname, ...) \ + if (GetCpuId().avx()) { \ + return funcname##__avx(__VA_ARGS__); \ } -#define AVX_F16C_DO(funcname, ...) \ - decltype(funcname##__base) funcname##__avx_f16c; \ - if (GetCpuId().avx() && GetCpuId().f16c()) { \ - return funcname##__avx_f16c(__VA_ARGS__); \ +#define AVX_F16C_DO(funcname, ...) \ + if (GetCpuId().avx() && GetCpuId().f16c()) { \ + return funcname##__avx_f16c(__VA_ARGS__); \ } #else // CAFFE2_PERF_WITH_AVX #define AVX_DO(funcname, ...) diff --git a/caffe2/perfkernels/cvtsh_ss_bugfix.h b/caffe2/perfkernels/cvtsh_ss_bugfix.h index ee20ce6..825e266 100644 --- a/caffe2/perfkernels/cvtsh_ss_bugfix.h +++ b/caffe2/perfkernels/cvtsh_ss_bugfix.h @@ -32,16 +32,17 @@ _cvtsh_ss(unsigned short a) #ifdef _MSC_VER +#include + // It seems that microsoft msvc does not have a _cvtsh_ss implementation so // we will add a dummy version to it. -static inline float -_cvtsh_ss(unsigned short x) { +static inline float _cvtsh_ss(unsigned short x) { union { - uint32_t intval; + std::uint32_t intval; float floatval; } t1; - uint32_t t2, t3; + std::uint32_t t2, t3; t1.intval = x & 0x7fff; // Non-sign bits t2 = x & 0x8000; // Sign bit t3 = x & 0x7c00; // Exponent diff --git a/caffe2/perfkernels/embedding_lookup.cc b/caffe2/perfkernels/embedding_lookup.cc index fa93ae7..e8c30a0 100644 --- a/caffe2/perfkernels/embedding_lookup.cc +++ b/caffe2/perfkernels/embedding_lookup.cc @@ -81,83 +81,86 @@ static void EmbeddingLookupGenericSlow( } // Proxy back to generic implementation -#define EMBEDDING_SPECIALIZATION( \ - IndexTypeName, \ - IndexType, \ - InTypeName, \ - InType, \ - OutTypeName, \ - OutType, \ - IS_WEIGHT_POSITIONAL) \ - void \ - EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL##__base( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const InType* input, \ - const IndexType* indices, \ - const int* lengths, \ - const float* weights, \ - const float* scale_bias, \ - bool normalize_by_lengths, \ - OutType* out) { \ - EmbeddingLookupGenericSlow< \ - IndexType, \ - InType, \ - OutType, \ - IS_WEIGHT_POSITIONAL>( \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - scale_bias, \ - normalize_by_lengths, \ - out); \ - } \ - template <> \ - void EmbeddingLookup( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const InType* input, \ - const IndexType* indices, \ - const int* lengths, \ - const float* weights, \ - const float* scale_bias, \ - bool normalize_by_lengths, \ - OutType* out) { \ - AVX2_FMA_DO( \ - EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - scale_bias, \ - 
normalize_by_lengths, \ - out); \ - BASE_DO( \ - EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - scale_bias, \ - normalize_by_lengths, \ - out); \ +#define EMBEDDING_SPECIALIZATION( \ + IndexTypeName, \ + IndexType, \ + InTypeName, \ + InType, \ + OutTypeName, \ + OutType, \ + IS_WEIGHT_POSITIONAL) \ + void \ + EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL##__base( \ + const int64_t block_size, \ + const int64_t output_size, \ + const int64_t index_size, \ + const int64_t data_size, \ + const InType* input, \ + const IndexType* indices, \ + const int* lengths, \ + const float* weights, \ + const float* scale_bias, \ + bool normalize_by_lengths, \ + OutType* out) { \ + EmbeddingLookupGenericSlow< \ + IndexType, \ + InType, \ + OutType, \ + IS_WEIGHT_POSITIONAL>( \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + scale_bias, \ + normalize_by_lengths, \ + out); \ + } \ + decltype( \ + EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL##__base) \ + EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL##__avx2_fma; \ + template <> \ + void EmbeddingLookup( \ + const int64_t block_size, \ + const int64_t output_size, \ + const int64_t index_size, \ + const int64_t data_size, \ + const InType* input, \ + const IndexType* indices, \ + const int* lengths, \ + const float* weights, \ + const float* scale_bias, \ + bool normalize_by_lengths, \ + OutType* out) { \ + AVX2_FMA_DO( \ + EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL, \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + scale_bias, \ + normalize_by_lengths, \ + out); \ + BASE_DO( \ + EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL, \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + scale_bias, \ + normalize_by_lengths, \ + out); \ } EMBEDDING_SPECIALIZATION(int32_t, int32_t, float, float, float, float, false); diff --git a/caffe2/perfkernels/embedding_lookup_avx2.cc b/caffe2/perfkernels/embedding_lookup_avx2.cc index 326818b..89fcc4b 100644 --- a/caffe2/perfkernels/embedding_lookup_avx2.cc +++ b/caffe2/perfkernels/embedding_lookup_avx2.cc @@ -73,35 +73,43 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma( CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); - _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = 
_mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); - _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[48]), _MM_HINT_T0); vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64); - _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72); // skip unnecessary prefetch of (&ip_next_T0[72]) vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80); - _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[80]), _MM_HINT_T0); vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88); // skip unnecessary prefetch of (&ip_next_T0[88]) vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96); - _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[96]), _MM_HINT_T0); vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104); // skip unnecessary prefetch of (&ip_next_T0[104]) vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112); - _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[112]), _MM_HINT_T0); vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); // skip unnecessary prefetch of (&ip_next_T0[120]) } @@ -179,19 +187,23 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma( CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); - _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); - _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[48]), _MM_HINT_T0); vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) } @@ -249,11 +261,13 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma( CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), 
vop16); - _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) } @@ -301,7 +315,8 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma( CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } @@ -355,7 +370,8 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma( &op[j], _mm256_fmadd_ps( vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j]))); - _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } for (; j < block_size; j++) { op[j] += wgt * ip[j]; @@ -488,35 +504,43 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); - _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); - _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[48]), _MM_HINT_T0); vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64); - _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72); // skip unnecessary prefetch of (&ip_next_T0[72]) vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80); - _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[80]), _MM_HINT_T0); vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88); // skip unnecessary prefetch of (&ip_next_T0[88]) vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96); - _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[96]), _MM_HINT_T0); vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104); // skip unnecessary prefetch of (&ip_next_T0[104]) vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112); - _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[112]), _MM_HINT_T0); vop120 = 
_mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); // skip unnecessary prefetch of (&ip_next_T0[120]) } @@ -594,19 +618,23 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); - _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); - _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[48]), _MM_HINT_T0); vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) } @@ -664,11 +692,13 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); - _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) } @@ -716,7 +746,8 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } @@ -770,7 +801,8 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( &op[j], _mm256_fmadd_ps( vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j]))); - _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } for (; j < block_size; j++) { op[j] += wgt * ip[j]; @@ -907,7 +939,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -931,7 +964,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps( 
_mm_loadu_si128(reinterpret_cast(ip + (32)))), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -955,7 +989,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (64)))), vop64); - _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -979,7 +1014,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (96)))), vop96); - _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[96]), _MM_HINT_T0); vop104 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1077,7 +1113,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1101,7 +1138,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (32)))), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1179,7 +1217,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1247,7 +1286,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1308,7 +1348,8 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps(_mm_loadu_si128( reinterpret_cast(&ip[j]))), _mm256_loadu_ps(&op[j]))); - _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } alignas(64) at::Half vtmp1[8]; for (; j < block_size; j++) { @@ -1448,7 +1489,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1472,7 +1514,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (32)))), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1496,7 +1539,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (64)))), vop64); - _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1520,7 +1564,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (96)))), vop96); - 
_mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[96]), _MM_HINT_T0); vop104 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1618,7 +1663,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1642,7 +1688,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (32)))), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1720,7 +1767,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1788,7 +1836,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1849,7 +1898,8 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps(_mm_loadu_si128( reinterpret_cast(&ip[j]))), _mm256_loadu_ps(&op[j]))); - _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } alignas(64) at::Half vtmp1[8]; for (; j < block_size; j++) { @@ -1993,7 +2043,8 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2041,7 +2092,8 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (64))))), _mm256_add_ps(vop64, vbio)); - _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2167,7 +2219,8 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2273,7 +2326,8 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2345,7 +2399,8 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + 
reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2411,7 +2466,8 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( reinterpret_cast(&ip[j])))), _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio))); - _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } for (; j < block_size; j++) { op[j] += wgt * ((float)ip[j]) + bio; @@ -2552,7 +2608,8 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2600,7 +2657,8 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (64))))), _mm256_add_ps(vop64, vbio)); - _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2726,7 +2784,8 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2832,7 +2891,8 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2904,7 +2964,8 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2970,7 +3031,8 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( reinterpret_cast(&ip[j])))), _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio))); - _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } for (; j < block_size; j++) { op[j] += wgt * ((float)ip[j]) + bio; diff --git a/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc b/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc index 1f4a831..0ae15c8 100644 --- a/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc +++ b/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc @@ -71,35 +71,43 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = 
_mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); - _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); - _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[48]), _MM_HINT_T0); vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64); - _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72); // skip unnecessary prefetch of (&ip_next_T0[72]) vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80); - _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[80]), _MM_HINT_T0); vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88); // skip unnecessary prefetch of (&ip_next_T0[88]) vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96); - _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[96]), _MM_HINT_T0); vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104); // skip unnecessary prefetch of (&ip_next_T0[104]) vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112); - _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[112]), _MM_HINT_T0); vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); // skip unnecessary prefetch of (&ip_next_T0[120]) } @@ -177,19 +185,23 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); - _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); - _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[48]), _MM_HINT_T0); vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); // 
skip unnecessary prefetch of (&ip_next_T0[56]) } @@ -247,11 +259,13 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); - _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) } @@ -299,7 +313,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } @@ -353,7 +368,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( &op[j], _mm256_fmadd_ps( vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j]))); - _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } for (; j < block_size; j++) { op[j] += wgt * ip[j]; @@ -480,35 +496,43 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); - _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); - _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[48]), _MM_HINT_T0); vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64); - _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72); // skip unnecessary prefetch of (&ip_next_T0[72]) vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80); - _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0); + _mm_prefetch( + 
reinterpret_cast(&ip_next_T0[80]), _MM_HINT_T0); vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88); // skip unnecessary prefetch of (&ip_next_T0[88]) vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96); - _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[96]), _MM_HINT_T0); vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104); // skip unnecessary prefetch of (&ip_next_T0[104]) vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112); - _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[112]), _MM_HINT_T0); vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); // skip unnecessary prefetch of (&ip_next_T0[120]) } @@ -586,19 +610,23 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); - _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); - _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[48]), _MM_HINT_T0); vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) } @@ -656,11 +684,13 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); - _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) } @@ -708,7 +738,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } @@ -762,7 +793,8 @@ static void 
Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( &op[j], _mm256_fmadd_ps( vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j]))); - _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } for (; j < block_size; j++) { op[j] += wgt * ip[j]; @@ -893,7 +925,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -917,7 +950,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (32)))), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -941,7 +975,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (64)))), vop64); - _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -965,7 +1000,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (96)))), vop96); - _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[96]), _MM_HINT_T0); vop104 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1063,7 +1099,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1087,7 +1124,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (32)))), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1165,7 +1203,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1233,7 +1272,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1294,7 +1334,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_cvtph_ps(_mm_loadu_si128( reinterpret_cast(&ip[j]))), _mm256_loadu_ps(&op[j]))); - _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } alignas(64) at::Half vtmp1[8]; for (; j < block_size; j++) { @@ -1428,7 +1469,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + 
reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1452,7 +1494,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (32)))), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1476,7 +1519,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (64)))), vop64); - _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1500,7 +1544,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (96)))), vop96); - _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[96]), _MM_HINT_T0); vop104 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1598,7 +1643,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1622,7 +1668,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (32)))), vop32); - _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1700,7 +1747,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1768,7 +1816,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( @@ -1829,7 +1878,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_cvtph_ps(_mm_loadu_si128( reinterpret_cast(&ip[j]))), _mm256_loadu_ps(&op[j]))); - _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } alignas(64) at::Half vtmp1[8]; for (; j < block_size; j++) { @@ -1969,7 +2019,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2017,7 +2068,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (64))))), _mm256_add_ps(vop64, vbio)); - _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps( vwgt, 
_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2145,7 +2197,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2253,7 +2306,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2327,7 +2381,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2394,7 +2449,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( reinterpret_cast(&ip[j])))), _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio))); - _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } for (; j < block_size; j++) { op[j] += wgt * ((float)ip[j]) + bio; @@ -2531,7 +2587,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2579,7 +2636,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (64))))), _mm256_add_ps(vop64, vbio)); - _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2707,7 +2765,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2815,7 +2874,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2889,7 +2949,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); - _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + _mm_prefetch( + 
reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( @@ -2956,7 +3017,8 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( reinterpret_cast(&ip[j])))), _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio))); - _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } for (; j < block_size; j++) { op[j] += wgt * ((float)ip[j]) + bio; diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc index 68c8c87..d8f6a43 100644 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc +++ b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc @@ -85,77 +85,80 @@ static void Fused8BitRowwiseEmbeddingLookupGenericSlow( } // Proxy back to generic implementation -#define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION( \ - IndexType, InType, OutType) \ - void \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false__base( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const InType* input, \ - const IndexType* indices, \ - const int* lengths, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - Fused8BitRowwiseEmbeddingLookupGenericSlow< \ - IndexType, \ - InType, \ - OutType, \ - false>( \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - normalize_by_lengths, \ - out); \ - } \ - template <> \ - void Fused8BitRowwiseEmbeddingLookup( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const InType* input, \ - const IndexType* indices, \ - const int* lengths, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - const int32_t one = 1; \ - CAFFE_ENFORCE_EQ( \ - reinterpret_cast(&one)[0], \ - 1, \ - "Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \ - AVX2_FMA_DO( \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - normalize_by_lengths, \ - out); \ - BASE_DO( \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - normalize_by_lengths, \ - out); \ +#define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION( \ + IndexType, InType, OutType) \ + void \ + Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false__base( \ + const int64_t block_size, \ + const int64_t output_size, \ + const int64_t index_size, \ + const int64_t data_size, \ + const InType* input, \ + const IndexType* indices, \ + const int* lengths, \ + const float* weights, \ + bool normalize_by_lengths, \ + OutType* out) { \ + Fused8BitRowwiseEmbeddingLookupGenericSlow< \ + IndexType, \ + InType, \ + OutType, \ + false>( \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + normalize_by_lengths, \ + out); \ + } \ + decltype( \ + Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false__base) \ + Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false__avx2_fma; \ + 
template <> \ + void Fused8BitRowwiseEmbeddingLookup( \ + const int64_t block_size, \ + const int64_t output_size, \ + const int64_t index_size, \ + const int64_t data_size, \ + const InType* input, \ + const IndexType* indices, \ + const int* lengths, \ + const float* weights, \ + bool normalize_by_lengths, \ + OutType* out) { \ + const int32_t one = 1; \ + CAFFE_ENFORCE_EQ( \ + reinterpret_cast(&one)[0], \ + 1, \ + "Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \ + AVX2_FMA_DO( \ + Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false, \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + normalize_by_lengths, \ + out); \ + BASE_DO( \ + Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false, \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + normalize_by_lengths, \ + out); \ } FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int32_t, uint8_t, float); diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py index 20f759c..748f5ce 100644 --- a/caffe2/perfkernels/hp_emblookup_codegen.py +++ b/caffe2/perfkernels/hp_emblookup_codegen.py @@ -37,7 +37,9 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused): if prefetch: code.append( - " _mm_prefetch((&ip_next_T0[%d]), _MM_HINT_T0);" % (regid) + " _mm_prefetch(\n" + " reinterpret_cast(&ip_next_T0[%d]), _MM_HINT_T0);" + % (regid) ) else: code.append( @@ -178,7 +180,10 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused): else: assert False - code.append(" _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);") + code.append( + " _mm_prefetch(\n" + " reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0);" + ) return code diff --git a/caffe2/perfkernels/math.h b/caffe2/perfkernels/math.h index 14265f9..63380fc 100644 --- a/caffe2/perfkernels/math.h +++ b/caffe2/perfkernels/math.h @@ -21,15 +21,15 @@ namespace math { void quantize_and_compress( const float* input_data, std::uint8_t* output_data, - std::size_t input_size, - std::size_t bitwidth, + std::uint64_t input_size, + std::uint64_t bitwidth, bool random, const float* random_buffer); void decompress_and_dequantize( const std::uint8_t* input_data, float* output_data, - std::size_t input_size); + std::uint64_t input_size); } // namespace math } // namespace caffe2 diff --git a/caffe2/perfkernels/math_cpu_avx2.cc b/caffe2/perfkernels/math_cpu_avx2.cc index 95292c3..cf99a67 100644 --- a/caffe2/perfkernels/math_cpu_avx2.cc +++ b/caffe2/perfkernels/math_cpu_avx2.cc @@ -7,6 +7,9 @@ #include #include +using std::uint64_t; +using std::uint8_t; + namespace caffe2 { namespace math { @@ -16,8 +19,8 @@ static constexpr double QEPSILON = 1e-8; void quantize_and_compress__avx2( const float* input_data, uint8_t* output_data, - size_t input_size, - size_t bitwidth, + uint64_t input_size, + uint64_t bitwidth, bool random, const float* random_buffer) { __m256i shuffle_mask_v = _mm256_set_epi8( @@ -56,10 +59,10 @@ void quantize_and_compress__avx2( __m256i permute_mask_v = _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00); - size_t data_per_byte = 8 / bitwidth; - size_t tail = input_size % data_per_byte; + uint64_t data_per_byte = 8 / bitwidth; + uint64_t tail = input_size % data_per_byte; tail = tail ? 
data_per_byte - tail : 0; - size_t segment_size = (input_size + data_per_byte - 1) / data_per_byte; + uint64_t segment_size = (input_size + data_per_byte - 1) / data_per_byte; // basic info float minimum_element = INFINITY, maximum_element = -INFINITY; @@ -77,11 +80,11 @@ void quantize_and_compress__avx2( float gap = (maximum_element - minimum_element) / ((1 << bitwidth) - 1.0f); float gap_inverse = 1. / (gap + QEPSILON); uint8_t max_q = (1 << bitwidth) - 1; - size_t bit_start = 0; + uint64_t bit_start = 0; if (random) { for (int start = 0; start < input_size; start += segment_size) { - size_t stride = start + segment_size <= input_size ? segment_size - : input_size - start; + uint64_t stride = start + segment_size <= input_size ? segment_size + : input_size - start; int i = 0; constexpr int VLEN = 8; for (; i < stride / VLEN * VLEN; i += VLEN) { @@ -122,8 +125,8 @@ void quantize_and_compress__avx2( } else { // !random for (int start = 0; start < input_size; start += segment_size) { - size_t stride = start + segment_size <= input_size ? segment_size - : input_size - start; + uint64_t stride = start + segment_size <= input_size ? segment_size + : input_size - start; int i = 0; constexpr int VLEN = 8; for (; i < stride / VLEN * VLEN; i += VLEN) { @@ -165,26 +168,26 @@ void quantize_and_compress__avx2( void decompress_and_dequantize__avx2( const uint8_t* input_data, float* output_data, - size_t input_size) { + uint64_t input_size) { // basic info const float minimum_element = reinterpret_cast(input_data + 2)[0]; const float maximum_element = reinterpret_cast(input_data + 2)[1]; - const size_t bitwidth = input_data[0]; + const uint64_t bitwidth = input_data[0]; const float gap = (maximum_element - minimum_element) / ((1 << bitwidth) - 1.f) + QEPSILON; // for exact recovering - const size_t tail = input_data[1]; + const uint64_t tail = input_data[1]; - const size_t output_size = (input_size - 10) * (8 / bitwidth) - tail; + const uint64_t output_size = (input_size - 10) * (8 / bitwidth) - tail; // decoding - size_t bit_start = 0; - const size_t segment_size = input_size - 10; + uint64_t bit_start = 0; + const uint64_t segment_size = input_size - 10; for (int start = 0; start < output_size; start += segment_size) { - size_t stride = start + segment_size <= output_size ? segment_size - : output_size - start; + uint64_t stride = start + segment_size <= output_size ? segment_size + : output_size - start; uint8_t mask = (1 << bitwidth) - 1; int i = 0; // Can process 8 elements at a time because we need to expand uint8_t diff --git a/caffe2/perfkernels/math_cpu_base.cc b/caffe2/perfkernels/math_cpu_base.cc index 1837641..6f2cc2b 100644 --- a/caffe2/perfkernels/math_cpu_base.cc +++ b/caffe2/perfkernels/math_cpu_base.cc @@ -3,10 +3,15 @@ // computation library to different compiler options (-mno-avx2 or -mavx2). #include +#include +#include #include "common.h" #include "math.h" +using std::uint64_t; +using std::uint8_t; + namespace caffe2 { namespace math { @@ -16,14 +21,14 @@ static constexpr double QEPSILON = 1e-8; void quantize_and_compress__base( const float* input_data, uint8_t* output_data, - size_t input_size, - size_t bitwidth, + uint64_t input_size, + uint64_t bitwidth, bool random, const float* random_buffer) { - size_t data_per_byte = 8 / bitwidth; - size_t tail = input_size % data_per_byte; + uint64_t data_per_byte = 8 / bitwidth; + uint64_t tail = input_size % data_per_byte; tail = tail ? 
data_per_byte - tail : 0; - size_t segment_size = (input_size + data_per_byte - 1) / data_per_byte; + uint64_t segment_size = (input_size + data_per_byte - 1) / data_per_byte; // basic info float minimum_element = INFINITY, maximum_element = -INFINITY; @@ -41,11 +46,11 @@ void quantize_and_compress__base( float gap = (maximum_element - minimum_element) / ((1 << bitwidth) - 1.0f); float gap_inverse = 1. / (gap + QEPSILON); uint8_t max_q = (1 << bitwidth) - 1; - size_t bit_start = 0; + uint64_t bit_start = 0; if (random) { for (int start = 0; start < input_size; start += segment_size) { - size_t stride = start + segment_size <= input_size ? segment_size - : input_size - start; + uint64_t stride = start + segment_size <= input_size ? segment_size + : input_size - start; int i = 0; for (; i < stride; ++i) { float fval = input_data[start + i]; @@ -64,8 +69,8 @@ void quantize_and_compress__base( } } else { for (int start = 0; start < input_size; start += segment_size) { - size_t stride = start + segment_size <= input_size ? segment_size - : input_size - start; + uint64_t stride = start + segment_size <= input_size ? segment_size + : input_size - start; int i = 0; for (; i < stride; ++i) { float fval = input_data[start + i]; @@ -84,11 +89,12 @@ void quantize_and_compress__base( } } +decltype(quantize_and_compress__base) quantize_and_compress__avx2; void quantize_and_compress( const float* input_data, uint8_t* output_data, - size_t input_size, - size_t bitwidth, + uint64_t input_size, + uint64_t bitwidth, bool random, const float* random_buffer) { AVX2_DO( @@ -112,26 +118,26 @@ void quantize_and_compress( void decompress_and_dequantize__base( const uint8_t* input_data, float* output_data, - size_t input_size) { + uint64_t input_size) { // basic info const float minimum_element = reinterpret_cast(input_data + 2)[0]; const float maximum_element = reinterpret_cast(input_data + 2)[1]; - const size_t bitwidth = input_data[0]; + const uint64_t bitwidth = input_data[0]; const float gap = (maximum_element - minimum_element) / ((1 << bitwidth) - 1.f) + QEPSILON; // for exact recovering - const size_t tail = input_data[1]; + const uint64_t tail = input_data[1]; - const size_t output_size = (input_size - 10) * (8 / bitwidth) - tail; + const uint64_t output_size = (input_size - 10) * (8 / bitwidth) - tail; // decoding - size_t bit_start = 0; - const size_t segment_size = input_size - 10; + uint64_t bit_start = 0; + const uint64_t segment_size = input_size - 10; for (int start = 0; start < output_size; start += segment_size) { - size_t stride = start + segment_size <= output_size ? segment_size - : output_size - start; + uint64_t stride = start + segment_size <= output_size ? 
segment_size
+                                                          : output_size - start;
     uint8_t mask = (1 << bitwidth) - 1;
     int i = 0;
     for (; i < stride; ++i) {
@@ -142,10 +148,11 @@ void decompress_and_dequantize__base(
   }
 }
 
+decltype(decompress_and_dequantize__base) decompress_and_dequantize__avx2;
 void decompress_and_dequantize(
     const uint8_t* input_data,
     float* output_data,
-    size_t input_size) {
+    uint64_t input_size) {
   AVX2_DO(decompress_and_dequantize, input_data, output_data, input_size);
   BASE_DO(decompress_and_dequantize, input_data, output_data, input_size);
 }
diff --git a/caffe2/perfkernels/typed_axpy.cc b/caffe2/perfkernels/typed_axpy.cc
index 8bcbc06..2ca219a 100644
--- a/caffe2/perfkernels/typed_axpy.cc
+++ b/caffe2/perfkernels/typed_axpy.cc
@@ -36,6 +36,8 @@ void TypedAxpyHalffloat__base(
   }
 }
 
+decltype(TypedAxpyHalffloat__base) TypedAxpyHalffloat__avx2_fma;
+decltype(TypedAxpyHalffloat__base) TypedAxpyHalffloat__avx_f16c;
 template <>
 void TypedAxpy<at::Half, float>(
     int N,
@@ -57,6 +59,8 @@ void TypedAxpy_uint8_float__base(
   }
 }
 
+decltype(TypedAxpy_uint8_float__base) TypedAxpy_uint8_float__avx2_fma;
+decltype(TypedAxpy_uint8_float__base) TypedAxpy_uint8_float__avx_f16c;
 template <>
 void TypedAxpy<std::uint8_t, float>(
     int N,
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index a49e597..2e2add2 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -321,7 +321,7 @@ if(USE_FBGEMM)
   if(NOT DEFINED FBGEMM_SOURCE_DIR)
     set(FBGEMM_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/fbgemm" CACHE STRING "FBGEMM source directory")
   endif()
-  if(NOT CAFFE2_COMPILER_SUPPORTS_AVX512F_EXTENSIONS)
+  if(NOT CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS)
     message(WARNING
       "A compiler with AVX512 support is required for FBGEMM. "
       "Not compiling with FBGEMM. "
diff --git a/cmake/MiscCheck.cmake b/cmake/MiscCheck.cmake
index f35502d..0d2e61c 100644
--- a/cmake/MiscCheck.cmake
+++ b/cmake/MiscCheck.cmake
@@ -171,41 +171,52 @@ CHECK_CXX_SOURCE_COMPILES(
    }" CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS)
 if (CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS)
   message(STATUS "Current compiler supports avx2 extension. Will build perfkernels.")
-  # Currently MSVC seems to have a symbol not found error while linking (related
-  # to source file order?). As a result we will currently disable the perfkernel
-  # in msvc.
   # Also see CMakeLists.txt under caffe2/perfkernels.
-  if (NOT MSVC)
-    set(CAFFE2_PERF_WITH_AVX 1)
-    set(CAFFE2_PERF_WITH_AVX2 1)
-  endif()
+  set(CAFFE2_PERF_WITH_AVX 1)
+  set(CAFFE2_PERF_WITH_AVX2 1)
 endif()
 cmake_pop_check_state()
 
-# ---[ Check if the compiler has AVX512F support.
+# ---[ Check if the compiler has AVX512 support.
 cmake_push_check_state(RESET)
 if (MSVC)
-  set(CMAKE_REQUIRED_FLAGS "/D__AVX512F__")
+  # We could have used MSVC's hidden option /arch:AVX512, which defines __AVX512F__,
+  # __AVX512DQ__, and __AVX512VL__, and /arch:AVX512F, which defines __AVX512F__.
+  # But we chose not to, so as not to rely on hidden options.
+  set(CMAKE_REQUIRED_FLAGS "/D__AVX512F__ /D__AVX512DQ__ /D__AVX512VL__")
 else()
-  set(CMAKE_REQUIRED_FLAGS "-mavx512f")
+  # We only consider the case where all of avx512f, avx512dq, and avx512vl are
+  # supported.
+  # Platforms where avx512f is supported but not avx512dq and avx512vl, as of
+  # Jan 15, 2019: linux_manywheel_2.7mu_cpu_build and
+  # linux_conda_3.7_cu100_build
+  set(CMAKE_REQUIRED_FLAGS "-mavx512f -mavx512dq -mavx512vl")
 endif()
 CHECK_CXX_SOURCE_COMPILES(
     "#if defined(_MSC_VER)
     #include
     #else
-    #include
+    #include
     #endif
+    // check avx512f
     __m512 addConstant(__m512 arg) {
       return _mm512_add_ps(arg, _mm512_set1_ps(1.f));
     }
+    // check avx512dq
+    __m512 andConstant(__m512 arg) {
+      return _mm512_and_ps(arg, _mm512_set1_ps(1.f));
+    }
     int main() {
       __m512i a = _mm512_set1_epi32(1);
       __m256i ymm = _mm512_extracti64x4_epi64(a, 0);
+      ymm = _mm256_abs_epi64(ymm); // check avx512vl
       __mmask16 m = _mm512_cmp_epi32_mask(a, a, _MM_CMPINT_EQ);
       __m512i r = _mm512_andnot_si512(a, a);
-    }" CAFFE2_COMPILER_SUPPORTS_AVX512F_EXTENSIONS)
-if (CAFFE2_COMPILER_SUPPORTS_AVX512F_EXTENSIONS)
+    }" CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS)
+if (CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS)
   message(STATUS "Current compiler supports avx512f extension. Will build fbgemm.")
+  # Also see CMakeLists.txt under caffe2/perfkernels.
+  set(CAFFE2_PERF_WITH_AVX512 1)
 endif()
 cmake_pop_check_state()
-- 
2.7.4
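
Note on the prefetch hunks above: the embedding-lookup kernels and the codegen now wrap every prefetch address in a cast before passing it to _mm_prefetch. The standalone sketch below is not part of the patch; the helper name prefetch_t0 is hypothetical. It only illustrates why the cast makes the call portable: MSVC declares _mm_prefetch as taking char const*, while GCC and Clang also accept a const char* argument, so one spelling compiles everywhere.

// Minimal sketch (not from the patch) of the cast emitted by the kernels above.
#include <immintrin.h>

template <typename T>
inline void prefetch_t0(const T* addr) {
  // Cast to const char* so the call compiles on MSVC, GCC, and Clang alike.
  _mm_prefetch(reinterpret_cast<const char*>(addr), _MM_HINT_T0);
}

int main() {
  float row[64] = {};
  prefetch_t0(&row[16]); // hint: bring the cache line holding row[16] into L1
  return 0;
}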
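
The FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION macro above guards dispatch with CAFFE_ENFORCE_EQ on the first byte of a 32-bit 1, i.e. a little-endian check, before trying the AVX2/FMA path. The sketch below only restates that guard in isolation, under the assumption that the byte-order test is the whole point of the check; it is not part of the patch.

// Standalone sketch of the endianness guard used by the specializations above.
#include <cstdint>
#include <iostream>

int main() {
  const std::int32_t one = 1;
  // On a little-endian target the lowest-addressed byte of 1 is 1.
  const bool little_endian =
      reinterpret_cast<const std::uint8_t*>(&one)[0] == 1;
  std::cout << (little_endian ? "little-endian: vectorized path allowed"
                              : "big-endian: not supported")
            << std::endl;
  return 0;
}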
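
The typed_axpy.cc and math_cpu_base.cc hunks above declare the ISA-specific variants (…__avx2, …__avx2_fma, …__avx_f16c) with decltype of the corresponding __base function. The sketch below shows the same pattern with hypothetical names (scale__base, scale__avx2); it is an illustration, not code from the patch. The benefit is that the declaration can never drift from the base signature: any mismatch surfaces at compile or link time instead of causing silent undefined behavior.

// Illustrative sketch of the decltype-based declaration pattern used above.
#include <cstddef>
#include <iostream>

void scale__base(const float* x, float* y, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    y[i] = 2.0f * x[i]; // reference (non-vectorized) implementation
  }
}

// Declares scale__avx2 with exactly the same type as scale__base; a separate
// translation unit built with AVX2 flags would provide the definition.
decltype(scale__base) scale__avx2;

int main() {
  float x[4] = {1.f, 2.f, 3.f, 4.f};
  float y[4] = {};
  scale__base(x, y, 4); // a dispatcher would pick scale__avx2 when available
  std::cout << y[0] << " " << y[3] << std::endl; // prints: 2 8
  return 0;
}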
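
The quantize_and_compress hunks above widen the size parameters to uint64_t but keep the same packing bookkeeping: data_per_byte values share one output byte, tail records the padding slots in the final byte for exact recovery, and segment_size is the number of packed payload bytes. The sketch below reproduces only that arithmetic with an illustrative input (23 values at 2 bits each); the concrete numbers are an example, not data from the patch.

// Hedged sketch of the size bookkeeping in quantize_and_compress (no packing).
#include <cstdint>
#include <iostream>

int main() {
  const std::uint64_t input_size = 23; // number of float values to quantize
  const std::uint64_t bitwidth = 2;    // the kernels assume bitwidth divides 8

  const std::uint64_t data_per_byte = 8 / bitwidth; // 4 values packed per byte
  std::uint64_t tail = input_size % data_per_byte;  // values in the last, partial group
  tail = tail ? data_per_byte - tail : 0;           // unused slots recorded in the header
  const std::uint64_t segment_size =
      (input_size + data_per_byte - 1) / data_per_byte; // packed payload bytes

  // For this example: 4 values per byte, 1 padding slot, 6 payload bytes.
  std::cout << "data_per_byte=" << data_per_byte << " tail=" << tail
            << " segment_size=" << segment_size << std::endl;
  return 0;
}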