handle empty inputs to SparseLengthsMean correctly (#15389)
author Jongsoo Park <jongsoo@fb.com>
Sat, 22 Dec 2018 06:17:35 +0000 (22:17 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 22 Dec 2018 06:20:14 +0000 (22:20 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15389

SparseLengthsMean was producing uninitialized data for empty inputs (lengths == 0); it should return zeros instead.
The unit tests also did not cover this special case; this diff fixes that as well.
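
As an illustration of the fixed semantics, here is a minimal sketch, assuming a standard Caffe2 Python build (blob names and shapes are illustrative): a segment with length 0 now yields a row of zeros rather than garbage.

    import numpy as np
    from caffe2.python import core, workspace

    # Two segments: the first averages rows 0 and 2, the second is empty.
    workspace.FeedBlob("data", np.random.rand(4, 8).astype(np.float32))
    workspace.FeedBlob("indices", np.array([0, 2], dtype=np.int32))
    workspace.FeedBlob("lengths", np.array([2, 0], dtype=np.int32))

    op = core.CreateOperator(
        "SparseLengthsMean", ["data", "indices", "lengths"], "out"
    )
    workspace.RunOperatorOnce(op)

    # After this fix, the empty segment is all zeros, not uninitialized memory.
    out = workspace.FetchBlob("out")
    np.testing.assert_array_equal(out[1], np.zeros(8, dtype=np.float32))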

Reviewed By: salexspb

Differential Revision: D13515970

fbshipit-source-id: 3c35265638f64f13f0262cee930c94f8628005da

caffe2/perfkernels/embedding_lookup_avx2.cc
caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc
caffe2/perfkernels/hp_emblookup_codegen.py
caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py
caffe2/python/operator_test/specialized_segment_ops_test.py

diff --git a/caffe2/perfkernels/embedding_lookup_avx2.cc b/caffe2/perfkernels/embedding_lookup_avx2.cc
index e470779..326818b 100644
@@ -105,7 +105,7 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
         vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
         // skip unnecessary prefetch of (&ip_next_T0[120])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -122,7 +122,7 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
         _mm256_storeu_ps(&op[104], vop104);
         _mm256_storeu_ps(&op[112], vop112);
         _mm256_storeu_ps(&op[120], vop120);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
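
The same two-branch rewrite repeats below for every unroll width and index/data type. Restated as a hedged Python sketch (names are illustrative, not from the source): before the fix, normalize_by_lengths == true with lengths[rangeIndex] == 0 matched neither branch, so nothing was ever stored to op; now the zero-length case falls into the plain-store branch, which writes out the zero-initialized accumulators.

    def store_output(acc, length, normalize_by_lengths):
        # acc stands in for the zero-initialized vop accumulators.
        if not normalize_by_lengths or length == 0:
            # Empty segments take this branch and emit zeros.
            return acc
        # length > 0 is guaranteed here, so the reciprocal is safe.
        return acc * (1.0 / length)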
@@ -195,7 +195,7 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
         // skip unnecessary prefetch of (&ip_next_T0[56])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -204,7 +204,7 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
         _mm256_storeu_ps(&op[40], vop40);
         _mm256_storeu_ps(&op[48], vop48);
         _mm256_storeu_ps(&op[56], vop56);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -257,12 +257,12 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
         _mm256_storeu_ps(&op[24], vop24);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -305,10 +305,10 @@ static void EmbeddingLookup_int32_t_float_float__avx2_fma(
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -520,7 +520,7 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma(
         vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
         // skip unnecessary prefetch of (&ip_next_T0[120])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -537,7 +537,7 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma(
         _mm256_storeu_ps(&op[104], vop104);
         _mm256_storeu_ps(&op[112], vop112);
         _mm256_storeu_ps(&op[120], vop120);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -610,7 +610,7 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma(
         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
         // skip unnecessary prefetch of (&ip_next_T0[56])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -619,7 +619,7 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma(
         _mm256_storeu_ps(&op[40], vop40);
         _mm256_storeu_ps(&op[48], vop48);
         _mm256_storeu_ps(&op[56], vop56);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -672,12 +672,12 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma(
         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
         _mm256_storeu_ps(&op[24], vop24);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -720,10 +720,10 @@ static void EmbeddingLookup_int64_t_float_float__avx2_fma(
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -999,7 +999,7 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma(
             vop120);
         // skip unnecessary prefetch of (&ip_next_T0[120])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -1016,7 +1016,7 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma(
         _mm256_storeu_ps(&op[104], vop104);
         _mm256_storeu_ps(&op[112], vop112);
         _mm256_storeu_ps(&op[120], vop120);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -1121,7 +1121,7 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma(
             vop56);
         // skip unnecessary prefetch of (&ip_next_T0[56])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -1130,7 +1130,7 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma(
         _mm256_storeu_ps(&op[40], vop40);
         _mm256_storeu_ps(&op[48], vop48);
         _mm256_storeu_ps(&op[56], vop56);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -1199,12 +1199,12 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma(
             vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
         _mm256_storeu_ps(&op[24], vop24);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -1255,10 +1255,10 @@ static void EmbeddingLookup_int32_t_half_float__avx2_fma(
             vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -1540,7 +1540,7 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma(
             vop120);
         // skip unnecessary prefetch of (&ip_next_T0[120])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -1557,7 +1557,7 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma(
         _mm256_storeu_ps(&op[104], vop104);
         _mm256_storeu_ps(&op[112], vop112);
         _mm256_storeu_ps(&op[120], vop120);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -1662,7 +1662,7 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma(
             vop56);
         // skip unnecessary prefetch of (&ip_next_T0[56])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -1671,7 +1671,7 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma(
         _mm256_storeu_ps(&op[40], vop40);
         _mm256_storeu_ps(&op[48], vop48);
         _mm256_storeu_ps(&op[56], vop56);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -1740,12 +1740,12 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma(
             vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
         _mm256_storeu_ps(&op[24], vop24);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -1796,10 +1796,10 @@ static void EmbeddingLookup_int64_t_half_float__avx2_fma(
             vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -2085,7 +2085,7 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_add_ps(vop120, vbio));
         // skip unnecessary prefetch of (&ip_next_T0[120])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -2102,7 +2102,7 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
         _mm256_storeu_ps(&op[104], vop104);
         _mm256_storeu_ps(&op[112], vop112);
         _mm256_storeu_ps(&op[120], vop120);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -2211,7 +2211,7 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_add_ps(vop56, vbio));
         // skip unnecessary prefetch of (&ip_next_T0[56])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -2220,7 +2220,7 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
         _mm256_storeu_ps(&op[40], vop40);
         _mm256_storeu_ps(&op[48], vop48);
         _mm256_storeu_ps(&op[56], vop56);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -2293,12 +2293,12 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_add_ps(vop24, vbio));
         // skip unnecessary prefetch of (&ip_next_T0[24])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
         _mm256_storeu_ps(&op[24], vop24);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -2353,10 +2353,10 @@ static void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_add_ps(vop8, vbio));
         // skip unnecessary prefetch of (&ip_next_T0[8])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -2644,7 +2644,7 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_add_ps(vop120, vbio));
         // skip unnecessary prefetch of (&ip_next_T0[120])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -2661,7 +2661,7 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
         _mm256_storeu_ps(&op[104], vop104);
         _mm256_storeu_ps(&op[112], vop112);
         _mm256_storeu_ps(&op[120], vop120);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -2770,7 +2770,7 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_add_ps(vop56, vbio));
         // skip unnecessary prefetch of (&ip_next_T0[56])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -2779,7 +2779,7 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
         _mm256_storeu_ps(&op[40], vop40);
         _mm256_storeu_ps(&op[48], vop48);
         _mm256_storeu_ps(&op[56], vop56);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -2852,12 +2852,12 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_add_ps(vop24, vbio));
         // skip unnecessary prefetch of (&ip_next_T0[24])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
         _mm256_storeu_ps(&op[24], vop24);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -2912,10 +2912,10 @@ static void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_add_ps(vop8, vbio));
         // skip unnecessary prefetch of (&ip_next_T0[8])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
diff --git a/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc b/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc
index 650a2dc..1f4a831 100644
@@ -103,7 +103,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
         vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
         // skip unnecessary prefetch of (&ip_next_T0[120])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -120,7 +120,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
         _mm256_storeu_ps(&op[104], vop104);
         _mm256_storeu_ps(&op[112], vop112);
         _mm256_storeu_ps(&op[120], vop120);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -193,7 +193,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
         // skip unnecessary prefetch of (&ip_next_T0[56])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -202,7 +202,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
         _mm256_storeu_ps(&op[40], vop40);
         _mm256_storeu_ps(&op[48], vop48);
         _mm256_storeu_ps(&op[56], vop56);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -255,12 +255,12 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
         _mm256_storeu_ps(&op[24], vop24);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -303,10 +303,10 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -512,7 +512,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
         vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
         // skip unnecessary prefetch of (&ip_next_T0[120])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -529,7 +529,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
         _mm256_storeu_ps(&op[104], vop104);
         _mm256_storeu_ps(&op[112], vop112);
         _mm256_storeu_ps(&op[120], vop120);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -602,7 +602,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
         vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
         // skip unnecessary prefetch of (&ip_next_T0[56])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -611,7 +611,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
         _mm256_storeu_ps(&op[40], vop40);
         _mm256_storeu_ps(&op[48], vop48);
         _mm256_storeu_ps(&op[56], vop56);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -664,12 +664,12 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
         vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
         _mm256_storeu_ps(&op[24], vop24);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -712,10 +712,10 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
         vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -985,7 +985,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
             vop120);
         // skip unnecessary prefetch of (&ip_next_T0[120])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -1002,7 +1002,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
         _mm256_storeu_ps(&op[104], vop104);
         _mm256_storeu_ps(&op[112], vop112);
         _mm256_storeu_ps(&op[120], vop120);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -1107,7 +1107,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
             vop56);
         // skip unnecessary prefetch of (&ip_next_T0[56])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -1116,7 +1116,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
         _mm256_storeu_ps(&op[40], vop40);
         _mm256_storeu_ps(&op[48], vop48);
         _mm256_storeu_ps(&op[56], vop56);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -1185,12 +1185,12 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
             vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
         _mm256_storeu_ps(&op[24], vop24);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -1241,10 +1241,10 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
             vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -1520,7 +1520,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
             vop120);
         // skip unnecessary prefetch of (&ip_next_T0[120])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -1537,7 +1537,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
         _mm256_storeu_ps(&op[104], vop104);
         _mm256_storeu_ps(&op[112], vop112);
         _mm256_storeu_ps(&op[120], vop120);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -1642,7 +1642,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
             vop56);
         // skip unnecessary prefetch of (&ip_next_T0[56])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -1651,7 +1651,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
         _mm256_storeu_ps(&op[40], vop40);
         _mm256_storeu_ps(&op[48], vop48);
         _mm256_storeu_ps(&op[56], vop56);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -1720,12 +1720,12 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
             vop24);
         // skip unnecessary prefetch of (&ip_next_T0[24])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
         _mm256_storeu_ps(&op[24], vop24);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -1776,10 +1776,10 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
             vop8);
         // skip unnecessary prefetch of (&ip_next_T0[8])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -2061,7 +2061,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_add_ps(vop120, vbio));
         // skip unnecessary prefetch of (&ip_next_T0[120])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -2078,7 +2078,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
         _mm256_storeu_ps(&op[104], vop104);
         _mm256_storeu_ps(&op[112], vop112);
         _mm256_storeu_ps(&op[120], vop120);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -2189,7 +2189,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_add_ps(vop56, vbio));
         // skip unnecessary prefetch of (&ip_next_T0[56])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -2198,7 +2198,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
         _mm256_storeu_ps(&op[40], vop40);
         _mm256_storeu_ps(&op[48], vop48);
         _mm256_storeu_ps(&op[56], vop56);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -2273,12 +2273,12 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_add_ps(vop24, vbio));
         // skip unnecessary prefetch of (&ip_next_T0[24])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
         _mm256_storeu_ps(&op[24], vop24);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -2335,10 +2335,10 @@ static void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
             _mm256_add_ps(vop8, vbio));
         // skip unnecessary prefetch of (&ip_next_T0[8])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -2623,7 +2623,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_add_ps(vop120, vbio));
         // skip unnecessary prefetch of (&ip_next_T0[120])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -2640,7 +2640,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
         _mm256_storeu_ps(&op[104], vop104);
         _mm256_storeu_ps(&op[112], vop112);
         _mm256_storeu_ps(&op[120], vop120);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -2751,7 +2751,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_add_ps(vop56, vbio));
         // skip unnecessary prefetch of (&ip_next_T0[56])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
@@ -2760,7 +2760,7 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
         _mm256_storeu_ps(&op[40], vop40);
         _mm256_storeu_ps(&op[48], vop48);
         _mm256_storeu_ps(&op[56], vop56);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -2835,12 +2835,12 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_add_ps(vop24, vbio));
         // skip unnecessary prefetch of (&ip_next_T0[24])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
         _mm256_storeu_ps(&op[16], vop16);
         _mm256_storeu_ps(&op[24], vop24);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
@@ -2897,10 +2897,10 @@ static void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
             _mm256_add_ps(vop8, vbio));
         // skip unnecessary prefetch of (&ip_next_T0[8])
       }
-      if (normalize_by_lengths == false) {
+      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
         _mm256_storeu_ps(&op[0], vop0);
         _mm256_storeu_ps(&op[8], vop8);
-      } else if (lengths[rangeIndex]) {
+      } else {
         __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
         _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
         _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py
index 887a975..c1cbd4f 100644
@@ -122,11 +122,11 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused):
         code.extend(compute(j, InType, use_weights, isa, prefetch))
     code.append("      }")
 
-    code.append("      if (normalize_by_lengths == false) {")
+    code.append("      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {")
     for i in range(0, uf):
         j = 8 * i
         code.append("        _mm256_storeu_ps(&op[" + str(j) + "], vop" + str(j) + ");")
-    code.append("      } else if (lengths[rangeIndex]) {")
+    code.append("      } else {")
     # inv of length
     code.append("        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);")
     for i in range(0, uf):
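
Because the kernels in both AVX2 files are generated, this two-line change to the generator updates every unrolled variant at once. A hedged, self-contained reduction of the emission loop above (uf is the unroll factor; the helper name is an assumption):

    def emit_store(uf):
        code = []
        code.append("      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {")
        for i in range(uf):
            j = 8 * i
            code.append("        _mm256_storeu_ps(&op[{0}], vop{0});".format(j))
        code.append("      } else {")
        code.append(
            "        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);"
        )
        for i in range(uf):
            j = 8 * i
            code.append(
                "        _mm256_storeu_ps(&op[{0}], _mm256_mul_ps(vop{0}, vlen_inv));".format(j)
            )
        code.append("      }")
        return "\n".join(code)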
diff --git a/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py b/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py
index 18d1e15..b57b8be 100644
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
+from __future__ import absolute_import, division, print_function, unicode_literals
 
-from caffe2.python import core, workspace
 import caffe2.python.hypothesis_test_util as hu
-
+import hypothesis.strategies as st
 import numpy as np
+from caffe2.python import core, workspace
 from hypothesis import given
-import hypothesis.strategies as st
 
 
 class TestLengthsReducerOpsFused8BitRowwise(hu.HypothesisTestCase):
     @given(
-        input_data=hu.tensor(min_dim=2, max_dim=2),
+        batchsize=st.integers(1, 20),
+        blocksize=st.sampled_from([8, 16, 32, 64, 85, 96, 128, 163]),
         weighted=st.booleans(),
-        seed=st.integers(0, 2**32 - 1),
+        seed=st.integers(0, 2 ** 32 - 1),
+        empty_indices=st.booleans(),
     )
-    def test_sparse_lengths_sum(self, input_data, weighted, seed):
+    def test_sparse_lengths_sum(
+        self, batchsize, blocksize, weighted, seed, empty_indices
+    ):
         net = core.Net("bench")
 
         np.random.seed(seed)
 
-        input_data = input_data.astype(np.float32)
-        indices = np.random.randint(
-            low=0,
-            high=len(input_data),
-            size=[np.random.randint(len(input_data))],
-            dtype=np.int32
-        )
+        input_data = np.random.rand(batchsize, blocksize).astype(np.float32)
+        if empty_indices:
+            indices = np.empty(0, dtype=np.int32)
+        else:
+            indices = np.random.randint(
+                low=0,
+                high=len(input_data),
+                size=[np.random.randint(len(input_data))],
+                dtype=np.int32,
+            )
         weights = np.random.uniform(size=[len(indices)]).astype(np.float32)
         lengths_split = np.clip(1, len(indices) // 2, 10)
-        lengths = np.ones(
-            [len(indices) // lengths_split], dtype=np.int32
-        ) * lengths_split
-        print(indices, weights, lengths)
+        lengths = (
+            np.ones([len(indices) // lengths_split], dtype=np.int32) * lengths_split
+        )
 
         quantized_data = net.FloatToFused8BitRowwiseQuantized(
-            'input_data', 'quantized_data'
+            "input_data", "quantized_data"
         )
         dequantized_data = net.Fused8BitRowwiseQuantizedToFloat(
-            quantized_data, 'dequantized_data'
+            quantized_data, "dequantized_data"
         )
 
         if weighted:
             net.SparseLengthsWeightedSum(
-                [dequantized_data, 'weights', 'indices', 'lengths'],
-                'sum_reference',
-                engine='fp16'
+                [dequantized_data, "weights", "indices", "lengths"],
+                "sum_reference",
             )
             net.SparseLengthsWeightedSumFused8BitRowwise(
-                [quantized_data, 'weights', 'indices', 'lengths'],
-                'sum_quantized'
+                [quantized_data, "weights", "indices", "lengths"], "sum_quantized"
             )
         else:
             net.SparseLengthsSum(
-                [dequantized_data, 'indices', 'lengths'],
-                'sum_reference',
-                engine='fp16'
+                [dequantized_data, "indices", "lengths"], "sum_reference",
             )
             net.SparseLengthsSumFused8BitRowwise(
-                [quantized_data, 'indices', 'lengths'], 'sum_quantized'
+                [quantized_data, "indices", "lengths"], "sum_quantized"
             )
 
-        workspace.FeedBlob('input_data', input_data)
-        workspace.FeedBlob('weights', weights)
-        workspace.FeedBlob('indices', indices)
-        workspace.FeedBlob('lengths', lengths)
+        workspace.FeedBlob("input_data", input_data)
+        workspace.FeedBlob("weights", weights)
+        workspace.FeedBlob("indices", indices)
+        workspace.FeedBlob("lengths", lengths)
 
-        workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
+        workspace.GlobalInit(["caffe2", "--caffe2_log_level=0"])
         workspace.CreateNet(net)
         workspace.RunNetOnce(net)
 
-        sum_reference = workspace.FetchBlob('sum_reference')
-        sum_quantized = workspace.FetchBlob('sum_quantized')
+        sum_reference = workspace.FetchBlob("sum_reference")
+        sum_quantized = workspace.FetchBlob("sum_quantized")
         np.testing.assert_array_almost_equal(sum_reference, sum_quantized)
 
     @given(
-        input_data=hu.tensor(min_dim=2, max_dim=2),
-        seed=st.integers(0, 2**32 - 1),
+        batchsize=st.integers(1, 20),
+        blocksize=st.sampled_from([8, 16, 32, 64, 85, 96, 128, 163]),
+        seed=st.integers(0, 2 ** 32 - 1),
+        empty_indices=st.booleans(),
     )
-    def test_sparse_lengths_mean(self, input_data, seed):
+    def test_sparse_lengths_mean(self, batchsize, blocksize, seed, empty_indices):
         net = core.Net("bench")
 
         np.random.seed(seed)
 
-        input_data = input_data.astype(np.float32)
-        indices = np.random.randint(
-            low=0,
-            high=len(input_data),
-            size=[np.random.randint(len(input_data))],
-            dtype=np.int32
-        )
-        lengths_split = np.clip(1, len(indices) // 2, 10)
-        lengths = np.ones(
-            [len(indices) // lengths_split], dtype=np.int32
-        ) * lengths_split
+        input_data = np.random.rand(batchsize, blocksize).astype(np.float32)
+        if empty_indices:
+            indices = np.empty(0, dtype=np.int32)
+            lengths = np.zeros(batchsize, dtype=np.int32)
+        else:
+            indices = np.random.randint(
+                low=0,
+                high=len(input_data),
+                size=[np.random.randint(len(input_data))],
+                dtype=np.int32,
+            )
+            lengths_split = np.clip(1, len(indices) // 2, 10)
+            lengths = (
+                np.ones([len(indices) // lengths_split], dtype=np.int32) * lengths_split
+            )
         print(indices, lengths)
 
         quantized_data = net.FloatToFused8BitRowwiseQuantized(
-            'input_data', 'quantized_data'
+            "input_data", "quantized_data"
         )
         dequantized_data = net.Fused8BitRowwiseQuantizedToFloat(
-            quantized_data, 'dequantized_data'
+            quantized_data, "dequantized_data"
         )
 
         net.SparseLengthsMean(
-            [dequantized_data, 'indices', 'lengths'],
-            'mean_reference',
-            engine='fp16'
+            [dequantized_data, "indices", "lengths"], "mean_reference"
         )
         net.SparseLengthsMeanFused8BitRowwise(
-            [quantized_data, 'indices', 'lengths'], 'mean_quantized'
+            [quantized_data, "indices", "lengths"], "mean_quantized"
         )
 
-        workspace.FeedBlob('input_data', input_data)
-        workspace.FeedBlob('indices', indices)
-        workspace.FeedBlob('lengths', lengths)
+        workspace.FeedBlob("input_data", input_data)
+        workspace.FeedBlob("indices", indices)
+        workspace.FeedBlob("lengths", lengths)
 
-        workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
+        workspace.GlobalInit(["caffe2", "--caffe2_log_level=0"])
         workspace.CreateNet(net)
         workspace.RunNetOnce(net)
 
-        mean_reference = workspace.FetchBlob('mean_reference')
-        mean_quantized = workspace.FetchBlob('mean_quantized')
+        mean_reference = workspace.FetchBlob("mean_reference")
+        mean_quantized = workspace.FetchBlob("mean_quantized")
         np.testing.assert_array_almost_equal(mean_reference, mean_quantized)
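
For reference, the zero-length behavior these tests now exercise can be stated as a plain NumPy reduction (a sketch mirroring the reference functions in these test files; names are illustrative):

    import numpy as np

    def sparse_lengths_mean_ref(data, indices, lengths):
        rptr = np.cumsum(np.insert(lengths, 0, 0))
        out = np.zeros((len(lengths), data.shape[1]), dtype=np.float32)
        for i in range(len(lengths)):
            if lengths[i] != 0:  # empty segments stay all-zero
                out[i] = data[indices[rptr[i]:rptr[i + 1]]].mean(axis=0)
        return out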
diff --git a/caffe2/python/operator_test/specialized_segment_ops_test.py b/caffe2/python/operator_test/specialized_segment_ops_test.py
index cdd82ac..54a840b 100644
@@ -2,11 +2,13 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 
 import unittest
 
+from caffe2.proto import caffe2_pb2
+from caffe2.python import core
+import caffe2.python.hip_test_util as hiputl
 import caffe2.python.hypothesis_test_util as hu
 import hypothesis.strategies as st
 import numpy as np
-from caffe2.python import core
-from hypothesis import given
+from hypothesis import given, assume
 
 
 class TestSpecializedSegmentOps(hu.HypothesisTestCase):
@@ -14,18 +16,31 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase):
         batchsize=st.integers(1, 20),
         fptype=st.sampled_from([np.float16, np.float32]),
         fp16asint=st.booleans(),
-        blocksize=st.sampled_from([8, 17, 32, 64, 85, 96, 128, 163]),
+        blocksize=st.sampled_from([8, 16, 32, 64, 85, 96, 128, 163]),
         normalize_by_lengths=st.booleans(),
+        empty_indices=st.booleans(),
         **hu.gcs
     )
     def test_sparse_lengths_sum_cpu(
-        self, batchsize, fptype, fp16asint, blocksize, normalize_by_lengths, gc, dc
+        self,
+        batchsize,
+        fptype,
+        fp16asint,
+        blocksize,
+        normalize_by_lengths,
+        empty_indices,
+        gc,
+        dc,
     ):
+        if fptype != np.float32:
+            assume(gc.device_type == caffe2_pb2.CPU)
+            assume(not hiputl.run_in_hip(gc, dc))
+            assume(caffe2_pb2.CUDA not in {d.device_type for d in dc})
 
-        if normalize_by_lengths == False:
-            print("<test_sparse_lengths_sum_cpu>")
-        else:
+        if normalize_by_lengths:
             print("<test_sparse_lengths_sum_mean_cpu>")
+        else:
+            print("<test_sparse_lengths_sum_cpu>")
 
         tblsize = 300
         if fptype == np.float32:
@@ -44,44 +59,42 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase):
                 atol = 1e-1
 
         # array of each row length
-        Lengths = np.random.randint(1, 30, size=batchsize).astype(np.int32)
-        # flat indices
-        Indices = np.random.randint(0, tblsize, size=sum(Lengths)).astype(np.int64)
-
-        if normalize_by_lengths == False:
-            op = core.CreateOperator(
-                "SparseLengthsSum", ["Tbl", "Indices", "Lengths"], "out"
-            )
+        if empty_indices:
+            Lengths = np.zeros(batchsize, dtype=np.int32)
         else:
-            op = core.CreateOperator(
-                "SparseLengthsMean", ["Tbl", "Indices", "Lengths"], "out"
-            )
+            Lengths = np.random.randint(1, 30, size=batchsize, dtype=np.int32)
+        # flat indices
+        Indices = np.random.randint(0, tblsize, size=sum(Lengths), dtype=np.int64)
 
-        self.ws.create_blob("Tbl").feed(Tbl)
-        self.ws.create_blob("Indices").feed(Indices)
-        self.ws.create_blob("Lengths").feed(Lengths)
-        self.ws.run(op)
+        op = core.CreateOperator(
+            "SparseLengths" + ("Mean" if normalize_by_lengths else "Sum"),
+            ["Tbl", "Indices", "Lengths"],
+            "out",
+        )
 
         def sparse_lengths_sum_ref(Tbl, Indices, Lengths):
             rptr = np.cumsum(np.insert(Lengths, [0], [0]))
             out = np.zeros((len(Lengths), blocksize))
-            if normalize_by_lengths == False:
+            if normalize_by_lengths:
                 for i in range(0, len(rptr[0:-1])):
-                    out[i] = Tbl[Indices[rptr[i] : rptr[i + 1]]].sum(axis=0)
+                    if Lengths[i] != 0:
+                        out[i] = (
+                            Tbl[Indices[rptr[i] : rptr[i + 1]]].sum(axis=0)
+                            * 1.0
+                            / float(Lengths[i])
+                        )
             else:
                 for i in range(0, len(rptr[0:-1])):
-                    out[i] = (
-                        Tbl[Indices[rptr[i] : rptr[i + 1]]].sum(axis=0)
-                        * 1.0
-                        / float(Lengths[i])
-                    )
-
-            return out
-
-        np.testing.assert_allclose(
-            self.ws.blobs[("out")].fetch(),
-            sparse_lengths_sum_ref(Tbl, Indices, Lengths),
-            rtol=1e-3,
+                    out[i] = Tbl[Indices[rptr[i] : rptr[i + 1]]].sum(axis=0)
+
+            return [out.astype(np.float32)]
+
+        self.assertReferenceChecks(
+            gc,
+            op,
+            [Tbl, Indices, Lengths],
+            sparse_lengths_sum_ref,
+            threshold=1e-3,
             atol=atol,
         )
 
@@ -89,12 +102,17 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase):
         batchsize=st.integers(1, 20),
         fptype=st.sampled_from([np.float16, np.float32]),
         fp16asint=st.booleans(),
-        blocksize=st.sampled_from([8, 17, 32, 64, 85, 96, 128, 163]),
+        blocksize=st.sampled_from([8, 16, 32, 64, 85, 96, 128, 163]),
+        empty_indices=st.booleans(),
         **hu.gcs
     )
     def test_sparse_lengths_weightedsum_cpu(
-        self, batchsize, fptype, fp16asint, blocksize, gc, dc
+        self, batchsize, fptype, fp16asint, blocksize, empty_indices, gc, dc
     ):
+        if fptype != np.float32:
+            assume(gc.device_type == caffe2_pb2.CPU)
+            assume(not hiputl.run_in_hip(gc, dc))
+            assume(caffe2_pb2.CUDA not in {d.device_type for d in dc})
 
         print("<test_sparse_lengths_weightedsum_cpu>")
 
@@ -115,21 +133,18 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase):
                 atol = 1e-1
 
         # array of each row length
-        Lengths = np.random.randint(1, 30, size=batchsize).astype(np.int32)
+        if empty_indices:
+            Lengths = np.zeros(batchsize, dtype=np.int32)
+        else:
+            Lengths = np.random.randint(1, 30, size=batchsize, dtype=np.int32)
         # flat indices
-        Indices = np.random.randint(0, tblsize, size=sum(Lengths)).astype(np.int64)
+        Indices = np.random.randint(0, tblsize, size=sum(Lengths), dtype=np.int64)
         Weights = np.random.rand(sum(Lengths)).astype(np.float32)
 
         op = core.CreateOperator(
             "SparseLengthsWeightedSum", ["Tbl", "Weights", "Indices", "Lengths"], "out"
         )
 
-        self.ws.create_blob("Tbl").feed(Tbl)
-        self.ws.create_blob("Indices").feed(Indices)
-        self.ws.create_blob("Lengths").feed(Lengths)
-        self.ws.create_blob("Weights").feed(Weights)
-        self.ws.run(op)
-
         def sparse_lengths_weightedsum_ref(Tbl, Weights, Indices, Lengths):
             rptr = np.cumsum(np.insert(Lengths, [0], [0]))
             out = np.zeros((len(Lengths), blocksize))
@@ -138,12 +153,14 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase):
                 out[i] = (Tbl[Indices[rptr[i] : rptr[i + 1]]] * w[:, np.newaxis]).sum(
                     axis=0
                 )
-            return out
-
-        np.testing.assert_allclose(
-            self.ws.blobs[("out")].fetch(),
-            sparse_lengths_weightedsum_ref(Tbl, Weights, Indices, Lengths),
-            rtol=1e-3,
+            return [out.astype(np.float32)]
+
+        self.assertReferenceChecks(
+            gc,
+            op,
+            [Tbl, Weights, Indices, Lengths],
+            sparse_lengths_weightedsum_ref,
+            threshold=1e-3,
             atol=atol,
         )
 
@@ -151,51 +168,42 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase):
         batchsize=st.integers(1, 20),
         blocksize=st.sampled_from([8, 16, 17, 26, 32, 64, 85, 96, 128, 148, 163]),
         normalize_by_lengths=st.booleans(),
-        **hu.gcs
+        empty_indices=st.booleans(),
+        **hu.gcs_cpu_only
     )
     def test_sparse_lengths_weightedsum_8BitsRowwiseOp_cpu(
-        self, batchsize, blocksize, normalize_by_lengths, gc, dc
+        self, batchsize, blocksize, normalize_by_lengths, empty_indices, gc, dc
     ):
-
-        if normalize_by_lengths == False:
+        if normalize_by_lengths:
             print(
-                "<test_sparse_lengths_weightedsum_SparseLengthsWeightedSum8BitsRowwise_cpu>"
+                "<test_sparse_lengths_weightedsum_SparseLengthsWeightedMean8BitsRowwise_cpu>"
             )
         else:
             print(
-                "<test_sparse_lengths_weightedsum_SparseLengthsWeightedMean8BitsRowwise_cpu>"
+                "<test_sparse_lengths_weightedsum_SparseLengthsWeightedSum8BitsRowwise_cpu>"
             )
 
         tblsize = 300
-        Tbl = np.random.randint(7, size=(tblsize, blocksize)).astype(np.uint8)
+        Tbl = np.random.randint(7, size=(tblsize, blocksize), dtype=np.uint8)
         atol = 1e-5
 
         # array of each row's length
-        Lengths = np.random.randint(1, 30, size=batchsize).astype(np.int32)
+        if empty_indices:
+            Lengths = np.zeros(batchsize, dtype=np.int32)
+        else:
+            Lengths = np.random.randint(1, 30, size=batchsize, dtype=np.int32)
         # flat indices
-        Indices = np.random.randint(0, tblsize, size=sum(Lengths)).astype(np.int64)
+        Indices = np.random.randint(0, tblsize, size=sum(Lengths), dtype=np.int64)
         Weights = np.random.rand(sum(Lengths)).astype(np.float32)
         Scale_Bias = np.random.rand(tblsize, 2).astype(np.float32)
 
-        if normalize_by_lengths == False:
-            op = core.CreateOperator(
-                "SparseLengthsWeightedSum8BitsRowwise",
-                ["Tbl", "Weights", "Indices", "Lengths", "Scale_Bias"],
-                "out",
-            )
-        else:
-            op = core.CreateOperator(
-                "SparseLengthsWeightedMean8BitsRowwise",
-                ["Tbl", "Weights", "Indices", "Lengths", "Scale_Bias"],
-                "out",
-            )
-
-        self.ws.create_blob("Tbl").feed(Tbl)
-        self.ws.create_blob("Weights").feed(Weights)
-        self.ws.create_blob("Indices").feed(Indices)
-        self.ws.create_blob("Lengths").feed(Lengths)
-        self.ws.create_blob("Scale_Bias").feed(Scale_Bias)
-        self.ws.run(op)
+        op = core.CreateOperator(
+            "SparseLengthsWeighted"
+            + ("Mean" if normalize_by_lengths else "Sum")
+            + "8BitsRowwise",
+            ["Tbl", "Weights", "Indices", "Lengths", "Scale_Bias"],
+            "out",
+        )
 
         def sparse_lengths_weightedsum_8BitsRowwiseOp_cpu_ref(
             Tbl, Weights, Indices, Lengths, Scale_Bias
@@ -207,19 +215,19 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase):
                 s = Scale_Bias[Indices[rptr[i] : rptr[i + 1]], 0][:, np.newaxis]
                 b = Scale_Bias[Indices[rptr[i] : rptr[i + 1]], 1][:, np.newaxis]
                 f = 1.0
-                if normalize_by_lengths == True:
+                if normalize_by_lengths and Lengths[i] != 0:
                     f = 1.0 / float(Lengths[i])
                 out[i] = (
                     w[:, np.newaxis] * (s * Tbl[Indices[rptr[i] : rptr[i + 1]]] + b)
                 ).sum(axis=0) * f
-            return out
-
-        np.testing.assert_allclose(
-            self.ws.blobs[("out")].fetch(),
-            sparse_lengths_weightedsum_8BitsRowwiseOp_cpu_ref(
-                Tbl, Weights, Indices, Lengths, Scale_Bias
-            ),
-            rtol=1e-3,
+            return [out.astype(np.float32)]
+
+        self.assertReferenceChecks(
+            gc,
+            op,
+            [Tbl, Weights, Indices, Lengths, Scale_Bias],
+            sparse_lengths_weightedsum_8BitsRowwiseOp_cpu_ref,
+            threshold=1e-3,
             atol=atol,
         )
 
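The new Lengths[i] != 0 guard on the normalization factor is the heart of the fix: for an empty segment the Mean variants keep f = 1.0, so the all-zero accumulator is stored unscaled instead of being divided by zero. The same guard in isolation (a hedged sketch with a hypothetical helper name):

    def mean_scale(length):
        # Factor applied by the Mean variants; an empty segment keeps
        # the factor at 1.0 so the zero accumulator passes through.
        return 1.0 / float(length) if length != 0 else 1.0

    assert mean_scale(4) == 0.25
    assert mean_scale(0) == 1.0  # empty input: zeros out, no inf/nan
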
@@ -227,45 +235,37 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase):
         batchsize=st.integers(1, 20),
         blocksize=st.sampled_from([8, 16, 17, 26, 32, 64, 85, 96, 128, 148, 163]),
         normalize_by_lengths=st.booleans(),
-        **hu.gcs
+        empty_indices=st.booleans(),
+        **hu.gcs_cpu_only
     )
     def test_sparse_lengths_sum_8BitsRowwiseOp_cpu(
-        self, batchsize, blocksize, normalize_by_lengths, gc, dc
+        self, batchsize, blocksize, normalize_by_lengths, empty_indices, gc, dc
     ):
-
-        if normalize_by_lengths == False:
-            print("<test_sparse_lengths_sum_SparseLengthsSum8BitsRowwise_cpu>")
-        else:
+        if normalize_by_lengths:
             print("<test_sparse_lengths_sum_SparseLengthsMean8BitsRowwise_cpu>")
+        else:
+            print("<test_sparse_lengths_sum_SparseLengthsSum8BitsRowwise_cpu>")
 
         tblsize = 300
-        Tbl = np.random.randint(7, size=(tblsize, blocksize)).astype(np.uint8)
+        Tbl = np.random.randint(7, size=(tblsize, blocksize), dtype=np.uint8)
         atol = 1e-5
 
         # array of each row's length
-        Lengths = np.random.randint(1, 30, size=batchsize).astype(np.int32)
+        if empty_indices:
+            Lengths = np.zeros(batchsize, dtype=np.int32)
+        else:
+            Lengths = np.random.randint(1, 30, size=batchsize, dtype=np.int32)
         # flat indices
-        Indices = np.random.randint(0, tblsize, size=sum(Lengths)).astype(np.int64)
+        Indices = np.random.randint(0, tblsize, size=sum(Lengths), dtype=np.int64)
         Scale_Bias = np.random.rand(tblsize, 2).astype(np.float32)
 
-        if normalize_by_lengths == False:
-            op = core.CreateOperator(
-                "SparseLengthsSum8BitsRowwise",
-                ["Tbl", "Indices", "Lengths", "Scale_Bias"],
-                "out",
-            )
-        else:
-            op = core.CreateOperator(
-                "SparseLengthsMean8BitsRowwise",
-                ["Tbl", "Indices", "Lengths", "Scale_Bias"],
-                "out",
-            )
-
-        self.ws.create_blob("Tbl").feed(Tbl)
-        self.ws.create_blob("Indices").feed(Indices)
-        self.ws.create_blob("Lengths").feed(Lengths)
-        self.ws.create_blob("Scale_Bias").feed(Scale_Bias)
-        self.ws.run(op)
+        op = core.CreateOperator(
+            "SparseLengths"
+            + ("Mean" if normalize_by_lengths else "Sum")
+            + "8BitsRowwise",
+            ["Tbl", "Indices", "Lengths", "Scale_Bias"],
+            "out",
+        )
 
         def sparse_lengths_sum_8BitsRowwiseOp_cpu_reg(
             Tbl, Indices, Lengths, Scale_Bias
@@ -276,17 +276,17 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase):
                 s = Scale_Bias[Indices[rptr[i] : rptr[i + 1]], 0][:, np.newaxis]
                 b = Scale_Bias[Indices[rptr[i] : rptr[i + 1]], 1][:, np.newaxis]
                 f = 1.0
-                if normalize_by_lengths == True:
+                if normalize_by_lengths and Lengths[i] != 0:
                     f = 1.0 / float(Lengths[i])
                 out[i] = (s * Tbl[Indices[rptr[i] : rptr[i + 1]]] + b).sum(axis=0) * f
-            return out
-
-        np.testing.assert_allclose(
-            self.ws.blobs[("out")].fetch(),
-            sparse_lengths_sum_8BitsRowwiseOp_cpu_reg(
-                Tbl, Indices, Lengths, Scale_Bias
-            ),
-            rtol=1e-3,
+            return [out.astype(np.float32)]
+
+        self.assertReferenceChecks(
+            gc,
+            op,
+            [Tbl, Indices, Lengths, Scale_Bias],
+            sparse_lengths_sum_8BitsRowwiseOp_cpu_reg,
+            threshold=1e-3,
             atol=atol,
         )
 
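Each test now delegates the feed/run/compare boilerplate to assertReferenceChecks. Conceptually (a simplified sketch under assumed semantics, not the helper's actual code), it performs the equivalent of:

    import numpy as np

    def reference_check_sketch(ws, op, inputs, reference, threshold, atol):
        # Feed each input blob, run the operator once, then compare every
        # output blob against the Python reference within the tolerances.
        for name, value in zip(op.input, inputs):
            ws.create_blob(name).feed(value)
        ws.run(op)
        for name, expected in zip(op.output, reference(*inputs)):
            np.testing.assert_allclose(
                ws.blobs[name].fetch(), expected, rtol=threshold, atol=atol
            )
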
@@ -294,34 +294,29 @@ class TestSpecializedSegmentOps(hu.HypothesisTestCase):
         batchsize=st.integers(1, 20),
         blocksize=st.sampled_from([8, 16, 17, 26, 32, 64, 85, 96, 128, 148, 163]),
         normalize_by_lengths=st.booleans(),
-        **hu.gcs
+        **hu.gcs_cpu_only
     )
     def test_sparse_lengths_sum_8BitsRowwiseOp_cpu_invalid_index(
         self, batchsize, blocksize, normalize_by_lengths, gc, dc
     ):
 
         tblsize = 300
-        Tbl = np.random.randint(7, size=(tblsize, blocksize)).astype(np.uint8)
+        Tbl = np.random.randint(7, size=(tblsize, blocksize), dtype=np.uint8)
 
         # array of each row's length
-        Lengths = np.random.randint(1, 30, size=batchsize).astype(np.int32)
+        Lengths = np.random.randint(1, 30, size=batchsize, dtype=np.int32)
         # flat indices
-        Indices = np.random.randint(0, tblsize, size=sum(Lengths)).astype(np.int64)
+        Indices = np.random.randint(0, tblsize, size=sum(Lengths), dtype=np.int64)
         Indices[0] += 1000
         Scale_Bias = np.random.rand(tblsize, 2).astype(np.float32)
 
-        if normalize_by_lengths == False:
-            op = core.CreateOperator(
-                "SparseLengthsSum8BitsRowwise",
-                ["Tbl", "Indices", "Lengths", "Scale_Bias"],
-                "out",
-            )
-        else:
-            op = core.CreateOperator(
-                "SparseLengthsMean8BitsRowwise",
-                ["Tbl", "Indices", "Lengths", "Scale_Bias"],
-                "out",
-            )
+        op = core.CreateOperator(
+            "SparseLengths"
+            + ("Mean" if normalize_by_lengths else "Sum")
+            + "8BitsRowwise",
+            ["Tbl", "Indices", "Lengths", "Scale_Bias"],
+            "out",
+        )
 
         self.ws.create_blob("Tbl").feed(Tbl)
         self.ws.create_blob("Indices").feed(Indices)