From db121375e7c07d43f48cfbc2a57e39e5f16e7d3b Mon Sep 17 00:00:00 2001 From: Jongsoo Park Date: Wed, 30 Jan 2019 22:46:07 -0800 Subject: [PATCH] more careful use of inline/template function in perfkernels (#15388) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15388 This is another pass to make perfkernels code safer from illegal instruction errors. Removed the dependency on c10/util/Logging.h. We err on the safe side at the expense of some verbosity. Reviewed By: dskhudia Differential Revision: D13502902 fbshipit-source-id: 4f833115df885c5b4f8c1ca83b9badea1553f944 --- caffe2/perfkernels/adagrad.cc | 54 +- caffe2/perfkernels/adagrad.h | 51 +- caffe2/perfkernels/adagrad_avx.cc | 85 +-- caffe2/perfkernels/embedding_lookup.cc | 291 +++---- caffe2/perfkernels/embedding_lookup.h | 11 +- caffe2/perfkernels/embedding_lookup_avx2.cc | 833 ++++++++++----------- .../embedding_lookup_fused_8bit_rowwise_avx2.cc | 825 ++++++++++---------- .../fused_8bit_rowwise_embedding_lookup.cc | 228 +++--- .../fused_8bit_rowwise_embedding_lookup.h | 10 +- caffe2/perfkernels/hp_emblookup_codegen.py | 71 +- caffe2/perfkernels/math_cpu_avx2.cc | 1 - caffe2/perfkernels/typed_axpy_avx.cc | 6 +- caffe2/perfkernels/typed_axpy_avx2.cc | 6 +- .../lengths_reducer_fused_8bit_rowwise_ops_test.py | 45 +- 14 files changed, 1262 insertions(+), 1255 deletions(-) diff --git a/caffe2/perfkernels/adagrad.cc b/caffe2/perfkernels/adagrad.cc index 2c65616..d90b62d 100644 --- a/caffe2/perfkernels/adagrad.cc +++ b/caffe2/perfkernels/adagrad.cc @@ -71,6 +71,22 @@ void rowwise_adagrad_update__base( internal::rowwise_adagrad_update_inlined(N, w, w_n, g, h, h_n, epsilon, lr); } +// version without prefetching +decltype(adagrad_update__base) adagrad_update__avx_f16c; +void adagrad_update( + int N, + const float* w, + const float* g, + const float* h, + float* nw, + float* nh, + float epsilon, + float decay, + float lr) { + AVX_F16C_DO(adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr); + BASE_DO(adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr); +} + decltype(adagrad_update_prefetch__base) adagrad_update_prefetch__avx_f16c; void adagrad_update_prefetch( int N, @@ -184,27 +200,11 @@ void rowwise_adagrad_update( BASE_DO(rowwise_adagrad_update, N, w, w_n, g, h, h_n, epsilon, lr); } -// version without prefetching -decltype(adagrad_update__base) adagrad_update__avx_f16c; -void adagrad_update( - int N, - const float* w, - const float* g, - const float* h, - float* nw, - float* nh, - float epsilon, - float decay, - float lr) { - AVX_F16C_DO(adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr); - BASE_DO(adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr); -} - SPARSE_ADAGRAD_SPECIALIZATION(int32_t, base); decltype(sparse_adagrad_int32_t__base) sparse_adagrad_int32_t__avx_f16c; template <> -void sparse_adagrad( +int sparse_adagrad( int num_rows, int block_size, uint64_t param_size, @@ -215,8 +215,7 @@ void sparse_adagrad( float* nw, float* nh, float epsilon, - float lr, - const std::string& param_name) { + float lr) { AVX_F16C_DO( sparse_adagrad_int32_t, num_rows, @@ -229,8 +228,7 @@ void sparse_adagrad( nw, nh, epsilon, - lr, - param_name); + lr); BASE_DO( sparse_adagrad_int32_t, num_rows, @@ -243,15 +241,14 @@ void sparse_adagrad( nw, nh, epsilon, - lr, - param_name); + lr); } SPARSE_ADAGRAD_SPECIALIZATION(int64_t, base); decltype(sparse_adagrad_int64_t__base) sparse_adagrad_int64_t__avx_f16c; template <> -void sparse_adagrad( +int sparse_adagrad( int num_rows, int block_size, uint64_t param_size, @@ -262,8 
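Note on the AVX_F16C_DO/BASE_DO pair used throughout this patch: each kernel comes in per-ISA variants (name__base, name__avx_f16c, ...), and the public entry point selects one at runtime based on a CPU-feature check, falling back to the portable base version. A minimal sketch of that dispatch, assuming a hypothetical cpu_has_avx_f16c() helper; this is illustrative only, not the actual macro expansion from the perfkernels common header.

// Sketch only; the real dispatch is macro-generated.
bool cpu_has_avx_f16c(); // assumption: runtime CPUID-based feature test

void adagrad_update__base(
    int N, const float* w, const float* g, const float* h,
    float* nw, float* nh, float epsilon, float decay, float lr);
void adagrad_update__avx_f16c(
    int N, const float* w, const float* g, const float* h,
    float* nw, float* nh, float epsilon, float decay, float lr);

void adagrad_update(
    int N, const float* w, const float* g, const float* h,
    float* nw, float* nh, float epsilon, float decay, float lr) {
  if (cpu_has_avx_f16c()) {
    // AVX_F16C_DO(...): run the vectorized variant and return.
    adagrad_update__avx_f16c(N, w, g, h, nw, nh, epsilon, decay, lr);
    return;
  }
  // BASE_DO(...): portable fallback, always compiled without AVX so it can
  // never raise an illegal-instruction fault on older CPUs.
  adagrad_update__base(N, w, g, h, nw, nh, epsilon, decay, lr);
}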
+259,7 @@ void sparse_adagrad( float* nw, float* nh, float epsilon, - float lr, - const std::string& param_name) { + float lr) { AVX_F16C_DO( sparse_adagrad_int64_t, num_rows, @@ -276,8 +272,7 @@ void sparse_adagrad( nw, nh, epsilon, - lr, - param_name); + lr); BASE_DO( sparse_adagrad_int64_t, num_rows, @@ -290,8 +285,7 @@ void sparse_adagrad( nw, nh, epsilon, - lr, - param_name); + lr); } } // namespace caffe2 diff --git a/caffe2/perfkernels/adagrad.h b/caffe2/perfkernels/adagrad.h index a16794c..c75e27c 100644 --- a/caffe2/perfkernels/adagrad.h +++ b/caffe2/perfkernels/adagrad.h @@ -6,12 +6,14 @@ #include #endif #include -#include <c10/util/Logging.h> namespace caffe2 { namespace internal { +// The following functions inside the internal namespace are inlined because +// they are performance critical. + template static inline void adagrad_update_base_inlined( int N, @@ -31,6 +33,23 @@ static inline void adagrad_update_base_inlined( } } +// version with prefetching +// TODO(msmelyan) +// The crux of the computation is computing a / (sqrt(b) + epsilon), +// where a and b are vectors and epsilon is very small (e.g., 10^-5) and does +// not change. Today it's computed using two vector sqrt and vector divide SIMD +// instructions. It is slow. We can take advantage of the existing fast vector +// VRSQRTPS instruction that computes approximate reciprocals of square roots +// of the vector. It is 6x faster than the vsqrt and vdiv combination. Since +// the addition of epsilon is just done to avoid division by zero, we +// approximate a / (sqrt(b) + epsilon) by a / sqrt(b + sqrt(epsilon)). If we do +// that, we can use VRSQRTPS instead. VRSQRTPS is not very accurate. +// Specifically, for a test on random numbers between 0.1 and 1 the absolute +// error was about 10^-3 compared to using the slower but more accurate +// combination of vsqrt and +// vdiv. 
Extend Marat's function with more NR iterations to get more accuracy +// for training +// TODO(msmelyan) +// explore streaming stores, but need to have unique indices (deduplication) inline void adagrad_update_prefetch_inlined( int N, const float* w, @@ -238,8 +257,12 @@ void adagrad_update( float decay, float lr); +/** + * @return num_rows on success; otherwise, the index of the first row that + * goes past the boundary of param_size + */ template -void sparse_adagrad( +int sparse_adagrad( int num_rows, // number of rows reading int block_size, // number of parameters per rows std::uint64_t param_size, // total number of parameters @@ -250,11 +273,10 @@ void sparse_adagrad( float* nw, // output parameters float* nh, // output momentums float epsilon, - float lr, - const std::string& param_name); // name of parameters (for error reporting) + float lr); #define SPARSE_ADAGRAD_SPECIALIZATION(SIndex, ISA) \ - void sparse_adagrad_##SIndex##__##ISA( \ + int sparse_adagrad_##SIndex##__##ISA( \ int num_rows, \ int block_size, \ std::uint64_t param_size, \ @@ -265,25 +287,15 @@ void sparse_adagrad( float* nw, \ float* nh, \ float epsilon, \ - float lr, \ - const std::string& param_name) { \ + float lr) { \ for (int i = 0; i < num_rows; ++i) { \ auto idx = indices[i]; \ auto offsetI = i * block_size; \ auto offsetIdx = idx * block_size; \ \ - CAFFE_ENFORCE_GE( \ - param_size, \ - block_size + offsetIdx, \ - param_name, \ - ", out of bound, idx:", \ - idx, \ - " for input i:", \ - i, \ - " and block size:", \ - block_size, \ - " max size:", \ - param_size); \ + if (block_size + offsetIdx > param_size) { \ + return i; \ + } \ \ if (block_size == 1) { \ float gi = g[i]; \ @@ -309,6 +321,7 @@ void sparse_adagrad( lr); \ } \ } \ + return num_rows; \ }; } // namespace caffe2 diff --git a/caffe2/perfkernels/adagrad_avx.cc b/caffe2/perfkernels/adagrad_avx.cc index 3c225e3..c1de220 100644 --- a/caffe2/perfkernels/adagrad_avx.cc +++ b/caffe2/perfkernels/adagrad_avx.cc @@ -6,23 +6,40 @@ namespace caffe2 { -// version with prefetching -// TODO(msmelyan) -// Crux of the computation is computing a / (sqrt(b) + epsilon), -// where a and b are vectors and epislon is very small (eg., 10^-5) and does not -// change. Today it's computed using two vector sqrt and vector divide simd -// instructions. It is slow. We can take advantage of existing fast vector -// VRSQRTPS instruction that computes approximate reciprocals of square roots -// of the vector. It is 6x faster than vsrt and vdiv combinations. Since the -// addition of epislon is just done to avoid division by zero, we approximate a -// / (sqrt(b) + epsilon) by a / (sqrt(b + sqrt(epsilon)) If we do that, we can -// use VRSQRTPS instead now. VRSQRTPS is not very accurate. Specifically, for -// the test on random numbers between 0.1 and 1 the absolute error was about -// 10^-3 compared to using slower but more accurate combination of vsqrt and -// vdiv. 
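Given the return contract documented above (num_rows on success, otherwise the first offending row), error reporting moves to the call site. A caller-side sketch of that contract; bounds_checked_rows is a hypothetical stand-in for sparse_adagrad, and std::runtime_error stands in for CAFFE_ENFORCE:

#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical stand-in mirroring the sparse_adagrad contract: returns
// num_rows on success, else the first row whose block would exceed param_size.
template <typename SIndex>
int bounds_checked_rows(
    int num_rows, int block_size, std::uint64_t param_size,
    const SIndex* indices) {
  for (int i = 0; i < num_rows; ++i) {
    auto offsetIdx = static_cast<std::uint64_t>(indices[i]) * block_size;
    if (offsetIdx + block_size > param_size) {
      return i; // no error formatting here; the caller decides how to report
    }
    // ... per-row update math would go here ...
  }
  return num_rows;
}

// Call-site pattern: turn the returned row index into a readable error.
template <typename SIndex>
void update_or_throw(
    int num_rows, int block_size, std::uint64_t param_size,
    const SIndex* indices) {
  int r = bounds_checked_rows(num_rows, block_size, param_size, indices);
  if (r != num_rows) {
    throw std::runtime_error(
        "out of bound, idx: " + std::to_string(indices[r]) +
        " for input i: " + std::to_string(r));
  }
}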
Extend Marat's function with more NR iterations to get more accuracy -// for training -// TODO(msmelyan) -// explore streaming stores, but need to have unique indices (deduplication) +// version without prefetching +void adagrad_update__avx_f16c( + int N, + const float* w, + const float* g, + const float* h, + float* nw, + float* nh, + float epsilon, + float decay, + float lr) { + constexpr size_t kSize = 8; + auto i = 0; + for (; i + kSize <= N; i += kSize) { + __m256 gi = _mm256_loadu_ps(g + i); + __m256 hi = _mm256_loadu_ps(h + i); + __m256 wi = _mm256_loadu_ps(w + i); + + __m256 nhi = _mm256_add_ps( + _mm256_mul_ps(_mm256_set1_ps(decay), hi), _mm256_mul_ps(gi, gi)); + _mm256_storeu_ps(nh + i, nhi); + __m256 vtmp = _mm256_div_ps( + gi, _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon))); + _mm256_storeu_ps( + nw + i, _mm256_add_ps(wi, _mm256_mul_ps(_mm256_set1_ps(lr), vtmp))); + } + + for (; i < N; ++i) { + float gi = g[i]; + float hi = nh[i] = decay * h[i] + gi * gi; + nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon); + } +} + void adagrad_update_prefetch__avx_f16c( int N, const float* w, @@ -108,40 +125,6 @@ void rowwise_adagrad_update__avx_f16c( internal::rowwise_adagrad_update_inlined(N, w, w_n, g, h, h_n, epsilon, lr); } -// version without prefetching -void adagrad_update__avx_f16c( - int N, - const float* w, - const float* g, - const float* h, - float* nw, - float* nh, - float epsilon, - float decay, - float lr) { - constexpr int kSize = 8; - auto i = 0; - for (; i + kSize <= N; i += kSize) { - __m256 gi = _mm256_loadu_ps(g + i); - __m256 hi = _mm256_loadu_ps(h + i); - __m256 wi = _mm256_loadu_ps(w + i); - - __m256 nhi = _mm256_add_ps( - _mm256_mul_ps(_mm256_set1_ps(decay), hi), _mm256_mul_ps(gi, gi)); - _mm256_storeu_ps(nh + i, nhi); - __m256 vtmp = _mm256_div_ps( - gi, _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon))); - _mm256_storeu_ps( - nw + i, _mm256_add_ps(wi, _mm256_mul_ps(_mm256_set1_ps(lr), vtmp))); - } - - for (; i < N; ++i) { - float gi = g[i]; - float hi = nh[i] = decay * h[i] + gi * gi; - nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon); - } -} - SPARSE_ADAGRAD_SPECIALIZATION(int32_t, avx_f16c); SPARSE_ADAGRAD_SPECIALIZATION(int64_t, avx_f16c); diff --git a/caffe2/perfkernels/embedding_lookup.cc b/caffe2/perfkernels/embedding_lookup.cc index e8c30a0..1a13a89 100644 --- a/caffe2/perfkernels/embedding_lookup.cc +++ b/caffe2/perfkernels/embedding_lookup.cc @@ -8,13 +8,16 @@ namespace caffe2 { -// Base implementation does runtime dispatch for each segment of reduction +/** + * Base implementation does runtime dispatch for each segment of reduction + * @return false if there is an out-of-bound error + */ template < typename IndexType, typename InType, typename OutType, bool IS_WEIGHT_POSITIONAL = false> -static void EmbeddingLookupGenericSlow( +static bool EmbeddingLookupGenericSlow( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -30,18 +33,14 @@ static void EmbeddingLookupGenericSlow( for (int m = 0; m < output_size; ++m) { memset(out, 0, sizeof(OutType) * block_size); EigenVectorArrayMap out_vector(out, block_size); + if (current + lengths[m] > index_size) { + return false; + } for (int i = 0; i < lengths[m]; ++i) { - CAFFE_ENFORCE_LT(current, index_size); int64_t idx = indices[current]; - CAFFE_ENFORCE( - 0 <= idx && idx < data_size, - "Index ", - current, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); - CAFFE_ENFORCE_LT(idx, data_size); + if (idx < 0 || idx >= data_size) { + return false; + 
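On the VRSQRTPS TODO carried over above: the idea is to replace the vsqrt+vdiv pair in g / (sqrt(h) + epsilon) with the approximate reciprocal square root plus Newton-Raphson refinement. An illustrative sketch, assuming AVX and folding epsilon under the root as the TODO suggests (plain epsilon is used here for simplicity); this is not part of the patch:

#include <immintrin.h>

// Computes a * rsqrt(b + eps) ~= a / sqrt(b + eps). One Newton-Raphson step,
// y' = y * (1.5 - 0.5 * x * y * y), roughly doubles the accurate bits of
// _mm256_rsqrt_ps (relative error on the order of 2^-12 on its own).
static inline __m256 approx_div_sqrt(__m256 a, __m256 b, float eps) {
  const __m256 x = _mm256_add_ps(b, _mm256_set1_ps(eps));
  __m256 y = _mm256_rsqrt_ps(x); // fast approximate 1/sqrt(x)
  const __m256 half_x = _mm256_mul_ps(_mm256_set1_ps(0.5f), x);
  y = _mm256_mul_ps(
      y,
      _mm256_sub_ps(
          _mm256_set1_ps(1.5f), _mm256_mul_ps(half_x, _mm256_mul_ps(y, y))));
  return _mm256_mul_ps(a, y);
}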
} #ifdef __GNUC__ if (current + 1 < index_size) { __builtin_prefetch(input + block_size * indices[current + 1], 0, 1); @@ -73,137 +72,155 @@ static void EmbeddingLookupGenericSlow( } out += block_size; } - CAFFE_ENFORCE_EQ( - current, - index_size, - "Your input seems to be incorrect: the sum of lengths values should be " - "the size of the indices tensor, but it appears not."); + return current == index_size; } // Proxy back to generic implementation -#define EMBEDDING_SPECIALIZATION( \ - IndexTypeName, \ - IndexType, \ - InTypeName, \ - InType, \ - OutTypeName, \ - OutType, \ - IS_WEIGHT_POSITIONAL) \ - void \ - EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL##__base( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const InType* input, \ - const IndexType* indices, \ - const int* lengths, \ - const float* weights, \ - const float* scale_bias, \ - bool normalize_by_lengths, \ - OutType* out) { \ - EmbeddingLookupGenericSlow< \ - IndexType, \ - InType, \ - OutType, \ - IS_WEIGHT_POSITIONAL>( \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - scale_bias, \ - normalize_by_lengths, \ - out); \ - } \ - decltype( \ - EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL##__base) \ - EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL##__avx2_fma; \ - template <> \ - void EmbeddingLookup( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const InType* input, \ - const IndexType* indices, \ - const int* lengths, \ - const float* weights, \ - const float* scale_bias, \ - bool normalize_by_lengths, \ - OutType* out) { \ - AVX2_FMA_DO( \ - EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - scale_bias, \ - normalize_by_lengths, \ - out); \ - BASE_DO( \ - EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - scale_bias, \ - normalize_by_lengths, \ - out); \ +#define EMBEDDING_SPECIALIZATION( \ + IndexType, InTypeName, InType, OutType, IS_WEIGHT_POSITIONAL) \ + bool \ + EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base( \ + const int64_t block_size, \ + const int64_t output_size, \ + const int64_t index_size, \ + const int64_t data_size, \ + const InType* input, \ + const IndexType* indices, \ + const int* lengths, \ + const float* weights, \ + const float* scale_bias, \ + bool normalize_by_lengths, \ + OutType* out) { \ + return EmbeddingLookupGenericSlow< \ + IndexType, \ + InType, \ + OutType, \ + IS_WEIGHT_POSITIONAL>( \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + scale_bias, \ + normalize_by_lengths, \ + out); \ + } \ + decltype( \ + EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base) \ + EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__avx2_fma; \ + bool \ + EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL( \ + const int64_t block_size, \ + const int64_t 
output_size, \ + const int64_t index_size, \ + const int64_t data_size, \ + const InType* input, \ + const IndexType* indices, \ + const int* lengths, \ + const float* weights, \ + const float* scale_bias, \ + bool normalize_by_lengths, \ + OutType* out) { \ + if (std::is_same::value) { \ + CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr"); \ + } else { \ + CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr"); \ + } \ + AVX2_FMA_DO( \ + EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + scale_bias, \ + normalize_by_lengths, \ + out); \ + BASE_DO( \ + EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + scale_bias, \ + normalize_by_lengths, \ + out); \ + } \ + template <> \ + void EmbeddingLookup( \ + const int64_t block_size, \ + const int64_t output_size, \ + const int64_t index_size, \ + const int64_t data_size, \ + const InType* input, \ + const IndexType* indices, \ + const int* lengths, \ + const float* weights, \ + const float* scale_bias, \ + bool normalize_by_lengths, \ + OutType* out) { \ + bool success = \ + EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL( \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + scale_bias, \ + normalize_by_lengths, \ + out); \ + if (success) { \ + return; \ + } \ + int64_t current = 0; \ + for (int m = 0; m < output_size; ++m) { \ + for (int i = 0; i < lengths[m]; ++i) { \ + CAFFE_ENFORCE_LT(current, index_size); \ + IndexType idx = indices[current]; \ + CAFFE_ENFORCE( \ + 0 <= idx && idx < data_size, \ + "Index ", \ + current, \ + " is out of bounds: ", \ + idx, \ + ", range 0 to ", \ + data_size); \ + ++current; \ + } \ + } \ + CAFFE_ENFORCE_EQ( \ + current, \ + index_size, \ + "Your input seems to be incorrect: the sum of lengths values should be " \ + "the size of the indices tensor, but it appears not."); \ } -EMBEDDING_SPECIALIZATION(int32_t, int32_t, float, float, float, float, false); -EMBEDDING_SPECIALIZATION(int64_t, int64_t, float, float, float, float, false); -EMBEDDING_SPECIALIZATION(int32_t, int32_t, half, at::Half, float, float, false); -EMBEDDING_SPECIALIZATION(int64_t, int64_t, half, at::Half, float, float, false); -EMBEDDING_SPECIALIZATION( - int32_t, - int32_t, - uint8_t, - uint8_t, - float, - float, - false); -EMBEDDING_SPECIALIZATION( - int64_t, - int64_t, - uint8_t, - uint8_t, - float, - float, - false); +EMBEDDING_SPECIALIZATION(int32_t, float, float, float, false); +EMBEDDING_SPECIALIZATION(int64_t, float, float, float, false); +EMBEDDING_SPECIALIZATION(int32_t, half, at::Half, float, false); +EMBEDDING_SPECIALIZATION(int64_t, half, at::Half, float, false); +EMBEDDING_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false); +EMBEDDING_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false); -EMBEDDING_SPECIALIZATION(int32_t, int32_t, float, float, float, float, true); -EMBEDDING_SPECIALIZATION(int64_t, int64_t, float, float, float, float, true); -EMBEDDING_SPECIALIZATION(int32_t, int32_t, half, at::Half, float, float, true); -EMBEDDING_SPECIALIZATION(int64_t, int64_t, half, at::Half, float, float, true); -EMBEDDING_SPECIALIZATION( - int32_t, - int32_t, - uint8_t, - uint8_t, - float, - float, - true); 
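The macro above implements the central error-handling change of this patch: the SIMD kernel only signals that some bound was violated, and the wrapper re-walks the indices outside the hot loop to produce a descriptive error. A de-macroized sketch of that pattern, with a hypothetical fast_lookup_kernel and std::runtime_error standing in for CAFFE_ENFORCE:

#include <cstdint>
#include <stdexcept>
#include <string>

bool fast_lookup_kernel(); // assumed ISA-specific kernel; true on success

void embedding_lookup_checked(
    std::int64_t output_size,
    std::int64_t index_size,
    std::int64_t data_size,
    const std::int64_t* indices,
    const int* lengths) {
  if (fast_lookup_kernel()) {
    return; // common case: no error-formatting cost in the hot path
  }
  // Failure path: re-scan slowly to say exactly which index was bad.
  std::int64_t current = 0;
  for (std::int64_t m = 0; m < output_size; ++m) {
    for (int i = 0; i < lengths[m]; ++i, ++current) {
      if (current >= index_size) {
        throw std::runtime_error("sum of lengths exceeds indices size");
      }
      std::int64_t idx = indices[current];
      if (idx < 0 || idx >= data_size) {
        throw std::runtime_error(
            "Index " + std::to_string(current) + " is out of bounds: " +
            std::to_string(idx) + ", range 0 to " + std::to_string(data_size));
      }
    }
  }
  if (current != index_size) {
    throw std::runtime_error(
        "the sum of lengths values should be the size of the indices tensor");
  }
}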
-EMBEDDING_SPECIALIZATION( - int64_t, - int64_t, - uint8_t, - uint8_t, - float, - float, - true); +EMBEDDING_SPECIALIZATION(int32_t, float, float, float, true); +EMBEDDING_SPECIALIZATION(int64_t, float, float, float, true); +EMBEDDING_SPECIALIZATION(int32_t, half, at::Half, float, true); +EMBEDDING_SPECIALIZATION(int64_t, half, at::Half, float, true); +EMBEDDING_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true); +EMBEDDING_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true); #undef EMBEDDING_SPECIALIZATION diff --git a/caffe2/perfkernels/embedding_lookup.h b/caffe2/perfkernels/embedding_lookup.h index d147708..1d0cd2a 100644 --- a/caffe2/perfkernels/embedding_lookup.h +++ b/caffe2/perfkernels/embedding_lookup.h @@ -1,6 +1,6 @@ #pragma once -#include "caffe2/core/common.h" +#include namespace caffe2 { @@ -28,7 +28,6 @@ namespace caffe2 { * if (normalize_weights && lengths[i] > 0) * for (k = 0..block_size-1) * out[i*block_size + k] /= lengths[i] - * */ template < typename IndexType, @@ -36,10 +35,10 @@ template < typename OutType, bool IS_WEIGHT_POSITIONAL = false> void EmbeddingLookup( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, + const std::int64_t block_size, + const std::int64_t output_size, + const std::int64_t index_size, + const std::int64_t data_size, const InType* input, const IndexType* indices, const int* lengths, diff --git a/caffe2/perfkernels/embedding_lookup_avx2.cc b/caffe2/perfkernels/embedding_lookup_avx2.cc index d17600d..271c07a 100644 --- a/caffe2/perfkernels/embedding_lookup_avx2.cc +++ b/caffe2/perfkernels/embedding_lookup_avx2.cc @@ -6,32 +6,28 @@ //// -------------------------- #include -#include #include -#include - namespace caffe2 { template -static void EmbeddingLookup_int32_t_float_float__avx2_fma( +static bool EmbeddingLookup_int32_t_float_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const float* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { - const int32_t prefdist_T0 = 16; - const int32_t fused_block_size = block_size + 0; - CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr"); + const int prefdist_T0 = 16; + const int fused_block_size = block_size + 0; + int dataInd = 0; if (block_size == 128) { // unrolling 16 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); @@ -49,28 +45,28 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma( __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? 
(dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( @@ -152,8 +148,7 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma( } } else if (block_size == 64) { // unrolling 8 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); @@ -163,28 +158,28 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma( __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( @@ -230,35 +225,34 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma( } } else if (block_size == 32) { // unrolling 4 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? 
(dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( @@ -286,33 +280,32 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma( } } else if (block_size == 16) { // unrolling 2 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( @@ -331,8 +324,7 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma( } } else { // generic code - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; for (; j + 8 <= block_size; j += 8) { @@ -341,28 +333,28 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma( for (; j < block_size; j++) { op[j] = 0.0f; } - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? 
(dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { @@ -391,20 +383,21 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma( } } } + return dataInd == index_size; } -void EmbeddingLookup_int32_t_float_float_false__avx2_fma( +bool EmbeddingLookup_int32_t_float_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const float* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { - EmbeddingLookup_int32_t_float_float__avx2_fma( + return EmbeddingLookup_int32_t_float_float__avx2_fma( block_size, output_size, index_size, @@ -417,19 +410,19 @@ void EmbeddingLookup_int32_t_float_float_false__avx2_fma( normalize_by_lengths, out); } -void EmbeddingLookup_int32_t_float_float_true__avx2_fma( +bool EmbeddingLookup_int32_t_float_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const float* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { - EmbeddingLookup_int32_t_float_float__avx2_fma( + return EmbeddingLookup_int32_t_float_float__avx2_fma( block_size, output_size, index_size, @@ -444,7 +437,7 @@ void EmbeddingLookup_int32_t_float_float_true__avx2_fma( } template -static void EmbeddingLookup_int64_t_float_float__avx2_fma( +static bool EmbeddingLookup_int64_t_float_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -458,10 +451,9 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( float* out) { const int64_t prefdist_T0 = 16; const int64_t fused_block_size = block_size + 0; - CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr"); + int64_t dataInd = 0; if (block_size == 128) { // unrolling 16 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); @@ -480,17 +472,15 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -501,7 +491,9 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( ? 
(dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( @@ -583,7 +575,6 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( } } else if (block_size == 64) { // unrolling 8 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); @@ -594,17 +585,15 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -615,7 +604,9 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( @@ -661,24 +652,21 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( } } else if (block_size == 32) { // unrolling 4 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -689,7 +677,9 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( ? 
(dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( @@ -717,22 +707,19 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( } } else if (block_size == 16) { // unrolling 2 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -743,7 +730,9 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( @@ -762,7 +751,6 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( } } else { // generic code - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; @@ -772,17 +760,15 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( for (; j < block_size; j++) { op[j] = 0.0f; } + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -793,7 +779,9 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( ? 
(dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { @@ -822,8 +810,9 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( } } } + return dataInd == index_size; } -void EmbeddingLookup_int64_t_float_float_false__avx2_fma( +bool EmbeddingLookup_int64_t_float_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -835,7 +824,7 @@ void EmbeddingLookup_int64_t_float_float_false__avx2_fma( const float* scale_bias, bool normalize_by_lengths, float* out) { - EmbeddingLookup_int64_t_float_float__avx2_fma( + return EmbeddingLookup_int64_t_float_float__avx2_fma( block_size, output_size, index_size, @@ -848,7 +837,7 @@ void EmbeddingLookup_int64_t_float_float_false__avx2_fma( normalize_by_lengths, out); } -void EmbeddingLookup_int64_t_float_float_true__avx2_fma( +bool EmbeddingLookup_int64_t_float_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -860,7 +849,7 @@ void EmbeddingLookup_int64_t_float_float_true__avx2_fma( const float* scale_bias, bool normalize_by_lengths, float* out) { - EmbeddingLookup_int64_t_float_float__avx2_fma( + return EmbeddingLookup_int64_t_float_float__avx2_fma( block_size, output_size, index_size, @@ -875,25 +864,24 @@ void EmbeddingLookup_int64_t_float_float_true__avx2_fma( } template -static void EmbeddingLookup_int32_t_half_float__avx2_fma( +static bool EmbeddingLookup_int32_t_half_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const at::Half* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { - const int32_t prefdist_T0 = 16; - const int32_t fused_block_size = block_size + 0; - CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr"); + const int prefdist_T0 = 16; + const int fused_block_size = block_size + 0; + int dataInd = 0; if (block_size == 128) { // unrolling 16 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); @@ -911,28 +899,28 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? 
(dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -1074,8 +1062,7 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( } } else if (block_size == 64) { // unrolling 8 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); @@ -1085,28 +1072,28 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -1182,35 +1169,34 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( } } else if (block_size == 32) { // unrolling 4 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? 
(dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -1253,33 +1239,32 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( } } else if (block_size == 16) { // unrolling 2 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -1306,8 +1291,7 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( } } else { // generic code - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; for (; j + 8 <= block_size; j += 8) { @@ -1316,28 +1300,28 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( for (; j < block_size; j++) { op[j] = 0.0f; } - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? 
(dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { @@ -1372,20 +1356,21 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( } } } + return dataInd == index_size; } -void EmbeddingLookup_int32_t_half_float_false__avx2_fma( +bool EmbeddingLookup_int32_t_half_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const at::Half* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { - EmbeddingLookup_int32_t_half_float__avx2_fma( + return EmbeddingLookup_int32_t_half_float__avx2_fma( block_size, output_size, index_size, @@ -1398,19 +1383,19 @@ void EmbeddingLookup_int32_t_half_float_false__avx2_fma( normalize_by_lengths, out); } -void EmbeddingLookup_int32_t_half_float_true__avx2_fma( +bool EmbeddingLookup_int32_t_half_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const at::Half* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { - EmbeddingLookup_int32_t_half_float__avx2_fma( + return EmbeddingLookup_int32_t_half_float__avx2_fma( block_size, output_size, index_size, @@ -1425,7 +1410,7 @@ void EmbeddingLookup_int32_t_half_float_true__avx2_fma( } template -static void EmbeddingLookup_int64_t_half_float__avx2_fma( +static bool EmbeddingLookup_int64_t_half_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -1439,10 +1424,9 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( float* out) { const int64_t prefdist_T0 = 16; const int64_t fused_block_size = block_size + 0; - CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr"); + int64_t dataInd = 0; if (block_size == 128) { // unrolling 16 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); @@ -1461,17 +1445,15 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -1482,7 +1464,9 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( ? 
(dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -1624,7 +1608,6 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( } } else if (block_size == 64) { // unrolling 8 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); @@ -1635,17 +1618,15 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -1656,7 +1637,9 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -1732,24 +1715,21 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( } } else if (block_size == 32) { // unrolling 4 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -1760,7 +1740,9 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( ? 
(dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -1803,22 +1785,19 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( } } else if (block_size == 16) { // unrolling 2 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -1829,7 +1808,9 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -1856,7 +1837,6 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( } } else { // generic code - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; @@ -1866,17 +1846,15 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( for (; j < block_size; j++) { op[j] = 0.0f; } + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -1887,7 +1865,9 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( ? 
(dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { @@ -1922,8 +1902,9 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( } } } + return dataInd == index_size; } -void EmbeddingLookup_int64_t_half_float_false__avx2_fma( +bool EmbeddingLookup_int64_t_half_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -1935,7 +1916,7 @@ void EmbeddingLookup_int64_t_half_float_false__avx2_fma( const float* scale_bias, bool normalize_by_lengths, float* out) { - EmbeddingLookup_int64_t_half_float__avx2_fma( + return EmbeddingLookup_int64_t_half_float__avx2_fma( block_size, output_size, index_size, @@ -1948,7 +1929,7 @@ void EmbeddingLookup_int64_t_half_float_false__avx2_fma( normalize_by_lengths, out); } -void EmbeddingLookup_int64_t_half_float_true__avx2_fma( +bool EmbeddingLookup_int64_t_half_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -1960,7 +1941,7 @@ void EmbeddingLookup_int64_t_half_float_true__avx2_fma( const float* scale_bias, bool normalize_by_lengths, float* out) { - EmbeddingLookup_int64_t_half_float__avx2_fma( + return EmbeddingLookup_int64_t_half_float__avx2_fma( block_size, output_size, index_size, @@ -1975,25 +1956,24 @@ void EmbeddingLookup_int64_t_half_float_true__avx2_fma( } template -static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( +static bool EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const uint8_t* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { - const int32_t prefdist_T0 = 16; - const int32_t fused_block_size = block_size + 0; - CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr"); + const int prefdist_T0 = 16; + const int fused_block_size = block_size + 0; + int dataInd = 0; if (block_size == 128) { // unrolling 16 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); @@ -2011,17 +1991,15 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2032,11 +2010,13 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; - const int32_t next_T0 
= (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -2176,8 +2156,7 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( } } else if (block_size == 64) { // unrolling 8 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); @@ -2187,17 +2166,15 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2208,11 +2185,13 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? 
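// Folding the per-row scale into wgt and the weighted bias into vbio makes
// each accumulation step compute out[j] += wgt * (scale * ip[j] + bias),
// dequantizing the uint8 row on the fly instead of materializing a float row.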
(dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -2287,24 +2266,21 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( } } else if (block_size == 32) { // unrolling 4 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2315,11 +2291,13 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -2362,22 +2340,19 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( } } else if (block_size == 16) { // unrolling 2 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2388,11 +2363,13 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? 
(dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -2419,8 +2396,7 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( } } else { // generic code - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; for (; j + 8 <= block_size; j += 8) { @@ -2429,33 +2405,32 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( for (; j < block_size; j++) { op[j] = 0.0f; } - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - assert(scale_bias); bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? 
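// Near the end of the indices array the lookahead is clamped back to the
// current position, so the prefetch index never reads past
// indices[index_size - 1].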
(dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { @@ -2487,20 +2462,21 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( } } } + return dataInd == index_size; } -void EmbeddingLookup_int32_t_uint8_t_float_false__avx2_fma( +bool EmbeddingLookup_int32_t_uint8_t_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const uint8_t* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { - EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( + return EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( block_size, output_size, index_size, @@ -2513,19 +2489,19 @@ void EmbeddingLookup_int32_t_uint8_t_float_false__avx2_fma( normalize_by_lengths, out); } -void EmbeddingLookup_int32_t_uint8_t_float_true__avx2_fma( +bool EmbeddingLookup_int32_t_uint8_t_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const uint8_t* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { - EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( + return EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( block_size, output_size, index_size, @@ -2540,7 +2516,7 @@ void EmbeddingLookup_int32_t_uint8_t_float_true__avx2_fma( } template -static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( +static bool EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -2554,10 +2530,9 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( float* out) { const int64_t prefdist_T0 = 16; const int64_t fused_block_size = block_size + 0; - CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr"); + int64_t dataInd = 0; if (block_size == 128) { // unrolling 16 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); @@ -2576,17 +2551,15 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2601,7 +2574,9 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( ? 
(dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -2741,7 +2716,6 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( } } else if (block_size == 64) { // unrolling 8 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); @@ -2752,17 +2726,15 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2777,7 +2749,9 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -2852,24 +2826,21 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( } } else if (block_size == 32) { // unrolling 4 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2884,7 +2855,9 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( ? 
(dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -2927,22 +2900,19 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( } } else if (block_size == 16) { // unrolling 2 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2957,7 +2927,9 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -2984,7 +2956,6 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( } } else { // generic code - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; @@ -2994,23 +2965,20 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( for (; j < block_size; j++) { op[j] = 0.0f; } + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - assert(scale_bias); bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); @@ -3020,7 +2988,9 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( ? 
(dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { @@ -3052,8 +3022,9 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( } } } + return dataInd == index_size; } -void EmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma( +bool EmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -3065,7 +3036,7 @@ void EmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma( const float* scale_bias, bool normalize_by_lengths, float* out) { - EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( + return EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( block_size, output_size, index_size, @@ -3078,7 +3049,7 @@ void EmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma( normalize_by_lengths, out); } -void EmbeddingLookup_int64_t_uint8_t_float_true__avx2_fma( +bool EmbeddingLookup_int64_t_uint8_t_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -3090,7 +3061,7 @@ void EmbeddingLookup_int64_t_uint8_t_float_true__avx2_fma( const float* scale_bias, bool normalize_by_lengths, float* out) { - EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( + return EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( block_size, output_size, index_size, diff --git a/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc b/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc index 0f3be37..12f790d 100644 --- a/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc +++ b/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc @@ -6,30 +6,27 @@ //// -------------------------- #include -#include #include -#include - namespace caffe2 { template -static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( +static bool Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const float* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, bool normalize_by_lengths, float* out) { - const int32_t prefdist_T0 = 16; - const int32_t fused_block_size = block_size + 2; + const int prefdist_T0 = 16; + const int fused_block_size = block_size + 2; + int dataInd = 0; if (block_size == 128) { // unrolling 16 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); @@ -47,28 +44,28 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx 
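// A bad index (or a segment length that would overrun index_size, checked
// before each segment) now makes the kernel return false rather than
// throwing via CAFFE_ENFORCE; presumably the hand-written dispatch layer
// re-raises, which is what lets this generated file drop its
// c10/util/Logging.h dependency.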
>= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( @@ -150,8 +147,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( } } else if (block_size == 64) { // unrolling 8 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); @@ -161,28 +157,28 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? 
(dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( @@ -228,35 +224,34 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( } } else if (block_size == 32) { // unrolling 4 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( @@ -284,33 +279,32 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( } } else if (block_size == 16) { // unrolling 2 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? 
(dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( @@ -329,8 +323,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( } } else { // generic code - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; for (; j + 8 <= block_size; j += 8) { @@ -339,28 +332,28 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( for (; j < block_size; j++) { op[j] = 0.0f; } - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { @@ -389,19 +382,20 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( } } } + return dataInd == index_size; } -void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_false__avx2_fma( +bool Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const float* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, bool normalize_by_lengths, float* out) { - Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( + return Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( block_size, output_size, index_size, @@ -413,18 +407,18 @@ void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_false__avx2_fma( normalize_by_lengths, out); } -void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_true__avx2_fma( +bool Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const float* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, bool normalize_by_lengths, float* out) { - Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( + return Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( block_size, output_size, index_size, @@ -438,7 +432,7 @@ void 
Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_true__avx2_fma( } template -static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( +static bool Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -451,9 +445,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( float* out) { const int64_t prefdist_T0 = 16; const int64_t fused_block_size = block_size + 2; + int64_t dataInd = 0; if (block_size == 128) { // unrolling 16 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); @@ -472,17 +466,15 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -493,7 +485,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( @@ -575,7 +569,6 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( } } else if (block_size == 64) { // unrolling 8 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); @@ -586,17 +579,15 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -607,7 +598,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( ? 
(dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( @@ -653,24 +646,21 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( } } else if (block_size == 32) { // unrolling 4 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -681,7 +671,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( @@ -709,22 +701,19 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( } } else if (block_size == 16) { // unrolling 2 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -735,7 +724,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( ? 
(dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( @@ -754,7 +745,6 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( } } else { // generic code - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; @@ -764,17 +754,15 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( for (; j < block_size; j++) { op[j] = 0.0f; } + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -785,7 +773,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { @@ -814,8 +804,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( } } } + return dataInd == index_size; } -void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float_false__avx2_fma( +bool Fused8BitRowwiseEmbeddingLookup_int64_t_float_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -826,7 +817,7 @@ void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float_false__avx2_fma( const float* weights, bool normalize_by_lengths, float* out) { - Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( + return Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( block_size, output_size, index_size, @@ -838,7 +829,7 @@ void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float_false__avx2_fma( normalize_by_lengths, out); } -void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float_true__avx2_fma( +bool Fused8BitRowwiseEmbeddingLookup_int64_t_float_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -849,7 +840,7 @@ void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float_true__avx2_fma( const float* weights, bool normalize_by_lengths, float* out) { - Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( + return Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( block_size, output_size, index_size, @@ -863,23 +854,23 @@ void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float_true__avx2_fma( } template -static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( +static bool Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const at::Half* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, bool 
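// normalize_by_lengths: when true, each accumulated output row is scaled by
// 1 / lengths[rangeIndex] (mean rather than sum pooling); zero-length
// segments are presumably left unscaled to avoid dividing by zero.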
normalize_by_lengths, float* out) { - const int32_t prefdist_T0 = 16; - const int32_t fused_block_size = block_size + 4; + const int prefdist_T0 = 16; + const int fused_block_size = block_size + 4; + int dataInd = 0; if (block_size == 128) { // unrolling 16 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); @@ -897,28 +888,28 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -1060,8 +1051,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( } } else if (block_size == 64) { // unrolling 8 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); @@ -1071,28 +1061,28 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? 
(dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -1168,35 +1158,34 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( } } else if (block_size == 32) { // unrolling 4 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -1239,33 +1228,32 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( } } else if (block_size == 16) { // unrolling 2 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? 
(dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -1292,8 +1280,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( } } else { // generic code - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; for (; j + 8 <= block_size; j += 8) { @@ -1302,28 +1289,28 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( for (; j < block_size; j++) { op[j] = 0.0f; } - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { @@ -1358,19 +1345,20 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( } } } + return dataInd == index_size; } -void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float_false__avx2_fma( +bool Fused8BitRowwiseEmbeddingLookup_int32_t_half_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const at::Half* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, bool normalize_by_lengths, float* out) { - Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( + return Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( block_size, output_size, index_size, @@ -1382,18 +1370,18 @@ void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float_false__avx2_fma( normalize_by_lengths, out); } -void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float_true__avx2_fma( +bool Fused8BitRowwiseEmbeddingLookup_int32_t_half_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const at::Half* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, bool normalize_by_lengths, float* out) { - Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( + return Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( block_size, output_size, index_size, @@ -1407,7 +1395,7 @@ void 
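// The _false/_true wrapper pairs pin IS_WEIGHT_POSITIONAL at compile time
// (positional weights index within a segment as dataInd - start, global
// weights as dataInd), presumably so the runtime dispatcher can hold plain,
// non-template function pointers.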
Fused8BitRowwiseEmbeddingLookup_int32_t_half_float_true__avx2_fma( } template -static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( +static bool Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -1420,9 +1408,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( float* out) { const int64_t prefdist_T0 = 16; const int64_t fused_block_size = block_size + 4; + int64_t dataInd = 0; if (block_size == 128) { // unrolling 16 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); @@ -1441,17 +1429,15 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -1462,7 +1448,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -1604,7 +1592,6 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( } } else if (block_size == 64) { // unrolling 8 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); @@ -1615,17 +1602,15 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -1636,7 +1621,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( ? 
(dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -1712,24 +1699,21 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( } } else if (block_size == 32) { // unrolling 4 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -1740,7 +1724,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -1783,22 +1769,19 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( } } else if (block_size == 16) { // unrolling 2 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -1809,7 +1792,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( ? 
(dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -1836,7 +1821,6 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( } } else { // generic code - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; @@ -1846,17 +1830,15 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( for (; j < block_size; j++) { op[j] = 0.0f; } + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; @@ -1867,7 +1849,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { @@ -1902,8 +1886,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( } } } + return dataInd == index_size; } -void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float_false__avx2_fma( +bool Fused8BitRowwiseEmbeddingLookup_int64_t_half_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -1914,7 +1899,7 @@ void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float_false__avx2_fma( const float* weights, bool normalize_by_lengths, float* out) { - Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( + return Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( block_size, output_size, index_size, @@ -1926,7 +1911,7 @@ void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float_false__avx2_fma( normalize_by_lengths, out); } -void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float_true__avx2_fma( +bool Fused8BitRowwiseEmbeddingLookup_int64_t_half_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -1937,7 +1922,7 @@ void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float_true__avx2_fma( const float* weights, bool normalize_by_lengths, float* out) { - Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( + return Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( block_size, output_size, index_size, @@ -1951,23 +1936,23 @@ void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float_true__avx2_fma( } template -static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( +static bool Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const uint8_t* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, bool normalize_by_lengths, float* out) { - const int32_t 
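// The fused 8-bit format is assumed here to store each row's float scale
// and bias inline with the quantized values: 8 extra bytes per row, hence
// fused_block_size = block_size + 8 for uint8 input (+4 half / +2 float
// elements in the variants above) and no separate scale_bias argument.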
prefdist_T0 = 16; - const int32_t fused_block_size = block_size + 8; + const int prefdist_T0 = 16; + const int fused_block_size = block_size + 8; + int dataInd = 0; if (block_size == 128) { // unrolling 16 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); @@ -1985,17 +1970,15 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2008,11 +1991,13 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -2152,8 +2137,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( } } else if (block_size == 64) { // unrolling 8 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); @@ -2163,17 +2147,15 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2186,11 +2168,13 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? 
(dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -2265,24 +2249,21 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( } } else if (block_size == 32) { // unrolling 4 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2295,11 +2276,13 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -2342,22 +2325,19 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( } } else if (block_size == 16) { // unrolling 2 times - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2370,11 +2350,13 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? 
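// Why next_T0 / idx_pref_T0 exist, as a sketch: while accumulating the
// current row, the kernel issues a prefetch for the row it expects to touch
// prefdist iterations later, so the DRAM fetch overlaps with compute. The
// prefetched index also comes from user data, which is exactly why the patch
// bounds-checks it before forming the pointer. Hypothetical helper name.
#include <immintrin.h>
#include <cstdint>

static inline void prefetch_future_row(
    const std::uint8_t* input,
    const int* indices,
    int dataInd,
    int index_size,
    int data_size,
    int fused_block_size) {
  const int prefdist = 16;
  const int next =
      (dataInd < index_size - prefdist) ? (dataInd + prefdist) : dataInd;
  const int idx_pref = indices[next];
  if (idx_pref < 0 || idx_pref >= data_size) {
    return; // invalid index: skip the hint; the main loop reports the error
  }
  _mm_prefetch(
      reinterpret_cast<const char*>(&input[idx_pref * fused_block_size]),
      _MM_HINT_T0);
}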
(dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -2401,8 +2383,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( } } else { // generic code - int32_t dataInd = 0; - for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; for (; j + 8 <= block_size; j += 8) { @@ -2411,17 +2392,15 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( for (; j < block_size; j++) { op[j] = 0.0f; } - for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } + for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { - const int32_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2434,11 +2413,13 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; - const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; - const int32_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { @@ -2470,19 +2451,20 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( } } } + return dataInd == index_size; } -void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float_false__avx2_fma( +bool Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const uint8_t* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, bool normalize_by_lengths, float* out) { - Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( + return Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( block_size, output_size, index_size, @@ -2494,18 +2476,18 @@ void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float_false__avx2_fma( normalize_by_lengths, out); } -void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float_true__avx2_fma( +bool Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const uint8_t* input, - const int32_t* indices, + const int* indices, const int* lengths, const float* weights, bool normalize_by_lengths, float* out) { - Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( + return 
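// The _false/_true pairs in this file only pin the IS_WEIGHT_POSITIONAL
// template parameter and forward the new bool status; the runtime dispatch
// tables need plain, non-template symbols to point at. A condensed sketch
// with a hypothetical kernel:
#include <cstdint>

template <bool IS_WEIGHT_POSITIONAL>
static bool lookup_impl(int64_t n, const float* weights) {
  // a positional kernel reads weights[i - start], a non-positional one reads
  // weights[i]; both are elided here since only the wrapper shape matters
  return n >= 0 && (!IS_WEIGHT_POSITIONAL || weights != nullptr);
}

bool lookup_impl_false(int64_t n, const float* weights) {
  return lookup_impl<false>(n, weights);
}

bool lookup_impl_true(int64_t n, const float* weights) {
  return lookup_impl<true>(n, weights);
}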
Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( block_size, output_size, index_size, @@ -2519,7 +2501,7 @@ void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float_true__avx2_fma( } template -static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( +static bool Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -2532,9 +2514,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( float* out) { const int64_t prefdist_T0 = 16; const int64_t fused_block_size = block_size + 8; + int64_t dataInd = 0; if (block_size == 128) { // unrolling 16 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); @@ -2553,17 +2535,15 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2580,7 +2560,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -2720,7 +2702,6 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( } } else if (block_size == 64) { // unrolling 8 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); @@ -2731,17 +2712,15 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2758,7 +2737,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( ? 
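// What fused_block_size = block_size + 8 encodes, as a sketch: every row of
// the fused 8-bit table stores block_size quantized bytes followed by a
// float scale and a float bias (the 8 trailing bytes), so a single row
// pointer reaches both the payload and the quantization metadata.
// Hypothetical helper for reading one element:
#include <cstdint>
#include <cstring>

static float dequantize_at(
    const std::uint8_t* row, std::int64_t block_size, std::int64_t j) {
  float scale;
  float bias;
  std::memcpy(&scale, row + block_size, sizeof(float));
  std::memcpy(&bias, row + block_size + sizeof(float), sizeof(float));
  return static_cast<float>(row[j]) * scale + bias;
}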
(dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -2833,24 +2814,21 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( } } else if (block_size == 32) { // unrolling 4 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2867,7 +2845,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -2910,22 +2890,19 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( } } else if (block_size == 16) { // unrolling 2 times - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -2942,7 +2919,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( ? 
(dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, @@ -2969,7 +2948,6 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( } } else { // generic code - int64_t dataInd = 0; for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; @@ -2979,17 +2957,15 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( for (; j < block_size; j++) { op[j] = 0.0f; } + if (dataInd + lengths[rangeIndex] > index_size) { + return false; + } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; - CAFFE_ENFORCE( - idx >= 0 && idx < data_size, - "Index ", - dataInd, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); + if (idx < 0 || idx >= data_size) { + return false; + } float wgt = 1.f; float bio; if (weights) { @@ -3006,7 +2982,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; - CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { @@ -3038,8 +3016,9 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( } } } + return dataInd == index_size; } -void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma( +bool Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -3050,7 +3029,7 @@ void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma( const float* weights, bool normalize_by_lengths, float* out) { - Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( + return Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( block_size, output_size, index_size, @@ -3062,7 +3041,7 @@ void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma( normalize_by_lengths, out); } -void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_true__avx2_fma( +bool Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -3073,7 +3052,7 @@ void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_true__avx2_fma( const float* weights, bool normalize_by_lengths, float* out) { - Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( + return Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( block_size, output_size, index_size, diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc index d8f6a43..6b3eee2 100644 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc +++ b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc @@ -9,13 +9,16 @@ namespace caffe2 { -// Base implementation does runtime dispatch for each segment of reduction +/** + * Base implementation does runtime dispatch for each segment of reduction + * @return false if there is an out-of-bound error + */ template 
< typename IndexType, typename InType, typename OutType, bool IS_WEIGHT_POSITIONAL = false> -static void Fused8BitRowwiseEmbeddingLookupGenericSlow( +static bool Fused8BitRowwiseEmbeddingLookupGenericSlow( const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -34,18 +37,14 @@ static void Fused8BitRowwiseEmbeddingLookupGenericSlow( for (int m = 0; m < output_size; ++m) { memset(out, 0, sizeof(OutType) * block_size); EigenVectorArrayMap out_vector(out, block_size); + if (current + lengths[m] > index_size) { + return false; + } for (int i = 0; i < lengths[m]; ++i) { - CAFFE_ENFORCE_LT(current, index_size); int64_t idx = indices[current]; - CAFFE_ENFORCE( - 0 <= idx && idx < data_size, - "Index ", - current, - " is out of bounds: ", - idx, - ", range 0 to ", - data_size); - CAFFE_ENFORCE_LT(idx, data_size); + if (idx < 0 || idx >= data_size) { + return false; + } #ifdef __GNUC__ if (current + 1 < index_size) { __builtin_prefetch( @@ -77,92 +76,135 @@ static void Fused8BitRowwiseEmbeddingLookupGenericSlow( } out += block_size; } - CAFFE_ENFORCE_EQ( - current, - index_size, - "Your input seems to be incorrect: the sum of lengths values should be " - "the size of the indices tensor, but it appears not."); + return current == index_size; } // Proxy back to generic implementation -#define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION( \ - IndexType, InType, OutType) \ - void \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false__base( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const InType* input, \ - const IndexType* indices, \ - const int* lengths, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - Fused8BitRowwiseEmbeddingLookupGenericSlow< \ - IndexType, \ - InType, \ - OutType, \ - false>( \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - normalize_by_lengths, \ - out); \ - } \ - decltype( \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false__base) \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false__avx2_fma; \ - template <> \ - void Fused8BitRowwiseEmbeddingLookup( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const InType* input, \ - const IndexType* indices, \ - const int* lengths, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - const int32_t one = 1; \ - CAFFE_ENFORCE_EQ( \ - reinterpret_cast(&one)[0], \ - 1, \ - "Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \ - AVX2_FMA_DO( \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - normalize_by_lengths, \ - out); \ - BASE_DO( \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##_false, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - normalize_by_lengths, \ - out); \ +#define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(IndexType, OutType) \ + bool \ + Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__base( \ + const int64_t block_size, \ + const int64_t output_size, \ + const int64_t index_size, \ + const int64_t data_size, \ + const uint8_t* input, \ + const IndexType* indices, \ + 
const int* lengths, \ + const float* weights, \ + bool normalize_by_lengths, \ + OutType* out) { \ + return Fused8BitRowwiseEmbeddingLookupGenericSlow< \ + IndexType, \ + uint8_t, \ + OutType, \ + false>( \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + normalize_by_lengths, \ + out); \ + } \ + decltype( \ + Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__base) \ + Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__avx2_fma; \ + bool Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType( \ + const int64_t block_size, \ + const int64_t output_size, \ + const int64_t index_size, \ + const int64_t data_size, \ + const uint8_t* input, \ + const IndexType* indices, \ + const int* lengths, \ + const float* weights, \ + bool normalize_by_lengths, \ + OutType* out) { \ + const int32_t one = 1; \ + CAFFE_ENFORCE_EQ( \ + reinterpret_cast(&one)[0], \ + 1, \ + "Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \ + AVX2_FMA_DO( \ + Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false, \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + normalize_by_lengths, \ + out); \ + BASE_DO( \ + Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false, \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + normalize_by_lengths, \ + out); \ + } \ + template <> \ + void Fused8BitRowwiseEmbeddingLookup( \ + const int64_t block_size, \ + const int64_t output_size, \ + const int64_t index_size, \ + const int64_t data_size, \ + const uint8_t* input, \ + const IndexType* indices, \ + const int* lengths, \ + const float* weights, \ + bool normalize_by_lengths, \ + OutType* out) { \ + bool success = \ + Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType( \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + normalize_by_lengths, \ + out); \ + if (success) { \ + return; \ + } \ + int64_t current = 0; \ + for (int m = 0; m < output_size; ++m) { \ + for (int i = 0; i < lengths[m]; ++i) { \ + CAFFE_ENFORCE_LT(current, index_size); \ + IndexType idx = indices[current]; \ + CAFFE_ENFORCE( \ + 0 <= idx && idx < data_size, \ + "Index ", \ + current, \ + " is out of bounds: ", \ + idx, \ + ", range 0 to ", \ + data_size); \ + ++current; \ + } \ + } \ + CAFFE_ENFORCE_EQ( \ + current, \ + index_size, \ + "Your input seems to be incorrect: the sum of lengths values should be " \ + "the size of the indices tensor, but it appears not."); \ } -FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int32_t, uint8_t, float); -FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int64_t, uint8_t, float); +FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int32_t, float); +FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int64_t, float); #undef FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h index 85363c6..0d0b67e 100644 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h +++ b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h @@ -1,6 +1,6 @@ #pragma once -#include "caffe2/core/common.h" +#include namespace caffe2 { @@ -42,10 +42,10 @@ template < typename OutType, bool IS_WEIGHT_POSITIONAL = false> void Fused8BitRowwiseEmbeddingLookup( - const int64_t 
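// Design note on the specialization above, as a standalone sketch: the fast
// kernel only reports success or failure through its bool result, and the
// descriptive CAFFE_ENFORCE message is rebuilt by a second, slow pass that
// runs only after a failure. Error paths may be slow; success paths stay
// branch-light. Names below are hypothetical, and std::out_of_range stands
// in for CAFFE_ENFORCE.
#include <cstdint>
#include <stdexcept>
#include <string>

static bool fast_validate(
    int64_t index_size, const int64_t* indices, int64_t data_size) {
  for (int64_t i = 0; i < index_size; ++i) {
    if (indices[i] < 0 || indices[i] >= data_size) {
      return false; // no message formatting in the hot path
    }
  }
  return true;
}

void checked_lookup(
    int64_t index_size, const int64_t* indices, int64_t data_size) {
  if (fast_validate(index_size, indices, data_size)) {
    return; // common case: nothing extra to pay for
  }
  // failure: pay for a readable diagnostic now
  for (int64_t i = 0; i < index_size; ++i) {
    if (indices[i] < 0 || indices[i] >= data_size) {
      throw std::out_of_range(
          "Index " + std::to_string(i) + " is out of bounds: " +
          std::to_string(indices[i]) + ", range 0 to " +
          std::to_string(data_size));
    }
  }
}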
block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, + const std::int64_t block_size, + const std::int64_t output_size, + const std::int64_t index_size, + const std::int64_t data_size, const InType* input, const IndexType* indices, const int* lengths, diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py index 72b58cc..0af4120 100644 --- a/caffe2/perfkernels/hp_emblookup_codegen.py +++ b/caffe2/perfkernels/hp_emblookup_codegen.py @@ -50,7 +50,6 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused): code = [] code.append(" // unrolling " + str(uf) + " times") - code.append(" " + IndexType + " dataInd = 0;") code.append( " for (" + IndexType @@ -63,14 +62,20 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused): # inner loop code.append( + " if (dataInd + lengths[rangeIndex] > index_size) {\n" + + " return false;\n" + + " }" + ) + code.append( " for (" + IndexType + " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {" # noqa ) code.append(" const " + IndexType + " idx = indices[dataInd];") code.append( - ' CAFFE_ENFORCE(\n idx >= 0 && idx < data_size,\n "Index ",\n dataInd,\n' # noqa - ' " is out of bounds: ",\n idx,\n ", range 0 to ",\n data_size);' # noqa + " if (idx < 0 || idx >= data_size) {\n" + + " return false;\n" + + " }" ) if InType == "uint8_t": @@ -109,11 +114,15 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused): ) ) code.append(" const " + IndexType + " idx_pref_T0 = indices[next_T0];") - code.append(" CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);") + code.append( + " if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {\n" + + " return false;\n" + + " }" + ) code.append( - " const {}* ip_next_T0 = &input[idx_pref_T0" - " * fused_block_size];".format(InType) + " const {}* ip_next_T0 = " + "&input[idx_pref_T0 * fused_block_size];".format(InType) ) for i in range(0, uf): @@ -188,7 +197,6 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused): return code code = [] - code.append(" " + IndexType + " dataInd = 0;") code.append( " for (" + IndexType @@ -207,14 +215,20 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused): # inner loop code.append( + " if (dataInd + lengths[rangeIndex] > index_size) {\n" + + " return false;\n" + + " }" + ) + code.append( " for (" + IndexType + " start = dataInd; dataInd < start + lengths[rangeIndex];\n ++dataInd) {" # noqa ) code.append(" const " + IndexType + " idx = indices[dataInd];") code.append( - ' CAFFE_ENFORCE(\n idx >= 0 && idx < data_size,\n "Index ",\n dataInd,\n' # noqa - + ' " is out of bounds: ",\n idx,\n ", range 0 to ",\n data_size);' # noqa + " if (idx < 0 || idx >= data_size) {\n" + + " return false;\n" + + " }" ) if InType == "uint8_t": @@ -233,7 +247,6 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused): code.append(" bio = wgt * scale_bias[1];") code.append(" wgt = wgt * scale_bias[0];") else: - code.append(" assert(scale_bias);") code.append(" bio = wgt * scale_bias[2 * idx + 1];") code.append(" wgt = wgt * scale_bias[2 * idx];") code.append(" __m256 vbio = _mm256_set1_ps(bio);") @@ -254,11 +267,14 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused): ) ) code.append(" const " + IndexType + " idx_pref_T0 = indices[next_T0];") - code.append(" CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);") code.append( - " const {}* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];".format( - InType - 
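# A sketch of what hp_emblookup_codegen.py now emits ahead of every inner
# loop: the segment-overrun guard and the per-index bounds check, both
# returning false instead of throwing. emit_guards is a hypothetical helper;
# the real script appends the same strings to its `code` list inline.
def emit_guards(code, IndexType):
    code.append(
        "    if (dataInd + lengths[rangeIndex] > index_size) {\n"
        "      return false;\n"
        "    }"
    )
    code.append("      const " + IndexType + " idx = indices[dataInd];")
    code.append(
        "      if (idx < 0 || idx >= data_size) {\n"
        "        return false;\n"
        "      }"
    )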
) + " if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {\n" + + " return false;\n" + + " }" + ) + code.append( + " const {}* ip_next_T0 = " + "&input[idx_pref_T0 * fused_block_size];".format(InType) ) # compute and store main loop @@ -318,11 +334,11 @@ else: filename = "embedding_lookup_avx2.cc" options = [ - ["int32_t", "int32_t", "float", "float", "float", "float"], + ["int32_t", "int", "float", "float", "float", "float"], ["int64_t", "int64_t", "float", "float", "float", "float"], - ["int32_t", "int32_t", "half", "at::Half", "float", "float"], + ["int32_t", "int", "half", "at::Half", "float", "float"], ["int64_t", "int64_t", "half", "at::Half", "float", "float"], - ["int32_t", "int32_t", "uint8_t", "uint8_t", "float", "float"], + ["int32_t", "int", "uint8_t", "uint8_t", "float", "float"], ["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"], ] @@ -336,9 +352,7 @@ code.append("//// DO NOT MODIFY!!!") code.append("//// --------------------------\n") code.append("#include ") -code.append("#include ") code.append("#include ") -code.append("#include \n") code.append("namespace caffe2 {\n") for o in options: @@ -350,7 +364,7 @@ for o in options: prefix, IndexTypeName, InTypeName, OutTypeName ) suffix = "__avx2_fma" - fn = "static void " + fn_base + suffix + fn = "static bool " + fn_base + suffix code.append(fn + "(") args = [] @@ -375,19 +389,9 @@ for o in options: code.append( " const {} fused_block_size = block_size + {};".format(IndexType, offset) ) + code.append(" " + IndexType + " dataInd = 0;") # code.append("printf(\"calling " + fn + "\\n\");"); - if not opts.fused: - if InType != "uint8_t": - code.append( - " CAFFE_ENFORCE(scale_bias == nullptr," - ' "scale_bias must be nullptr");' - ) - else: - code.append( - " CAFFE_ENFORCE(scale_bias != nullptr," - ' "scale_bias must not be nullptr");' - ) code.append(" if (block_size == 128) {") code += unroll(16, IndexType, InType, OutType, True, "AVX2", opts.fused) @@ -401,13 +405,14 @@ for o in options: code.append(" // generic code") code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused) code.append(" }") + code.append(" return dataInd == index_size;") code.append("}") for is_weight_positional in ["false", "true"]: - code.append("void " + fn_base + "_" + is_weight_positional + suffix + "(") + code.append("bool " + fn_base + "_" + is_weight_positional + suffix + "(") code += args - code.append(" " + fn_base + suffix + "<" + is_weight_positional + ">(") + code.append(" return " + fn_base + suffix + "<" + is_weight_positional + ">(") code.append(" block_size,") code.append(" output_size,") code.append(" index_size,") diff --git a/caffe2/perfkernels/math_cpu_avx2.cc b/caffe2/perfkernels/math_cpu_avx2.cc index cf99a67..9bc5d17 100644 --- a/caffe2/perfkernels/math_cpu_avx2.cc +++ b/caffe2/perfkernels/math_cpu_avx2.cc @@ -3,7 +3,6 @@ // computation library to different compiler options (-mno-avx2 or -mavx2). #include -#include #include #include diff --git a/caffe2/perfkernels/typed_axpy_avx.cc b/caffe2/perfkernels/typed_axpy_avx.cc index 780598d..f83ba86 100644 --- a/caffe2/perfkernels/typed_axpy_avx.cc +++ b/caffe2/perfkernels/typed_axpy_avx.cc @@ -1,8 +1,6 @@ -#include "caffe2/core/types.h" #include "caffe2/perfkernels/cvtsh_ss_bugfix.h" -#include "caffe2/perfkernels/typed_axpy.h" -#include "caffe2/utils/math.h" +#include #include #include @@ -15,7 +13,7 @@ void TypedAxpyHalffloat__avx_f16c( float* y) { // if x does not start at the 16 byte boundary, we will process the first few. // before we get to a real one. 
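// Context for the cast change below, as a sketch: the loop peels scalar
// iterations until x reaches a 16-byte boundary so the vector body can run
// on aligned data. Casting through std::uintptr_t instead of unsigned long
// matters because unsigned long is 32 bits on LLP64 platforms (64-bit
// Windows), where it would truncate the pointer before the modulo.
// Hypothetical helper name and float element type for brevity.
#include <cstddef>
#include <cstdint>

static void axpy_scalar_head(
    std::size_t& N, const float*& x, float*& y, float a) {
  while ((reinterpret_cast<std::uintptr_t>(x) % 16) && N) {
    *(y++) += *(x++) * a; // scalar until aligned
    --N;
  }
  // ... a 16-byte-aligned SIMD main loop would follow ...
}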
- while (N && (unsigned long)x % 16) { + while ((reinterpret_cast(x) % 16) && N) { *(y++) += _cvtsh_ss((*(x++)).x) * a; --N; } diff --git a/caffe2/perfkernels/typed_axpy_avx2.cc b/caffe2/perfkernels/typed_axpy_avx2.cc index 7187ddd..683a8a5 100644 --- a/caffe2/perfkernels/typed_axpy_avx2.cc +++ b/caffe2/perfkernels/typed_axpy_avx2.cc @@ -13,7 +13,7 @@ void TypedAxpyHalffloat__avx2_fma( float* y) { // if x does not start at the 16 byte boundary, we will process the first few. // before we get to a real one. - while (((unsigned long)x % 16) && N) { + while ((reinterpret_cast(x) % 16) && N) { *(y++) += _cvtsh_ss((*(x++)).x) * a; --N; } @@ -48,8 +48,8 @@ void TypedAxpy_uint8_float__avx2_fma( float* y) { // if x does not start at the 16 byte boundary, we will process the first few. // before we get to a real one. - while (((unsigned long)x % 16) && N) { - *(y++) += (float)(*(x++)) * a; + while ((reinterpret_cast(x) % 16) && N) { + *(y++) += static_cast(*(x++)) * a; --N; } diff --git a/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py b/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py index b57b8be..707f40c 100644 --- a/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py +++ b/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py @@ -24,19 +24,23 @@ class TestLengthsReducerOpsFused8BitRowwise(hu.HypothesisTestCase): input_data = np.random.rand(batchsize, blocksize).astype(np.float32) if empty_indices: - indices = np.empty(0, dtype=np.int32) + lengths = np.zeros(batchsize, dtype=np.int32) + num_indices = 0 else: - indices = np.random.randint( - low=0, - high=len(input_data), - size=[np.random.randint(len(input_data))], - dtype=np.int32, + num_indices = np.random.randint(len(input_data)) + num_lengths = np.clip(1, num_indices // 2, 10) + lengths = ( + np.ones([num_indices // num_lengths], dtype=np.int32) * num_lengths ) - weights = np.random.uniform(size=[len(indices)]).astype(np.float32) - lengths_split = np.clip(1, len(indices) // 2, 10) - lengths = ( - np.ones([len(indices) // lengths_split], dtype=np.int32) * lengths_split + # readjust num_indices when num_lengths doesn't divide num_indices + num_indices = num_indices // num_lengths * num_lengths + indices = np.random.randint( + low=0, + high=len(input_data), + size=[num_indices], + dtype=np.int32, ) + weights = np.random.uniform(size=[len(indices)]).astype(np.float32) quantized_data = net.FloatToFused8BitRowwiseQuantized( "input_data", "quantized_data" @@ -87,19 +91,22 @@ class TestLengthsReducerOpsFused8BitRowwise(hu.HypothesisTestCase): input_data = np.random.rand(batchsize, blocksize).astype(np.float32) if empty_indices: - indices = np.empty(0, dtype=np.int32) lengths = np.zeros(batchsize, dtype=np.int32) + num_indices = 0 else: - indices = np.random.randint( - low=0, - high=len(input_data), - size=[np.random.randint(len(input_data))], - dtype=np.int32, - ) - lengths_split = np.clip(1, len(indices) // 2, 10) + num_indices = np.random.randint(len(input_data)) + num_lengths = np.clip(1, num_indices // 2, 10) lengths = ( - np.ones([len(indices) // lengths_split], dtype=np.int32) * lengths_split + np.ones([num_indices // num_lengths], dtype=np.int32) * num_lengths ) + # readjust num_indices when num_lengths doesn't divide num_indices + num_indices = num_indices // num_lengths * num_lengths + indices = np.random.randint( + low=0, + high=len(input_data), + size=[num_indices], + dtype=np.int32, + ) print(indices, lengths) quantized_data = net.FloatToFused8BitRowwiseQuantized( -- 2.7.4
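# The test rewrite above makes lengths, not indices, the source of truth: it
# derives num_indices from the lengths it will emit, so sum(lengths) always
# equals len(indices). With the kernels now returning false (and the op
# erroring out) on any mismatch, the old construction, which could leave a
# remainder of unconsumed indices, would fail. A sketch of the invariant
# with sample numbers (hypothetical values, mirroring the test's np.clip
# argument order):
import numpy as np

num_indices = 10
num_lengths = np.clip(1, num_indices // 2, 10)  # segment length, here 5
lengths = np.ones([num_indices // num_lengths], dtype=np.int32) * num_lengths
num_indices = num_indices // num_lengths * num_lengths  # drop any remainder
indices = np.random.randint(0, 100, size=[num_indices], dtype=np.int32)
assert lengths.sum() == len(indices)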