--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file hgemm.cpp
+ * @date 03 April 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is half-precision GEMM interface
+ *
+ */
+
+#include <hgemm.h>
+#include <hgemm_kernel_4x4.h>
+#include <hgemm_kernel_4x8.h>
+#include <hgemm_kernel_8x8.h>
+#include <hgemm_kernel_pack.h>
+#include <hgemm_util.h>
+
+#define HGEMM_KERNEL_4x4 hgemm_kernel_4x4
+#define HGEMM_KERNEL_4x8 hgemm_kernel_4x8
+#define HGEMM_KERNEL_8x8 hgemm_kernel_8x8
+
+void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
+ unsigned int N, unsigned int K, float alpha, float beta) {
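+ // Packed-kernel paths are used only for alpha == 1, beta == 0 and when the
+ // dimensions match the kernel blocking (8x8 needs M, N, K % 8 == 0; 4x8 needs
+ // M, K % 4 == 0 and N % 8 == 0); everything else goes to the NEON fallback.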
+ if (alpha == 1.F && beta == 0.F) {
+ if (M % 8 == 0 && N % 8 == 0 && K % 8 == 0) {
+ hgemm_noTrans_8x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ } else if (M % 4 == 0 && N % 8 == 0 && K % 4 == 0) {
+ hgemm_noTrans_4x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ } else {
+ hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ }
+ } else {
+ hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ }
+}
+
+void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha, float beta) {
+ __fp16 *sa = alignedMalloc(M * K);
+ __fp16 *sb = alignedMalloc(K * N);
+
+ unsigned int ms, mms, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
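+
+ // Classic three-level blocking: M by M_BLOCKING, K by K_BLOCKING and N by
+ // N_BLOCKING. Each k_min x n_min panel of B and m2_min x k_min panel of A is
+ // packed into sb / sa before the 4x4 micro-kernel consumes it.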
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+ packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+ for (mms = ms; mms < ms + m_min; mms += m2_min) {
+ m2_min = (ms + m_min) - mms;
+ if (m2_min >= 3 * GEMM_UNROLLING_4) {
+ m2_min = 3 * GEMM_UNROLLING_4;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
+ m2_min = 2 * GEMM_UNROLLING_4;
+ } else if (m2_min > GEMM_UNROLLING_4) {
+ m2_min = GEMM_UNROLLING_4;
+ }
+
+ packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
+ sa + k_min * (mms - ms));
+
+ HGEMM_KERNEL_4x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+ C + mms * ldc, ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+ HGEMM_KERNEL_4x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sa);
+ free(sb);
+}
+
+void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha, float beta) {
+ __fp16 *sa = alignedMalloc(M * K);
+ __fp16 *sb = alignedMalloc(K * N);
+
+ unsigned int ms, mms, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ unsigned int l1stride = 1;
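+ // l1stride is cleared when N fits into a single N_BLOCKING block; in that
+ // case every packed A panel is written to and read from the start of sa.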
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
+ GEMM_UNROLLING_8;
+ } else {
+ l1stride = 0;
+ }
+ packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
+
+ for (mms = ms; mms < ms + m_min; mms += m2_min) {
+ m2_min = (ms + m_min) - mms;
+ if (m2_min >= 3 * GEMM_UNROLLING_4) {
+ m2_min = 3 * GEMM_UNROLLING_4;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
+ m2_min = 2 * GEMM_UNROLLING_4;
+ } else if (m2_min > GEMM_UNROLLING_4) {
+ m2_min = GEMM_UNROLLING_4;
+ }
+
+ packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
+ sa + k_min * (mms - ms) * l1stride);
+
+ HGEMM_KERNEL_4x8(m2_min, n_min, k_min,
+ sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
+ ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+ HGEMM_KERNEL_4x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sa);
+ free(sb);
+}
+
+void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha, float beta) {
+ __fp16 *sa = alignedMalloc(M * K);
+ __fp16 *sb = alignedMalloc(K * N);
+
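+ // Same 4x8 blocking as the __fp16-C overload above; this variant passes a
+ // float *C to the kernel so results are accumulated into an fp32 output.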
+ unsigned int ms, mms, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ unsigned int l1stride = 1;
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
+ GEMM_UNROLLING_8;
+ } else {
+ l1stride = 0;
+ }
+ packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
+
+ for (mms = ms; mms < ms + m_min; mms += m2_min) {
+ m2_min = (ms + m_min) - mms;
+ if (m2_min >= 3 * GEMM_UNROLLING_4) {
+ m2_min = 3 * GEMM_UNROLLING_4;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
+ m2_min = 2 * GEMM_UNROLLING_4;
+ } else if (m2_min > GEMM_UNROLLING_4) {
+ m2_min = GEMM_UNROLLING_4;
+ }
+
+ packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
+ sa + k_min * (mms - ms) * l1stride);
+
+ HGEMM_KERNEL_4x8(m2_min, n_min, k_min,
+ sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
+ ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+ HGEMM_KERNEL_4x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sa);
+ free(sb);
+}
+
+void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha, float beta) {
+
+ __fp16 *sa = alignedMalloc(M * K);
+ __fp16 *sb = alignedMalloc(K * N);
+
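+ // 8x8 variant of the same driver: A is packed in 8-row panels (packing_A8)
+ // and the block sizes are rounded to multiples of GEMM_UNROLLING_8.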
+ unsigned int ms, mms, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+ }
+ packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
+
+ for (mms = ms; mms < ms + m_min; mms += m2_min) {
+ m2_min = (ms + m_min) - mms;
+ if (m2_min >= 3 * GEMM_UNROLLING_8) {
+ m2_min = 3 * GEMM_UNROLLING_8;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
+ m2_min = 2 * GEMM_UNROLLING_8;
+ } else if (m2_min > GEMM_UNROLLING_8) {
+ m2_min = GEMM_UNROLLING_8;
+ }
+
+ packing_A8(m2_min, k_min, A + mms * lda + ks, lda,
+ sa + k_min * (mms - ms));
+
+ HGEMM_KERNEL_8x8(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+ C + mms * ldc, ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+ }
+
+ packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+ HGEMM_KERNEL_8x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sa);
+ free(sb);
+}
+
+void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha, float beta) {
+
+ __fp16 *sA = alignedMalloc(M * K);
+ __fp16 *sB = alignedMalloc(K * N);
+
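+ // fp32-accumulating twin of the 8x8 driver above: same blocking, but the
+ // kernel writes into a float *C.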
+ unsigned int ms, ms2, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+ }
+ packing_B8(k_min, n_min, B + ks * ldb, ldb, sB);
+
+ for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
+ m2_min = (ms + m_min) - ms2;
+ if (m2_min >= 3 * GEMM_UNROLLING_8) {
+ m2_min = 3 * GEMM_UNROLLING_8;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
+ m2_min = 2 * GEMM_UNROLLING_8;
+ } else if (m2_min > GEMM_UNROLLING_8) {
+ m2_min = GEMM_UNROLLING_8;
+ }
+
+ packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
+ sA + k_min * (ms2 - ms));
+
+ HGEMM_KERNEL_8x8(m2_min, n_min, k_min, sA + k_min * (ms2 - ms), sB,
+ C + ms2 * ldc, ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+ }
+
+ packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sB);
+ HGEMM_KERNEL_8x8(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sA);
+ free(sB);
+}
+
+void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha, float beta) {
+
+ unsigned int k = 0;
+ unsigned int N8 = (N >> 3) << 3;
+ __fp16 a[16];
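+ // K is consumed in chunks of 16, 8, 4 and finally 1: the row slice of A is
+ // pre-scaled by alpha, partial products are accumulated in fp16, and the
+ // result is widened to fp32 before being added to the matching row of C.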
+ for (; (K - k) >= 16; k += 16) {
+ for (unsigned int m = 0; m < M; m++) {
+ vst1q_f16(&a[0], vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha));
+ vst1q_f16(&a[8], vmulq_n_f16(vld1q_f16(&A[m * K + k + 8]), alpha));
+ for (unsigned int n = 0; n < N8; n += 8) {
+ float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 8) * N + n]), a[8]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 9) * N + n]), a[9]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 10) * N + n]), a[10]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 11) * N + n]), a[11]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 12) * N + n]), a[12]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 13) * N + n]), a[13]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 14) * N + n]), a[14]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 15) * N + n]), a[15]);
+
+ float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
+ vcvt_f32_f16(vget_low_f16(b0_7_0)));
+ float32x4_t c0_7_high_32 = vaddq_f32(
+ vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
+
+ vst1q_f32(&C[m * N + n], c0_7_low_32);
+ vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
+ }
+ if (N != N8) {
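+ // Leftover columns (N - N8 < 8): stage them in 8-wide stack buffers so the
+ // same vector code can run; only the first N - N8 lanes are stored back.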
+ unsigned int n = N8;
+ __fp16 valsB_0[8];
+ __fp16 valsB_1[8];
+ __fp16 valsB_2[8];
+ __fp16 valsB_3[8];
+ __fp16 valsB_4[8];
+ __fp16 valsB_5[8];
+ __fp16 valsB_6[8];
+ __fp16 valsB_7[8];
+ __fp16 valsB_8[8];
+ __fp16 valsB_9[8];
+ __fp16 valsB_10[8];
+ __fp16 valsB_11[8];
+ __fp16 valsB_12[8];
+ __fp16 valsB_13[8];
+ __fp16 valsB_14[8];
+ __fp16 valsB_15[8];
+ float valsC[8];
+ for (unsigned int idx = n; idx < N; idx++) {
+ valsB_0[idx - n] = B[k * N + idx];
+ valsB_1[idx - n] = B[(k + 1) * N + idx];
+ valsB_2[idx - n] = B[(k + 2) * N + idx];
+ valsB_3[idx - n] = B[(k + 3) * N + idx];
+ valsB_4[idx - n] = B[(k + 4) * N + idx];
+ valsB_5[idx - n] = B[(k + 5) * N + idx];
+ valsB_6[idx - n] = B[(k + 6) * N + idx];
+ valsB_7[idx - n] = B[(k + 7) * N + idx];
+ valsB_8[idx - n] = B[(k + 8) * N + idx];
+ valsB_9[idx - n] = B[(k + 9) * N + idx];
+ valsB_10[idx - n] = B[(k + 10) * N + idx];
+ valsB_11[idx - n] = B[(k + 11) * N + idx];
+ valsB_12[idx - n] = B[(k + 12) * N + idx];
+ valsB_13[idx - n] = B[(k + 13) * N + idx];
+ valsB_14[idx - n] = B[(k + 14) * N + idx];
+ valsB_15[idx - n] = B[(k + 15) * N + idx];
+ valsC[idx - n] = C[m * N + idx];
+ }
+
+ float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_4), a[4]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_5), a[5]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_6), a[6]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_7), a[7]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_8), a[8]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_9), a[9]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_10), a[10]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_11), a[11]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_12), a[12]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_13), a[13]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_14), a[14]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_15), a[15]);
+
+ float32x4_t c0_7_low_32 =
+ vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
+
+ float32x4_t c0_7_high_32 =
+ vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
+
+ vst1q_f32(valsC, c0_7_low_32);
+ vst1q_f32(valsC + 4, c0_7_high_32);
+
+ for (unsigned int idx = n; idx < N; idx++) {
+ C[m * N + idx] = valsC[idx - n];
+ }
+ }
+ }
+ }
+
+ for (; (K - k) >= 8; k += 8) {
+ for (unsigned int m = 0; m < M; m++) {
+ vst1q_f16(a, vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha));
+
+ for (unsigned int n = 0; n < N8; n += 8) {
+ float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]);
+
+ float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
+ vcvt_f32_f16(vget_low_f16(b0_7_0)));
+ float32x4_t c0_7_high_32 = vaddq_f32(
+ vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
+
+ vst1q_f32(&C[m * N + n], c0_7_low_32);
+ vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
+ }
+ if (N != N8) {
+ unsigned int n = N8;
+ __fp16 valsB_0[8];
+ __fp16 valsB_1[8];
+ __fp16 valsB_2[8];
+ __fp16 valsB_3[8];
+ __fp16 valsB_4[8];
+ __fp16 valsB_5[8];
+ __fp16 valsB_6[8];
+ __fp16 valsB_7[8];
+ float valsC[8];
+ for (unsigned int idx = n; idx < N; idx++) {
+ valsB_0[idx - n] = B[k * N + idx];
+ valsB_1[idx - n] = B[(k + 1) * N + idx];
+ valsB_2[idx - n] = B[(k + 2) * N + idx];
+ valsB_3[idx - n] = B[(k + 3) * N + idx];
+ valsB_4[idx - n] = B[(k + 4) * N + idx];
+ valsB_5[idx - n] = B[(k + 5) * N + idx];
+ valsB_6[idx - n] = B[(k + 6) * N + idx];
+ valsB_7[idx - n] = B[(k + 7) * N + idx];
+ valsC[idx - n] = C[m * N + idx];
+ }
+
+ float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_4), a[4]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_5), a[5]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_6), a[6]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_7), a[7]);
+
+ float32x4_t c0_7_low_32 =
+ vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
+
+ float32x4_t c0_7_high_32 =
+ vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
+
+ vst1q_f32(valsC, c0_7_low_32);
+ vst1q_f32(valsC + 4, c0_7_high_32);
+
+ for (unsigned int idx = n; idx < N; idx++) {
+ C[m * N + idx] = valsC[idx - n];
+ }
+ }
+ }
+ }
+
+ for (; (K - k) >= 4; k += 4) {
+ for (unsigned int m = 0; m < M; m++) {
+ vst1_f16(a, vmul_n_f16(vld1_f16(&A[m * K + k]), alpha));
+
+ for (unsigned int n = 0; n < N8; n += 8) {
+
+ float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
+ float16x8_t b0_7_2 = vmulq_n_f16(vld1q_f16(&B[(k + 2) * N + n]), a[2]);
+ b0_7_2 = vfmaq_n_f16(b0_7_2, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
+
+ float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
+ vcvt_f32_f16(vget_low_f16(b0_7_0)));
+ float32x4_t c0_7_high_32 = vaddq_f32(
+ vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
+
+ c0_7_low_32 =
+ vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_2)));
+ c0_7_high_32 =
+ vaddq_f32(c0_7_high_32, vcvt_f32_f16(vget_high_f16(b0_7_2)));
+
+ vst1q_f32(&C[m * N + n], c0_7_low_32);
+ vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
+ }
+ if (N != N8) {
+ unsigned int n = N8;
+ __fp16 valsB_0[8];
+ __fp16 valsB_1[8];
+ __fp16 valsB_2[8];
+ __fp16 valsB_3[8];
+ float valsC[8];
+ for (unsigned int idx = n; idx < N; idx++) {
+ valsB_0[idx - n] = B[k * N + idx];
+ valsB_1[idx - n] = B[(k + 1) * N + idx];
+ valsB_2[idx - n] = B[(k + 2) * N + idx];
+ valsB_3[idx - n] = B[(k + 3) * N + idx];
+ valsC[idx - n] = C[m * N + idx];
+ }
+
+ float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
+
+ float32x4_t c0_7_low_32 =
+ vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
+
+ float32x4_t c0_7_high_32 =
+ vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
+
+ vst1q_f32(valsC, c0_7_low_32);
+ vst1q_f32(valsC + 4, c0_7_high_32);
+
+ for (unsigned int idx = n; idx < N; idx++) {
+ C[m * N + idx] = valsC[idx - n];
+ }
+ }
+ }
+ }
+
+ for (; k < K; k++) {
+ for (unsigned int m = 0; m < M; m++) {
+ __fp16 a0 = alpha * A[m * K + k];
+
+ for (unsigned int n = 0; n < N8; n += 8) {
+ float16x8_t b0_7 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a0);
+
+ float32x4_t c0_7_low_32 =
+ vaddq_f32(vld1q_f32(&C[m * N + n]), vcvt_f32_f16(vget_low_f16(b0_7)));
+
+ float32x4_t c0_7_high_32 = vaddq_f32(vld1q_f32(&C[m * N + n + 4]),
+ vcvt_f32_f16(vget_high_f16(b0_7)));
+
+ vst1q_f32(&C[m * N + n], c0_7_low_32);
+ vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
+ }
+ if (N != N8) {
+ unsigned int n = N8;
+ __fp16 valsB[8];
+ float valsC[8];
+ for (unsigned int idx = n; idx < N; idx++) {
+ valsB[idx - n] = B[k * N + idx];
+ valsC[idx - n] = C[m * N + idx];
+ }
+
+ float16x8_t b = vmulq_n_f16(vld1q_f16(valsB), a0);
+
+ float32x4_t c0_7_low_32 =
+ vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
+
+ float32x4_t c0_7_high_32 =
+ vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
+
+ vst1q_f32(valsC, c0_7_low_32);
+ vst1q_f32(valsC + 4, c0_7_high_32);
+
+ for (unsigned int idx = n; idx < N; idx++) {
+ C[m * N + idx] = valsC[idx - n];
+ }
+ }
+ }
+ }
+}
*
*/
-#include <hgemm_kernel_4x4.h>
-#include <hgemm_kernel_8x8.h>
-#include <hgemm_kernel_pack.h>
-#include <hgemm_util.h>
+/**
+ * @brief hgemm computation with neon : C = alpha*op(A)*op(B) + beta*C,
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C float * for Matrix C
+ * @param[in] M number of op(A)'s and C's rows
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
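+ *
+ * Minimal usage sketch (assumes row-major, tightly packed operands, i.e. the
+ * dispatched kernels see lda = K and ldb = ldc = N):
+ * @code
+ *   // A: M x K __fp16, B: K x N __fp16, C: M x N float
+ *   hgemm_noTrans(A, B, C, M, N, K); // C = 1.0f * A * B + 0.0f * C
+ * @endcode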
+ */
+void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
+ unsigned int N, unsigned int K, float alpha = 1.F,
+ float beta = 0.F);
-#define KERNEL_4x4 hgemm_kernel_4x4
-#define KERNEL_8x8 hgemm_kernel_8x8
+/**
+ * @brief hgemm fallback with neon : C = alpha*op(A)*op(B) + beta*C,
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda leading dimension of matrix A
+ * @param B input matrix B
+ * @param ldb leading dimension of matrix B
+ * @param C output matrix C
+ * @param ldc leading dimension of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
/**
* @brief hgemm noTrans computation with 4x4 kernel : C = A*B,
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param a input matrix A
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda leading dimension of matrix A
+ * @param B input matrix B
+ * @param ldb leading dimension of matrix B
+ * @param C output matrix C
+ * @param ldc leading dimension of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 8x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
* @param lda length of the col of matrix C
- * @param b input matrix B
+ * @param B input matrix B
* @param ldb length of the col of matrix C
- * @param c output matrix C
+ * @param C output matrix C
* @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
*/
-void hgemm_noTrans_4x4(unsigned int m, unsigned int n, unsigned int k,
- const __fp16 *a, unsigned int lda, const __fp16 *b,
- unsigned int ldb, __fp16 *c, unsigned int ldc);
+void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
/**
* @brief hgemm noTrans computation with 8x8 kernel : C = A*B,
- *
- * @param m length of the row of matrix A
- * @param n length of the col of matrix B
- * @param k length of the col of matrix A
- * @param a input matrix A
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix C
+ * @param B input matrix B
+ * @param ldb length of the col of matrix C
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 4x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda leading dimension of matrix A
+ * @param B input matrix B
+ * @param ldb leading dimension of matrix B
+ * @param C output matrix C
+ * @param ldc leading dimension of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 4x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
* @param lda length of the col of matrix C
- * @param b input matrix B
+ * @param B input matrix B
* @param ldb length of the col of matrix C
- * @param c output matrix C
+ * @param C output matrix C
* @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
*/
-void hgemm_noTrans_8x8(unsigned int m, unsigned int n, unsigned int k,
- const __fp16 *a, unsigned int lda, const __fp16 *b,
- unsigned int ldb, __fp16 *c, unsigned int ldc);
+void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
/**
* @brief packing function of input matrix A
- *
- * @param m length of the row of the matrix
- * @param k length of the col of the matrix
- * @param from input of original source of the matrix
+ *
+ * @param M length of the row of the matrix
+ * @param K length of the col of the matrix
+ * @param src input of original source of the matrix
* @param lda leading dimension of the matrix
- * @param to output of packed data of the matrix
+ * @param dst output of packed data of the matrix
*/
-void packA_4(unsigned int m, unsigned int k, const __fp16 *from,
- unsigned int lda, const __fp16 *to) {
+void packing_A4(unsigned int M, unsigned int K, const __fp16 *src,
+ unsigned int lda, const __fp16 *dst) {
- assert(k != 0 && m != 0 && k % 4 == 0 && m % 4 == 0);
+ assert(K != 0 && M != 0 && K % 4 == 0 && M % 4 == 0);
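+ // Packs A into 4-row panels: each 4x4 block of A is transposed so that, for
+ // every k in the block, the four row elements are stored contiguously.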
unsigned int i, j;
- __fp16 *a_offset, *a_offset1, *a_offset2, *a_offset3, *a_offset4;
- __fp16 *b_offset;
- __fp16 ctemp1, ctemp2, ctemp3, ctemp4;
- __fp16 ctemp5, ctemp6, ctemp7, ctemp8;
- __fp16 ctemp9, ctemp10, ctemp11, ctemp12;
- __fp16 ctemp13, ctemp14, ctemp15, ctemp16;
+ __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4;
+ __fp16 *b_off;
+ __fp16 c1, c2, c3, c4;
+ __fp16 c5, c6, c7, c8;
+ __fp16 c9, c10, c11, c12;
+ __fp16 c13, c14, c15, c16;
- a_offset = (__fp16 *)from;
- b_offset = (__fp16 *)to;
+ a_off = (__fp16 *)src;
+ b_off = (__fp16 *)dst;
- j = (m >> 2);
+ j = (M >> 2);
do {
- a_offset1 = a_offset;
- a_offset2 = a_offset1 + lda;
- a_offset3 = a_offset2 + lda;
- a_offset4 = a_offset3 + lda;
- a_offset += 4 * lda;
+ a_off1 = a_off;
+ a_off2 = a_off1 + lda;
+ a_off3 = a_off2 + lda;
+ a_off4 = a_off3 + lda;
+ a_off += 4 * lda;
- i = (k >> 2);
+ i = (K >> 2);
do {
- ctemp1 = *(a_offset1 + 0);
- ctemp2 = *(a_offset1 + 1);
- ctemp3 = *(a_offset1 + 2);
- ctemp4 = *(a_offset1 + 3);
-
- ctemp5 = *(a_offset2 + 0);
- ctemp6 = *(a_offset2 + 1);
- ctemp7 = *(a_offset2 + 2);
- ctemp8 = *(a_offset2 + 3);
-
- ctemp9 = *(a_offset3 + 0);
- ctemp10 = *(a_offset3 + 1);
- ctemp11 = *(a_offset3 + 2);
- ctemp12 = *(a_offset3 + 3);
-
- ctemp13 = *(a_offset4 + 0);
- ctemp14 = *(a_offset4 + 1);
- ctemp15 = *(a_offset4 + 2);
- ctemp16 = *(a_offset4 + 3);
-
- *(b_offset + 0) = ctemp1;
- *(b_offset + 1) = ctemp5;
- *(b_offset + 2) = ctemp9;
- *(b_offset + 3) = ctemp13;
-
- *(b_offset + 4) = ctemp2;
- *(b_offset + 5) = ctemp6;
- *(b_offset + 6) = ctemp10;
- *(b_offset + 7) = ctemp14;
-
- *(b_offset + 8) = ctemp3;
- *(b_offset + 9) = ctemp7;
- *(b_offset + 10) = ctemp11;
- *(b_offset + 11) = ctemp15;
-
- *(b_offset + 12) = ctemp4;
- *(b_offset + 13) = ctemp8;
- *(b_offset + 14) = ctemp12;
- *(b_offset + 15) = ctemp16;
-
- a_offset1 += 4;
- a_offset2 += 4;
- a_offset3 += 4;
- a_offset4 += 4;
-
- b_offset += 16;
+ c1 = *(a_off1 + 0);
+ c2 = *(a_off1 + 1);
+ c3 = *(a_off1 + 2);
+ c4 = *(a_off1 + 3);
+
+ c5 = *(a_off2 + 0);
+ c6 = *(a_off2 + 1);
+ c7 = *(a_off2 + 2);
+ c8 = *(a_off2 + 3);
+
+ c9 = *(a_off3 + 0);
+ c10 = *(a_off3 + 1);
+ c11 = *(a_off3 + 2);
+ c12 = *(a_off3 + 3);
+
+ c13 = *(a_off4 + 0);
+ c14 = *(a_off4 + 1);
+ c15 = *(a_off4 + 2);
+ c16 = *(a_off4 + 3);
+
+ *(b_off + 0) = c1;
+ *(b_off + 1) = c5;
+ *(b_off + 2) = c9;
+ *(b_off + 3) = c13;
+
+ *(b_off + 4) = c2;
+ *(b_off + 5) = c6;
+ *(b_off + 6) = c10;
+ *(b_off + 7) = c14;
+
+ *(b_off + 8) = c3;
+ *(b_off + 9) = c7;
+ *(b_off + 10) = c11;
+ *(b_off + 11) = c15;
+
+ *(b_off + 12) = c4;
+ *(b_off + 13) = c8;
+ *(b_off + 14) = c12;
+ *(b_off + 15) = c16;
+
+ a_off1 += 4;
+ a_off2 += 4;
+ a_off3 += 4;
+ a_off4 += 4;
+
+ b_off += 16;
i--;
} while (i > 0);
j--;
/**
* @brief packing function of input matrix A
- *
- * @param m length of the row of the matrix
- * @param k length of the col of the matrix
- * @param from input of original source of the matrix
+ *
+ * @param M length of the row of the matrix
+ * @param K length of the col of the matrix
+ * @param src input of original source of the matrix
* @param lda leading dimension of the matrix
- * @param to output of packed data of the matrix
+ * @param dst output of packed data of the matrix
*/
-void packA_8(unsigned int m, unsigned int k, const __fp16 *from,
- unsigned int lda, const __fp16 *to) {
+void packing_A8(unsigned int M, unsigned int K, const __fp16 *src,
+ unsigned int lda, const __fp16 *dst) {
- assert(k != 0 && m != 0 && k % 8 == 0 && m % 8 == 0);
+ assert(K != 0 && M != 0 && K % 8 == 0 && M % 8 == 0);
uint16x4_t msk = {0xFFFF, 0xFFFF, 0x0000, 0x0000};
uint16x4_t inv_msk = {0x0000, 0x0000, 0xFFFF, 0xFFFF};
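+ // msk / inv_msk are lane-select masks used by the vtrn/vbsl sequence below to
+ // recombine 4-element halves while transposing each 8x8 block of A.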
- const __fp16 *a_offset = (__fp16 *)from;
- __fp16 *b_offset = (__fp16 *)to;
-
- for (unsigned int i = 0; i < m; i += 8) {
- const __fp16 *a_offset1 = a_offset;
- const __fp16 *a_offset2 = a_offset1 + lda;
- const __fp16 *a_offset3 = a_offset2 + lda;
- const __fp16 *a_offset4 = a_offset3 + lda;
- const __fp16 *a_offset5 = a_offset4 + lda;
- const __fp16 *a_offset6 = a_offset5 + lda;
- const __fp16 *a_offset7 = a_offset6 + lda;
- const __fp16 *a_offset8 = a_offset7 + lda;
- a_offset += 8 * lda;
-
- for (unsigned int j = 0; j < k; j += 8) {
- float16x8_t _v0 = vld1q_f16(a_offset1);
- float16x8_t _v1 = vld1q_f16(a_offset2);
- float16x8_t _v2 = vld1q_f16(a_offset3);
- float16x8_t _v3 = vld1q_f16(a_offset4);
-
- float16x8_t _v4 = vld1q_f16(a_offset5);
- float16x8_t _v5 = vld1q_f16(a_offset6);
- float16x8_t _v6 = vld1q_f16(a_offset7);
- float16x8_t _v7 = vld1q_f16(a_offset8);
-
- a_offset1 += 8;
- a_offset2 += 8;
- a_offset3 += 8;
- a_offset4 += 8;
- a_offset5 += 8;
- a_offset6 += 8;
- a_offset7 += 8;
- a_offset8 += 8;
+ const __fp16 *a_off = (__fp16 *)src;
+ __fp16 *b_off = (__fp16 *)dst;
+
+ for (unsigned int i = 0; i < M; i += 8) {
+ const __fp16 *a_off1 = a_off;
+ const __fp16 *a_off2 = a_off1 + lda;
+ const __fp16 *a_off3 = a_off2 + lda;
+ const __fp16 *a_off4 = a_off3 + lda;
+ const __fp16 *a_off5 = a_off4 + lda;
+ const __fp16 *a_off6 = a_off5 + lda;
+ const __fp16 *a_off7 = a_off6 + lda;
+ const __fp16 *a_off8 = a_off7 + lda;
+ a_off += 8 * lda;
+
+ for (unsigned int j = 0; j < K; j += 8) {
+ float16x8_t _v0 = vld1q_f16(a_off1);
+ float16x8_t _v1 = vld1q_f16(a_off2);
+ float16x8_t _v2 = vld1q_f16(a_off3);
+ float16x8_t _v3 = vld1q_f16(a_off4);
+
+ float16x8_t _v4 = vld1q_f16(a_off5);
+ float16x8_t _v5 = vld1q_f16(a_off6);
+ float16x8_t _v6 = vld1q_f16(a_off7);
+ float16x8_t _v7 = vld1q_f16(a_off8);
+
+ a_off1 += 8;
+ a_off2 += 8;
+ a_off3 += 8;
+ a_off4 += 8;
+ a_off5 += 8;
+ a_off6 += 8;
+ a_off7 += 8;
+ a_off8 += 8;
float16x8x2_t _vv0 = vtrnq_f16(_v0, _v1);
float16x8x2_t _vv1 = vtrnq_f16(_v2, _v3);
_v15 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v11, mid_v11),
vbsl_f16(inv_msk, tmp_high_v15, mid_v15));
- vst1q_f16(b_offset + 0, _v8);
- vst1q_f16(b_offset + 8, _v12);
- vst1q_f16(b_offset + 16, _v9);
- vst1q_f16(b_offset + 24, _v13);
- vst1q_f16(b_offset + 32, _v10);
- vst1q_f16(b_offset + 40, _v14);
- vst1q_f16(b_offset + 48, _v11);
- vst1q_f16(b_offset + 56, _v15);
- b_offset += 64;
+ vst1q_f16(b_off + 0, _v8);
+ vst1q_f16(b_off + 8, _v12);
+ vst1q_f16(b_off + 16, _v9);
+ vst1q_f16(b_off + 24, _v13);
+ vst1q_f16(b_off + 32, _v10);
+ vst1q_f16(b_off + 40, _v14);
+ vst1q_f16(b_off + 48, _v11);
+ vst1q_f16(b_off + 56, _v15);
+ b_off += 64;
}
}
}
/**
* @brief packing function of input matrix B
- *
- * @param m length of the row of the matrix
- * @param k length of the col of the matrix
- * @param from input of original source of the matrix
+ *
+ * @param K number of rows of the matrix
+ * @param N number of columns of the matrix
+ * @param src input of original source of the matrix
* @param ldb leading dimension of the matrix
- * @param to output of packed data of the matrix
+ * @param dst output of packed data of the matrix
*/
-void packB_4(unsigned int k, unsigned int n, const __fp16 *from,
- unsigned int ldb, const __fp16 *to) {
- assert(k != 0 && n != 0 && k % 4 == 0 && n % 4 == 0);
+void packing_B4(unsigned int K, unsigned int N, const __fp16 *src,
+ unsigned int ldb, const __fp16 *dst) {
+ assert(K != 0 && N != 0 && K % 4 == 0 && N % 4 == 0);
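+ // Packs B into 4-column panels: panel p starts at dst + p * 4 * K and holds
+ // its 4x4 row blocks back to back (b_off1 advances by K * 4 per panel).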
unsigned int i, j;
- __fp16 *a_offset, *a_offset1, *a_offset2, *a_offset3, *a_offset4;
- __fp16 *b_offset, *b_offset1;
- __fp16 ctemp1, ctemp2, ctemp3, ctemp4;
- __fp16 ctemp5, ctemp6, ctemp7, ctemp8;
- __fp16 ctemp9, ctemp10, ctemp11, ctemp12;
- __fp16 ctemp13, ctemp14, ctemp15, ctemp16;
- a_offset = (__fp16 *)from;
- b_offset = (__fp16 *)to;
+ __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4;
+ __fp16 *b_off, *b_off1;
+ __fp16 c1, c2, c3, c4;
+ __fp16 c5, c6, c7, c8;
+ __fp16 c9, c10, c11, c12;
+ __fp16 c13, c14, c15, c16;
+ a_off = (__fp16 *)src;
+ b_off = (__fp16 *)dst;
- j = (k >> 2);
+ j = (K >> 2);
do {
- a_offset1 = a_offset;
- a_offset2 = a_offset1 + ldb;
- a_offset3 = a_offset2 + ldb;
- a_offset4 = a_offset3 + ldb;
- a_offset += 4 * ldb;
+ a_off1 = a_off;
+ a_off2 = a_off1 + ldb;
+ a_off3 = a_off2 + ldb;
+ a_off4 = a_off3 + ldb;
+ a_off += 4 * ldb;
- b_offset1 = b_offset;
- b_offset += 16;
+ b_off1 = b_off;
+ b_off += 16;
- i = (n >> 2);
+ i = (N >> 2);
do {
- ctemp1 = *(a_offset1 + 0);
- ctemp2 = *(a_offset1 + 1);
- ctemp3 = *(a_offset1 + 2);
- ctemp4 = *(a_offset1 + 3);
-
- ctemp5 = *(a_offset2 + 0);
- ctemp6 = *(a_offset2 + 1);
- ctemp7 = *(a_offset2 + 2);
- ctemp8 = *(a_offset2 + 3);
-
- ctemp9 = *(a_offset3 + 0);
- ctemp10 = *(a_offset3 + 1);
- ctemp11 = *(a_offset3 + 2);
- ctemp12 = *(a_offset3 + 3);
-
- ctemp13 = *(a_offset4 + 0);
- ctemp14 = *(a_offset4 + 1);
- ctemp15 = *(a_offset4 + 2);
- ctemp16 = *(a_offset4 + 3);
-
- a_offset1 += 4;
- a_offset2 += 4;
- a_offset3 += 4;
- a_offset4 += 4;
-
- *(b_offset1 + 0) = ctemp1;
- *(b_offset1 + 1) = ctemp2;
- *(b_offset1 + 2) = ctemp3;
- *(b_offset1 + 3) = ctemp4;
-
- *(b_offset1 + 4) = ctemp5;
- *(b_offset1 + 5) = ctemp6;
- *(b_offset1 + 6) = ctemp7;
- *(b_offset1 + 7) = ctemp8;
-
- *(b_offset1 + 8) = ctemp9;
- *(b_offset1 + 9) = ctemp10;
- *(b_offset1 + 10) = ctemp11;
- *(b_offset1 + 11) = ctemp12;
-
- *(b_offset1 + 12) = ctemp13;
- *(b_offset1 + 13) = ctemp14;
- *(b_offset1 + 14) = ctemp15;
- *(b_offset1 + 15) = ctemp16;
-
- b_offset1 += k * 4;
+ c1 = *(a_off1 + 0);
+ c2 = *(a_off1 + 1);
+ c3 = *(a_off1 + 2);
+ c4 = *(a_off1 + 3);
+
+ c5 = *(a_off2 + 0);
+ c6 = *(a_off2 + 1);
+ c7 = *(a_off2 + 2);
+ c8 = *(a_off2 + 3);
+
+ c9 = *(a_off3 + 0);
+ c10 = *(a_off3 + 1);
+ c11 = *(a_off3 + 2);
+ c12 = *(a_off3 + 3);
+
+ c13 = *(a_off4 + 0);
+ c14 = *(a_off4 + 1);
+ c15 = *(a_off4 + 2);
+ c16 = *(a_off4 + 3);
+
+ a_off1 += 4;
+ a_off2 += 4;
+ a_off3 += 4;
+ a_off4 += 4;
+
+ *(b_off1 + 0) = c1;
+ *(b_off1 + 1) = c2;
+ *(b_off1 + 2) = c3;
+ *(b_off1 + 3) = c4;
+
+ *(b_off1 + 4) = c5;
+ *(b_off1 + 5) = c6;
+ *(b_off1 + 6) = c7;
+ *(b_off1 + 7) = c8;
+
+ *(b_off1 + 8) = c9;
+ *(b_off1 + 9) = c10;
+ *(b_off1 + 10) = c11;
+ *(b_off1 + 11) = c12;
+
+ *(b_off1 + 12) = c13;
+ *(b_off1 + 13) = c14;
+ *(b_off1 + 14) = c15;
+ *(b_off1 + 15) = c16;
+
+ b_off1 += K * 4;
i--;
} while (i > 0);
j--;
/**
* @brief packing function of input matrix B
- *
- * @param m length of the row of the matrix
- * @param k length of the col of the matrix
- * @param from input of original source of the matrix
+ *
+ * @param K number of rows of the matrix
+ * @param N number of columns of the matrix
+ * @param src input of original source of the matrix
* @param ldb leading dimension of the matrix
- * @param to output of packed data of the matrix
+ * @param dst output of packed data of the matrix
*/
-void packB_8(unsigned int k, unsigned int n, const __fp16 *from,
- unsigned int ldb, const __fp16 *to) {
- assert(k != 0 && n != 0 && n % 8 == 0);
-
- for (int i = 0; i < k; i++) {
- const __fp16 *a_offset1 = from + i * ldb;
- __fp16 *b_offset = (__fp16 *)to + i * 8;
- for (int j = 0; j < n; j += 8) {
- float16x8_t _v0 = vld1q_f16(a_offset1);
- a_offset1 += 8;
-
- vst1q_f16(b_offset, _v0);
- b_offset += 8 * k;
+void packing_B8(unsigned int K, unsigned int N, const __fp16 *src,
+ unsigned int ldb, const __fp16 *dst) {
+ assert(K != 0 && N != 0 && N % 8 == 0);
+
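+ // Packs B into 8-column panels: panel p occupies dst[p * 8 * K ..] and row i
+ // of the panel is stored contiguously at offset i * 8 within it.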
+ for (unsigned int i = 0; i < K; i++) {
+ const __fp16 *a_off = src + i * ldb;
+ __fp16 *b_off = (__fp16 *)dst + i * 8;
+ for (unsigned int j = 0; j < N; j += 8) {
+ float16x8_t v = vld1q_f16(a_off);
+ a_off += 8;
+
+ vst1q_f16(b_off, v);
+ b_off += 8 * K;
}
}
}