[ blas/neon ] Add NEON fp16 function for saxpy
authorDebadri Samaddar <s.debadri@samsung.com>
Thu, 3 Aug 2023 11:20:55 +0000 (16:50 +0530)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 21 Aug 2023 06:29:23 +0000 (15:29 +0900)
Enable neon saxpy function for Android (ARM) __fp16 computation

Signed-off-by: Debadri Samaddar <s.debadri@samsung.com>
nntrainer/tensor/blas_interface.cpp
nntrainer/tensor/blas_neon.cpp
nntrainer/tensor/blas_neon.h
nntrainer/tensor/meson.build

index adcce2e..1318c58 100644 (file)
     }                                               \
   } while (0);
 
+#define saxpy_loop_fp16()                                                  \
+  do {                                                                     \
+    unsigned int i;                                                        \
+    for (i = 0; i < N; ++i)                                                \
+      Y[i * incY] = Y[i * incY] + static_cast<_FP16>(alpha) * X[i * incX]; \
+  } while (0);
+
 namespace nntrainer {
 
 #ifdef ENABLE_FP16
@@ -53,8 +60,17 @@ static void saxpy_FP16(const unsigned int N, const float alpha, const _FP16 *X,
   if (incX < 0 or incY < 0)
     throw std::invalid_argument(
       "Error: negative inc not supported without cblas");
-  for (unsigned int i = 0; i < N; ++i)
-    Y[i * incY] = Y[i * incY] + static_cast<_FP16>(alpha) * X[i * incX];
+
+#ifdef USE__FP16
+  // USE__FP16 is defined when platform is android
+  if (incX == 1 && incY == 1) {
+    nntrainer::neon::saxpy_neon_fp16(N, alpha, X, Y);
+  } else {
+    saxpy_loop_fp16();
+  }
+#else
+  saxpy_loop_fp16();
+#endif
 }
 
 static void sgemv_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
index c68eda1..57d331c 100644 (file)
@@ -519,4 +519,36 @@ void sgemv_transpose_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y,
   }
 }
 
+void saxpy_neon_fp16(const unsigned int N, const float alpha, const __fp16 *X, __fp16 *Y) {
+
+  const float16x8_t v_alphaX8 = vmovq_n_f16(alpha);
+  const float16x4_t v_alphaX4 = vmov_n_f16(alpha);
+
+  unsigned int idx = 0;
+
+  // processing batch of 8
+  for(; (N - idx) >= 8 ; idx += 8){
+    float16x8_t x = vld1q_f16(&X[idx]);
+    float16x8_t y = vld1q_f16(&Y[idx]);
+
+    // alpha*X + Y -> mulacc
+    float16x8_t mulacc = vfmaq_f16(y, v_alphaX8, x);
+    vst1q_f16(&Y[idx], mulacc);
+  }
+
+  // processing remaining batch of 4
+  for(; (N - idx) >= 4 ; idx += 4){
+    float16x4_t x = vld1_f16(&X[idx]);
+    float16x4_t y = vld1_f16(&Y[idx]);
+
+    // alpha*X + Y -> mulacc
+    float16x4_t mulacc = vfma_f16(y, v_alphaX4, x);
+    vst1_f16(&Y[idx], mulacc);
+  }
+
+  // pocessing remaining values
+  for (; idx < N; idx++)
+    Y[idx] = Y[idx] + alpha * X[idx];
+}
+
 } // namespace nntrainer::neon
index 0bf51ef..c573409 100644 (file)
@@ -76,6 +76,15 @@ void sgemv_transpose_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y,
                                uint32_t rows, uint32_t cols, float alpha,
                                float beta);
 
+/**
+ * @brief     saxpy computation with neon: Y = alpha*X + Y
+ * @param[in] N number of elements in Y
+ * @param[in] alpha float number
+ * @param[in] X __fp16 * for Vector X
+ * @param[in] Y __fp16 * for Vector Y
+ */
+void saxpy_neon_fp16(const unsigned int N, const float alpha, const __fp16 *X, __fp16 *Y);
+
 } // namespace nntrainer::neon
 
 #endif /* __cplusplus */
index 40205f4..39edb3a 100644 (file)
@@ -17,6 +17,7 @@ tensor_sources = [
   'optimized_v2_planner.cpp',
   'optimized_v3_planner.cpp',
   'task_executor.cpp',
+  'blas_neon.cpp',
 ]
 
 tensor_headers = [