[ unittest ] Add TC for K=1 hgemm case
authorskykongkong8 <ss.kong@samsung.com>
Wed, 10 Jul 2024 04:38:53 +0000 (13:38 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 30 Jul 2024 22:45:30 +0000 (07:45 +0900)
- Missing optimizations for K=1 GEMM case was recently detected.
- Add such TC accordingly.

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
23 files changed:
nntrainer/tensor/blas_interface.cpp
nntrainer/tensor/blas_neon.cpp
nntrainer/tensor/blas_neon.h
nntrainer/tensor/hgemm/hgemm.cpp
nntrainer/tensor/hgemm/hgemm.h
nntrainer/tensor/hgemm/hgemm_common.h
nntrainer/tensor/hgemm/hgemm_kernel.h [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_kernel_1x4.h
nntrainer/tensor/hgemm/hgemm_kernel_1x8.h
nntrainer/tensor/hgemm/hgemm_kernel_4x4.h
nntrainer/tensor/hgemm/hgemm_kernel_4x8.h
nntrainer/tensor/hgemm/hgemm_kernel_8x16.h
nntrainer/tensor/hgemm/hgemm_kernel_8x8.h
nntrainer/tensor/hgemm/hgemm_noTrans.cpp [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_noTrans.h [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_transA.cpp [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_transA.h [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_transAB.cpp [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_transAB.h [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_transB.cpp [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm_transB.h [new file with mode: 0644]
nntrainer/tensor/hgemm/meson.build
test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp

index e04c1ce499d6ec78f1c935396342924ba65ee44f..08f31b34d0f3b22e2e4857b1e06dfd4ba6bec6d2 100644 (file)
@@ -326,7 +326,7 @@ static void sgemm_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
                        const unsigned int ldc) {
 
 #if (defined USE__FP16 && USE_NEON)
-  nntrainer::neon::hgemm(A, B, C, M, N, K, alpha, beta, TransA == CblasTrans,
+  nntrainer::neon::custom_hgemm(A, B, C, M, N, K, alpha, beta, TransA == CblasTrans,
                          TransB == CblasTrans);
 #else
   float *A_ = new float[M * K];
index 0e429a6d37879687cbb269f26782c2009e8ec309..4b6c05c72ea8a399b97e884bb7667b6de33ce696 100644 (file)
@@ -1586,55 +1586,12 @@ unsigned int isamax(const unsigned int N, const __fp16 *X) {
   return retIdx;
 }
 
-void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
-           uint32_t K, float alpha, float beta, bool TransA, bool TransB) {
-  if (K == 1) {
-    return hgemm_K1(A, B, C, M, N, K, alpha, beta, TransA, TransB);
-  }
-  // dynamic creation to avoid reaching stack limit(causes segmentation fault)
-  float *C32 = (float *)malloc(M * N * sizeof(float));
-
-  // performing beta*C
-  unsigned int idx = 0;
-  unsigned int size = M * N;
-  unsigned int size8 = (size >> 3) << 3;
-  unsigned int size4 = (size >> 2) << 2;
-  if (std::fpclassify(beta) != FP_ZERO) {
-    for (; idx < size8; idx += 8) {
-      float16x8_t c =
-        vmulq_n_f16(vld1q_f16(&C[idx]), static_cast<__fp16>(beta));
-
-      vst1q_f32(&C32[idx], vcvt_f32_f16(vget_low_f16(c)));
-      vst1q_f32(&C32[idx + 4], vcvt_f32_f16(vget_high_f16(c)));
-    }
-    // remaining 4
-    for (; idx < size4; idx += 4) {
-      float16x4_t c = vmul_n_f16(vld1_f16(&C[idx]), static_cast<__fp16>(beta));
-
-      vst1q_f32(&C32[idx], vcvt_f32_f16(c));
-    }
-
-    // remaining values if dimensions not a multiple of 8
-    for (; idx < size; idx++) {
-      C32[idx] = C[idx] * beta;
-    }
-  } else {
-    float32x4_t zeros = vmovq_n_f32(0.F);
-    for (; idx < size4; idx += 4) {
-      vst1q_f32(&C32[idx], zeros);
-    }
-    for (; idx < size; idx++) {
-      C32[idx] = 0.F;
-    }
-  }
-
-  hgemm_ensure_divisibility(A, B, C32, M, N, K, alpha, beta, TransA, TransB);
-  
-  copy_fp32_to_fp16(M * N, C32, C);
-  free(C32);
+void custom_hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
+                  uint32_t N, uint32_t K, float alpha, float beta, bool TransA,
+                  bool TransB) {
+  hgemm(A, B, C, M, N, K, alpha, beta, TransA, TransB);
 }
 
-
 void ele_mul(const unsigned int N, const __fp16 *X, const __fp16 *Y, __fp16 *Z,
              float alpha, float beta) {
   unsigned int i = 0;
index 736057f0f92e59bdb034cdc2c2fc4d7520211493..b7be55d48e0fb097473dbc50c5661a26f5ef00b0 100644 (file)
@@ -327,7 +327,7 @@ unsigned int isamax(const unsigned int N, const __fp16 *X);
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
+void custom_hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
            uint32_t K, float alpha, float beta, bool TransA, bool TransB);
 
 
index c5a144ceee0537467be50885975b629e64545b3a..81d22b22c4bbf4f16577195fb7215668c2618f04 100644 (file)
  *
  */
 
-#include <cmath>
 #include <hgemm.h>
-#include <hgemm_kernel_1x4.h>
-#include <hgemm_kernel_1x8.h>
-#include <hgemm_kernel_4x4.h>
-#include <hgemm_kernel_4x8.h>
-#include <hgemm_kernel_8x16.h>
-#include <hgemm_kernel_8x8.h>
-#include <hgemm_kernel_pack.h>
+#include <hgemm_noTrans.h>
 #include <hgemm_padding.h>
+#include <hgemm_transA.h>
+#include <hgemm_transAB.h>
+#include <hgemm_transB.h>
 #include <hgemm_util.h>
-#include <limits>
-#include <matrix_transpose_neon.h>
+#include <hgemm_common.h>
+#include <cmath>
 
-#define HGEMM_KERNEL_1x4 hgemm_kernel_1x4
-#define HGEMM_KERNEL_4x4 hgemm_kernel_4x4
-#define HGEMM_KERNEL_1x8 hgemm_kernel_1x8
-#define HGEMM_KERNEL_4x8 hgemm_kernel_4x8
-#define HGEMM_KERNEL_8x8 hgemm_kernel_8x8
-#define HGEMM_KERNEL_8x16 hgemm_kernel_8x16
 
-void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
-                   unsigned int N, unsigned int K, float alpha, float beta) {
-  const float eps = std::numeric_limits<float>::epsilon();
-  if (std::abs(alpha - 1.F) < eps) {
-    hgemm_noTrans_strict(A, B, C32, M, N, K, alpha, beta);
-  } else {
-    hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N,
+           unsigned int K, float alpha, float beta, bool TransA, bool TransB) {
+  if (K == 1) {
+    return hgemm_K1(A, B, C, M, N, K, alpha, beta, TransA, TransB);
   }
-}
+  // dynamic creation to avoid reaching stack limit(causes segmentation fault)
+  float *C32 = (float *)malloc(M * N * sizeof(float));
+
+  // performing beta*C
+  unsigned int idx = 0;
+  unsigned int size = M * N;
+  unsigned int size8 = (size >> 3) << 3;
+  unsigned int size4 = (size >> 2) << 2;
+
+  if (std::fpclassify(beta) != FP_ZERO) {
+    for (; idx < size8; idx += 8) {
+      float16x8_t c =
+        vmulq_n_f16(vld1q_f16(&C[idx]), static_cast<__fp16>(beta));
+
+      vst1q_f32(&C32[idx], vcvt_f32_f16(vget_low_f16(c)));
+      vst1q_f32(&C32[idx + 4], vcvt_f32_f16(vget_high_f16(c)));
+    }
+    // remaining 4
+    for (; idx < size4; idx += 4) {
+      float16x4_t c = vmul_n_f16(vld1_f16(&C[idx]), static_cast<__fp16>(beta));
+
+      vst1q_f32(&C32[idx], vcvt_f32_f16(c));
+    }
 
-void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, float *C32,
-                          unsigned int M, unsigned int N, unsigned int K,
-                          float alpha, float beta) {
-  // used bitwise operator instead of modulo for performance
-  // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
-  if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
-    hgemm_noTrans_8x16(M, N, K, A, K, B, N, C32, N, alpha, beta);
-  } else if ((M & 0x7) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
-    hgemm_noTrans_8x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
-  } else if ((M & 0x3) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
-    hgemm_noTrans_4x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
-  } else if ((K & 0x7) == 0 && (N & 0x7) == 0) {
-    hgemm_noTrans_1x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
-  } else if ((K & 0x7) == 0 && (N & 0x3) == 0) {
-    hgemm_noTrans_1x4(M, N, K, A, K, B, N, C32, N, alpha, beta);
+    // remaining values if dimensions not a multiple of 8
+    for (; idx < size; idx++) {
+      C32[idx] = C[idx] * beta;
+    }
   } else {
-    hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+    float32x4_t zeros = vmovq_n_f32(0.F);
+    for (; idx < size4; idx += 4) {
+      vst1q_f32(&C32[idx], zeros);
+    }
+    for (; idx < size; idx++) {
+      C32[idx] = 0.F;
+    }
   }
-}
 
-void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C,
-                          unsigned int M, unsigned int N, unsigned int K,
-                          float alpha, float beta) {
-  if (alpha == 1.F) {
-    // used bitwise operator instead of modulo for performance
-    // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
-    if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
-      hgemm_noTrans_8x16(M, N, K, A, K, B, N, C, N, alpha, beta);
-    } else if ((M & 0x7) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
-      hgemm_noTrans_8x8(M, N, K, A, K, B, N, C, N, alpha, beta);
-    } else if ((M & 0x3) == 0 && (N & 0x7) == 0 && (K & 0x3) == 0) {
-      hgemm_noTrans_4x8(M, N, K, A, K, B, N, C, N, alpha, beta);
-    } else if ((M & 0x3) == 0 && (N & 0x3) == 0 && (K & 0x3) == 0) {
-      hgemm_noTrans_4x4(M, N, K, A, K, B, N, C, N, alpha, beta);
-    } else if ((N & 0x7) == 0 && (K & 0x7) == 0) {
-      hgemm_noTrans_1x8(M, N, K, A, K, B, N, C, N, alpha, beta);
-    } else if ((N & 0x3) == 0 && (K & 0x7) == 0) {
-      hgemm_noTrans_1x4(M, N, K, A, K, B, N, C, N, alpha, beta);
-    }
+  hgemm_ensure_divisibility(A, B, C32, M, N, K, alpha, beta, TransA, TransB);
+  
+  unsigned int L = M*N;
+  unsigned int L8 = (L >> 3) <<3;
+
+  for (unsigned int idx = 0; idx < L8; idx += 8) {
+    float32x4_t x1 = vld1q_f32(&C32[idx]);
+    float32x4_t x2 = vld1q_f32(&C32[idx + 4]);
+
+    float16x8_t y1 = vcombine_f16(vcvt_f16_f32(x1), vcvt_f16_f32(x2));
+
+    vst1q_f16(&C[idx], y1);
+  }
+  for (unsigned int idx = L8; idx < L; ++idx) {
+    C[idx] = static_cast<__fp16>(C32[idx]);
   }
+
+  free(C32);
 }
 
 void hgemm_ensure_divisibility(const __fp16 *A, const __fp16 *B, float *C32,
@@ -149,1307 +150,14 @@ void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32,
   }
 }
 
-void hgemm_K1_noTrans(unsigned int M, unsigned int N, unsigned int K,
-                      const __fp16 *A, unsigned int lda, const __fp16 *B,
-                      unsigned int ldb, __fp16 *C, unsigned int ldc,
-                      float alpha, float beta) {
-  const float eps = std::numeric_limits<float>::epsilon();
-  float16x8_t a_vec;
-  unsigned int N8 = (N >> 3) << 3;
-  for (unsigned int m = 0; m < M; ++m) {
-    a_vec = vmovq_n_f16(alpha * A[m]);
-    if (std::fpclassify(beta) != FP_ZERO) {
-      for (unsigned int n = 0; n < N8; n += 8) {
-        vst1q_f16(&C[m * ldc + n],
-                  vaddq_f16(vmulq_f16(a_vec, vld1q_f16(&B[n])),
-                            vmulq_n_f16(vld1q_f16(&C[m * ldc + n]), beta)));
-      }
-    } else {
-      for (unsigned int n = 0; n < N8; n += 8) {
-        vst1q_f16(&C[m * ldc + n], vmulq_f16(a_vec, vld1q_f16(&B[n])));
-      }
-    }
-    for (unsigned int n = N8; n < N; ++n) {
-      C[m * ldc + n] = alpha * A[m] * B[n] + beta * C[m * ldc + n];
-    }
-  }
-}
-
-void hgemm_K1_transA(unsigned int M, unsigned int N, unsigned int K,
-                     const __fp16 *A, unsigned int lda, const __fp16 *B,
-                     unsigned int ldb, __fp16 *C, unsigned int ldc, float alpha,
-                     float beta) {
-  __fp16 *A_T = alignedMalloc(M * K);
-
-  transpose_neon<__fp16>(K, M, A, M, A_T, K);
-
-  hgemm_K1_noTrans(M, N, K, A_T, lda, B, ldb, C, ldc, alpha, beta);
-
-  free(A_T);
-}
-
-void hgemm_K1_transB(unsigned int M, unsigned int N, unsigned int K,
-                     const __fp16 *A, unsigned int lda, const __fp16 *B,
-                     unsigned int ldb, __fp16 *C, unsigned int ldc, float alpha,
-                     float beta) {
-  __fp16 *B_T = alignedMalloc(K * N);
-
-  transpose_neon<__fp16>(N, K, B, K, B_T, N);
-
-  hgemm_K1_noTrans(M, N, K, A, lda, B_T, ldb, C, ldc, alpha, beta);
-
-  free(B_T);
-}
-
-void hgemm_K1_transAB(unsigned int M, unsigned int N, unsigned int K,
-                      const __fp16 *A, unsigned int lda, const __fp16 *B,
-                      unsigned int ldb, __fp16 *C, unsigned int ldc,
-                      float alpha, float beta) {
-  __fp16 *A_T = alignedMalloc(M * K);
-  __fp16 *B_T = alignedMalloc(K * N);
-
-  transpose_neon<__fp16>(K, M, A, M, A_T, K);
-  transpose_neon<__fp16>(N, K, B, K, B_T, N);
-
-  hgemm_K1_noTrans(M, N, K, A_T, lda, B_T, ldb, C, ldc, alpha, beta);
-
-  free(A_T);
-  free(B_T);
-}
-
-void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, __fp16 *C, unsigned int ldc,
-                       float alpha, float beta) {
-  __fp16 *sa = alignedMalloc(M * K);
-  __fp16 *sb = alignedMalloc(K * N);
-
-  unsigned int ms, mms, ns, ks;
-  unsigned int m_min, m2_min, n_min, k_min;
-  for (ms = 0; ms < M; ms += M_BLOCKING) {
-    m_min = M - ms;
-    if (m_min > M_BLOCKING) {
-      m_min = M_BLOCKING;
-    }
-
-    for (ks = 0; ks < K; ks += k_min) {
-      k_min = K - ks;
-      if (k_min >= (K_BLOCKING << 1)) {
-        k_min = K_BLOCKING;
-      } else if (k_min > K_BLOCKING) {
-        k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
-      }
-
-      n_min = N;
-      if (N >= N_BLOCKING * 2) {
-        n_min = N_BLOCKING;
-      } else if (N > N_BLOCKING) {
-        n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
-      }
-      packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
-
-      for (mms = ms; mms < ms + m_min; mms += m2_min) {
-        m2_min = (ms + m_min) - mms;
-        if (m2_min >= 3 * GEMM_UNROLLING_1) {
-          m2_min = 3 * GEMM_UNROLLING_1;
-        } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
-          m2_min = 2 * GEMM_UNROLLING_1;
-        } else if (m2_min > GEMM_UNROLLING_1) {
-          m2_min = GEMM_UNROLLING_1;
-        }
-
-        packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
-                   sa + k_min * (mms - ms));
-
-        HGEMM_KERNEL_1x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
-                         C + mms * ldc, ldc);
-      }
-
-      for (ns = n_min; ns < N; ns += n_min) {
-        n_min = N - ns;
-        if (n_min >= N_BLOCKING * 2) {
-          n_min = N_BLOCKING;
-        } else if (n_min > N_BLOCKING) {
-          n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
-        }
-
-        packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
-        HGEMM_KERNEL_1x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
-      }
-    }
-  }
-
-  free(sa);
-  free(sb);
-}
-
-void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, float *C, unsigned int ldc,
-                       float alpha, float beta) {
-  __fp16 *sa = alignedMalloc(M * K);
-  __fp16 *sb = alignedMalloc(K * N);
-
-  unsigned int ms, mms, ns, ks;
-  unsigned int m_min, m2_min, n_min, k_min;
-  for (ms = 0; ms < M; ms += M_BLOCKING) {
-    m_min = M - ms;
-    if (m_min > M_BLOCKING) {
-      m_min = M_BLOCKING;
-    }
-
-    for (ks = 0; ks < K; ks += k_min) {
-      k_min = K - ks;
-      if (k_min >= (K_BLOCKING << 1)) {
-        k_min = K_BLOCKING;
-      } else if (k_min > K_BLOCKING) {
-        k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
-      }
-
-      n_min = N;
-      if (N >= N_BLOCKING * 2) {
-        n_min = N_BLOCKING;
-      } else if (N > N_BLOCKING) {
-        n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
-      }
-      packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
-
-      for (mms = ms; mms < ms + m_min; mms += m2_min) {
-        m2_min = (ms + m_min) - mms;
-        if (m2_min >= 3 * GEMM_UNROLLING_1) {
-          m2_min = 3 * GEMM_UNROLLING_1;
-        } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
-          m2_min = 2 * GEMM_UNROLLING_1;
-        } else if (m2_min > GEMM_UNROLLING_1) {
-          m2_min = GEMM_UNROLLING_1;
-        }
-
-        packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
-                   sa + k_min * (mms - ms));
-
-        HGEMM_KERNEL_1x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
-                         C + mms * ldc, ldc);
-      }
-
-      for (ns = n_min; ns < N; ns += n_min) {
-        n_min = N - ns;
-        if (n_min >= N_BLOCKING * 2) {
-          n_min = N_BLOCKING;
-        } else if (n_min > N_BLOCKING) {
-          n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
-        }
-
-        packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
-        HGEMM_KERNEL_1x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
-      }
-    }
-  }
-
-  free(sa);
-  free(sb);
-}
-
-void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, __fp16 *C, unsigned int ldc,
-                       float alpha, float beta) {
-  __fp16 *sa = alignedMalloc(M * K);
-  __fp16 *sb = alignedMalloc(K * N);
-
-  unsigned int ms, mms, ns, ks;
-  unsigned int m_min, m2_min, n_min, k_min;
-  for (ms = 0; ms < M; ms += M_BLOCKING) {
-    m_min = M - ms;
-    if (m_min > M_BLOCKING) {
-      m_min = M_BLOCKING;
-    }
-
-    for (ks = 0; ks < K; ks += k_min) {
-      k_min = K - ks;
-      if (k_min >= (K_BLOCKING << 1)) {
-        k_min = K_BLOCKING;
-      } else if (k_min > K_BLOCKING) {
-        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
-      }
-
-      n_min = N;
-      if (N >= N_BLOCKING * 2) {
-        n_min = N_BLOCKING;
-      } else if (N > N_BLOCKING) {
-        n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
-      }
-      packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
-
-      for (mms = ms; mms < ms + m_min; mms += m2_min) {
-        m2_min = (ms + m_min) - mms;
-        if (m2_min >= 3 * GEMM_UNROLLING_4) {
-          m2_min = 3 * GEMM_UNROLLING_4;
-        } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
-          m2_min = 2 * GEMM_UNROLLING_4;
-        } else if (m2_min > GEMM_UNROLLING_4) {
-          m2_min = GEMM_UNROLLING_4;
-        }
-
-        packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
-                   sa + k_min * (mms - ms));
-
-        HGEMM_KERNEL_4x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
-                         C + mms * ldc, ldc);
-      }
-
-      for (ns = n_min; ns < N; ns += n_min) {
-        n_min = N - ns;
-        if (n_min >= N_BLOCKING * 2) {
-          n_min = N_BLOCKING;
-        } else if (n_min > N_BLOCKING) {
-          n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
-        }
-
-        packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
-        HGEMM_KERNEL_4x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
-      }
-    }
-  }
-
-  free(sa);
-  free(sb);
-}
-
-void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, __fp16 *C, unsigned int ldc,
-                       float alpha, float beta) {
-  __fp16 *sa = alignedMalloc(M * K);
-  __fp16 *sb = alignedMalloc(K * N);
-
-  unsigned int ms, mms, ns, ks;
-  unsigned int m_min, m2_min, n_min, k_min;
-  unsigned int l1stride = 1;
-  for (ms = 0; ms < M; ms += M_BLOCKING) {
-    m_min = M - ms;
-    if (m_min > M_BLOCKING) {
-      m_min = M_BLOCKING;
-    }
-
-    for (ks = 0; ks < K; ks += k_min) {
-      k_min = K - ks;
-      if (k_min >= (K_BLOCKING << 1)) {
-        k_min = K_BLOCKING;
-      } else if (k_min > K_BLOCKING) {
-        k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
-      }
-
-      n_min = N;
-      if (N >= N_BLOCKING * 2) {
-        n_min = N_BLOCKING;
-      } else if (N > N_BLOCKING) {
-        n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
-                GEMM_UNROLLING_8;
-      } else {
-        l1stride = 0;
-      }
-      packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
-
-      for (mms = ms; mms < ms + m_min; mms += m2_min) {
-        m2_min = (ms + m_min) - mms;
-        if (m2_min >= 3 * GEMM_UNROLLING_1) {
-          m2_min = 3 * GEMM_UNROLLING_1;
-        } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
-          m2_min = 2 * GEMM_UNROLLING_1;
-        } else if (m2_min > GEMM_UNROLLING_1) {
-          m2_min = GEMM_UNROLLING_1;
-        }
-
-        packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
-                   sa + k_min * (mms - ms) * l1stride);
-
-        HGEMM_KERNEL_1x8(m2_min, n_min, k_min,
-                         sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
-                         ldc);
-      }
-
-      for (ns = n_min; ns < N; ns += n_min) {
-        n_min = N - ns;
-        if (n_min >= N_BLOCKING * 2) {
-          n_min = N_BLOCKING;
-        } else if (n_min > N_BLOCKING) {
-          n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
-        }
-
-        packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
-        HGEMM_KERNEL_1x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
-      }
-    }
-  }
-
-  free(sa);
-  free(sb);
-}
-
-void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, float *C, unsigned int ldc,
-                       float alpha, float beta) {
-  __fp16 *sa = alignedMalloc(M * K);
-  __fp16 *sb = alignedMalloc(K * N);
-
-  unsigned int ms, mms, ns, ks;
-  unsigned int m_min, m2_min, n_min, k_min;
-  unsigned int l1stride = 1;
-  for (ms = 0; ms < M; ms += M_BLOCKING) {
-    m_min = M - ms;
-    if (m_min > M_BLOCKING) {
-      m_min = M_BLOCKING;
-    }
-
-    for (ks = 0; ks < K; ks += k_min) {
-      k_min = K - ks;
-      if (k_min >= (K_BLOCKING << 1)) {
-        k_min = K_BLOCKING;
-      } else if (k_min > K_BLOCKING) {
-        k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
-      }
-
-      n_min = N;
-      if (N >= N_BLOCKING * 2) {
-        n_min = N_BLOCKING;
-      } else if (N > N_BLOCKING) {
-        n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
-                GEMM_UNROLLING_8;
-      } else {
-        l1stride = 0;
-      }
-      packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
-
-      for (mms = ms; mms < ms + m_min; mms += m2_min) {
-        m2_min = (ms + m_min) - mms;
-        if (m2_min >= 3 * GEMM_UNROLLING_1) {
-          m2_min = 3 * GEMM_UNROLLING_1;
-        } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
-          m2_min = 2 * GEMM_UNROLLING_1;
-        } else if (m2_min > GEMM_UNROLLING_1) {
-          m2_min = GEMM_UNROLLING_1;
-        }
-
-        packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
-                   sa + k_min * (mms - ms) * l1stride);
-
-        HGEMM_KERNEL_1x8(m2_min, n_min, k_min,
-                         sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
-                         ldc);
-      }
-
-      for (ns = n_min; ns < N; ns += n_min) {
-        n_min = N - ns;
-        if (n_min >= N_BLOCKING * 2) {
-          n_min = N_BLOCKING;
-        } else if (n_min > N_BLOCKING) {
-          n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
-        }
-
-        packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
-        HGEMM_KERNEL_1x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
-      }
-    }
-  }
-
-  free(sa);
-  free(sb);
-}
-
-void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, float *C, unsigned int ldc,
-                       float alpha, float beta) {
-  __fp16 *sa = alignedMalloc(M * K);
-  __fp16 *sb = alignedMalloc(K * N);
-
-  unsigned int ms, mms, ns, ks;
-  unsigned int m_min, m2_min, n_min, k_min;
-  for (ms = 0; ms < M; ms += M_BLOCKING) {
-    m_min = M - ms;
-    if (m_min > M_BLOCKING) {
-      m_min = M_BLOCKING;
-    }
-
-    for (ks = 0; ks < K; ks += k_min) {
-      k_min = K - ks;
-      if (k_min >= (K_BLOCKING << 1)) {
-        k_min = K_BLOCKING;
-      } else if (k_min > K_BLOCKING) {
-        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
-      }
-
-      n_min = N;
-      if (N >= N_BLOCKING * 2) {
-        n_min = N_BLOCKING;
-      } else if (N > N_BLOCKING) {
-        n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
-      }
-      packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
-
-      for (mms = ms; mms < ms + m_min; mms += m2_min) {
-        m2_min = (ms + m_min) - mms;
-        if (m2_min >= 3 * GEMM_UNROLLING_4) {
-          m2_min = 3 * GEMM_UNROLLING_4;
-        } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
-          m2_min = 2 * GEMM_UNROLLING_4;
-        } else if (m2_min > GEMM_UNROLLING_4) {
-          m2_min = GEMM_UNROLLING_4;
-        }
-
-        packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
-                   sa + k_min * (mms - ms));
-
-        HGEMM_KERNEL_4x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
-                         C + mms * ldc, ldc);
-      }
-
-      for (ns = n_min; ns < N; ns += n_min) {
-        n_min = N - ns;
-        if (n_min >= N_BLOCKING * 2) {
-          n_min = N_BLOCKING;
-        } else if (n_min > N_BLOCKING) {
-          n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
-        }
-
-        packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
-        HGEMM_KERNEL_4x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
-      }
-    }
-  }
-
-  free(sa);
-  free(sb);
-}
-
-void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, __fp16 *C, unsigned int ldc,
-                       float alpha, float beta) {
-  __fp16 *sa = alignedMalloc(M * K);
-  __fp16 *sb = alignedMalloc(K * N);
-
-  unsigned int ms, mms, ns, ks;
-  unsigned int m_min, m2_min, n_min, k_min;
-  unsigned int l1stride = 1;
-  for (ms = 0; ms < M; ms += M_BLOCKING) {
-    m_min = M - ms;
-    if (m_min > M_BLOCKING) {
-      m_min = M_BLOCKING;
-    }
-
-    for (ks = 0; ks < K; ks += k_min) {
-      k_min = K - ks;
-      if (k_min >= (K_BLOCKING << 1)) {
-        k_min = K_BLOCKING;
-      } else if (k_min > K_BLOCKING) {
-        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
-      }
-
-      n_min = N;
-      if (N >= N_BLOCKING * 2) {
-        n_min = N_BLOCKING;
-      } else if (N > N_BLOCKING) {
-        n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
-                GEMM_UNROLLING_8;
-      } else {
-        l1stride = 0;
-      }
-      packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
-
-      for (mms = ms; mms < ms + m_min; mms += m2_min) {
-        m2_min = (ms + m_min) - mms;
-        if (m2_min >= 3 * GEMM_UNROLLING_4) {
-          m2_min = 3 * GEMM_UNROLLING_4;
-        } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
-          m2_min = 2 * GEMM_UNROLLING_4;
-        } else if (m2_min > GEMM_UNROLLING_4) {
-          m2_min = GEMM_UNROLLING_4;
-        }
-
-        packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
-                   sa + k_min * (mms - ms) * l1stride);
-
-        HGEMM_KERNEL_4x8(m2_min, n_min, k_min,
-                         sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
-                         ldc);
-      }
-
-      for (ns = n_min; ns < N; ns += n_min) {
-        n_min = N - ns;
-        if (n_min >= N_BLOCKING * 2) {
-          n_min = N_BLOCKING;
-        } else if (n_min > N_BLOCKING) {
-          n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
-        }
-
-        packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
-        HGEMM_KERNEL_4x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
-      }
-    }
-  }
-
-  free(sa);
-  free(sb);
-}
-
-void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, float *C, unsigned int ldc,
-                       float alpha, float beta) {
-  __fp16 *sa = alignedMalloc(M * K);
-  __fp16 *sb = alignedMalloc(K * N);
-
-  unsigned int ms, mms, ns, ks;
-  unsigned int m_min, m2_min, n_min, k_min;
-  unsigned int l1stride = 1;
-  for (ms = 0; ms < M; ms += M_BLOCKING) {
-    m_min = M - ms;
-    if (m_min > M_BLOCKING) {
-      m_min = M_BLOCKING;
-    }
-
-    for (ks = 0; ks < K; ks += k_min) {
-      k_min = K - ks;
-      if (k_min >= (K_BLOCKING << 1)) {
-        k_min = K_BLOCKING;
-      } else if (k_min > K_BLOCKING) {
-        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
-      }
-
-      n_min = N;
-      if (N >= N_BLOCKING * 2) {
-        n_min = N_BLOCKING;
-      } else if (N > N_BLOCKING) {
-        n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
-                GEMM_UNROLLING_8;
-      } else {
-        l1stride = 0;
-      }
-      packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
-
-      for (mms = ms; mms < ms + m_min; mms += m2_min) {
-        m2_min = (ms + m_min) - mms;
-        if (m2_min >= 3 * GEMM_UNROLLING_4) {
-          m2_min = 3 * GEMM_UNROLLING_4;
-        } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
-          m2_min = 2 * GEMM_UNROLLING_4;
-        } else if (m2_min > GEMM_UNROLLING_4) {
-          m2_min = GEMM_UNROLLING_4;
-        }
-
-        packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
-                   sa + k_min * (mms - ms) * l1stride);
-
-        HGEMM_KERNEL_4x8(m2_min, n_min, k_min,
-                         sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
-                         ldc);
-      }
-
-      for (ns = n_min; ns < N; ns += n_min) {
-        n_min = N - ns;
-        if (n_min >= N_BLOCKING * 2) {
-          n_min = N_BLOCKING;
-        } else if (n_min > N_BLOCKING) {
-          n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
-        }
-
-        packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
-        HGEMM_KERNEL_4x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
-      }
-    }
-  }
-
-  free(sa);
-  free(sb);
-}
-
-void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, __fp16 *C, unsigned int ldc,
-                       float alpha, float beta) {
-
-  __fp16 *sa = alignedMalloc(M * K);
-  __fp16 *sb = alignedMalloc(K * N);
-
-  unsigned int ms, mms, ns, ks;
-  unsigned int m_min, m2_min, n_min, k_min;
-  for (ms = 0; ms < M; ms += M_BLOCKING) {
-    m_min = M - ms;
-    if (m_min > M_BLOCKING) {
-      m_min = M_BLOCKING;
-    }
-
-    for (ks = 0; ks < K; ks += k_min) {
-      k_min = K - ks;
-      if (k_min >= (K_BLOCKING << 1)) {
-        k_min = K_BLOCKING;
-      } else if (k_min > K_BLOCKING) {
-        k_min = (k_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
-      }
-
-      n_min = N;
-      if (N >= N_BLOCKING * 2) {
-        n_min = N_BLOCKING;
-      } else if (N > N_BLOCKING) {
-        n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
-      }
-      packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
-
-      for (mms = ms; mms < ms + m_min; mms += m2_min) {
-        m2_min = (ms + m_min) - mms;
-        if (m2_min >= 3 * GEMM_UNROLLING_8) {
-          m2_min = 3 * GEMM_UNROLLING_8;
-        } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
-          m2_min = 2 * GEMM_UNROLLING_8;
-        } else if (m2_min > GEMM_UNROLLING_8) {
-          m2_min = GEMM_UNROLLING_8;
-        }
-
-        packing_A8(m2_min, k_min, A + mms * lda + ks, lda,
-                   sa + k_min * (mms - ms));
-
-        HGEMM_KERNEL_8x8(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
-                         C + mms * ldc, ldc);
-      }
-
-      for (ns = n_min; ns < N; ns += n_min) {
-        n_min = N - ns;
-        if (n_min >= N_BLOCKING * 2) {
-          n_min = N_BLOCKING;
-        } else if (n_min > N_BLOCKING) {
-          n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
-        }
-
-        packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
-        HGEMM_KERNEL_8x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
-      }
-    }
-  }
-
-  free(sa);
-  free(sb);
-}
-
-void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, float *C, unsigned int ldc,
-                       float alpha, float beta) {
-
-  __fp16 *sA = alignedMalloc(M * K);
-  __fp16 *sB = alignedMalloc(K * N);
-
-  unsigned int ms, ms2, ns, ks;
-  unsigned int m_min, m2_min, n_min, k_min;
-  for (ms = 0; ms < M; ms += M_BLOCKING) {
-    m_min = M - ms;
-    if (m_min > M_BLOCKING) {
-      m_min = M_BLOCKING;
-    }
-
-    for (ks = 0; ks < K; ks += k_min) {
-      k_min = K - ks;
-      if (k_min >= (K_BLOCKING << 1)) {
-        k_min = K_BLOCKING;
-      } else if (k_min > K_BLOCKING) {
-        k_min = (k_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
-      }
-
-      n_min = N;
-      if (N >= N_BLOCKING * 2) {
-        n_min = N_BLOCKING;
-      } else if (N > N_BLOCKING) {
-        n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
-      }
-      packing_B8(k_min, n_min, B + ks * ldb, ldb, sB);
-
-      for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
-        m2_min = (ms + m_min) - ms2;
-        if (m2_min >= 3 * GEMM_UNROLLING_8) {
-          m2_min = 3 * GEMM_UNROLLING_8;
-        } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
-          m2_min = 2 * GEMM_UNROLLING_8;
-        } else if (m2_min > GEMM_UNROLLING_8) {
-          m2_min = GEMM_UNROLLING_8;
-        }
-
-        packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
-                   sA + k_min * (ms2 - ms));
-
-        HGEMM_KERNEL_8x8(m2_min, n_min, k_min, sA + k_min * (ms2 - ms), sB,
-                         C + ms2 * ldc, ldc);
-      }
-
-      for (ns = n_min; ns < N; ns += n_min) {
-        n_min = N - ns;
-        if (n_min >= N_BLOCKING * 2) {
-          n_min = N_BLOCKING;
-        } else if (n_min > N_BLOCKING) {
-          n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
-        }
-
-        packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sB);
-        HGEMM_KERNEL_8x8(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
-      }
-    }
-  }
-
-  free(sA);
-  free(sB);
-}
-
-void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
-                        const __fp16 *A, unsigned int lda, const __fp16 *B,
-                        unsigned int ldb, __fp16 *C, unsigned int ldc,
-                        float alpha, float beta) {
-  __fp16 *sA = alignedMalloc(M * K);
-  __fp16 *sB = alignedMalloc(K * N);
-
-  unsigned int ms, ms2, ns, ks;
-  unsigned int m_min, m2_min, n_min, k_min;
-  unsigned int stride_l1 = 1;
-
-  for (ms = 0; ms < M; ms += M_BLOCKING) {
-    m_min = M - ms;
-    if (m_min > M_BLOCKING) {
-      m_min = M_BLOCKING;
-    }
-
-    for (ks = 0; ks < K; ks += k_min) {
-      k_min = K - ks;
-      if (k_min >= (K_BLOCKING << 1)) {
-        k_min = K_BLOCKING;
-      } else if (k_min > K_BLOCKING) {
-        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
-      }
-
-      n_min = N;
-      if (N >= N_BLOCKING * 2) {
-        n_min = N_BLOCKING;
-      } else if (N > N_BLOCKING) {
-        n_min = ((n_min / 2 + GEMM_UNROLLING_16 - 1) / GEMM_UNROLLING_16) *
-                GEMM_UNROLLING_16;
-      } else {
-        stride_l1 = 0;
-      }
-      packing_B16(k_min, n_min, B + ks * ldb, ldb, sB);
-
-      for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
-        m2_min = (ms + m_min) - ms2;
-        if (m2_min >= 3 * GEMM_UNROLLING_8) {
-          m2_min = 3 * GEMM_UNROLLING_8;
-        } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
-          m2_min = 2 * GEMM_UNROLLING_8;
-        } else if (m2_min > GEMM_UNROLLING_8) {
-          m2_min = GEMM_UNROLLING_8;
-        }
-
-        packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
-                   sA + k_min * (ms2 - ms) * stride_l1);
-
-        HGEMM_KERNEL_8x16(m2_min, n_min, k_min,
-                          sA + stride_l1 * k_min * (ms2 - ms), sB,
-                          C + ms2 * ldc, ldc);
-      }
-
-      for (ns = n_min; ns < N; ns += n_min) {
-        n_min = N - ns;
-        if (n_min >= N_BLOCKING * 2) {
-          n_min = N_BLOCKING;
-        } else if (n_min > N_BLOCKING) {
-          n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
-        }
-
-        packing_B16(k_min, n_min, B + ns + ldb * ks, ldb, sB);
-        HGEMM_KERNEL_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
-      }
-    }
-  }
-
-  free(sA);
-  free(sB);
-}
-
-void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
-                        const __fp16 *A, unsigned int lda, const __fp16 *B,
-                        unsigned int ldb, float *C, unsigned int ldc,
-                        float alpha, float beta) {
-  __fp16 *sA = alignedMalloc(M * K);
-  __fp16 *sB = alignedMalloc(K * N);
-
-  unsigned int ms, ms2, ns, ks;
-  unsigned int m_min, m2_min, n_min, k_min;
-  unsigned int stride_l1 = 1;
-
-  for (ms = 0; ms < M; ms += M_BLOCKING) {
-    m_min = M - ms;
-    if (m_min > M_BLOCKING) {
-      m_min = M_BLOCKING;
-    }
-
-    for (ks = 0; ks < K; ks += k_min) {
-      k_min = K - ks;
-      if (k_min >= (K_BLOCKING << 1)) {
-        k_min = K_BLOCKING;
-      } else if (k_min > K_BLOCKING) {
-        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
-      }
-
-      n_min = N;
-      if (N >= N_BLOCKING * 2) {
-        n_min = N_BLOCKING;
-      } else if (N > N_BLOCKING) {
-        n_min = ((n_min / 2 + GEMM_UNROLLING_16 - 1) / GEMM_UNROLLING_16) *
-                GEMM_UNROLLING_16;
-      } else {
-        stride_l1 = 0;
-      }
-      packing_B16(k_min, n_min, B + ks * ldb, ldb, sB);
-
-      for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
-        m2_min = (ms + m_min) - ms2;
-        if (m2_min >= 3 * GEMM_UNROLLING_8) {
-          m2_min = 3 * GEMM_UNROLLING_8;
-        } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
-          m2_min = 2 * GEMM_UNROLLING_8;
-        } else if (m2_min > GEMM_UNROLLING_8) {
-          m2_min = GEMM_UNROLLING_8;
-        }
-
-        packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
-                   sA + k_min * (ms2 - ms));
-
-        HGEMM_KERNEL_8x16(m2_min, n_min, k_min, sA + k_min * (ms2 - ms), sB,
-                          C + ms2 * ldc, ldc);
-      }
-
-      for (ns = n_min; ns < N; ns += n_min) {
-        n_min = N - ns;
-        if (n_min >= N_BLOCKING * 2) {
-          n_min = N_BLOCKING;
-        } else if (n_min > N_BLOCKING) {
-          n_min =
-            (n_min / 2 + GEMM_UNROLLING_16 - 1) & ~(GEMM_UNROLLING_16 - 1);
-
-        }
-
-        packing_B16(k_min, n_min, B + ns + ldb * ks, ldb, sB);
-
-        HGEMM_KERNEL_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
-      }
-    }
-  }
-
-  free(sA);
-  free(sB);
-}
-
-void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
-                            const __fp16 *A, unsigned int lda, const __fp16 *B,
-                            unsigned int ldb, float *C, unsigned int ldc,
-                            float alpha, float beta) {
-
-  unsigned int k = 0;
-  unsigned int N8 = (N >> 3) << 3;
-  __fp16 a[16];
-  for (; (K - k) >= 16; k += 16) {
-    for (unsigned int m = 0; m < M; m++) {
-      vst1q_f16(&a[0], vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha));
-      vst1q_f16(&a[8], vmulq_n_f16(vld1q_f16(&A[m * K + k + 8]), alpha));
-      for (unsigned int n = 0; n < N8; n += 8) {
-        float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 8) * N + n]), a[8]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 9) * N + n]), a[9]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 10) * N + n]), a[10]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 11) * N + n]), a[11]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 12) * N + n]), a[12]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 13) * N + n]), a[13]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 14) * N + n]), a[14]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 15) * N + n]), a[15]);
-
-        float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
-                                            vcvt_f32_f16(vget_low_f16(b0_7_0)));
-        float32x4_t c0_7_high_32 = vaddq_f32(
-          vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
-
-        vst1q_f32(&C[m * N + n], c0_7_low_32);
-        vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
-      }
-      if (N != N8) {
-        unsigned int n = N8;
-        __fp16 valsB_0[8];
-        __fp16 valsB_1[8];
-        __fp16 valsB_2[8];
-        __fp16 valsB_3[8];
-        __fp16 valsB_4[8];
-        __fp16 valsB_5[8];
-        __fp16 valsB_6[8];
-        __fp16 valsB_7[8];
-        __fp16 valsB_8[8];
-        __fp16 valsB_9[8];
-        __fp16 valsB_10[8];
-        __fp16 valsB_11[8];
-        __fp16 valsB_12[8];
-        __fp16 valsB_13[8];
-        __fp16 valsB_14[8];
-        __fp16 valsB_15[8];
-        float valsC[8];
-        for (unsigned int idx = n; idx < N; idx++) {
-          valsB_0[idx - n] = B[k * N + idx];
-          valsB_1[idx - n] = B[(k + 1) * N + idx];
-          valsB_2[idx - n] = B[(k + 2) * N + idx];
-          valsB_3[idx - n] = B[(k + 3) * N + idx];
-          valsB_4[idx - n] = B[(k + 4) * N + idx];
-          valsB_5[idx - n] = B[(k + 5) * N + idx];
-          valsB_6[idx - n] = B[(k + 6) * N + idx];
-          valsB_7[idx - n] = B[(k + 7) * N + idx];
-          valsB_8[idx - n] = B[(k + 8) * N + idx];
-          valsB_9[idx - n] = B[(k + 9) * N + idx];
-          valsB_10[idx - n] = B[(k + 10) * N + idx];
-          valsB_11[idx - n] = B[(k + 11) * N + idx];
-          valsB_12[idx - n] = B[(k + 12) * N + idx];
-          valsB_13[idx - n] = B[(k + 13) * N + idx];
-          valsB_14[idx - n] = B[(k + 14) * N + idx];
-          valsB_15[idx - n] = B[(k + 15) * N + idx];
-          valsC[idx - n] = C[m * N + idx];
-        }
-
-        float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_4), a[4]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_5), a[5]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_6), a[6]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_7), a[7]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_8), a[8]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_9), a[9]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_10), a[10]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_11), a[11]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_12), a[12]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_13), a[13]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_14), a[14]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_15), a[15]);
-
-        float32x4_t c0_7_low_32 =
-          vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
-
-        float32x4_t c0_7_high_32 =
-          vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
-
-        vst1q_f32(valsC, c0_7_low_32);
-        vst1q_f32(valsC + 4, c0_7_high_32);
-
-        for (unsigned int idx = n; idx < N; idx++) {
-          C[m * N + idx] = valsC[idx - n];
-        }
-      }
-    }
-  }
-
-  for (; (K - k) >= 8; k += 8) {
-    for (unsigned int m = 0; m < M; m++) {
-      vst1q_f16(a, vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha));
-
-      for (unsigned int n = 0; n < N8; n += 8) {
-        float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]);
-
-        float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
-                                            vcvt_f32_f16(vget_low_f16(b0_7_0)));
-        float32x4_t c0_7_high_32 = vaddq_f32(
-          vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
-
-        vst1q_f32(&C[m * N + n], c0_7_low_32);
-        vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
-      }
-      if (N != N8) {
-        unsigned int n = N8;
-        __fp16 valsB_0[8];
-        __fp16 valsB_1[8];
-        __fp16 valsB_2[8];
-        __fp16 valsB_3[8];
-        __fp16 valsB_4[8];
-        __fp16 valsB_5[8];
-        __fp16 valsB_6[8];
-        __fp16 valsB_7[8];
-        float valsC[8];
-        for (unsigned int idx = n; idx < N; idx++) {
-          valsB_0[idx - n] = B[k * N + idx];
-          valsB_1[idx - n] = B[(k + 1) * N + idx];
-          valsB_2[idx - n] = B[(k + 2) * N + idx];
-          valsB_3[idx - n] = B[(k + 3) * N + idx];
-          valsB_4[idx - n] = B[(k + 4) * N + idx];
-          valsB_5[idx - n] = B[(k + 5) * N + idx];
-          valsB_6[idx - n] = B[(k + 6) * N + idx];
-          valsB_7[idx - n] = B[(k + 7) * N + idx];
-          valsC[idx - n] = C[m * N + idx];
-        }
-
-        float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_4), a[4]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_5), a[5]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_6), a[6]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_7), a[7]);
-
-        float32x4_t c0_7_low_32 =
-          vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
-
-        float32x4_t c0_7_high_32 =
-          vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
-
-        vst1q_f32(valsC, c0_7_low_32);
-        vst1q_f32(valsC + 4, c0_7_high_32);
-
-        for (unsigned int idx = n; idx < N; idx++) {
-          C[m * N + idx] = valsC[idx - n];
-        }
-      }
-    }
-  }
-
-  for (; (K - k) >= 4; k += 4) {
-    for (unsigned int m = 0; m < M; m++) {
-      vst1_f16(a, vmul_n_f16(vld1_f16(&A[m * K + k]), alpha));
-
-      for (unsigned int n = 0; n < N8; n += 8) {
-
-        float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
-        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
-        float16x8_t b0_7_2 = vmulq_n_f16(vld1q_f16(&B[(k + 2) * N + n]), a[2]);
-        b0_7_2 = vfmaq_n_f16(b0_7_2, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
-
-        float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
-                                            vcvt_f32_f16(vget_low_f16(b0_7_0)));
-        float32x4_t c0_7_high_32 = vaddq_f32(
-          vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
-
-        c0_7_low_32 =
-          vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_2)));
-        c0_7_high_32 =
-          vaddq_f32(c0_7_high_32, vcvt_f32_f16(vget_high_f16(b0_7_2)));
-
-        vst1q_f32(&C[m * N + n], c0_7_low_32);
-        vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
-      }
-      if (N != N8) {
-        unsigned int n = N8;
-        __fp16 valsB_0[8];
-        __fp16 valsB_1[8];
-        __fp16 valsB_2[8];
-        __fp16 valsB_3[8];
-        float valsC[8];
-        for (unsigned int idx = n; idx < N; idx++) {
-          valsB_0[idx - n] = B[k * N + idx];
-          valsB_1[idx - n] = B[(k + 1) * N + idx];
-          valsB_2[idx - n] = B[(k + 2) * N + idx];
-          valsB_3[idx - n] = B[(k + 3) * N + idx];
-          valsC[idx - n] = C[m * N + idx];
-        }
-
-        float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
-        b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
-
-        float32x4_t c0_7_low_32 =
-          vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
-
-        float32x4_t c0_7_high_32 =
-          vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
-
-        vst1q_f32(valsC, c0_7_low_32);
-        vst1q_f32(valsC + 4, c0_7_high_32);
-
-        for (unsigned int idx = n; idx < N; idx++) {
-          C[m * N + idx] = valsC[idx - n];
-        }
-      }
-    }
-  }
-
-  for (; k < K; k++) {
-    for (unsigned int m = 0; m < M; m++) {
-      __fp16 a0 = alpha * A[m * K + k];
-
-      for (unsigned int n = 0; n < N8; n += 8) {
-        float16x8_t b0_7 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a0);
-
-        float32x4_t c0_7_low_32 =
-          vaddq_f32(vld1q_f32(&C[m * N + n]), vcvt_f32_f16(vget_low_f16(b0_7)));
-
-        float32x4_t c0_7_high_32 = vaddq_f32(vld1q_f32(&C[m * N + n + 4]),
-                                             vcvt_f32_f16(vget_high_f16(b0_7)));
-
-        vst1q_f32(&C[m * N + n], c0_7_low_32);
-        vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
-      }
-      if (N != N8) {
-        unsigned int n = N8;
-        __fp16 valsB[8];
-        float valsC[8];
-        for (unsigned int idx = n; idx < N; idx++) {
-          valsB[idx - n] = B[k * N + idx];
-          valsC[idx - n] = C[m * N + idx];
-        }
-
-        float16x8_t b = vmulq_n_f16(vld1q_f16(valsB), a0);
-
-        float32x4_t c0_7_low_32 =
-          vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
-
-        float32x4_t c0_7_high_32 =
-          vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
-
-        vst1q_f32(valsC, c0_7_low_32);
-        vst1q_f32(valsC + 4, c0_7_high_32);
-
-        for (unsigned int idx = n; idx < N; idx++) {
-          C[m * N + idx] = valsC[idx - n];
-        }
-      }
-    }
-  }
-}
-
-void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, float *C, unsigned int ldc,
-                       float alpha, float beta) {
-
-  __fp16 *sA = alignedMalloc(M * K);
-  __fp16 *sB = alignedMalloc(K * N);
-
-  unsigned int ms, ms2, ns, ks;
-  unsigned int m_min, m2_min, n_min, k_min;
-  unsigned int stride_l1 = 1;
-
-  for (ms = 0; ms < M; ms += M_BLOCKING) {
-    m_min = M - ms;
-    if (m_min > M_BLOCKING) {
-      m_min = M_BLOCKING;
-    }
-
-    for (ks = 0; ks < K; ks += k_min) {
-      k_min = K - ks;
-      if (k_min >= (K_BLOCKING << 1)) {
-        k_min = K_BLOCKING;
-      } else if (k_min > K_BLOCKING) {
-        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
-      }
-
-      n_min = N;
-      if (N >= N_BLOCKING * 2) {
-        n_min = N_BLOCKING;
-      } else if (N > N_BLOCKING) {
-        n_min = ((n_min / 2 + GEMM_UNROLLING_16 - 1) / GEMM_UNROLLING_16) *
-                GEMM_UNROLLING_16;
-      } else {
-        stride_l1 = 0;
-      }
-      packing_transB16(k_min, n_min, B + (ks), ldb, sB);
-
-      for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
-        m2_min = (ms + m_min) - ms2;
-        if (m2_min >= 3 * GEMM_UNROLLING_8) {
-          m2_min = 3 * GEMM_UNROLLING_8;
-        } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
-          m2_min = 2 * GEMM_UNROLLING_8;
-        } else if (m2_min > GEMM_UNROLLING_8) {
-          m2_min = GEMM_UNROLLING_8;
-        }
-        packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
-                   sA + k_min * (ms2 - ms) * stride_l1);
-        HGEMM_KERNEL_8x16(m2_min, n_min, k_min,
-                          sA + k_min * (ms2 - ms) * stride_l1, sB,
-                          C + ms2 * ldc, ldc);
-      }
-
-      for (ns = n_min; ns < N; ns += n_min) {
-        n_min = N - ns;
-        if (n_min >= N_BLOCKING * 2) {
-          n_min = N_BLOCKING;
-        } else if (n_min > N_BLOCKING) {
-          n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
-        }
-        packing_transB16(k_min, n_min, B + ks + ldb * ns, ldb, sB);
-        HGEMM_KERNEL_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
-      }
-    }
-  }
-
-  free(sA);
-  free(sB);
-}
-
-void hgemm_transB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
-                  unsigned int N, unsigned int K, float alpha, float beta) {
-  if (((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0)) {
-    return hgemm_transB_8x16(M, N, K, A, K, B, K, C, N, alpha, beta);
-  } else {
-    return hgemm_transB_fallback(A, B, C, M, N, K, alpha, beta);
-  }
-}
-
-void hgemm_transB_fallback(const __fp16 *A, const __fp16 *B, float *C,
-                           unsigned int M, unsigned int N, unsigned int K,
-                           float alpha, float beta) {
-  __fp16 *B_T = alignedMalloc(K * N);
-
-  transpose_neon<__fp16>(N, K, B, K, B_T, N);
-
-  hgemm_noTrans(A, B_T, C, M, N, K, alpha, beta);
-
-  free(B_T);
-}
-
-void hgemm_transA(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
-                  unsigned int N, unsigned int K, float alpha, float beta) {
-  __fp16 *A_T = alignedMalloc(M * K);
-
-  transpose_neon<__fp16>(K, M, A, M, A_T, K);
-
-  hgemm_noTrans(A_T, B, C, M, N, K, alpha, beta);
-
-  free(A_T);
-}
-
-void hgemm_transAB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
-                   unsigned int N, unsigned int K, float alpha, float beta) {
-  __fp16 *A_T = alignedMalloc(M * K);
-  __fp16 *B_T = alignedMalloc(K * N);
-
-  transpose_neon<__fp16>(K, M, A, M, A_T, K);
-  transpose_neon<__fp16>(N, K, B, K, B_T, N);
-
-  hgemm_noTrans(A_T, B_T, C, M, N, K, alpha, beta);
-
-  free(A_T);
-  free(B_T);
-}
-
-void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
-              uint32_t N, uint32_t K, float alpha, float beta, bool TransA,
+void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
+              unsigned int N, unsigned int K, float alpha, float beta, bool TransA,
               bool TransB) {
   unsigned int lda = (TransA) ? M : K;
   unsigned int ldb = (TransB) ? K : N;
+  
+  return hgemm_K1_noTrans(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
+
   if (!TransA && TransB) {
     hgemm_K1_transB(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
   } else if (TransA && !TransB) {
index 45595e035f6de045e673d4dc2a208b99a62a08f9..d2dd28941c25a8b9a2f8e5bada087cc379bf70be 100644 (file)
 
 /**
  * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
- * @param[in] A __fp16 * for Matrix A
- * @param[in] B __fp16 * for Matrix B
- * @param[in] C float * for Matrix C
- * @param[in] M number of op(A)'s and C's row
- * @param[in] N number of op(B)'s and C's columns
- * @param[in] K number of op(A)'s and columns and op(B)'s rows
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
-                   unsigned int N, unsigned int K, float alpha = 1.F,
-                   float beta = 0.F);
-
-/**
- * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
- * @param[in] A __fp16 * for Matrix A
- * @param[in] B __fp16 * for Matrix B
- * @param[in] C __fp16 * for Matrix C
- * @param[in] M number of op(A)'s and C's row
- * @param[in] N number of op(B)'s and C's columns
- * @param[in] K number of op(A)'s and columns and op(B)'s rows
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C,
-                          unsigned int M, unsigned int N, unsigned int K,
-                          float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * where op(X) is one of X or X**T
  * @param[in] A __fp16 * for Matrix A
  * @param[in] B __fp16 * for Matrix B
  * @param[in] C __fp16 * for Matrix C
@@ -53,9 +24,8 @@ void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C,
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, float *C,
-                          unsigned int M, unsigned int N, unsigned int K,
-                          float alpha = 1.F, float beta = 0.F);
+void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N,
+           unsigned int K, float alpha, float beta, bool TransA, bool TransB);
 
 /**
  * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
@@ -87,402 +57,6 @@ void hgemm_ensure_divisibility(const __fp16 *A, const __fp16 *B, float *C32,
 void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
                     unsigned int N, unsigned int K, float alpha = 1.F, float beta = 0.F,
                     bool TransA = false, bool TransB = false);
-
-/**
- * @brief     hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
-                            const __fp16 *A, unsigned int lda, const __fp16 *B,
-                            unsigned int ldb, float *C, unsigned int ldc,
-                            float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief     hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_K1_noTrans(unsigned int M, unsigned int N, unsigned int K,
-                      const __fp16 *A, unsigned int lda, const __fp16 *B,
-                      unsigned int ldb, __fp16 *C, unsigned int ldc,
-                      float alpha = 1.F, float beta = 0.F);
-/**
- * @brief     hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_K1_transA(unsigned int M, unsigned int N, unsigned int K,
-                     const __fp16 *A, unsigned int lda, const __fp16 *B,
-                     unsigned int ldb, __fp16 *C, unsigned int ldc,
-                     float alpha = 1.F, float beta = 0.F);
-/**
- * @brief     hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_K1_transB(unsigned int M, unsigned int N, unsigned int K,
-                     const __fp16 *A, unsigned int lda, const __fp16 *B,
-                     unsigned int ldb, __fp16 *C, unsigned int ldc,
-                     float alpha = 1.F, float beta = 0.F);
-/**
- * @brief     hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_K1_transAB(unsigned int M, unsigned int N, unsigned int K,
-                      const __fp16 *A, unsigned int lda, const __fp16 *B,
-                      unsigned int ldb, __fp16 *C, unsigned int ldc,
-                      float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 1x4 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, __fp16 *C, unsigned int ldc,
-                       float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 1x4 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, float *C, unsigned int ldc,
-                       float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 4x4 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, __fp16 *C, unsigned int ldc,
-                       float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 1x8 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, __fp16 *C, unsigned int ldc,
-                       float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 1x8 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, float *C, unsigned int ldc,
-                       float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 8x8 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, __fp16 *C, unsigned int ldc,
-                       float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 4x4 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, float *C, unsigned int ldc,
-                       float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 8x8 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, float *C, unsigned int ldc,
-                       float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 4x8 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, __fp16 *C, unsigned int ldc,
-                       float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 4x8 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, float *C, unsigned int ldc,
-                       float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 8x16 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
-                        const __fp16 *A, unsigned int lda, const __fp16 *B,
-                        unsigned int ldb, __fp16 *C, unsigned int ldc,
-                        float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 8x16 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
-                        const __fp16 *A, unsigned int lda, const __fp16 *B,
-                        unsigned int ldb, float *C, unsigned int ldc,
-                        float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
- * where op(X) is one of X or X**T
- * @param[in] A __fp16 * for Matrix A
- * @param[in] B __fp16 * for Matrix B
- * @param[in] C __fp16 * for Matrix C
- * @param[in] M number of op(A)'s and C's row
- * @param[in] N number of op(B)'s and C's columns
- * @param[in] K number of op(A)'s and columns and op(B)'s rows
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_transA(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
-                  unsigned int N, unsigned int K, float alpha, float beta);
-/**
- * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
- * where op(X) is one of X or X**T
- * @param[in] A __fp16 * for Matrix A
- * @param[in] B __fp16 * for Matrix B
- * @param[in] C __fp16 * for Matrix C
- * @param[in] M number of op(A)'s and C's row
- * @param[in] N number of op(B)'s and C's columns
- * @param[in] K number of op(A)'s and columns and op(B)'s rows
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_transB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
-                  unsigned int N, unsigned int K, float alpha, float beta);
-
-void hgemm_transB_fallback(const __fp16 *A, const __fp16 *B, float *C,
-                           unsigned int M, unsigned int N, unsigned int K,
-                           float alpha, float beta);
-
-/**
- * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
- * where op(X) is one of X or X**T
- * @param[in] A __fp16 * for Matrix A
- * @param[in] B __fp16 * for Matrix B
- * @param[in] C __fp16 * for Matrix C
- * @param[in] M number of op(A)'s and C's row
- * @param[in] N number of op(B)'s and C's columns
- * @param[in] K number of op(A)'s and columns and op(B)'s rows
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K,
-                       const __fp16 *A, unsigned int lda, const __fp16 *B,
-                       unsigned int ldb, float *C, unsigned int ldc,
-                       float alpha = 1.F, float beta = 0.F);
-/**
- * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
- * where op(X) is one of X or X**T
- * @param[in] A __fp16 * for Matrix A
- * @param[in] B __fp16 * for Matrix B
- * @param[in] C __fp16 * for Matrix C
- * @param[in] M number of op(A)'s and C's row
- * @param[in] N number of op(B)'s and C's columns
- * @param[in] K number of op(A)'s and columns and op(B)'s rows
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_transAB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
-                   unsigned int N, unsigned int K, float alpha, float beta);
 /**
  * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
  * where op(X) is one of X or X**T
@@ -495,6 +69,6 @@ void hgemm_transAB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
-              uint32_t N, uint32_t K, float alpha, float beta, bool TransA,
+void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
+              unsigned int N, unsigned int K, float alpha, float beta, bool TransA,
               bool TransB);
index d5c965c477fb18056501d8172659a49af8930a76..bdf9bcbcace1f883a41a8658c84f400b79649dd9 100644 (file)
@@ -13,6 +13,7 @@
 #include <arm_neon.h>
 #include <assert.h>
 
+
 #define A(i, j) a[(i) * lda + (j)]
 #define B(i, j) b[(i) * ldb + (j)]
 #define C(i, j) c[(i) * ldc + (j)]
@@ -27,6 +28,8 @@
 #define VL_FP16 (8)
 #define VL_FP16_HALF (4)
 
+
+
 /**
  * @todo Add macro for instructions in other CPU architectures
  */
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel.h b/nntrainer/tensor/hgemm/hgemm_kernel.h
new file mode 100644 (file)
index 0000000..4bcea0f
--- /dev/null
@@ -0,0 +1,13 @@
+// #include <hgemm_kernel_1x4.h>
+// #include <hgemm_kernel_1x8.h>
+// #include <hgemm_kernel_4x4.h>
+// #include <hgemm_kernel_4x8.h>
+// #include <hgemm_kernel_8x16.h>
+// #include <hgemm_kernel_8x8.h>
+
+// #define HGEMM_KERNEL_1x4 hgemm_kernel_1x4
+// #define HGEMM_KERNEL_4x4 hgemm_kernel_4x4
+// #define HGEMM_KERNEL_1x8 hgemm_kernel_1x8
+// #define HGEMM_KERNEL_4x8 hgemm_kernel_4x8
+// #define HGEMM_KERNEL_8x8 hgemm_kernel_8x8
+// #define HGEMM_KERNEL_8x16 hgemm_kernel_8x16
\ No newline at end of file
index c189f63603e9ecd270f6dede42d7c35c58b88022..d0018876215b497d3e29e5d3d16d1bdab37452ed 100644 (file)
@@ -11,8 +11,9 @@
  *
  */
 
-#include <hgemm_common.h>
 #include <stdlib.h>
+#include <arm_neon.h>
+#include <assert.h>
 
 /**
  * @brief hgemm 1x4 kernel sc = sa * sb
@@ -33,11 +34,11 @@ void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
   __fp16 *a = sa, *b = sb, *c = sc;
   unsigned int i, j, l;
   for (i = 0; i < M; i++) {
-    for (j = 0; j < N; j += VL_FP16_HALF) {
+    for (j = 0; j < N; j += 4) {
       __builtin_prefetch(b, 0, 3);
       __builtin_prefetch(a, 0, 3);
 
-      for (l = 0; l < K; l += VL_FP16_HALF) {
+      for (l = 0; l < K; l += 4) {
         float16x4_t v24 = {0.F};
         float16x4_t v0 = vld1_f16(b);
         float16_t v16 = *a;
@@ -99,11 +100,11 @@ void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
   float *c = sc;
   unsigned int i, j, l;
   for (i = 0; i < M; i++) {
-    for (j = 0; j < N; j += VL_FP16_HALF) {
+    for (j = 0; j < N; j += 4) {
       __builtin_prefetch(b, 0, 3);
       __builtin_prefetch(a, 0, 3);
 
-      for (l = 0; l < K; l += VL_FP16_HALF) {
+      for (l = 0; l < K; l += 4) {
         float16x4_t v24 = {0.F};
         float16x4_t v0 = vld1_f16(b);
         float16_t v16 = *a;
index 5503dd7c3e760a73d45cee8c4db71742c263ed3b..3114ca32a221d47ec0a9ffff457d143f125c2dcb 100644 (file)
@@ -12,7 +12,8 @@
  *
  */
 
-#include <hgemm_common.h>
+#include <arm_neon.h>
+#include <assert.h>
 #include <stdlib.h>
 
 // 1. Partial sum 64 digits : worst accuracy, best latency
index 8a837bb19b5ef05913be61666e45d55439030aa7..18e86ccca116533fcc2fda272adb6372fcfb9612 100644 (file)
@@ -11,7 +11,8 @@
  *
  */
 
-#include <hgemm_common.h>
+#include <arm_neon.h>
+#include <assert.h>
 #include <stdlib.h>
 
 #define INIT_KERNEL_4x4()  \
@@ -230,8 +231,8 @@ void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
 
   __fp16 *a = sa, *b = sb, *c = sc;
   unsigned int i, j, l;
-  for (i = 0; i < M; i += VL_FP16_HALF) {
-    for (j = 0; j < N; j += VL_FP16_HALF) {
+  for (i = 0; i < M; i += 4) {
+    for (j = 0; j < N; j += 4) {
       __builtin_prefetch(b, 0, 3);
       __builtin_prefetch(a, 0, 3);
 
@@ -241,7 +242,7 @@ void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
       float16x4_t v27;
       INIT_KERNEL_4x4();
 
-      for (l = 0; l < K; l += VL_FP16_HALF) {
+      for (l = 0; l < K; l += 4) {
         float16x4_t v0 = vld1_f16(b);
         float16x4_t v16 = vld1_f16(a);
 
@@ -322,8 +323,8 @@ void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
   unsigned int i, j, l;
   unsigned int K16 = (K >> 4) << 4;
   unsigned int K8 = (K >> 3) << 3;
-  for (i = 0; i < M; i += VL_FP16_HALF) {
-    for (j = 0; j < N; j += VL_FP16_HALF) {
+  for (i = 0; i < M; i += 4) {
+    for (j = 0; j < N; j += 4) {
       __builtin_prefetch(b, 0, 3);
       __builtin_prefetch(a, 0, 3);
 
index aa9183e75a267513f664b04b63a4ad351f12b5f2..b1757bb5526a5aa32f0ff2cf72499bfa184e03e4 100644 (file)
@@ -11,7 +11,8 @@
  *
  */
 
-#include <hgemm_common.h>
+#include <arm_neon.h>
+#include <assert.h>
 #include <stdlib.h>
 
 #define INIT_KERNEL_4X8()  \
index 63886cbb0fbc667a0bae2ae2c9af70b2a8d04d53..d29cbfc2af814b13da1c2b04dc6de2f2126cf256 100644 (file)
@@ -11,7 +11,8 @@
  *
  */
 
-#include <hgemm_common.h>
+#include <arm_neon.h>
+#include <assert.h>
 #include <iostream>
 #include <stdlib.h>
 
@@ -809,8 +810,6 @@ void hgemm_kernel_8x16(unsigned int M, unsigned int N, unsigned int K,
   assert(M > 0 && N > 0 && K > 0);
   assert(M % 8 == 0 && N % 16 == 0 && K % 4 == 0);
 
-  // std::cout << " m : " << M << " , n : " << N << " , k : " << K << std::endl;
-
   __fp16 *a = sa, *b = sb;
   float *c = sc;
   unsigned int i, j, l;
index 374b2f5cd2680bf9aae2ac21d99dbe63d35e0932..2e3eb6a75b0078dd844cb2e340280a9ee1c7b5a5 100644 (file)
@@ -11,7 +11,8 @@
  *
  */
 
-#include <hgemm_common.h>
+#include <arm_neon.h>
+#include <assert.h>
 #include <stdlib.h>
 
 #define INIT_KERNEL_8x8()   \
@@ -416,8 +417,8 @@ void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K,
 
   __fp16 *a = sa, *b = sb, *c = sc;
   unsigned int i, j, l;
-  for (i = 0; i < M; i += VL_FP16) {
-    for (j = 0; j < N; j += VL_FP16) {
+  for (i = 0; i < M; i += 8) {
+    for (j = 0; j < N; j += 8) {
       __builtin_prefetch(b, 0, 3);
       __builtin_prefetch(a, 0, 3);
 
@@ -469,8 +470,8 @@ void hgemm_kernel_8x8(unsigned int M, unsigned int N, unsigned int K,
   unsigned int K4 = (K >> 2) << 2;
   unsigned int K8 = (K >> 3) << 3;
   unsigned int K16 = (K >> 4) << 4;
-  for (i = 0; i < M; i += VL_FP16) {
-    for (j = 0; j < N; j += VL_FP16) {
+  for (i = 0; i < M; i += 8) {
+    for (j = 0; j < N; j += 8) {
       __builtin_prefetch(b, 0, 3);
       __builtin_prefetch(a, 0, 3);
 
diff --git a/nntrainer/tensor/hgemm/hgemm_noTrans.cpp b/nntrainer/tensor/hgemm/hgemm_noTrans.cpp
new file mode 100644 (file)
index 0000000..64a32b3
--- /dev/null
@@ -0,0 +1,1221 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   hgemm_noTrans.cpp
+ * @date   10 July 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM interface of non-transposed case
+ *
+ */
+
+#include <cmath>
+
+#include <hgemm_kernel_pack.h>
+#include <hgemm_noTrans.h>
+#include <hgemm_util.h>
+#include <limits>
+// #include <hgemm_kernel.h>
+
+#include <matrix_transpose_neon.h>
+#include <hgemm_common.h>
+
+#include <hgemm_kernel_1x4.h>
+#include <hgemm_kernel_1x8.h>
+#include <hgemm_kernel_4x4.h>
+#include <hgemm_kernel_4x8.h>
+#include <hgemm_kernel_8x16.h>
+#include <hgemm_kernel_8x8.h>
+
+
+
+void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
+                   unsigned int N, unsigned int K, float alpha, float beta) {
+  const float eps = std::numeric_limits<float>::epsilon();
+  if (std::abs(alpha - 1.F) < eps) {
+    hgemm_noTrans_strict(A, B, C32, M, N, K, alpha, beta);
+  } else {
+    hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+  }
+}
+
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, float *C32,
+                          unsigned int M, unsigned int N, unsigned int K,
+                          float alpha, float beta) {
+  // used bitwise operator instead of modulo for performance
+  // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
+  if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
+    hgemm_noTrans_8x16(M, N, K, A, K, B, N, C32, N, alpha, beta);
+  } else if ((M & 0x7) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
+    hgemm_noTrans_8x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+  } else if ((M & 0x3) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
+    hgemm_noTrans_4x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+  } else if ((K & 0x7) == 0 && (N & 0x7) == 0) {
+    hgemm_noTrans_1x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+  } else if ((K & 0x7) == 0 && (N & 0x3) == 0) {
+    hgemm_noTrans_1x4(M, N, K, A, K, B, N, C32, N, alpha, beta);
+  } else {
+    hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+  }
+}
+
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C,
+                          unsigned int M, unsigned int N, unsigned int K,
+                          float alpha, float beta) {
+  if (alpha == 1.F) {
+    // used bitwise operator instead of modulo for performance
+    // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
+    if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
+      hgemm_noTrans_8x16(M, N, K, A, K, B, N, C, N, alpha, beta);
+    } else if ((M & 0x7) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
+      hgemm_noTrans_8x8(M, N, K, A, K, B, N, C, N, alpha, beta);
+    } else if ((M & 0x3) == 0 && (N & 0x7) == 0 && (K & 0x3) == 0) {
+      hgemm_noTrans_4x8(M, N, K, A, K, B, N, C, N, alpha, beta);
+    } else if ((M & 0x3) == 0 && (N & 0x3) == 0 && (K & 0x3) == 0) {
+      hgemm_noTrans_4x4(M, N, K, A, K, B, N, C, N, alpha, beta);
+    } else if ((N & 0x7) == 0 && (K & 0x7) == 0) {
+      hgemm_noTrans_1x8(M, N, K, A, K, B, N, C, N, alpha, beta);
+    } else if ((N & 0x3) == 0 && (K & 0x7) == 0) {
+      hgemm_noTrans_1x4(M, N, K, A, K, B, N, C, N, alpha, beta);
+    }
+  }
+}
+
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha, float beta) {
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+      packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_1) {
+          m2_min = 3 * GEMM_UNROLLING_1;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
+          m2_min = 2 * GEMM_UNROLLING_1;
+        } else if (m2_min > GEMM_UNROLLING_1) {
+          m2_min = GEMM_UNROLLING_1;
+        }
+
+        packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms));
+
+        hgemm_kernel_1x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+                         C + mms * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+        }
+
+        packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        hgemm_kernel_1x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha, float beta) {
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+      packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_1) {
+          m2_min = 3 * GEMM_UNROLLING_1;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
+          m2_min = 2 * GEMM_UNROLLING_1;
+        } else if (m2_min > GEMM_UNROLLING_1) {
+          m2_min = GEMM_UNROLLING_1;
+        }
+
+        packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms));
+
+        hgemm_kernel_1x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+                         C + mms * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+        }
+
+        packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        hgemm_kernel_1x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
+void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha, float beta) {
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+      packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_4) {
+          m2_min = 3 * GEMM_UNROLLING_4;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
+          m2_min = 2 * GEMM_UNROLLING_4;
+        } else if (m2_min > GEMM_UNROLLING_4) {
+          m2_min = GEMM_UNROLLING_4;
+        }
+
+        packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms));
+
+        hgemm_kernel_4x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+                         C + mms * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+        }
+
+        packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        hgemm_kernel_4x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
+void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha, float beta) {
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  unsigned int l1stride = 1;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
+                GEMM_UNROLLING_8;
+      } else {
+        l1stride = 0;
+      }
+      packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_1) {
+          m2_min = 3 * GEMM_UNROLLING_1;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
+          m2_min = 2 * GEMM_UNROLLING_1;
+        } else if (m2_min > GEMM_UNROLLING_1) {
+          m2_min = GEMM_UNROLLING_1;
+        }
+
+        packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms) * l1stride);
+
+        hgemm_kernel_1x8(m2_min, n_min, k_min,
+                         sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
+                         ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+        }
+
+        packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        hgemm_kernel_1x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
+void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha, float beta) {
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  unsigned int l1stride = 1;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
+                GEMM_UNROLLING_8;
+      } else {
+        l1stride = 0;
+      }
+      packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_1) {
+          m2_min = 3 * GEMM_UNROLLING_1;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
+          m2_min = 2 * GEMM_UNROLLING_1;
+        } else if (m2_min > GEMM_UNROLLING_1) {
+          m2_min = GEMM_UNROLLING_1;
+        }
+
+        packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms) * l1stride);
+
+        hgemm_kernel_1x8(m2_min, n_min, k_min,
+                         sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
+                         ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+        }
+
+        packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        hgemm_kernel_1x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
+void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha, float beta) {
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+      packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_4) {
+          m2_min = 3 * GEMM_UNROLLING_4;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
+          m2_min = 2 * GEMM_UNROLLING_4;
+        } else if (m2_min > GEMM_UNROLLING_4) {
+          m2_min = GEMM_UNROLLING_4;
+        }
+
+        packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms));
+
+        hgemm_kernel_4x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+                         C + mms * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+        }
+
+        packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        hgemm_kernel_4x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
+void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha, float beta) {
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  unsigned int l1stride = 1;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
+                GEMM_UNROLLING_8;
+      } else {
+        l1stride = 0;
+      }
+      packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_4) {
+          m2_min = 3 * GEMM_UNROLLING_4;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
+          m2_min = 2 * GEMM_UNROLLING_4;
+        } else if (m2_min > GEMM_UNROLLING_4) {
+          m2_min = GEMM_UNROLLING_4;
+        }
+
+        packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms) * l1stride);
+
+        hgemm_kernel_4x8(m2_min, n_min, k_min,
+                         sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
+                         ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+        }
+
+        packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        hgemm_kernel_4x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
+void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha, float beta) {
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  unsigned int l1stride = 1;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
+                GEMM_UNROLLING_8;
+      } else {
+        l1stride = 0;
+      }
+      packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_4) {
+          m2_min = 3 * GEMM_UNROLLING_4;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
+          m2_min = 2 * GEMM_UNROLLING_4;
+        } else if (m2_min > GEMM_UNROLLING_4) {
+          m2_min = GEMM_UNROLLING_4;
+        }
+
+        packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms) * l1stride);
+
+        hgemm_kernel_4x8(m2_min, n_min, k_min,
+                         sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
+                         ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+        }
+
+        packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        hgemm_kernel_4x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
+void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha, float beta) {
+
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+      }
+      packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_8) {
+          m2_min = 3 * GEMM_UNROLLING_8;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
+          m2_min = 2 * GEMM_UNROLLING_8;
+        } else if (m2_min > GEMM_UNROLLING_8) {
+          m2_min = GEMM_UNROLLING_8;
+        }
+
+        packing_A8(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms));
+
+        hgemm_kernel_8x8(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+                         C + mms * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+        }
+
+        packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        hgemm_kernel_8x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
+void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha, float beta) {
+
+  __fp16 *sA = alignedMalloc(M * K);
+  __fp16 *sB = alignedMalloc(K * N);
+
+  unsigned int ms, ms2, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+      }
+      packing_B8(k_min, n_min, B + ks * ldb, ldb, sB);
+
+      for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
+        m2_min = (ms + m_min) - ms2;
+        if (m2_min >= 3 * GEMM_UNROLLING_8) {
+          m2_min = 3 * GEMM_UNROLLING_8;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
+          m2_min = 2 * GEMM_UNROLLING_8;
+        } else if (m2_min > GEMM_UNROLLING_8) {
+          m2_min = GEMM_UNROLLING_8;
+        }
+
+        packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
+                   sA + k_min * (ms2 - ms));
+
+        hgemm_kernel_8x8(m2_min, n_min, k_min, sA + k_min * (ms2 - ms), sB,
+                         C + ms2 * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+        }
+
+        packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sB);
+        hgemm_kernel_8x8(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sA);
+  free(sB);
+}
+
+void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
+                        const __fp16 *A, unsigned int lda, const __fp16 *B,
+                        unsigned int ldb, __fp16 *C, unsigned int ldc,
+                        float alpha, float beta) {
+  __fp16 *sA = alignedMalloc(M * K);
+  __fp16 *sB = alignedMalloc(K * N);
+
+  unsigned int ms, ms2, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  unsigned int stride_l1 = 1;
+
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = ((n_min / 2 + GEMM_UNROLLING_16 - 1) / GEMM_UNROLLING_16) *
+                GEMM_UNROLLING_16;
+      } else {
+        stride_l1 = 0;
+      }
+      packing_B16(k_min, n_min, B + ks * ldb, ldb, sB);
+
+      for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
+        m2_min = (ms + m_min) - ms2;
+        if (m2_min >= 3 * GEMM_UNROLLING_8) {
+          m2_min = 3 * GEMM_UNROLLING_8;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
+          m2_min = 2 * GEMM_UNROLLING_8;
+        } else if (m2_min > GEMM_UNROLLING_8) {
+          m2_min = GEMM_UNROLLING_8;
+        }
+
+        packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
+                   sA + k_min * (ms2 - ms) * stride_l1);
+
+        hgemm_kernel_8x16(m2_min, n_min, k_min,
+                          sA + stride_l1 * k_min * (ms2 - ms), sB,
+                          C + ms2 * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+        }
+
+        packing_B16(k_min, n_min, B + ns + ldb * ks, ldb, sB);
+        hgemm_kernel_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sA);
+  free(sB);
+}
+
+void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
+                        const __fp16 *A, unsigned int lda, const __fp16 *B,
+                        unsigned int ldb, float *C, unsigned int ldc,
+                        float alpha, float beta) {
+  __fp16 *sA = alignedMalloc(M * K);
+  __fp16 *sB = alignedMalloc(K * N);
+
+  unsigned int ms, ms2, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  unsigned int stride_l1 = 1;
+
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = ((n_min / 2 + GEMM_UNROLLING_16 - 1) / GEMM_UNROLLING_16) *
+                GEMM_UNROLLING_16;
+      } else {
+        stride_l1 = 0;
+      }
+      packing_B16(k_min, n_min, B + ks * ldb, ldb, sB);
+
+      for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
+        m2_min = (ms + m_min) - ms2;
+        if (m2_min >= 3 * GEMM_UNROLLING_8) {
+          m2_min = 3 * GEMM_UNROLLING_8;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
+          m2_min = 2 * GEMM_UNROLLING_8;
+        } else if (m2_min > GEMM_UNROLLING_8) {
+          m2_min = GEMM_UNROLLING_8;
+        }
+
+        packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
+                   sA + k_min * (ms2 - ms));
+
+        hgemm_kernel_8x16(m2_min, n_min, k_min, sA + k_min * (ms2 - ms), sB,
+                          C + ms2 * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min =
+            (n_min / 2 + GEMM_UNROLLING_16 - 1) & ~(GEMM_UNROLLING_16 - 1);
+        }
+
+        packing_B16(k_min, n_min, B + ns + ldb * ks, ldb, sB);
+
+        hgemm_kernel_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sA);
+  free(sB);
+}
+
+void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
+                            const __fp16 *A, unsigned int lda, const __fp16 *B,
+                            unsigned int ldb, float *C, unsigned int ldc,
+                            float alpha, float beta) {
+
+  unsigned int k = 0;
+  unsigned int N8 = (N >> 3) << 3;
+  __fp16 a[16];
+  for (; (K - k) >= 16; k += 16) {
+    for (unsigned int m = 0; m < M; m++) {
+      vst1q_f16(&a[0], vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha));
+      vst1q_f16(&a[8], vmulq_n_f16(vld1q_f16(&A[m * K + k + 8]), alpha));
+      for (unsigned int n = 0; n < N8; n += 8) {
+        float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 8) * N + n]), a[8]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 9) * N + n]), a[9]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 10) * N + n]), a[10]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 11) * N + n]), a[11]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 12) * N + n]), a[12]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 13) * N + n]), a[13]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 14) * N + n]), a[14]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 15) * N + n]), a[15]);
+
+        float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
+                                            vcvt_f32_f16(vget_low_f16(b0_7_0)));
+        float32x4_t c0_7_high_32 = vaddq_f32(
+          vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
+
+        vst1q_f32(&C[m * N + n], c0_7_low_32);
+        vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
+      }
+      if (N != N8) {
+        unsigned int n = N8;
+        __fp16 valsB_0[8];
+        __fp16 valsB_1[8];
+        __fp16 valsB_2[8];
+        __fp16 valsB_3[8];
+        __fp16 valsB_4[8];
+        __fp16 valsB_5[8];
+        __fp16 valsB_6[8];
+        __fp16 valsB_7[8];
+        __fp16 valsB_8[8];
+        __fp16 valsB_9[8];
+        __fp16 valsB_10[8];
+        __fp16 valsB_11[8];
+        __fp16 valsB_12[8];
+        __fp16 valsB_13[8];
+        __fp16 valsB_14[8];
+        __fp16 valsB_15[8];
+        float valsC[8];
+        for (unsigned int idx = n; idx < N; idx++) {
+          valsB_0[idx - n] = B[k * N + idx];
+          valsB_1[idx - n] = B[(k + 1) * N + idx];
+          valsB_2[idx - n] = B[(k + 2) * N + idx];
+          valsB_3[idx - n] = B[(k + 3) * N + idx];
+          valsB_4[idx - n] = B[(k + 4) * N + idx];
+          valsB_5[idx - n] = B[(k + 5) * N + idx];
+          valsB_6[idx - n] = B[(k + 6) * N + idx];
+          valsB_7[idx - n] = B[(k + 7) * N + idx];
+          valsB_8[idx - n] = B[(k + 8) * N + idx];
+          valsB_9[idx - n] = B[(k + 9) * N + idx];
+          valsB_10[idx - n] = B[(k + 10) * N + idx];
+          valsB_11[idx - n] = B[(k + 11) * N + idx];
+          valsB_12[idx - n] = B[(k + 12) * N + idx];
+          valsB_13[idx - n] = B[(k + 13) * N + idx];
+          valsB_14[idx - n] = B[(k + 14) * N + idx];
+          valsB_15[idx - n] = B[(k + 15) * N + idx];
+          valsC[idx - n] = C[m * N + idx];
+        }
+
+        float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_4), a[4]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_5), a[5]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_6), a[6]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_7), a[7]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_8), a[8]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_9), a[9]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_10), a[10]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_11), a[11]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_12), a[12]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_13), a[13]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_14), a[14]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_15), a[15]);
+
+        float32x4_t c0_7_low_32 =
+          vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
+
+        float32x4_t c0_7_high_32 =
+          vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
+
+        vst1q_f32(valsC, c0_7_low_32);
+        vst1q_f32(valsC + 4, c0_7_high_32);
+
+        for (unsigned int idx = n; idx < N; idx++) {
+          C[m * N + idx] = valsC[idx - n];
+        }
+      }
+    }
+  }
+
+  for (; (K - k) >= 8; k += 8) {
+    for (unsigned int m = 0; m < M; m++) {
+      vst1q_f16(a, vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha));
+
+      for (unsigned int n = 0; n < N8; n += 8) {
+        float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]);
+
+        float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
+                                            vcvt_f32_f16(vget_low_f16(b0_7_0)));
+        float32x4_t c0_7_high_32 = vaddq_f32(
+          vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
+
+        vst1q_f32(&C[m * N + n], c0_7_low_32);
+        vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
+      }
+      if (N != N8) {
+        unsigned int n = N8;
+        __fp16 valsB_0[8];
+        __fp16 valsB_1[8];
+        __fp16 valsB_2[8];
+        __fp16 valsB_3[8];
+        __fp16 valsB_4[8];
+        __fp16 valsB_5[8];
+        __fp16 valsB_6[8];
+        __fp16 valsB_7[8];
+        float valsC[8];
+        for (unsigned int idx = n; idx < N; idx++) {
+          valsB_0[idx - n] = B[k * N + idx];
+          valsB_1[idx - n] = B[(k + 1) * N + idx];
+          valsB_2[idx - n] = B[(k + 2) * N + idx];
+          valsB_3[idx - n] = B[(k + 3) * N + idx];
+          valsB_4[idx - n] = B[(k + 4) * N + idx];
+          valsB_5[idx - n] = B[(k + 5) * N + idx];
+          valsB_6[idx - n] = B[(k + 6) * N + idx];
+          valsB_7[idx - n] = B[(k + 7) * N + idx];
+          valsC[idx - n] = C[m * N + idx];
+        }
+
+        float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_4), a[4]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_5), a[5]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_6), a[6]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_7), a[7]);
+
+        float32x4_t c0_7_low_32 =
+          vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
+
+        float32x4_t c0_7_high_32 =
+          vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
+
+        vst1q_f32(valsC, c0_7_low_32);
+        vst1q_f32(valsC + 4, c0_7_high_32);
+
+        for (unsigned int idx = n; idx < N; idx++) {
+          C[m * N + idx] = valsC[idx - n];
+        }
+      }
+    }
+  }
+
+  for (; (K - k) >= 4; k += 4) {
+    for (unsigned int m = 0; m < M; m++) {
+      vst1_f16(a, vmul_n_f16(vld1_f16(&A[m * K + k]), alpha));
+
+      for (unsigned int n = 0; n < N8; n += 8) {
+
+        float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
+        float16x8_t b0_7_2 = vmulq_n_f16(vld1q_f16(&B[(k + 2) * N + n]), a[2]);
+        b0_7_2 = vfmaq_n_f16(b0_7_2, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
+
+        float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
+                                            vcvt_f32_f16(vget_low_f16(b0_7_0)));
+        float32x4_t c0_7_high_32 = vaddq_f32(
+          vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
+
+        c0_7_low_32 =
+          vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_2)));
+        c0_7_high_32 =
+          vaddq_f32(c0_7_high_32, vcvt_f32_f16(vget_high_f16(b0_7_2)));
+
+        vst1q_f32(&C[m * N + n], c0_7_low_32);
+        vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
+      }
+      if (N != N8) {
+        unsigned int n = N8;
+        __fp16 valsB_0[8];
+        __fp16 valsB_1[8];
+        __fp16 valsB_2[8];
+        __fp16 valsB_3[8];
+        float valsC[8];
+        for (unsigned int idx = n; idx < N; idx++) {
+          valsB_0[idx - n] = B[k * N + idx];
+          valsB_1[idx - n] = B[(k + 1) * N + idx];
+          valsB_2[idx - n] = B[(k + 2) * N + idx];
+          valsB_3[idx - n] = B[(k + 3) * N + idx];
+          valsC[idx - n] = C[m * N + idx];
+        }
+
+        float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
+
+        float32x4_t c0_7_low_32 =
+          vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
+
+        float32x4_t c0_7_high_32 =
+          vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
+
+        vst1q_f32(valsC, c0_7_low_32);
+        vst1q_f32(valsC + 4, c0_7_high_32);
+
+        for (unsigned int idx = n; idx < N; idx++) {
+          C[m * N + idx] = valsC[idx - n];
+        }
+      }
+    }
+  }
+
+  for (; k < K; k++) {
+    for (unsigned int m = 0; m < M; m++) {
+      __fp16 a0 = alpha * A[m * K + k];
+
+      for (unsigned int n = 0; n < N8; n += 8) {
+        float16x8_t b0_7 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a0);
+
+        float32x4_t c0_7_low_32 =
+          vaddq_f32(vld1q_f32(&C[m * N + n]), vcvt_f32_f16(vget_low_f16(b0_7)));
+
+        float32x4_t c0_7_high_32 = vaddq_f32(vld1q_f32(&C[m * N + n + 4]),
+                                             vcvt_f32_f16(vget_high_f16(b0_7)));
+
+        vst1q_f32(&C[m * N + n], c0_7_low_32);
+        vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
+      }
+      if (N != N8) {
+        unsigned int n = N8;
+        __fp16 valsB[8];
+        float valsC[8];
+        for (unsigned int idx = n; idx < N; idx++) {
+          valsB[idx - n] = B[k * N + idx];
+          valsC[idx - n] = C[m * N + idx];
+        }
+
+        float16x8_t b = vmulq_n_f16(vld1q_f16(valsB), a0);
+
+        float32x4_t c0_7_low_32 =
+          vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
+
+        float32x4_t c0_7_high_32 =
+          vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
+
+        vst1q_f32(valsC, c0_7_low_32);
+        vst1q_f32(valsC + 4, c0_7_high_32);
+
+        for (unsigned int idx = n; idx < N; idx++) {
+          C[m * N + idx] = valsC[idx - n];
+        }
+      }
+    }
+  }
+}
+
+void hgemm_K1_noTrans(unsigned int M, unsigned int N, unsigned int K,
+                      const __fp16 *A, unsigned int lda, const __fp16 *B,
+                      unsigned int ldb, __fp16 *C, unsigned int ldc,
+                      float alpha, float beta) {
+  const float eps = std::numeric_limits<float>::epsilon();
+  float16x8_t a_vec;
+  unsigned int N8 = (N >> 3) << 3;
+  for (unsigned int m = 0; m < M; ++m) {
+    a_vec = vmovq_n_f16(alpha * A[m]);
+    if (std::fpclassify(beta) != FP_ZERO) {
+      for (unsigned int n = 0; n < N8; n += 8) {
+        vst1q_f16(&C[m * ldc + n],
+                  vaddq_f16(vmulq_f16(a_vec, vld1q_f16(&B[n])),
+                            vmulq_n_f16(vld1q_f16(&C[m * ldc + n]), beta)));
+      }
+    } else {
+      for (unsigned int n = 0; n < N8; n += 8) {
+        vst1q_f16(&C[m * ldc + n], vmulq_f16(a_vec, vld1q_f16(&B[n])));
+      }
+    }
+    for (unsigned int n = N8; n < N; ++n) {
+      C[m * ldc + n] = alpha * A[m] * B[n] + beta * C[m * ldc + n];
+    }
+  }
+}
diff --git a/nntrainer/tensor/hgemm/hgemm_noTrans.h b/nntrainer/tensor/hgemm/hgemm_noTrans.h
new file mode 100644 (file)
index 0000000..1270f37
--- /dev/null
@@ -0,0 +1,334 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   hgemm_noTrans.h
+ * @date   10 July 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM interface of non-transposed case
+ *
+ */
+
+/**
+ * @brief hgemm noTrans computation with 1x4 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 1x4 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 4x4 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 1x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 1x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 8x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 4x4 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 8x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 4x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 4x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 8x16 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
+                        const __fp16 *A, unsigned int lda, const __fp16 *B,
+                        unsigned int ldb, __fp16 *C, unsigned int ldc,
+                        float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 8x16 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
+                        const __fp16 *A, unsigned int lda, const __fp16 *B,
+                        unsigned int ldb, float *C, unsigned int ldc,
+                        float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief     hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_K1_noTrans(unsigned int M, unsigned int N, unsigned int K,
+                      const __fp16 *A, unsigned int lda, const __fp16 *B,
+                      unsigned int ldb, __fp16 *C, unsigned int ldc,
+                      float alpha = 1.F, float beta = 0.F);
+/**
+ * @brief     hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
+                            const __fp16 *A, unsigned int lda, const __fp16 *B,
+                            unsigned int ldb, float *C, unsigned int ldc,
+                            float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C float * for Matrix C
+ * @param[in] M number of op(A)'s and C's row
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s and columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
+                   unsigned int N, unsigned int K, float alpha = 1.F,
+                   float beta = 0.F);
+
+/**
+ * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C __fp16 * for Matrix C
+ * @param[in] M number of op(A)'s and C's row
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s and columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C,
+                          unsigned int M, unsigned int N, unsigned int K,
+                          float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C __fp16 * for Matrix C
+ * @param[in] M number of op(A)'s and C's row
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s and columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, float *C,
+                          unsigned int M, unsigned int N, unsigned int K,
+                          float alpha = 1.F, float beta = 0.F);
diff --git a/nntrainer/tensor/hgemm/hgemm_transA.cpp b/nntrainer/tensor/hgemm/hgemm_transA.cpp
new file mode 100644 (file)
index 0000000..b510a34
--- /dev/null
@@ -0,0 +1,41 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   hgemm_transA.cpp
+ * @date   10 July 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM interface of transposed A case
+ *
+ */
+
+#include <hgemm_noTrans.h>
+#include <hgemm_transA.h>
+#include <hgemm_util.h>
+#include <matrix_transpose_neon.h>
+
+void hgemm_transA(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
+                  unsigned int N, unsigned int K, float alpha, float beta) {
+  __fp16 *A_T = alignedMalloc(M * K);
+
+  transpose_neon<__fp16>(K, M, A, M, A_T, K);
+
+  hgemm_noTrans(A_T, B, C, M, N, K, alpha, beta);
+
+  free(A_T);
+}
+
+void hgemm_K1_transA(unsigned int M, unsigned int N, unsigned int K,
+                     const __fp16 *A, unsigned int lda, const __fp16 *B,
+                     unsigned int ldb, __fp16 *C, unsigned int ldc, float alpha,
+                     float beta) {
+  __fp16 *A_T = alignedMalloc(M * K);
+
+  transpose_neon<__fp16>(K, M, A, M, A_T, K);
+
+  hgemm_K1_noTrans(M, N, K, A_T, lda, B, ldb, C, ldc, alpha, beta);
+
+  free(A_T);
+}
diff --git a/nntrainer/tensor/hgemm/hgemm_transA.h b/nntrainer/tensor/hgemm/hgemm_transA.h
new file mode 100644 (file)
index 0000000..68272bd
--- /dev/null
@@ -0,0 +1,45 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   hgemm_transA.h
+ * @date   10 July 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM interface of transposed A case
+ *
+ */
+
+/**
+ * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * where op(X) is one of X or X**T
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C __fp16 * for Matrix C
+ * @param[in] M number of op(A)'s and C's row
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s and columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_transA(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
+                  unsigned int N, unsigned int K, float alpha, float beta);
+/**
+ * @brief     hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_K1_transA(unsigned int M, unsigned int N, unsigned int K,
+                     const __fp16 *A, unsigned int lda, const __fp16 *B,
+                     unsigned int ldb, __fp16 *C, unsigned int ldc,
+                     float alpha = 1.F, float beta = 0.F);
diff --git a/nntrainer/tensor/hgemm/hgemm_transAB.cpp b/nntrainer/tensor/hgemm/hgemm_transAB.cpp
new file mode 100644 (file)
index 0000000..0ab9708
--- /dev/null
@@ -0,0 +1,47 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   hgemm_transAB.cpp
+ * @date   10 July 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM interface of transposed  AB case
+ *
+ */
+
+#include <hgemm_noTrans.h>
+#include <hgemm_transAB.h>
+#include <hgemm_util.h>
+#include <matrix_transpose_neon.h>
+
+void hgemm_transAB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
+                   unsigned int N, unsigned int K, float alpha, float beta) {
+  __fp16 *A_T = alignedMalloc(M * K);
+  __fp16 *B_T = alignedMalloc(K * N);
+
+  transpose_neon<__fp16>(K, M, A, M, A_T, K);
+  transpose_neon<__fp16>(N, K, B, K, B_T, N);
+
+  hgemm_noTrans(A_T, B_T, C, M, N, K, alpha, beta);
+
+  free(A_T);
+  free(B_T);
+}
+
+void hgemm_K1_transAB(unsigned int M, unsigned int N, unsigned int K,
+                      const __fp16 *A, unsigned int lda, const __fp16 *B,
+                      unsigned int ldb, __fp16 *C, unsigned int ldc,
+                      float alpha, float beta) {
+  __fp16 *A_T = alignedMalloc(M * K);
+  __fp16 *B_T = alignedMalloc(K * N);
+
+  transpose_neon<__fp16>(K, M, A, M, A_T, K);
+  transpose_neon<__fp16>(N, K, B, K, B_T, N);
+
+  hgemm_K1_noTrans(M, N, K, A_T, lda, B_T, ldb, C, ldc, alpha, beta);
+
+  free(A_T);
+  free(B_T);
+}
diff --git a/nntrainer/tensor/hgemm/hgemm_transAB.h b/nntrainer/tensor/hgemm/hgemm_transAB.h
new file mode 100644 (file)
index 0000000..08e131d
--- /dev/null
@@ -0,0 +1,45 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   hgemm_transAB.h
+ * @date   10 July 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM interface of transposed  AB case
+ *
+ */
+
+/**
+ * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * where op(X) is one of X or X**T
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C __fp16 * for Matrix C
+ * @param[in] M number of op(A)'s and C's row
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s and columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_transAB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
+                   unsigned int N, unsigned int K, float alpha, float beta);
+/**
+ * @brief     hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_K1_transAB(unsigned int M, unsigned int N, unsigned int K,
+                      const __fp16 *A, unsigned int lda, const __fp16 *B,
+                      unsigned int ldb, __fp16 *C, unsigned int ldc,
+                      float alpha = 1.F, float beta = 0.F);
\ No newline at end of file
diff --git a/nntrainer/tensor/hgemm/hgemm_transB.cpp b/nntrainer/tensor/hgemm/hgemm_transB.cpp
new file mode 100644 (file)
index 0000000..adc7907
--- /dev/null
@@ -0,0 +1,134 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   hgemm_transB.cpp
+ * @date   10 July 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM interface of transposed B case
+ *
+ */
+
+#include <cmath>
+#include <hgemm_kernel_8x16.h>
+#include <hgemm_common.h>
+// #include <hgemm_kernel.h>
+#include <hgemm_kernel_pack.h>
+#include <hgemm_noTrans.h>
+#include <hgemm_transB.h>
+#include <hgemm_util.h>
+#include <limits>
+#include <matrix_transpose_neon.h>
+
+// #define HGEMM_KERNEL_8x16 hgemm_kernel_8x16 /// @todo change to macro kernel
+// #if !defined(HGEMM_KERNEL_8x16) hgemm_kernel_8x16
+// #endif
+
+void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha, float beta) {
+
+  __fp16 *sA = alignedMalloc(M * K);
+  __fp16 *sB = alignedMalloc(K * N);
+
+  unsigned int ms, ms2, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  unsigned int stride_l1 = 1;
+
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = ((n_min / 2 + GEMM_UNROLLING_16 - 1) / GEMM_UNROLLING_16) *
+                GEMM_UNROLLING_16;
+      } else {
+        stride_l1 = 0;
+      }
+      packing_transB16(k_min, n_min, B + (ks), ldb, sB);
+
+      for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
+        m2_min = (ms + m_min) - ms2;
+        if (m2_min >= 3 * GEMM_UNROLLING_8) {
+          m2_min = 3 * GEMM_UNROLLING_8;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
+          m2_min = 2 * GEMM_UNROLLING_8;
+        } else if (m2_min > GEMM_UNROLLING_8) {
+          m2_min = GEMM_UNROLLING_8;
+        }
+        packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
+                   sA + k_min * (ms2 - ms) * stride_l1);
+        hgemm_kernel_8x16(m2_min, n_min, k_min,
+                          sA + k_min * (ms2 - ms) * stride_l1, sB,
+                          C + ms2 * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+        }
+        packing_transB16(k_min, n_min, B + ks + ldb * ns, ldb, sB);
+        hgemm_kernel_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns,
+        ldc);
+      }
+    }
+  }
+
+  free(sA);
+  free(sB);
+}
+
+void hgemm_transB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
+                  unsigned int N, unsigned int K, float alpha, float beta) {
+  const float eps = std::numeric_limits<float>::epsilon();
+  if (((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0 &&
+       (std::abs(alpha - 1.F) < eps))) {
+    return hgemm_transB_8x16(M, N, K, A, K, B, K, C, N, alpha, beta);
+  } else {
+    return hgemm_transB_fallback(A, B, C, M, N, K, alpha, beta);
+  }
+}
+
+void hgemm_transB_fallback(const __fp16 *A, const __fp16 *B, float *C,
+                           unsigned int M, unsigned int N, unsigned int K,
+                           float alpha, float beta) {
+  __fp16 *B_T = alignedMalloc(K * N);
+
+  transpose_neon<__fp16>(N, K, B, K, B_T, N);
+
+  hgemm_noTrans(A, B_T, C, M, N, K, alpha, beta);
+
+  free(B_T);
+}
+
+void hgemm_K1_transB(unsigned int M, unsigned int N, unsigned int K,
+                     const __fp16 *A, unsigned int lda, const __fp16 *B,
+                     unsigned int ldb, __fp16 *C, unsigned int ldc, float alpha,
+                     float beta) {
+  __fp16 *B_T = alignedMalloc(K * N);
+
+  transpose_neon<__fp16>(N, K, B, K, B_T, N);
+
+  hgemm_K1_noTrans(M, N, K, A, lda, B_T, ldb, C, ldc, alpha, beta);
+
+  free(B_T);
+}
diff --git a/nntrainer/tensor/hgemm/hgemm_transB.h b/nntrainer/tensor/hgemm/hgemm_transB.h
new file mode 100644 (file)
index 0000000..cda6422
--- /dev/null
@@ -0,0 +1,66 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   hgemm_transB.h
+ * @date   10 July 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM interface of transposed B case
+ *
+ */
+
+/**
+ * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * where op(X) is one of X or X**T
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C __fp16 * for Matrix C
+ * @param[in] M number of op(A)'s and C's row
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s and columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_transB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
+                  unsigned int N, unsigned int K, float alpha, float beta);
+
+void hgemm_transB_fallback(const __fp16 *A, const __fp16 *B, float *C,
+                           unsigned int M, unsigned int N, unsigned int K,
+                           float alpha, float beta);
+
+/**
+ * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * where op(X) is one of X or X**T
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C __fp16 * for Matrix C
+ * @param[in] M number of op(A)'s and C's row
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s and columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+/**
+ * @brief     hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_K1_transB(unsigned int M, unsigned int N, unsigned int K,
+                     const __fp16 *A, unsigned int lda, const __fp16 *B,
+                     unsigned int ldb, __fp16 *C, unsigned int ldc,
+                     float alpha = 1.F, float beta = 0.F);
index f5aa25da914d725390c037d7dfdcebc344aedee5..ec5af07bba8dda44ed3ad0591e49be2af532668b 100644 (file)
@@ -13,6 +13,10 @@ hgemm_sources = [
     'hgemm_padding_a.cpp',
     'hgemm_padding_b.cpp',
     'hgemm_kernel_pack.cpp',
+    'hgemm_noTrans.cpp',
+    'hgemm_transA.cpp',
+    'hgemm_transB.cpp',
+    'hgemm_transAB.cpp',
 ]
 
 foreach s : hgemm_sources
index 2c81bdcbd40f2b56f1714e299e48eb5e4ee3ad11..7bf0bd24b7c64b37a15bfbe9676353dd2416fd4f 100644 (file)
@@ -1071,6 +1071,67 @@ TEST(nntrainer_Tensor, dot_gemm_50_768_516) {
   EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
 }
 
+TEST(nntrainer_Tensor, dot_gemm_K1) {
+  /// @note GEMM : A X B = C
+  int batch = 1;
+  int channel = 1;
+  int height = 56;
+  int width = 1;
+
+  int height_b = 1;
+  int width_b = 516;
+
+  bool transA = false;
+  bool transB = false;
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+  nntrainer::Tensor A(batch, channel, height, width, t_type_nchw_fp16);
+  nntrainer::Tensor B(batch, channel, height_b, width_b, t_type_nchw_fp16);
+
+  nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+  nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32);
+
+  const float alpha = 1e-1;
+  const int MOD = 10;
+
+  GEN_TEST_INPUT(A, ((i * (batch * height * channel) + j * (batch * height) +
+                      k * (width) + l + 1) %
+                     MOD) *
+                      alpha);
+  GEN_TEST_INPUT_B(B, ((i * (batch * height_b * channel) +
+                        j * (batch * height_b) + k * (width_b) + l + 1) %
+                       MOD) *
+                        alpha);
+  GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) +
+                           j * (batch * height) + k * (width) + l + 1) %
+                          MOD) *
+                           alpha);
+  GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) +
+                             j * (batch * height_b) + k * (width_b) + l + 1) %
+                            MOD) *
+                             alpha);
+
+  nntrainer::Tensor C = A.dot(B, transA, transB);
+
+  nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB);
+
+  float mseErrorNeon =
+    mse<__fp16>(C.getData<__fp16>(), C_fp32.getData<float>(), C.size());
+
+  double cosSimNeon = cosine_similarity<__fp16>(
+    C.getData<__fp16>(), C_fp32.getData<float>(), C.size());
+
+  const float epsilon = 1e-3 * width;
+
+  EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon);
+  EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
+}
+
 TEST(nntrainer_Tensor, dot_gemv_768_96000) {
   /// @note GEMV : A X B = C
   int batch = 1;