Y[i * incY] = Y[i * incY] + static_cast<_FP16>(alpha) * X[i * incX]; \
} while (0);
+#define sgemm_loop_fp16() \
+ do { \
+ for (unsigned int m = 0; m < M; ++m) { \
+ for (unsigned int n = 0; n < N; ++n) { \
+ _FP16 c = 0; \
+ _FP16 c_old = C[m * ldc + n]; \
+ for (unsigned int k = 0; k < K; ++k) { \
+ _FP16 a, b; \
+ a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]); \
+ b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]); \
+ c += a * b; \
+ } \
+ C[m * ldc + n] = static_cast<_FP16>(alpha) * c; \
+ if (beta != 0.0) \
+ C[m * ldc + n] += static_cast<_FP16>(beta) * c_old; \
+ } \
+ } \
+ } while (0);
+
namespace nntrainer {
#ifdef ENABLE_FP16
const unsigned int ldb, const float beta, _FP16 *C,
const unsigned int ldc) {
- for (unsigned int m = 0; m < M; ++m) {
- for (unsigned int n = 0; n < N; ++n) {
- _FP16 c = 0;
- _FP16 c_old = C[m * ldc + n];
- for (unsigned int k = 0; k < K; ++k) {
- _FP16 a, b;
- a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]);
- b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]);
- c += a * b;
- }
- C[m * ldc + n] = static_cast<_FP16>(alpha) * c;
- if (beta != 0.0)
- C[m * ldc + n] += static_cast<_FP16>(beta) * c_old;
- }
+#ifdef USE__FP16
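+  // the NEON kernel loads 8 fp16 lanes per iteration, so it is only taken
+  // when every dimension is a multiple of 8; otherwise fall back to the
+  // scalar loop defined above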
+ if ((M % 8 == 0) && (N % 8 == 0) && (K % 8 == 0)) {
+ nntrainer::neon::sgemm_neon_fp16(A, B, C, M, N, K, alpha, beta,
+ TransA == CblasTrans,
+ TransB == CblasTrans);
+ } else {
+ sgemm_loop_fp16();
}
+#else
+ sgemm_loop_fp16();
+#endif
}
static unsigned int isamax_FP16(const unsigned int N, const _FP16 *X,
return retIdx;
}
+
+void sgemm_neon_fp16(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
+ uint32_t N, uint32_t K, float alpha, float beta,
+ bool TransA, bool TransB) {
+
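+  // broadcast alpha and beta across all 8 fp16 lanes of a NEON register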
+ float16x8_t v_alpha = vmovq_n_f16(alpha);
+ float16x8_t v_beta = vmovq_n_f16(beta);
+
+  // performing beta*C (requires M * N to be a multiple of 8)
+ for (unsigned int idx = 0; idx < (M * N); idx += 8) {
+ float16x8_t c = vld1q_f16(&C[idx]);
+ c = vmulq_f16(v_beta, c);
+ vst1q_f16(&C[idx], c);
+ }
+
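+  // scratch used to reduce a NEON accumulator down to a single scalar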
+ __fp16 r[4];
+
+ if (!TransA && TransB) {
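+    // C = alpha * A * B^T : row m of A and row n of B are both contiguous,
+    // so every C[m][n] is a dot product of two contiguous fp16 vectors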
+ for (unsigned int m = 0; m < M; m++) {
+ for (unsigned int n = 0; n < N; n++) {
+ float16x8_t sum = vmovq_n_f16(0);
+
+ for (unsigned int k = 0; k < K; k += 8) {
+ float16x8_t a = vld1q_f16(&A[m * K + k]);
+ float16x8_t b = vld1q_f16(&B[n * K + k]);
+ sum = vfmaq_f16(sum, a, b);
+ }
+ sum = vmulq_f16(v_alpha, sum);
+
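+        // horizontal reduction: fold the upper and lower halves of the
+        // accumulator, then add the remaining four lanes on the scalar side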
+ float16x4_t sum_high = vget_high_f16(sum);
+ float16x4_t sum_low = vget_low_f16(sum);
+
+ sum_low = vadd_f16(sum_high, sum_low);
+ vst1_f16(r, sum_low);
+
+ C[m * N + n] += r[0] + r[1] + r[2] + r[3];
+ }
+ }
+ } else if (TransA && !TransB) {
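+    // C += alpha * A^T * B : k is the outer loop so the scalar
+    // a = alpha * A[k][m] can be broadcast over the contiguous row k of B
+    // and accumulated into row m of C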
+ for (unsigned int k = 0; k < K; k++) {
+ for (unsigned int m = 0; m < M; m++) {
+ __fp16 a = alpha * A[k * M + m];
+
+ for (unsigned int n = 0; n < N; n += 8) {
+ float16x8_t b = vld1q_f16(&B[k * N + n]);
+
+ // load previously calculated C
+ float16x8_t c = vld1q_f16(&C[m * N + n]);
+ c = vfmaq_n_f16(c, b, a);
+ vst1q_f16(&C[m * N + n], c);
+ }
+ }
+ }
+ } else if (!TransA && !TransB) {
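+    // C += alpha * A * B : same k-outer scheme, only the indexing of A
+    // changes since A is not transposed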
+ for (unsigned int k = 0; k < K; k++) {
+ for (unsigned int m = 0; m < M; m++) {
+ __fp16 a = alpha * A[m * K + k];
+
+ for (unsigned int n = 0; n < N; n += 8) {
+ float16x8_t b = vld1q_f16(&B[k * N + n]);
+
+ // load previously calculated C
+ float16x8_t c = vld1q_f16(&C[m * N + n]);
+ c = vfmaq_n_f16(c, b, a);
+ vst1q_f16(&C[m * N + n], c);
+ }
+ }
+ }
+ } else { // TransA && TransB
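+    // both operands are accessed with a stride here, so keep a scalar loop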
+ for (unsigned int m = 0; m < M; m++) {
+ for (unsigned int n = 0; n < N; n++) {
+ __fp16 sum = 0;
+        for (unsigned int k = 0; k < K; k++) {
+ __fp16 a = A[k * M + m];
+ __fp16 b = B[n * K + k];
+ sum += a * b;
+ }
+
+ sum = alpha * sum;
+ C[m * N + n] += sum;
+ }
+ }
+ }
+}
#endif
} // namespace nntrainer::neon
* @param[in] X __fp16 * for Vector X
*/
unsigned int isamax_neon_fp16(const unsigned int N, const __fp16 *X);
+
+/**
+ * @brief sgemm computation with neon : C = alpha*op(A)*op(B) + beta*C,
+ * where op(X) is one of X or X**T
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C __fp16 * for Matrix C
+ * @param[in] M number of op(A)'s and C's rows
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ * @param[in] TransA bool whether to transpose A
+ * @param[in] TransB bool whether to transpose B
+ */
+void sgemm_neon_fp16(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
+ uint32_t N, uint32_t K, float alpha, float beta,
+ bool TransA, bool TransB);
#endif
} // namespace nntrainer::neon