From e012b183dd7a9e2f505c821e23238741e262fe23 Mon Sep 17 00:00:00 2001
From: Jongsoo Park
Date: Fri, 21 Dec 2018 22:17:35 -0800
Subject: [PATCH] handle empty inputs to SparseLengthsMean correctly (#15389)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15389

SparseLengthsMean was generating uninitialized data for empty inputs
(lengths == 0). We should return zeros. The unit tests were also not
covering this special case, which is fixed by this diff.

Reviewed By: salexspb

Differential Revision: D13515970

fbshipit-source-id: 3c35265638f64f13f0262cee930c94f8628005da
---
 caffe2/perfkernels/embedding_lookup_avx2.cc        |  96 +++----
 .../embedding_lookup_fused_8bit_rowwise_avx2.cc    |  96 +++----
 caffe2/perfkernels/hp_emblookup_codegen.py         |   4 +-
 .../lengths_reducer_fused_8bit_rowwise_ops_test.py | 132 +++++-----
 .../operator_test/specialized_segment_ops_test.py  | 277 ++++++++++-----------
 5 files changed, 301 insertions(+), 304 deletions(-)

diff --git a/caffe2/perfkernels/embedding_lookup_avx2.cc b/caffe2/perfkernels/embedding_lookup_avx2.cc
index e470779..326818b 100644
--- a/caffe2/perfkernels/embedding_lookup_avx2.cc
+++ b/caffe2/perfkernels/embedding_lookup_avx2.cc
@@ -105,7 +105,7 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
       vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
       // skip unnecessary prefetch of (&ip_next_T0[120])
     }
-    if (normalize_by_lengths == false) {
+    if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
       _mm256_storeu_ps(&op[0], vop0);
       _mm256_storeu_ps(&op[8], vop8);
       _mm256_storeu_ps(&op[16], vop16);
@@ -122,7 +122,7 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
       _mm256_storeu_ps(&op[104], vop104);
       _mm256_storeu_ps(&op[112], vop112);
       _mm256_storeu_ps(&op[120], vop120);
-    } else if (lengths[rangeIndex]) {
+    } else {
       __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
       _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
       _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -195,7 +195,7 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
       vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
       // skip unnecessary prefetch of (&ip_next_T0[56])
     }
-    if (normalize_by_lengths == false) {
+    if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
       _mm256_storeu_ps(&op[0], vop0);
       _mm256_storeu_ps(&op[8], vop8);
       _mm256_storeu_ps(&op[16], vop16);
@@ -204,7 +204,7 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
       _mm256_storeu_ps(&op[40], vop40);
       _mm256_storeu_ps(&op[48], vop48);
       _mm256_storeu_ps(&op[56], vop56);
-    } else if (lengths[rangeIndex]) {
+    } else {
       __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
       _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
       _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -257,12 +257,12 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
       vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
       // skip unnecessary prefetch of (&ip_next_T0[24])
     }
-    if (normalize_by_lengths == false) {
+    if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
      _mm256_storeu_ps(&op[0], vop0);
       _mm256_storeu_ps(&op[8], vop8);
       _mm256_storeu_ps(&op[16], vop16);
       _mm256_storeu_ps(&op[24], vop24);
-    } else if (lengths[rangeIndex]) {
+    } else {
       __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
       _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
       _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -305,10 +305,10 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
       vop8 =
_mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -520,7 +520,7 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); // skip unnecessary prefetch of (&ip_next_T0[120]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -537,7 +537,7 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -610,7 +610,7 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -619,7 +619,7 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -672,12 +672,12 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -720,10 +720,10 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma( vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -999,7 +999,7 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( vop120); // skip unnecessary prefetch of (&ip_next_T0[120]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || 
lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -1016,7 +1016,7 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -1121,7 +1121,7 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -1130,7 +1130,7 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -1199,12 +1199,12 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -1255,10 +1255,10 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma( vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -1540,7 +1540,7 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( vop120); // skip unnecessary prefetch of (&ip_next_T0[120]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -1557,7 +1557,7 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -1662,7 +1662,7 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -1671,7 +1671,7 @@ static void 
EmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -1740,12 +1740,12 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -1796,10 +1796,10 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma( vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -2085,7 +2085,7 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_add_ps(vop120, vbio)); // skip unnecessary prefetch of (&ip_next_T0[120]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -2102,7 +2102,7 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -2211,7 +2211,7 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_add_ps(vop56, vbio)); // skip unnecessary prefetch of (&ip_next_T0[56]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -2220,7 +2220,7 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -2293,12 +2293,12 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_add_ps(vop24, vbio)); // skip unnecessary prefetch of (&ip_next_T0[24]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); 
_mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -2353,10 +2353,10 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_add_ps(vop8, vbio)); // skip unnecessary prefetch of (&ip_next_T0[8]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -2644,7 +2644,7 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_add_ps(vop120, vbio)); // skip unnecessary prefetch of (&ip_next_T0[120]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -2661,7 +2661,7 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -2770,7 +2770,7 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_add_ps(vop56, vbio)); // skip unnecessary prefetch of (&ip_next_T0[56]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -2779,7 +2779,7 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -2852,12 +2852,12 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_add_ps(vop24, vbio)); // skip unnecessary prefetch of (&ip_next_T0[24]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -2912,10 +2912,10 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_add_ps(vop8, vbio)); // skip unnecessary prefetch of (&ip_next_T0[8]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); diff --git a/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc b/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc index 650a2dc..1f4a831 100644 --- 
a/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc +++ b/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc @@ -103,7 +103,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); // skip unnecessary prefetch of (&ip_next_T0[120]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -120,7 +120,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -193,7 +193,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -202,7 +202,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -255,12 +255,12 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -303,10 +303,10 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -512,7 +512,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); // skip unnecessary prefetch of (&ip_next_T0[120]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -529,7 +529,7 @@ static void 
Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -602,7 +602,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -611,7 +611,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -664,12 +664,12 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -712,10 +712,10 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -985,7 +985,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( vop120); // skip unnecessary prefetch of (&ip_next_T0[120]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -1002,7 +1002,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -1107,7 +1107,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); 
_mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -1116,7 +1116,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -1185,12 +1185,12 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -1241,10 +1241,10 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma( vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -1520,7 +1520,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( vop120); // skip unnecessary prefetch of (&ip_next_T0[120]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -1537,7 +1537,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -1642,7 +1642,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -1651,7 +1651,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -1720,12 +1720,12 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); 
_mm256_storeu_ps(&op[24], vop24); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -1776,10 +1776,10 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma( vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -2061,7 +2061,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_add_ps(vop120, vbio)); // skip unnecessary prefetch of (&ip_next_T0[120]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -2078,7 +2078,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -2189,7 +2189,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_add_ps(vop56, vbio)); // skip unnecessary prefetch of (&ip_next_T0[56]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -2198,7 +2198,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -2273,12 +2273,12 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_add_ps(vop24, vbio)); // skip unnecessary prefetch of (&ip_next_T0[24]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -2335,10 +2335,10 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( _mm256_add_ps(vop8, vbio)); // skip unnecessary prefetch of (&ip_next_T0[8]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); 
_mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -2623,7 +2623,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_add_ps(vop120, vbio)); // skip unnecessary prefetch of (&ip_next_T0[120]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -2640,7 +2640,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -2751,7 +2751,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_add_ps(vop56, vbio)); // skip unnecessary prefetch of (&ip_next_T0[56]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); @@ -2760,7 +2760,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -2835,12 +2835,12 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_add_ps(vop24, vbio)); // skip unnecessary prefetch of (&ip_next_T0[24]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); @@ -2897,10 +2897,10 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( _mm256_add_ps(vop8, vbio)); // skip unnecessary prefetch of (&ip_next_T0[8]) } - if (normalize_by_lengths == false) { + if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); - } else if (lengths[rangeIndex]) { + } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py index 887a975..c1cbd4f 100644 --- a/caffe2/perfkernels/hp_emblookup_codegen.py +++ b/caffe2/perfkernels/hp_emblookup_codegen.py @@ -122,11 +122,11 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused): code.extend(compute(j, InType, use_weights, isa, prefetch)) code.append(" }") - code.append(" if (normalize_by_lengths == false) {") + code.append(" if (!normalize_by_lengths || lengths[rangeIndex] == 0) {") for i in range(0, uf): j = 8 * i code.append(" _mm256_storeu_ps(&op[" + str(j) + "], vop" + str(j) + ");") - code.append(" } else if (lengths[rangeIndex]) {") + 
code.append(" } else {") # inv of length code.append(" __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);") for i in range(0, uf): diff --git a/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py b/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py index 18d1e15..b57b8be 100644 --- a/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py +++ b/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py @@ -1,127 +1,129 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals +from __future__ import absolute_import, division, print_function, unicode_literals -from caffe2.python import core, workspace import caffe2.python.hypothesis_test_util as hu - +import hypothesis.strategies as st import numpy as np +from caffe2.python import core, workspace from hypothesis import given -import hypothesis.strategies as st class TestLengthsReducerOpsFused8BitRowwise(hu.HypothesisTestCase): @given( - input_data=hu.tensor(min_dim=2, max_dim=2), + batchsize=st.integers(1, 20), + blocksize=st.sampled_from([8, 16, 32, 64, 85, 96, 128, 163]), weighted=st.booleans(), - seed=st.integers(0, 2**32 - 1), + seed=st.integers(0, 2 ** 32 - 1), + empty_indices=st.booleans(), ) - def test_sparse_lengths_sum(self, input_data, weighted, seed): + def test_sparse_lengths_sum( + self, batchsize, blocksize, weighted, seed, empty_indices + ): net = core.Net("bench") np.random.seed(seed) - input_data = input_data.astype(np.float32) - indices = np.random.randint( - low=0, - high=len(input_data), - size=[np.random.randint(len(input_data))], - dtype=np.int32 - ) + input_data = np.random.rand(batchsize, blocksize).astype(np.float32) + if empty_indices: + indices = np.empty(0, dtype=np.int32) + else: + indices = np.random.randint( + low=0, + high=len(input_data), + size=[np.random.randint(len(input_data))], + dtype=np.int32, + ) weights = np.random.uniform(size=[len(indices)]).astype(np.float32) lengths_split = np.clip(1, len(indices) // 2, 10) - lengths = np.ones( - [len(indices) // lengths_split], dtype=np.int32 - ) * lengths_split - print(indices, weights, lengths) + lengths = ( + np.ones([len(indices) // lengths_split], dtype=np.int32) * lengths_split + ) quantized_data = net.FloatToFused8BitRowwiseQuantized( - 'input_data', 'quantized_data' + "input_data", "quantized_data" ) dequantized_data = net.Fused8BitRowwiseQuantizedToFloat( - quantized_data, 'dequantized_data' + quantized_data, "dequantized_data" ) if weighted: net.SparseLengthsWeightedSum( - [dequantized_data, 'weights', 'indices', 'lengths'], - 'sum_reference', - engine='fp16' + [dequantized_data, "weights", "indices", "lengths"], + "sum_reference", ) net.SparseLengthsWeightedSumFused8BitRowwise( - [quantized_data, 'weights', 'indices', 'lengths'], - 'sum_quantized' + [quantized_data, "weights", "indices", "lengths"], "sum_quantized" ) else: net.SparseLengthsSum( - [dequantized_data, 'indices', 'lengths'], - 'sum_reference', - engine='fp16' + [dequantized_data, "indices", "lengths"], "sum_reference", ) net.SparseLengthsSumFused8BitRowwise( - [quantized_data, 'indices', 'lengths'], 'sum_quantized' + [quantized_data, "indices", "lengths"], "sum_quantized" ) - workspace.FeedBlob('input_data', input_data) - workspace.FeedBlob('weights', weights) - workspace.FeedBlob('indices', indices) - workspace.FeedBlob('lengths', lengths) + workspace.FeedBlob("input_data", input_data) + workspace.FeedBlob("weights", weights) + 
workspace.FeedBlob("indices", indices) + workspace.FeedBlob("lengths", lengths) - workspace.GlobalInit(['caffe2', '--caffe2_log_level=0']) + workspace.GlobalInit(["caffe2", "--caffe2_log_level=0"]) workspace.CreateNet(net) workspace.RunNetOnce(net) - sum_reference = workspace.FetchBlob('sum_reference') - sum_quantized = workspace.FetchBlob('sum_quantized') + sum_reference = workspace.FetchBlob("sum_reference") + sum_quantized = workspace.FetchBlob("sum_quantized") np.testing.assert_array_almost_equal(sum_reference, sum_quantized) @given( - input_data=hu.tensor(min_dim=2, max_dim=2), - seed=st.integers(0, 2**32 - 1), + batchsize=st.integers(1, 20), + blocksize=st.sampled_from([8, 16, 32, 64, 85, 96, 128, 163]), + seed=st.integers(0, 2 ** 32 - 1), + empty_indices=st.booleans(), ) - def test_sparse_lengths_mean(self, input_data, seed): + def test_sparse_lengths_mean(self, batchsize, blocksize, seed, empty_indices): net = core.Net("bench") np.random.seed(seed) - input_data = input_data.astype(np.float32) - indices = np.random.randint( - low=0, - high=len(input_data), - size=[np.random.randint(len(input_data))], - dtype=np.int32 - ) - lengths_split = np.clip(1, len(indices) // 2, 10) - lengths = np.ones( - [len(indices) // lengths_split], dtype=np.int32 - ) * lengths_split + input_data = np.random.rand(batchsize, blocksize).astype(np.float32) + if empty_indices: + indices = np.empty(0, dtype=np.int32) + lengths = np.zeros(batchsize, dtype=np.int32) + else: + indices = np.random.randint( + low=0, + high=len(input_data), + size=[np.random.randint(len(input_data))], + dtype=np.int32, + ) + lengths_split = np.clip(1, len(indices) // 2, 10) + lengths = ( + np.ones([len(indices) // lengths_split], dtype=np.int32) * lengths_split + ) print(indices, lengths) quantized_data = net.FloatToFused8BitRowwiseQuantized( - 'input_data', 'quantized_data' + "input_data", "quantized_data" ) dequantized_data = net.Fused8BitRowwiseQuantizedToFloat( - quantized_data, 'dequantized_data' + quantized_data, "dequantized_data" ) net.SparseLengthsMean( - [dequantized_data, 'indices', 'lengths'], - 'mean_reference', - engine='fp16' + [dequantized_data, "indices", "lengths"], "mean_reference" ) net.SparseLengthsMeanFused8BitRowwise( - [quantized_data, 'indices', 'lengths'], 'mean_quantized' + [quantized_data, "indices", "lengths"], "mean_quantized" ) - workspace.FeedBlob('input_data', input_data) - workspace.FeedBlob('indices', indices) - workspace.FeedBlob('lengths', lengths) + workspace.FeedBlob("input_data", input_data) + workspace.FeedBlob("indices", indices) + workspace.FeedBlob("lengths", lengths) - workspace.GlobalInit(['caffe2', '--caffe2_log_level=0']) + workspace.GlobalInit(["caffe2", "--caffe2_log_level=0"]) workspace.CreateNet(net) workspace.RunNetOnce(net) - mean_reference = workspace.FetchBlob('mean_reference') - mean_quantized = workspace.FetchBlob('mean_quantized') + mean_reference = workspace.FetchBlob("mean_reference") + mean_quantized = workspace.FetchBlob("mean_quantized") np.testing.assert_array_almost_equal(mean_reference, mean_quantized) diff --git a/caffe2/python/operator_test/specialized_segment_ops_test.py b/caffe2/python/operator_test/specialized_segment_ops_test.py index cdd82ac..54a840b 100644 --- a/caffe2/python/operator_test/specialized_segment_ops_test.py +++ b/caffe2/python/operator_test/specialized_segment_ops_test.py @@ -2,11 +2,13 @@ from __future__ import absolute_import, division, print_function, unicode_litera import unittest +from caffe2.proto import caffe2_pb2 +from caffe2.python 
import core +import caffe2.python.hip_test_util as hiputl import caffe2.python.hypothesis_test_util as hu import hypothesis.strategies as st import numpy as np -from caffe2.python import core -from hypothesis import given +from hypothesis import given, assume class TestSpecializedSegmentOps(hu.HypothesisTestCase): @@ -14,18 +16,31 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase): batchsize=st.integers(1, 20), fptype=st.sampled_from([np.float16, np.float32]), fp16asint=st.booleans(), - blocksize=st.sampled_from([8, 17, 32, 64, 85, 96, 128, 163]), + blocksize=st.sampled_from([8, 16, 32, 64, 85, 96, 128, 163]), normalize_by_lengths=st.booleans(), + empty_indices=st.booleans(), **hu.gcs ) def test_sparse_lengths_sum_cpu( - self, batchsize, fptype, fp16asint, blocksize, normalize_by_lengths, gc, dc + self, + batchsize, + fptype, + fp16asint, + blocksize, + normalize_by_lengths, + empty_indices, + gc, + dc, ): + if fptype != np.float32: + assume(gc.device_type == caffe2_pb2.CPU) + assume(not hiputl.run_in_hip(gc, dc)) + assume(caffe2_pb2.CUDA not in {d.device_type for d in dc}) - if normalize_by_lengths == False: - print("") - else: + if normalize_by_lengths: print("") + else: + print("") tblsize = 300 if fptype == np.float32: @@ -44,44 +59,42 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase): atol = 1e-1 # array of each row length - Lengths = np.random.randint(1, 30, size=batchsize).astype(np.int32) - # flat indices - Indices = np.random.randint(0, tblsize, size=sum(Lengths)).astype(np.int64) - - if normalize_by_lengths == False: - op = core.CreateOperator( - "SparseLengthsSum", ["Tbl", "Indices", "Lengths"], "out" - ) + if empty_indices: + Lengths = np.zeros(batchsize, dtype=np.int32) else: - op = core.CreateOperator( - "SparseLengthsMean", ["Tbl", "Indices", "Lengths"], "out" - ) + Lengths = np.random.randint(1, 30, size=batchsize, dtype=np.int32) + # flat indices + Indices = np.random.randint(0, tblsize, size=sum(Lengths), dtype=np.int64) - self.ws.create_blob("Tbl").feed(Tbl) - self.ws.create_blob("Indices").feed(Indices) - self.ws.create_blob("Lengths").feed(Lengths) - self.ws.run(op) + op = core.CreateOperator( + "SparseLengths" + ("Mean" if normalize_by_lengths else "Sum"), + ["Tbl", "Indices", "Lengths"], + "out", + ) def sparse_lengths_sum_ref(Tbl, Indices, Lengths): rptr = np.cumsum(np.insert(Lengths, [0], [0])) out = np.zeros((len(Lengths), blocksize)) - if normalize_by_lengths == False: + if normalize_by_lengths: for i in range(0, len(rptr[0:-1])): - out[i] = Tbl[Indices[rptr[i] : rptr[i + 1]]].sum(axis=0) + if Lengths[i] != 0: + out[i] = ( + Tbl[Indices[rptr[i] : rptr[i + 1]]].sum(axis=0) + * 1.0 + / float(Lengths[i]) + ) else: for i in range(0, len(rptr[0:-1])): - out[i] = ( - Tbl[Indices[rptr[i] : rptr[i + 1]]].sum(axis=0) - * 1.0 - / float(Lengths[i]) - ) - - return out - - np.testing.assert_allclose( - self.ws.blobs[("out")].fetch(), - sparse_lengths_sum_ref(Tbl, Indices, Lengths), - rtol=1e-3, + out[i] = Tbl[Indices[rptr[i] : rptr[i + 1]]].sum(axis=0) + + return [out.astype(np.float32)] + + self.assertReferenceChecks( + gc, + op, + [Tbl, Indices, Lengths], + sparse_lengths_sum_ref, + threshold=1e-3, atol=atol, ) @@ -89,12 +102,17 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase): batchsize=st.integers(1, 20), fptype=st.sampled_from([np.float16, np.float32]), fp16asint=st.booleans(), - blocksize=st.sampled_from([8, 17, 32, 64, 85, 96, 128, 163]), + blocksize=st.sampled_from([8, 16, 32, 64, 85, 96, 128, 163]), + empty_indices=st.booleans(), **hu.gcs ) 
def test_sparse_lengths_weightedsum_cpu( - self, batchsize, fptype, fp16asint, blocksize, gc, dc + self, batchsize, fptype, fp16asint, blocksize, empty_indices, gc, dc ): + if fptype != np.float32: + assume(gc.device_type == caffe2_pb2.CPU) + assume(not hiputl.run_in_hip(gc, dc)) + assume(caffe2_pb2.CUDA not in {d.device_type for d in dc}) print("") @@ -115,21 +133,18 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase): atol = 1e-1 # array of each row length - Lengths = np.random.randint(1, 30, size=batchsize).astype(np.int32) + if empty_indices: + Lengths = np.zeros(batchsize, dtype=np.int32) + else: + Lengths = np.random.randint(1, 30, size=batchsize, dtype=np.int32) # flat indices - Indices = np.random.randint(0, tblsize, size=sum(Lengths)).astype(np.int64) + Indices = np.random.randint(0, tblsize, size=sum(Lengths), dtype=np.int64) Weights = np.random.rand(sum(Lengths)).astype(np.float32) op = core.CreateOperator( "SparseLengthsWeightedSum", ["Tbl", "Weights", "Indices", "Lengths"], "out" ) - self.ws.create_blob("Tbl").feed(Tbl) - self.ws.create_blob("Indices").feed(Indices) - self.ws.create_blob("Lengths").feed(Lengths) - self.ws.create_blob("Weights").feed(Weights) - self.ws.run(op) - def sparse_lengths_weightedsum_ref(Tbl, Weights, Indices, Lengths): rptr = np.cumsum(np.insert(Lengths, [0], [0])) out = np.zeros((len(Lengths), blocksize)) @@ -138,12 +153,14 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase): out[i] = (Tbl[Indices[rptr[i] : rptr[i + 1]]] * w[:, np.newaxis]).sum( axis=0 ) - return out - - np.testing.assert_allclose( - self.ws.blobs[("out")].fetch(), - sparse_lengths_weightedsum_ref(Tbl, Weights, Indices, Lengths), - rtol=1e-3, + return [out.astype(np.float32)] + + self.assertReferenceChecks( + gc, + op, + [Tbl, Weights, Indices, Lengths], + sparse_lengths_weightedsum_ref, + threshold=1e-3, atol=atol, ) @@ -151,51 +168,42 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase): batchsize=st.integers(1, 20), blocksize=st.sampled_from([8, 16, 17, 26, 32, 64, 85, 96, 128, 148, 163]), normalize_by_lengths=st.booleans(), - **hu.gcs + empty_indices=st.booleans(), + **hu.gcs_cpu_only ) def test_sparse_lengths_weightedsum_8BitsRowwiseOp_cpu( - self, batchsize, blocksize, normalize_by_lengths, gc, dc + self, batchsize, blocksize, normalize_by_lengths, empty_indices, gc, dc ): - - if normalize_by_lengths == False: + if normalize_by_lengths: print( - "" + "" ) else: print( - "" + "" ) tblsize = 300 - Tbl = np.random.randint(7, size=(tblsize, blocksize)).astype(np.uint8) + Tbl = np.random.randint(7, size=(tblsize, blocksize), dtype=np.uint8) atol = 1e-5 # array of each row length - Lengths = np.random.randint(1, 30, size=batchsize).astype(np.int32) + if empty_indices: + Lengths = np.zeros(batchsize, dtype=np.int32) + else: + Lengths = np.random.randint(1, 30, size=batchsize, dtype=np.int32) # flat indices - Indices = np.random.randint(0, tblsize, size=sum(Lengths)).astype(np.int64) + Indices = np.random.randint(0, tblsize, size=sum(Lengths), dtype=np.int64) Weights = np.random.rand(sum(Lengths)).astype(np.float32) Scale_Bias = np.random.rand(tblsize, 2).astype(np.float32) - if normalize_by_lengths == False: - op = core.CreateOperator( - "SparseLengthsWeightedSum8BitsRowwise", - ["Tbl", "Weights", "Indices", "Lengths", "Scale_Bias"], - "out", - ) - else: - op = core.CreateOperator( - "SparseLengthsWeightedMean8BitsRowwise", - ["Tbl", "Weights", "Indices", "Lengths", "Scale_Bias"], - "out", - ) - - self.ws.create_blob("Tbl").feed(Tbl) - 
self.ws.create_blob("Weights").feed(Weights) - self.ws.create_blob("Indices").feed(Indices) - self.ws.create_blob("Lengths").feed(Lengths) - self.ws.create_blob("Scale_Bias").feed(Scale_Bias) - self.ws.run(op) + op = core.CreateOperator( + "SparseLengthsWeighted" + + ("Mean" if normalize_by_lengths else "Sum") + + "8BitsRowwise", + ["Tbl", "Weights", "Indices", "Lengths", "Scale_Bias"], + "out", + ) def sparse_lengths_weightedsum_8BitsRowwiseOp_cpu_ref( Tbl, Weights, Indices, Lengths, Scale_Bias @@ -207,19 +215,19 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase): s = Scale_Bias[Indices[rptr[i] : rptr[i + 1]], 0][:, np.newaxis] b = Scale_Bias[Indices[rptr[i] : rptr[i + 1]], 1][:, np.newaxis] f = 1.0 - if normalize_by_lengths == True: + if normalize_by_lengths and Lengths[i] != 0: f = 1.0 / float(Lengths[i]) out[i] = ( w[:, np.newaxis] * (s * Tbl[Indices[rptr[i] : rptr[i + 1]]] + b) ).sum(axis=0) * f - return out - - np.testing.assert_allclose( - self.ws.blobs[("out")].fetch(), - sparse_lengths_weightedsum_8BitsRowwiseOp_cpu_ref( - Tbl, Weights, Indices, Lengths, Scale_Bias - ), - rtol=1e-3, + return [out.astype(np.float32)] + + self.assertReferenceChecks( + gc, + op, + [Tbl, Weights, Indices, Lengths, Scale_Bias], + sparse_lengths_weightedsum_8BitsRowwiseOp_cpu_ref, + threshold=1e-3, atol=atol, ) @@ -227,45 +235,37 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase): batchsize=st.integers(1, 20), blocksize=st.sampled_from([8, 16, 17, 26, 32, 64, 85, 96, 128, 148, 163]), normalize_by_lengths=st.booleans(), - **hu.gcs + empty_indices=st.booleans(), + **hu.gcs_cpu_only ) def test_sparse_lengths_sum_8BitsRowwiseOp_cpu( - self, batchsize, blocksize, normalize_by_lengths, gc, dc + self, batchsize, blocksize, normalize_by_lengths, empty_indices, gc, dc ): - - if normalize_by_lengths == False: - print("") - else: + if normalize_by_lengths: print("") + else: + print("") tblsize = 300 - Tbl = np.random.randint(7, size=(tblsize, blocksize)).astype(np.uint8) + Tbl = np.random.randint(7, size=(tblsize, blocksize), dtype=np.uint8) atol = 1e-5 # array of each row length - Lengths = np.random.randint(1, 30, size=batchsize).astype(np.int32) + if empty_indices: + Lengths = np.zeros(batchsize, dtype=np.int32) + else: + Lengths = np.random.randint(1, 30, size=batchsize, dtype=np.int32) # flat indices - Indices = np.random.randint(0, tblsize, size=sum(Lengths)).astype(np.int64) + Indices = np.random.randint(0, tblsize, size=sum(Lengths), dtype=np.int64) Scale_Bias = np.random.rand(tblsize, 2).astype(np.float32) - if normalize_by_lengths == False: - op = core.CreateOperator( - "SparseLengthsSum8BitsRowwise", - ["Tbl", "Indices", "Lengths", "Scale_Bias"], - "out", - ) - else: - op = core.CreateOperator( - "SparseLengthsMean8BitsRowwise", - ["Tbl", "Indices", "Lengths", "Scale_Bias"], - "out", - ) - - self.ws.create_blob("Tbl").feed(Tbl) - self.ws.create_blob("Indices").feed(Indices) - self.ws.create_blob("Lengths").feed(Lengths) - self.ws.create_blob("Scale_Bias").feed(Scale_Bias) - self.ws.run(op) + op = core.CreateOperator( + "SparseLengths" + + ("Mean" if normalize_by_lengths else "Sum") + + "8BitsRowwise", + ["Tbl", "Indices", "Lengths", "Scale_Bias"], + "out", + ) def sparse_lengths_sum_8BitsRowwiseOp_cpu_reg( Tbl, Indices, Lengths, Scale_Bias @@ -276,17 +276,17 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase): s = Scale_Bias[Indices[rptr[i] : rptr[i + 1]], 0][:, np.newaxis] b = Scale_Bias[Indices[rptr[i] : rptr[i + 1]], 1][:, np.newaxis] f = 1.0 - if normalize_by_lengths == True: + 
if normalize_by_lengths and Lengths[i] != 0: f = 1.0 / float(Lengths[i]) out[i] = (s * Tbl[Indices[rptr[i] : rptr[i + 1]]] + b).sum(axis=0) * f - return out - - np.testing.assert_allclose( - self.ws.blobs[("out")].fetch(), - sparse_lengths_sum_8BitsRowwiseOp_cpu_reg( - Tbl, Indices, Lengths, Scale_Bias - ), - rtol=1e-3, + return [out.astype(np.float32)] + + self.assertReferenceChecks( + gc, + op, + [Tbl, Indices, Lengths, Scale_Bias], + sparse_lengths_sum_8BitsRowwiseOp_cpu_reg, + threshold=1e-3, atol=atol, ) @@ -294,34 +294,29 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase): batchsize=st.integers(1, 20), blocksize=st.sampled_from([8, 16, 17, 26, 32, 64, 85, 96, 128, 148, 163]), normalize_by_lengths=st.booleans(), - **hu.gcs + **hu.gcs_cpu_only ) def test_sparse_lengths_sum_8BitsRowwiseOp_cpu_invalid_index( self, batchsize, blocksize, normalize_by_lengths, gc, dc ): tblsize = 300 - Tbl = np.random.randint(7, size=(tblsize, blocksize)).astype(np.uint8) + Tbl = np.random.randint(7, size=(tblsize, blocksize), dtype=np.uint8) # array of each row length - Lengths = np.random.randint(1, 30, size=batchsize).astype(np.int32) + Lengths = np.random.randint(1, 30, size=batchsize, dtype=np.int32) # flat indices - Indices = np.random.randint(0, tblsize, size=sum(Lengths)).astype(np.int64) + Indices = np.random.randint(0, tblsize, size=sum(Lengths), dtype=np.int64) Indices[0] += 1000 Scale_Bias = np.random.rand(tblsize, 2).astype(np.float32) - if normalize_by_lengths == False: - op = core.CreateOperator( - "SparseLengthsSum8BitsRowwise", - ["Tbl", "Indices", "Lengths", "Scale_Bias"], - "out", - ) - else: - op = core.CreateOperator( - "SparseLengthsMean8BitsRowwise", - ["Tbl", "Indices", "Lengths", "Scale_Bias"], - "out", - ) + op = core.CreateOperator( + "SparseLengths" + + ("Mean" if normalize_by_lengths else "Sum") + + "8BitsRowwise", + ["Tbl", "Indices", "Lengths", "Scale_Bias"], + "out", + ) self.ws.create_blob("Tbl").feed(Tbl) self.ws.create_blob("Indices").feed(Indices) -- 2.7.4
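Editor's note: the contract this patch establishes in the generated AVX2 kernels can be stated compactly in plain NumPy. The sketch below is not part of the patch; it is an illustrative reference (the function and variable names are my own) that mirrors the `sparse_lengths_sum_ref` helpers in the updated tests: an empty segment (length 0) contributes a row of zeros, and the mean path divides by the segment length only when it is nonzero.

```python
import numpy as np

def sparse_lengths_reducer_ref(table, indices, lengths, normalize_by_lengths=False):
    """Reference semantics for SparseLengthsSum / SparseLengthsMean.

    table:   (num_rows, block_size) float array
    indices: flat int array, the concatenation of all segments
    lengths: (num_segments,) int array; sum(lengths) == len(indices)
    """
    offsets = np.cumsum(np.insert(lengths, 0, 0))
    out = np.zeros((len(lengths), table.shape[1]), dtype=np.float32)
    for i in range(len(lengths)):
        seg = indices[offsets[i]:offsets[i + 1]]
        if len(seg) == 0:
            # Empty segment: leave the row at zero instead of returning
            # uninitialized data (the bug this patch fixes), and never
            # divide by a zero length.
            continue
        acc = table[seg].sum(axis=0)
        if normalize_by_lengths:
            acc /= float(lengths[i])
        out[i] = acc
    return out

# Example: the middle segment is empty and must come back as zeros.
table = np.arange(12, dtype=np.float32).reshape(4, 3)
indices = np.array([0, 2, 3], dtype=np.int64)
lengths = np.array([2, 0, 1], dtype=np.int32)
print(sparse_lengths_reducer_ref(table, indices, lengths, normalize_by_lengths=True))
# -> [[3. 4. 5.], [0. 0. 0.], [9. 10. 11.]]
```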
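For an end-to-end check of the fixed operator through the same Caffe2 Python API the tests above use, a minimal standalone sketch follows. It is not part of the patch; the blob names are arbitrary, and it assumes a Caffe2 build that includes this fix.

```python
import numpy as np
from caffe2.python import core, workspace

table = np.random.rand(10, 16).astype(np.float32)
indices = np.empty(0, dtype=np.int64)   # no lookups at all
lengths = np.zeros(4, dtype=np.int32)   # four empty segments

workspace.FeedBlob("table", table)
workspace.FeedBlob("indices", indices)
workspace.FeedBlob("lengths", lengths)

op = core.CreateOperator(
    "SparseLengthsMean", ["table", "indices", "lengths"], "out"
)
workspace.RunOperatorOnce(op)

# With the fix, empty segments yield all-zero rows rather than garbage.
np.testing.assert_array_equal(
    workspace.FetchBlob("out"), np.zeros((4, 16), dtype=np.float32)
)
```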