[blas/neon] isamax improvement for larger input length
authorDebadri Samaddar <s.debadri@samsung.com>
Wed, 22 May 2024 07:22:02 +0000 (12:52 +0530)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 4 Jun 2024 09:57:23 +0000 (18:57 +0900)
Used uint32_t operations to process indices larger than 65535.
Added unittest of shape(1,1,768,768) for max_abs which calls isamax

Signed-off-by: Debadri Samaddar <s.debadri@samsung.com>
nntrainer/tensor/blas_neon.cpp
test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp

index 6f02978e1f4a604a624555eb816a61d39d28f933..32c8eb9349f9b4a54a70df80f4b7d90872129369 100644 (file)
@@ -1508,7 +1508,8 @@ unsigned int isamax(const unsigned int N, const __fp16 *X) {
   unsigned int idx = 8;
 
   // processing batch of 8
-  for (; (N - idx) >= 8; idx += 8) {
+  // stop before idx reaches UNIT16_MAX
+  for (; ((N - idx) >= 8) && ((idx + 8) <= UINT16_MAX); idx += 8) {
     float16x8_t values = vld1q_f16(&X[idx]);
     curr_index = vaddq_u16(curr_index, stride);
 
@@ -1527,11 +1528,50 @@ unsigned int isamax(const unsigned int N, const __fp16 *X) {
 
   // getting the index of the maxima
   maxNum = maxVal[0];
-  retIdx = max_index[0];
+  retIdx = indices[0];
   for (unsigned int i = 1; i < 8; i++) {
     if (maxVal[i] > maxNum) {
       maxNum = maxVal[i];
-      retIdx = max_index[i];
+      retIdx = indices[i];
+    }
+  }
+
+  // if idx is more than UNIT16_MAX
+  if ((N > UINT16_MAX) && (N - idx) >= 4) {
+    uint32_t indices_u32[] = {idx, idx + 1, idx + 2, idx + 3};
+    uint32x4_t stride_u32 = vmovq_n_u32(4);
+    float16x4_t batch_4 = vld1_f16(&X[0]);
+    uint32x4_t curr_index_u32 = vld1q_u32(indices_u32);
+    uint32x4_t max_index_u32 = curr_index_u32;
+
+    idx += 4;
+    // processing batch of 4
+    for (; (N - idx) >= 4; idx += 4) {
+      float16x4_t values_4 = vld1_f16(&X[idx]);
+      curr_index_u32 = vaddq_u32(curr_index_u32, stride_u32);
+
+      // comparison
+      uint16x4_t mask_4 = vcgt_f16(batch_4, values_4);
+
+      // converting to u32 mask as required by vbslq_u32
+      uint32x4_t mask_4_u32 = vmovl_u16(mask_4);
+
+      // blend values and indices based on the mask
+      batch_4 = vbsl_f16(mask_4, batch_4, values_4);
+      max_index_u32 = vbslq_u32(mask_4_u32, max_index_u32, curr_index_u32);
+    }
+
+    // storing indices and max values
+    __fp16 maxVal_4[4];
+    vst1_f16(maxVal_4, batch_4);
+    vst1q_u32(indices_u32, max_index_u32);
+
+    // getting the index of the maxima
+    for (unsigned int i = 0; i < 4; i++) {
+      if (maxVal_4[i] > maxNum) {
+        maxNum = maxVal_4[i];
+        retIdx = indices_u32[i];
+      }
     }
   }
 
index d62be87aaf3e7352b9b11a5a9258dc1f8c3ba466..ed4a04a71b055b814a52b0b8c5ffaea03a714f2f 100644 (file)
@@ -384,6 +384,47 @@ TEST(nntrainer_Tensor, max_abs) {
   EXPECT_NEAR(result_neon, result_fp32, epsilon);
 }
 
+TEST(nntrainer_Tensor, max_abs_768) {
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+  size_t batch = 1;
+  size_t channel = 1;
+  size_t height = 768;
+  size_t width = 768;
+
+  nntrainer::Tensor input(
+    nntrainer::TensorDim(batch, channel, height, width, t_type_nchw_fp16));
+
+  nntrainer::Tensor input_fp32(
+    nntrainer::TensorDim(batch, channel, height, width, t_type_nchw_fp32));
+
+  const float alpha = 1e-1;
+  const int MOD = 10;
+
+  GEN_TEST_INPUT(input, ((k * l * (batch * height * channel) +
+                          l * (batch * height) + k * (width) + l + 1) %
+                         MOD) *
+                          alpha);
+  GEN_TEST_INPUT(input_fp32, ((k * l * (batch * height * channel) +
+                               l * (batch * height) + k * (width) + l + 1) %
+                              MOD) *
+                               alpha);
+
+  __fp16 result_neon = input.max_abs();
+  float result_fp32 = input_fp32.max_abs();
+
+  float absErrorNeon = std::abs(result_neon - result_fp32);
+
+  const float epsilon = 1e-3;
+
+  EXPECT_IN_RANGE(absErrorNeon, 0, epsilon);
+}
+
 TEST(nntrainer_Tensor, sum_gemv_transpose_2_10) {
   int batch = 3;
   int channel = 2;