[ hgemm ] Use optimized hgemm if possible
authorskykongkong8 <ss.kong@samsung.com>
Wed, 3 Apr 2024 04:23:42 +0000 (13:23 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 3 Apr 2024 11:48:34 +0000 (20:48 +0900)
- We can use optimized version of hgemm with following condition:
1. noTrans hgemm
2. M, N, K is divisible with 4 or 8
3. Row Major GEMM
4. alpha = 1.0, beta = 0.0 (will be patched soon)
- Otherwise, use previous version as a fallback.
- Note that there are a few optimization strategy is left for optimal hgemm.

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/hgemm/hgemm.cpp [new file with mode: 0644]
nntrainer/tensor/hgemm/hgemm.h
nntrainer/tensor/hgemm/hgemm_common.h
nntrainer/tensor/hgemm/hgemm_kernel_4x4.h
nntrainer/tensor/hgemm/hgemm_kernel_pack.h
nntrainer/tensor/hgemm/hgemm_util.h
nntrainer/tensor/hgemm/meson.build [new file with mode: 0644]

diff --git a/nntrainer/tensor/hgemm/hgemm.cpp b/nntrainer/tensor/hgemm/hgemm.cpp
new file mode 100644 (file)
index 0000000..97f074c
--- /dev/null
@@ -0,0 +1,663 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file   hgemm.cpp
+ * @date   03 April 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM interface
+ *
+ */
+
+#include <hgemm.h>
+#include <hgemm_kernel_4x4.h>
+#include <hgemm_kernel_4x8.h>
+#include <hgemm_kernel_8x8.h>
+#include <hgemm_kernel_pack.h>
+#include <hgemm_util.h>
+
+#define HGEMM_KERNEL_4x4 hgemm_kernel_4x4
+#define HGEMM_KERNEL_4x8 hgemm_kernel_4x8
+#define HGEMM_KERNEL_8x8 hgemm_kernel_8x8
+
+void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
+                   unsigned int N, unsigned int K, float alpha, float beta) {
+  if (alpha == 1.F && beta == 0.F) {
+    if (M % 8 == 0 && N % 8 == 0 && K % 8 == 0) {
+      hgemm_noTrans_8x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+    } else if (M % 4 == 0 && N % 8 == 0 && K % 4 == 0) {
+      hgemm_noTrans_4x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+    } else {
+      hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+    }
+  } else
+    hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+}
+
+void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha, float beta) {
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+      packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_4) {
+          m2_min = 3 * GEMM_UNROLLING_4;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
+          m2_min = 2 * GEMM_UNROLLING_4;
+        } else if (m2_min > GEMM_UNROLLING_4) {
+          m2_min = GEMM_UNROLLING_4;
+        }
+
+        packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms));
+
+        HGEMM_KERNEL_4x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+                         C + mms * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+        }
+
+        packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        HGEMM_KERNEL_4x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
+void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha, float beta) {
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  unsigned int l1stride = 1;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
+                GEMM_UNROLLING_8;
+      } else {
+        l1stride = 0;
+      }
+      packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_4) {
+          m2_min = 3 * GEMM_UNROLLING_4;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
+          m2_min = 2 * GEMM_UNROLLING_4;
+        } else if (m2_min > GEMM_UNROLLING_4) {
+          m2_min = GEMM_UNROLLING_4;
+        }
+
+        packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms) * l1stride);
+
+        HGEMM_KERNEL_4x8(m2_min, n_min, k_min,
+                         sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
+                         ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+        }
+
+        packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        HGEMM_KERNEL_4x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
+void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha, float beta) {
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  unsigned int l1stride = 1;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
+                GEMM_UNROLLING_8;
+      } else {
+        l1stride = 0;
+      }
+      packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_4) {
+          m2_min = 3 * GEMM_UNROLLING_4;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
+          m2_min = 2 * GEMM_UNROLLING_4;
+        } else if (m2_min > GEMM_UNROLLING_4) {
+          m2_min = GEMM_UNROLLING_4;
+        }
+
+        packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms) * l1stride);
+
+        HGEMM_KERNEL_4x8(m2_min, n_min, k_min,
+                         sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
+                         ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+        }
+
+        packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        HGEMM_KERNEL_4x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
+void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha, float beta) {
+
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+      }
+      packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_8) {
+          m2_min = 3 * GEMM_UNROLLING_8;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
+          m2_min = 2 * GEMM_UNROLLING_8;
+        } else if (m2_min > GEMM_UNROLLING_8) {
+          m2_min = GEMM_UNROLLING_8;
+        }
+
+        packing_A8(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms));
+
+        HGEMM_KERNEL_8x8(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+                         C + mms * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+        }
+
+        packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        HGEMM_KERNEL_8x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
+void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha, float beta) {
+
+  __fp16 *sA = alignedMalloc(M * K);
+  __fp16 *sB = alignedMalloc(K * N);
+
+  unsigned int ms, ms2, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+      }
+      packing_B8(k_min, n_min, B + ks * ldb, ldb, sB);
+
+      for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
+        m2_min = (ms + m_min) - ms2;
+        if (m2_min >= 3 * GEMM_UNROLLING_8) {
+          m2_min = 3 * GEMM_UNROLLING_8;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
+          m2_min = 2 * GEMM_UNROLLING_8;
+        } else if (m2_min > GEMM_UNROLLING_8) {
+          m2_min = GEMM_UNROLLING_8;
+        }
+
+        packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
+                   sA + k_min * (ms2 - ms));
+
+        HGEMM_KERNEL_8x8(m2_min, n_min, k_min, sA + k_min * (ms2 - ms), sB,
+                         C + ms2 * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+        }
+
+        packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sB);
+        HGEMM_KERNEL_8x8(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sA);
+  free(sB);
+}
+
+void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
+                            const __fp16 *A, unsigned int lda, const __fp16 *B,
+                            unsigned int ldb, float *C, unsigned int ldc,
+                            float alpha, float beta) {
+
+  unsigned int k = 0;
+  unsigned int N8 = (N >> 3) << 3;
+  __fp16 a[16];
+  for (; (K - k) >= 16; k += 16) {
+    for (unsigned int m = 0; m < M; m++) {
+      vst1q_f16(&a[0], vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha));
+      vst1q_f16(&a[8], vmulq_n_f16(vld1q_f16(&A[m * K + k + 8]), alpha));
+      for (unsigned int n = 0; n < N8; n += 8) {
+        float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 8) * N + n]), a[8]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 9) * N + n]), a[9]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 10) * N + n]), a[10]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 11) * N + n]), a[11]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 12) * N + n]), a[12]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 13) * N + n]), a[13]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 14) * N + n]), a[14]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 15) * N + n]), a[15]);
+
+        float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
+                                            vcvt_f32_f16(vget_low_f16(b0_7_0)));
+        float32x4_t c0_7_high_32 = vaddq_f32(
+          vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
+
+        vst1q_f32(&C[m * N + n], c0_7_low_32);
+        vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
+      }
+      if (N != N8) {
+        unsigned int n = N8;
+        __fp16 valsB_0[8];
+        __fp16 valsB_1[8];
+        __fp16 valsB_2[8];
+        __fp16 valsB_3[8];
+        __fp16 valsB_4[8];
+        __fp16 valsB_5[8];
+        __fp16 valsB_6[8];
+        __fp16 valsB_7[8];
+        __fp16 valsB_8[8];
+        __fp16 valsB_9[8];
+        __fp16 valsB_10[8];
+        __fp16 valsB_11[8];
+        __fp16 valsB_12[8];
+        __fp16 valsB_13[8];
+        __fp16 valsB_14[8];
+        __fp16 valsB_15[8];
+        float valsC[8];
+        for (unsigned int idx = n; idx < N; idx++) {
+          valsB_0[idx - n] = B[k * N + idx];
+          valsB_1[idx - n] = B[(k + 1) * N + idx];
+          valsB_2[idx - n] = B[(k + 2) * N + idx];
+          valsB_3[idx - n] = B[(k + 3) * N + idx];
+          valsB_4[idx - n] = B[(k + 4) * N + idx];
+          valsB_5[idx - n] = B[(k + 5) * N + idx];
+          valsB_6[idx - n] = B[(k + 6) * N + idx];
+          valsB_7[idx - n] = B[(k + 7) * N + idx];
+          valsB_8[idx - n] = B[(k + 8) * N + idx];
+          valsB_9[idx - n] = B[(k + 9) * N + idx];
+          valsB_10[idx - n] = B[(k + 10) * N + idx];
+          valsB_11[idx - n] = B[(k + 11) * N + idx];
+          valsB_12[idx - n] = B[(k + 12) * N + idx];
+          valsB_13[idx - n] = B[(k + 13) * N + idx];
+          valsB_14[idx - n] = B[(k + 14) * N + idx];
+          valsB_15[idx - n] = B[(k + 15) * N + idx];
+          valsC[idx - n] = C[m * N + idx];
+        }
+
+        float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_4), a[4]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_5), a[5]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_6), a[6]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_7), a[7]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_8), a[8]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_9), a[9]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_10), a[10]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_11), a[11]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_12), a[12]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_13), a[13]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_14), a[14]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_15), a[15]);
+
+        float32x4_t c0_7_low_32 =
+          vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
+
+        float32x4_t c0_7_high_32 =
+          vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
+
+        vst1q_f32(valsC, c0_7_low_32);
+        vst1q_f32(valsC + 4, c0_7_high_32);
+
+        for (unsigned int idx = n; idx < N; idx++) {
+          C[m * N + idx] = valsC[idx - n];
+        }
+      }
+    }
+  }
+
+  for (; (K - k) >= 8; k += 8) {
+    for (unsigned int m = 0; m < M; m++) {
+      vst1q_f16(a, vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha));
+
+      for (unsigned int n = 0; n < N8; n += 8) {
+        float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]);
+
+        float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
+                                            vcvt_f32_f16(vget_low_f16(b0_7_0)));
+        float32x4_t c0_7_high_32 = vaddq_f32(
+          vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
+
+        vst1q_f32(&C[m * N + n], c0_7_low_32);
+        vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
+      }
+      if (N != N8) {
+        unsigned int n = N8;
+        __fp16 valsB_0[8];
+        __fp16 valsB_1[8];
+        __fp16 valsB_2[8];
+        __fp16 valsB_3[8];
+        __fp16 valsB_4[8];
+        __fp16 valsB_5[8];
+        __fp16 valsB_6[8];
+        __fp16 valsB_7[8];
+        float valsC[8];
+        for (unsigned int idx = n; idx < N; idx++) {
+          valsB_0[idx - n] = B[k * N + idx];
+          valsB_1[idx - n] = B[(k + 1) * N + idx];
+          valsB_2[idx - n] = B[(k + 2) * N + idx];
+          valsB_3[idx - n] = B[(k + 3) * N + idx];
+          valsB_4[idx - n] = B[(k + 4) * N + idx];
+          valsB_5[idx - n] = B[(k + 5) * N + idx];
+          valsB_6[idx - n] = B[(k + 6) * N + idx];
+          valsB_7[idx - n] = B[(k + 7) * N + idx];
+          valsC[idx - n] = C[m * N + idx];
+        }
+
+        float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_4), a[4]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_5), a[5]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_6), a[6]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_7), a[7]);
+
+        float32x4_t c0_7_low_32 =
+          vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
+
+        float32x4_t c0_7_high_32 =
+          vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
+
+        vst1q_f32(valsC, c0_7_low_32);
+        vst1q_f32(valsC + 4, c0_7_high_32);
+
+        for (unsigned int idx = n; idx < N; idx++) {
+          C[m * N + idx] = valsC[idx - n];
+        }
+      }
+    }
+  }
+
+  for (; (K - k) >= 4; k += 4) {
+    for (unsigned int m = 0; m < M; m++) {
+      vst1_f16(a, vmul_n_f16(vld1_f16(&A[m * K + k]), alpha));
+
+      for (unsigned int n = 0; n < N8; n += 8) {
+
+        float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
+        b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
+        float16x8_t b0_7_2 = vmulq_n_f16(vld1q_f16(&B[(k + 2) * N + n]), a[2]);
+        b0_7_2 = vfmaq_n_f16(b0_7_2, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
+
+        float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
+                                            vcvt_f32_f16(vget_low_f16(b0_7_0)));
+        float32x4_t c0_7_high_32 = vaddq_f32(
+          vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
+
+        c0_7_low_32 =
+          vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_2)));
+        c0_7_high_32 =
+          vaddq_f32(c0_7_high_32, vcvt_f32_f16(vget_high_f16(b0_7_2)));
+
+        vst1q_f32(&C[m * N + n], c0_7_low_32);
+        vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
+      }
+      if (N != N8) {
+        unsigned int n = N8;
+        __fp16 valsB_0[8];
+        __fp16 valsB_1[8];
+        __fp16 valsB_2[8];
+        __fp16 valsB_3[8];
+        float valsC[8];
+        for (unsigned int idx = n; idx < N; idx++) {
+          valsB_0[idx - n] = B[k * N + idx];
+          valsB_1[idx - n] = B[(k + 1) * N + idx];
+          valsB_2[idx - n] = B[(k + 2) * N + idx];
+          valsB_3[idx - n] = B[(k + 3) * N + idx];
+          valsC[idx - n] = C[m * N + idx];
+        }
+
+        float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
+        b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
+
+        float32x4_t c0_7_low_32 =
+          vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
+
+        float32x4_t c0_7_high_32 =
+          vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
+
+        vst1q_f32(valsC, c0_7_low_32);
+        vst1q_f32(valsC + 4, c0_7_high_32);
+
+        for (unsigned int idx = n; idx < N; idx++) {
+          C[m * N + idx] = valsC[idx - n];
+        }
+      }
+    }
+  }
+
+  for (; k < K; k++) {
+    for (unsigned int m = 0; m < M; m++) {
+      __fp16 a0 = alpha * A[m * K + k];
+
+      for (unsigned int n = 0; n < N8; n += 8) {
+        float16x8_t b0_7 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a0);
+
+        float32x4_t c0_7_low_32 =
+          vaddq_f32(vld1q_f32(&C[m * N + n]), vcvt_f32_f16(vget_low_f16(b0_7)));
+
+        float32x4_t c0_7_high_32 = vaddq_f32(vld1q_f32(&C[m * N + n + 4]),
+                                             vcvt_f32_f16(vget_high_f16(b0_7)));
+
+        vst1q_f32(&C[m * N + n], c0_7_low_32);
+        vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
+      }
+      if (N != N8) {
+        unsigned int n = N8;
+        __fp16 valsB[8];
+        float valsC[8];
+        for (unsigned int idx = n; idx < N; idx++) {
+          valsB[idx - n] = B[k * N + idx];
+          valsC[idx - n] = C[m * N + idx];
+        }
+
+        float16x8_t b = vmulq_n_f16(vld1q_f16(valsB), a0);
+
+        float32x4_t c0_7_low_32 =
+          vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
+
+        float32x4_t c0_7_high_32 =
+          vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
+
+        vst1q_f32(valsC, c0_7_low_32);
+        vst1q_f32(valsC + 4, c0_7_high_32);
+
+        for (unsigned int idx = n; idx < N; idx++) {
+          C[m * N + idx] = valsC[idx - n];
+        }
+      }
+    }
+  }
+}
index bcd74b68d2aa72cf2b4285aec64f9e94fa01f683..fde7112a60e1f2dae4a521a3e7b6a7dff29bc684 100644 (file)
  *
  */
 
-#include <hgemm_kernel_4x4.h>
-#include <hgemm_kernel_8x8.h>
-#include <hgemm_kernel_pack.h>
-#include <hgemm_util.h>
+/**
+ * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C __fp16 * for Matrix C
+ * @param[in] M number of op(A)'s and C's row
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s and columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
+                   unsigned int N, unsigned int K, float alpha = 1.F,
+                   float beta = 0.F);
 
-#define KERNEL_4x4 hgemm_kernel_4x4
-#define KERNEL_8x8 hgemm_kernel_8x8
+/**
+ * @brief     hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix C
+ * @param B input matrix B
+ * @param ldb length of the col of matrix C
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
+                            const __fp16 *A, unsigned int lda, const __fp16 *B,
+                            unsigned int ldb, float *C, unsigned int ldc,
+                            float alpha = 1.F, float beta = 0.F);
 
 /**
  * @brief hgemm noTrans computation with 4x4 kernel : C = A*B,
- * 
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B 
- * @param k length of the col of matrix A
- * @param a input matrix A
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix C
+ * @param B input matrix B
+ * @param ldb length of the col of matrix C
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 8x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
  * @param lda length of the col of matrix C
- * @param b input matrix B
+ * @param B input matrix B
  * @param ldb length of the col of matrix C
- * @param c output matrix C
+ * @param C output matrix C
  * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
  */
-void hgemm_noTrans_4x4(unsigned int m, unsigned int n, unsigned int k,
-                       const __fp16 *a, unsigned int lda, const __fp16 *b,
-                       unsigned int ldb, __fp16 *c, unsigned int ldc);
+void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
 
 /**
  * @brief hgemm noTrans computation with 8x8 kernel : C = A*B,
- * 
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B 
- * @param k length of the col of matrix A
- * @param a input matrix A
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix C
+ * @param B input matrix B
+ * @param ldb length of the col of matrix C
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 4x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix C
+ * @param B input matrix B
+ * @param ldb length of the col of matrix C
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 4x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
  * @param lda length of the col of matrix C
- * @param b input matrix B
+ * @param B input matrix B
  * @param ldb length of the col of matrix C
- * @param c output matrix C
+ * @param C output matrix C
  * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
  */
-void hgemm_noTrans_8x8(unsigned int m, unsigned int n, unsigned int k,
-                       const __fp16 *a, unsigned int lda, const __fp16 *b,
-                       unsigned int ldb, __fp16 *c, unsigned int ldc);
+void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
index f5bb3fa36b27e6142890eee8226b66d90c1c555b..68728102423a459b05aad20ce33b6cfb8f6c8ddb 100644 (file)
  * @brief  This is common settings for hgemm
  *
  */
+#include <arm_neon.h>
+#include <assert.h>
 
 #define A(i, j) a[(i)*lda + (j)]
 #define B(i, j) b[(i)*ldb + (j)]
 #define C(i, j) c[(i)*ldc + (j)]
 
-#define GEMM_N (384)
-#define GEMM_K (256)
-#define GEMM_M (4096)
+#define N_BLOCKING (384)
+#define K_BLOCKING (256)
+#define M_BLOCKING (4096)
 #define GEMM_UNROLLING_8 (8)
 #define GEMM_UNROLLING_4 (4)
 #define VL_FP16 (8)
index fbb2f7447011c35b10640f4ac3827ba880ae347a..6166b9407dca3dc5653c87dbd06c81d19403d8db 100644 (file)
  *
  */
 
-#include <cmath>
 #include <hgemm_common.h>
-#include <math.h>
 #include <stdlib.h>
 
 /**
  * @brief hgemm 4x4 kernel sc = sa * sb
- * 
+ *
  * @param m length of the row of matrix A
- * @param n length of the col of matrix B 
+ * @param n length of the col of matrix B
  * @param k length of the col of matrix A
  * @param sa sub-matrix of input matrix A
  * @param sb sub-matrix of input matrix B
  * @param sc sub-matrix of output matrix C
  * @param ldc leading dimension of matrix C
  */
-void hgemm_kernel_4x4(unsigned int m, unsigned int n, unsigned int k,
+void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
                       __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
-  assert(m > 0 && n > 0 && k > 0);
-  assert(m % 4 == 0 && n % 4 == 0 && k % 4 == 0);
+  assert(M > 0 && N > 0 && K > 0);
+  assert(M % 4 == 0 && N % 4 == 0 && K % 4 == 0);
 
   __fp16 *a = sa, *b = sb, *c = sc;
   unsigned int i, j, l;
-  for (i = 0; i < m; i += VL_FP16_HALF) {
-    for (j = 0; j < n; j += VL_FP16_HALF) {
+  for (i = 0; i < M; i += VL_FP16_HALF) {
+    for (j = 0; j < N; j += VL_FP16_HALF) {
       __builtin_prefetch(b, 0, 3);
       __builtin_prefetch(a, 0, 3);
 
@@ -44,7 +42,7 @@ void hgemm_kernel_4x4(unsigned int m, unsigned int n, unsigned int k,
       float16x4_t v26 = {0};
       float16x4_t v27 = {0};
 
-      for (l = 0; l < k; l += VL_FP16_HALF) {
+      for (l = 0; l < K; l += VL_FP16_HALF) {
         float16x4_t v0 = vld1_f16(b);
         float16x4_t v16 = vld1_f16(a);
 
@@ -95,12 +93,11 @@ void hgemm_kernel_4x4(unsigned int m, unsigned int n, unsigned int k,
       vst1_f16(c + 3 * ldc, v27);
 
       c += 4;
-      a -= 4 * k;
+      a -= 4 * K;
     }
     sc += ldc * 4;
     c = sc;
-    a += 4 * k;
+    a += 4 * K;
     b = sb;
   }
 }
-
index 60881f15d269c596f9fcd97d1dcaf65a9b85e806..14c112c1d144f766d81fcd55fe8cb3bf7a29a8a6 100644 (file)
 
 /**
  * @brief packing function of input matrix A
- * 
- * @param m length of the row of the matrix
- * @param k length of the col of the matrix
- * @param from input of original source of the matrix
+ *
+ * @param M length of the row of the matrix
+ * @param K length of the col of the matrix
+ * @param src input of original source of the matrix
  * @param lda leading dimension of the matrix
- * @param to output of packed data of the matrix
+ * @param dst output of packed data of the matrix
  */
-void packA_4(unsigned int m, unsigned int k, const __fp16 *from,
-             unsigned int lda, const __fp16 *to) {
+void packing_A4(unsigned int M, unsigned int K, const __fp16 *src,
+                unsigned int lda, const __fp16 *dst) {
 
-  assert(k != 0 && m != 0 && k % 4 == 0 && m % 4 == 0);
+  assert(K != 0 && M != 0 && K % 4 == 0 && M % 4 == 0);
   unsigned int i, j;
 
-  __fp16 *a_offset, *a_offset1, *a_offset2, *a_offset3, *a_offset4;
-  __fp16 *b_offset;
-  __fp16 ctemp1, ctemp2, ctemp3, ctemp4;
-  __fp16 ctemp5, ctemp6, ctemp7, ctemp8;
-  __fp16 ctemp9, ctemp10, ctemp11, ctemp12;
-  __fp16 ctemp13, ctemp14, ctemp15, ctemp16;
+  __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4;
+  __fp16 *b_off;
+  __fp16 c1, c2, c3, c4;
+  __fp16 c5, c6, c7, c8;
+  __fp16 c9, c10, c11, c12;
+  __fp16 c13, c14, c15, c16;
 
-  a_offset = (__fp16 *)from;
-  b_offset = (__fp16 *)to;
+  a_off = (__fp16 *)src;
+  b_off = (__fp16 *)dst;
 
-  j = (m >> 2);
+  j = (M >> 2);
   do {
-    a_offset1 = a_offset;
-    a_offset2 = a_offset1 + lda;
-    a_offset3 = a_offset2 + lda;
-    a_offset4 = a_offset3 + lda;
-    a_offset += 4 * lda;
+    a_off1 = a_off;
+    a_off2 = a_off1 + lda;
+    a_off3 = a_off2 + lda;
+    a_off4 = a_off3 + lda;
+    a_off += 4 * lda;
 
-    i = (k >> 2);
+    i = (K >> 2);
     do {
-      ctemp1 = *(a_offset1 + 0);
-      ctemp2 = *(a_offset1 + 1);
-      ctemp3 = *(a_offset1 + 2);
-      ctemp4 = *(a_offset1 + 3);
-
-      ctemp5 = *(a_offset2 + 0);
-      ctemp6 = *(a_offset2 + 1);
-      ctemp7 = *(a_offset2 + 2);
-      ctemp8 = *(a_offset2 + 3);
-
-      ctemp9 = *(a_offset3 + 0);
-      ctemp10 = *(a_offset3 + 1);
-      ctemp11 = *(a_offset3 + 2);
-      ctemp12 = *(a_offset3 + 3);
-
-      ctemp13 = *(a_offset4 + 0);
-      ctemp14 = *(a_offset4 + 1);
-      ctemp15 = *(a_offset4 + 2);
-      ctemp16 = *(a_offset4 + 3);
-
-      *(b_offset + 0) = ctemp1;
-      *(b_offset + 1) = ctemp5;
-      *(b_offset + 2) = ctemp9;
-      *(b_offset + 3) = ctemp13;
-
-      *(b_offset + 4) = ctemp2;
-      *(b_offset + 5) = ctemp6;
-      *(b_offset + 6) = ctemp10;
-      *(b_offset + 7) = ctemp14;
-
-      *(b_offset + 8) = ctemp3;
-      *(b_offset + 9) = ctemp7;
-      *(b_offset + 10) = ctemp11;
-      *(b_offset + 11) = ctemp15;
-
-      *(b_offset + 12) = ctemp4;
-      *(b_offset + 13) = ctemp8;
-      *(b_offset + 14) = ctemp12;
-      *(b_offset + 15) = ctemp16;
-
-      a_offset1 += 4;
-      a_offset2 += 4;
-      a_offset3 += 4;
-      a_offset4 += 4;
-
-      b_offset += 16;
+      c1 = *(a_off1 + 0);
+      c2 = *(a_off1 + 1);
+      c3 = *(a_off1 + 2);
+      c4 = *(a_off1 + 3);
+
+      c5 = *(a_off2 + 0);
+      c6 = *(a_off2 + 1);
+      c7 = *(a_off2 + 2);
+      c8 = *(a_off2 + 3);
+
+      c9 = *(a_off3 + 0);
+      c10 = *(a_off3 + 1);
+      c11 = *(a_off3 + 2);
+      c12 = *(a_off3 + 3);
+
+      c13 = *(a_off4 + 0);
+      c14 = *(a_off4 + 1);
+      c15 = *(a_off4 + 2);
+      c16 = *(a_off4 + 3);
+
+      *(b_off + 0) = c1;
+      *(b_off + 1) = c5;
+      *(b_off + 2) = c9;
+      *(b_off + 3) = c13;
+
+      *(b_off + 4) = c2;
+      *(b_off + 5) = c6;
+      *(b_off + 6) = c10;
+      *(b_off + 7) = c14;
+
+      *(b_off + 8) = c3;
+      *(b_off + 9) = c7;
+      *(b_off + 10) = c11;
+      *(b_off + 11) = c15;
+
+      *(b_off + 12) = c4;
+      *(b_off + 13) = c8;
+      *(b_off + 14) = c12;
+      *(b_off + 15) = c16;
+
+      a_off1 += 4;
+      a_off2 += 4;
+      a_off3 += 4;
+      a_off4 += 4;
+
+      b_off += 16;
       i--;
     } while (i > 0);
     j--;
@@ -102,54 +102,54 @@ void packA_4(unsigned int m, unsigned int k, const __fp16 *from,
 
 /**
  * @brief packing function of input matrix A
- * 
- * @param m length of the row of the matrix
- * @param k length of the col of the matrix
- * @param from input of original source of the matrix
+ *
+ * @param M length of the row of the matrix
+ * @param K length of the col of the matrix
+ * @param src input of original source of the matrix
  * @param lda leading dimension of the matrix
- * @param to output of packed data of the matrix
+ * @param dst output of packed data of the matrix
  */
-void packA_8(unsigned int m, unsigned int k, const __fp16 *from,
-             unsigned int lda, const __fp16 *to) {
+void packing_A8(unsigned int M, unsigned int K, const __fp16 *src,
+                unsigned int lda, const __fp16 *dst) {
 
-  assert(k != 0 && m != 0 && k % 8 == 0 && m % 8 == 0);
+  assert(K != 0 && M != 0 && K % 8 == 0 && M % 8 == 0);
 
   uint16x4_t msk = {0xFFFF, 0xFFFF, 0x0000, 0x0000};
   uint16x4_t inv_msk = {0x0000, 0x0000, 0xFFFF, 0xFFFF};
 
-  const __fp16 *a_offset = (__fp16 *)from;
-  __fp16 *b_offset = (__fp16 *)to;
-
-  for (unsigned int i = 0; i < m; i += 8) {
-    const __fp16 *a_offset1 = a_offset;
-    const __fp16 *a_offset2 = a_offset1 + lda;
-    const __fp16 *a_offset3 = a_offset2 + lda;
-    const __fp16 *a_offset4 = a_offset3 + lda;
-    const __fp16 *a_offset5 = a_offset4 + lda;
-    const __fp16 *a_offset6 = a_offset5 + lda;
-    const __fp16 *a_offset7 = a_offset6 + lda;
-    const __fp16 *a_offset8 = a_offset7 + lda;
-    a_offset += 8 * lda;
-
-    for (unsigned int j = 0; j < k; j += 8) {
-      float16x8_t _v0 = vld1q_f16(a_offset1);
-      float16x8_t _v1 = vld1q_f16(a_offset2);
-      float16x8_t _v2 = vld1q_f16(a_offset3);
-      float16x8_t _v3 = vld1q_f16(a_offset4);
-
-      float16x8_t _v4 = vld1q_f16(a_offset5);
-      float16x8_t _v5 = vld1q_f16(a_offset6);
-      float16x8_t _v6 = vld1q_f16(a_offset7);
-      float16x8_t _v7 = vld1q_f16(a_offset8);
-
-      a_offset1 += 8;
-      a_offset2 += 8;
-      a_offset3 += 8;
-      a_offset4 += 8;
-      a_offset5 += 8;
-      a_offset6 += 8;
-      a_offset7 += 8;
-      a_offset8 += 8;
+  const __fp16 *a_off = (__fp16 *)src;
+  __fp16 *b_off = (__fp16 *)dst;
+
+  for (unsigned int i = 0; i < M; i += 8) {
+    const __fp16 *a_off1 = a_off;
+    const __fp16 *a_off2 = a_off1 + lda;
+    const __fp16 *a_off3 = a_off2 + lda;
+    const __fp16 *a_off4 = a_off3 + lda;
+    const __fp16 *a_off5 = a_off4 + lda;
+    const __fp16 *a_off6 = a_off5 + lda;
+    const __fp16 *a_off7 = a_off6 + lda;
+    const __fp16 *a_off8 = a_off7 + lda;
+    a_off += 8 * lda;
+
+    for (unsigned int j = 0; j < K; j += 8) {
+      float16x8_t _v0 = vld1q_f16(a_off1);
+      float16x8_t _v1 = vld1q_f16(a_off2);
+      float16x8_t _v2 = vld1q_f16(a_off3);
+      float16x8_t _v3 = vld1q_f16(a_off4);
+
+      float16x8_t _v4 = vld1q_f16(a_off5);
+      float16x8_t _v5 = vld1q_f16(a_off6);
+      float16x8_t _v6 = vld1q_f16(a_off7);
+      float16x8_t _v7 = vld1q_f16(a_off8);
+
+      a_off1 += 8;
+      a_off2 += 8;
+      a_off3 += 8;
+      a_off4 += 8;
+      a_off5 += 8;
+      a_off6 += 8;
+      a_off7 += 8;
+      a_off8 += 8;
 
       float16x8x2_t _vv0 = vtrnq_f16(_v0, _v1);
       float16x8x2_t _vv1 = vtrnq_f16(_v2, _v3);
@@ -224,101 +224,101 @@ void packA_8(unsigned int m, unsigned int k, const __fp16 *from,
       _v15 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v11, mid_v11),
                           vbsl_f16(inv_msk, tmp_high_v15, mid_v15));
 
-      vst1q_f16(b_offset + 0, _v8);
-      vst1q_f16(b_offset + 8, _v12);
-      vst1q_f16(b_offset + 16, _v9);
-      vst1q_f16(b_offset + 24, _v13);
-      vst1q_f16(b_offset + 32, _v10);
-      vst1q_f16(b_offset + 40, _v14);
-      vst1q_f16(b_offset + 48, _v11);
-      vst1q_f16(b_offset + 56, _v15);
-      b_offset += 64;
+      vst1q_f16(b_off + 0, _v8);
+      vst1q_f16(b_off + 8, _v12);
+      vst1q_f16(b_off + 16, _v9);
+      vst1q_f16(b_off + 24, _v13);
+      vst1q_f16(b_off + 32, _v10);
+      vst1q_f16(b_off + 40, _v14);
+      vst1q_f16(b_off + 48, _v11);
+      vst1q_f16(b_off + 56, _v15);
+      b_off += 64;
     }
   }
 }
 
 /**
  * @brief packing function of input matrix B
- * 
- * @param m length of the row of the matrix
- * @param k length of the col of the matrix
- * @param from input of original source of the matrix
+ *
+ * @param M length of the row of the matrix
+ * @param K length of the col of the matrix
+ * @param src input of original source of the matrix
  * @param ldb leading dimension of the matrix
- * @param to output of packed data of the matrix
+ * @param dst output of packed data of the matrix
  */
-void packB_4(unsigned int k, unsigned int n, const __fp16 *from,
-             unsigned int ldb, const __fp16 *to) {
-  assert(k != 0 && n != 0 && k % 4 == 0 && n % 4 == 0);
+void packing_B4(unsigned int K, unsigned int N, const __fp16 *src,
+                unsigned int ldb, const __fp16 *dst) {
+  assert(K != 0 && N != 0 && K % 4 == 0 && N % 4 == 0);
   unsigned int i, j;
 
-  __fp16 *a_offset, *a_offset1, *a_offset2, *a_offset3, *a_offset4;
-  __fp16 *b_offset, *b_offset1;
-  __fp16 ctemp1, ctemp2, ctemp3, ctemp4;
-  __fp16 ctemp5, ctemp6, ctemp7, ctemp8;
-  __fp16 ctemp9, ctemp10, ctemp11, ctemp12;
-  __fp16 ctemp13, ctemp14, ctemp15, ctemp16;
-  a_offset = (__fp16 *)from;
-  b_offset = (__fp16 *)to;
+  __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4;
+  __fp16 *b_off, *b_off1;
+  __fp16 c1, c2, c3, c4;
+  __fp16 c5, c6, c7, c8;
+  __fp16 c9, c10, c11, c12;
+  __fp16 c13, c14, c15, c16;
+  a_off = (__fp16 *)src;
+  b_off = (__fp16 *)dst;
 
-  j = (k >> 2);
+  j = (K >> 2);
   do {
-    a_offset1 = a_offset;
-    a_offset2 = a_offset1 + ldb;
-    a_offset3 = a_offset2 + ldb;
-    a_offset4 = a_offset3 + ldb;
-    a_offset += 4 * ldb;
+    a_off1 = a_off;
+    a_off2 = a_off1 + ldb;
+    a_off3 = a_off2 + ldb;
+    a_off4 = a_off3 + ldb;
+    a_off += 4 * ldb;
 
-    b_offset1 = b_offset;
-    b_offset += 16;
+    b_off1 = b_off;
+    b_off += 16;
 
-    i = (n >> 2);
+    i = (N >> 2);
     do {
-      ctemp1 = *(a_offset1 + 0);
-      ctemp2 = *(a_offset1 + 1);
-      ctemp3 = *(a_offset1 + 2);
-      ctemp4 = *(a_offset1 + 3);
-
-      ctemp5 = *(a_offset2 + 0);
-      ctemp6 = *(a_offset2 + 1);
-      ctemp7 = *(a_offset2 + 2);
-      ctemp8 = *(a_offset2 + 3);
-
-      ctemp9 = *(a_offset3 + 0);
-      ctemp10 = *(a_offset3 + 1);
-      ctemp11 = *(a_offset3 + 2);
-      ctemp12 = *(a_offset3 + 3);
-
-      ctemp13 = *(a_offset4 + 0);
-      ctemp14 = *(a_offset4 + 1);
-      ctemp15 = *(a_offset4 + 2);
-      ctemp16 = *(a_offset4 + 3);
-
-      a_offset1 += 4;
-      a_offset2 += 4;
-      a_offset3 += 4;
-      a_offset4 += 4;
-
-      *(b_offset1 + 0) = ctemp1;
-      *(b_offset1 + 1) = ctemp2;
-      *(b_offset1 + 2) = ctemp3;
-      *(b_offset1 + 3) = ctemp4;
-
-      *(b_offset1 + 4) = ctemp5;
-      *(b_offset1 + 5) = ctemp6;
-      *(b_offset1 + 6) = ctemp7;
-      *(b_offset1 + 7) = ctemp8;
-
-      *(b_offset1 + 8) = ctemp9;
-      *(b_offset1 + 9) = ctemp10;
-      *(b_offset1 + 10) = ctemp11;
-      *(b_offset1 + 11) = ctemp12;
-
-      *(b_offset1 + 12) = ctemp13;
-      *(b_offset1 + 13) = ctemp14;
-      *(b_offset1 + 14) = ctemp15;
-      *(b_offset1 + 15) = ctemp16;
-
-      b_offset1 += k * 4;
+      c1 = *(a_off1 + 0);
+      c2 = *(a_off1 + 1);
+      c3 = *(a_off1 + 2);
+      c4 = *(a_off1 + 3);
+
+      c5 = *(a_off2 + 0);
+      c6 = *(a_off2 + 1);
+      c7 = *(a_off2 + 2);
+      c8 = *(a_off2 + 3);
+
+      c9 = *(a_off3 + 0);
+      c10 = *(a_off3 + 1);
+      c11 = *(a_off3 + 2);
+      c12 = *(a_off3 + 3);
+
+      c13 = *(a_off4 + 0);
+      c14 = *(a_off4 + 1);
+      c15 = *(a_off4 + 2);
+      c16 = *(a_off4 + 3);
+
+      a_off1 += 4;
+      a_off2 += 4;
+      a_off3 += 4;
+      a_off4 += 4;
+
+      *(b_off1 + 0) = c1;
+      *(b_off1 + 1) = c2;
+      *(b_off1 + 2) = c3;
+      *(b_off1 + 3) = c4;
+
+      *(b_off1 + 4) = c5;
+      *(b_off1 + 5) = c6;
+      *(b_off1 + 6) = c7;
+      *(b_off1 + 7) = c8;
+
+      *(b_off1 + 8) = c9;
+      *(b_off1 + 9) = c10;
+      *(b_off1 + 10) = c11;
+      *(b_off1 + 11) = c12;
+
+      *(b_off1 + 12) = c13;
+      *(b_off1 + 13) = c14;
+      *(b_off1 + 14) = c15;
+      *(b_off1 + 15) = c16;
+
+      b_off1 += K * 4;
       i--;
     } while (i > 0);
     j--;
@@ -327,26 +327,26 @@ void packB_4(unsigned int k, unsigned int n, const __fp16 *from,
 
 /**
  * @brief packing function of input matrix B
- * 
- * @param m length of the row of the matrix
- * @param k length of the col of the matrix
- * @param from input of original source of the matrix
+ *
+ * @param M length of the row of the matrix
+ * @param K length of the col of the matrix
+ * @param src input of original source of the matrix
  * @param ldb leading dimension of the matrix
- * @param to output of packed data of the matrix
+ * @param dst output of packed data of the matrix
  */
-void packB_8(unsigned int k, unsigned int n, const __fp16 *from,
-             unsigned int ldb, const __fp16 *to) {
-  assert(k != 0 && n != 0 && n % 8 == 0);
-
-  for (int i = 0; i < k; i++) {
-    const __fp16 *a_offset1 = from + i * ldb;
-    __fp16 *b_offset = (__fp16 *)to + i * 8;
-    for (int j = 0; j < n; j += 8) {
-      float16x8_t _v0 = vld1q_f16(a_offset1);
-      a_offset1 += 8;
-
-      vst1q_f16(b_offset, _v0);
-      b_offset += 8 * k;
+void packing_B8(unsigned int K, unsigned int N, const __fp16 *src,
+                unsigned int ldb, const __fp16 *dst) {
+  assert(K != 0 && N != 0 && N % 8 == 0);
+
+  for (int i = 0; i < K; i++) {
+    const __fp16 *a_off = src + i * ldb;
+    __fp16 *b_off = (__fp16 *)dst + i * 8;
+    for (int j = 0; j < N; j += 8) {
+      float16x8_t v = vld1q_f16(a_off);
+      a_off += 8;
+
+      vst1q_f16(b_off, v);
+      b_off += 8 * K;
     }
   }
 }
index 1996a59e659c891fd383cbd6d02e7e68ba85ed37..4c71d0a8be77c2900d26eee2d3e201f1ae4e7583 100644 (file)
 
 /**
  * @brief aligned dynamic allocation function
- * 
- * @param size amount of data to allocate
+ *
+ * @param sz amount of data to allocate
  * @return __fp16* addr of allocated memory
  */
-static inline __fp16 *alignedMalloc(int size) {
-  void *ptr = 0;
-  int iRet = posix_memalign(&ptr, 64, size * sizeof(__fp16));
+static inline __fp16 *alignedMalloc(int sz) {
+  void *addr = 0;
+  int iRet = posix_memalign(&addr, 64, sz * sizeof(__fp16));
   assert(0 == iRet);
-  return (__fp16 *)ptr;
+  return (__fp16 *)addr;
 }
diff --git a/nntrainer/tensor/hgemm/meson.build b/nntrainer/tensor/hgemm/meson.build
new file mode 100644 (file)
index 0000000..cf9efc3
--- /dev/null
@@ -0,0 +1,21 @@
+hgemm_headers = [
+  'hgemm.h',
+  'hgemm_util.h',
+  'hgemm_kernel_pack.h',
+  'hgemm_kernel_4x4.h',
+  'hgemm_kernel_4x8.h',
+  'hgemm_kernel_8x8.h',
+]
+
+hgemm_sources = [
+    'hgemm.cpp'
+]
+
+foreach s : hgemm_sources
+  nntrainer_sources += meson.current_source_dir() / s
+endforeach
+
+foreach h : hgemm_headers
+  nntrainer_headers += meson.current_source_dir() / h
+endforeach
+