From: skykongkong8
Date: Wed, 3 Apr 2024 04:23:42 +0000 (+0900)
Subject: [ hgemm ] Use optimized hgemm if possible
X-Git-Tag: accepted/tizen/7.0/unified/20240830.164841~220
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a936c13b2fcc348d4313f220d8eaa1fe0b2d85a0;p=platform%2Fcore%2Fml%2Fnntrainer.git

[ hgemm ] Use optimized hgemm if possible

- We can use the optimized version of hgemm under the following conditions:
  1. noTrans hgemm
  2. M, N, and K are divisible by 4 or 8
  3. Row-major GEMM
  4. alpha = 1.0, beta = 0.0 (will be patched soon)
- Otherwise, use the previous version as a fallback.
- Note that a few optimization strategies are still left for a fully optimal hgemm.
- Illustrative sketches of the dispatch, blocking, packing, and fallback paths are appended after the patch.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8
---
diff --git a/nntrainer/tensor/hgemm/hgemm.cpp b/nntrainer/tensor/hgemm/hgemm.cpp new file mode 100644 index 00000000..97f074ce --- /dev/null +++ b/nntrainer/tensor/hgemm/hgemm.cpp @@ -0,0 +1,663 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Sungsik Kong + * + * @file hgemm.cpp + * @date 03 April 2024 + * @see https://github.com/nnstreamer/nntrainer + * @author Sungsik Kong + * @bug No known bugs except for NYI items + * @brief This is half-precision GEMM interface + * + */ + +#include +#include +#include +#include +#include +#include + +#define HGEMM_KERNEL_4x4 hgemm_kernel_4x4 +#define HGEMM_KERNEL_4x8 hgemm_kernel_4x8 +#define HGEMM_KERNEL_8x8 hgemm_kernel_8x8 + +void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M, + unsigned int N, unsigned int K, float alpha, float beta) { + if (alpha == 1.F && beta == 0.F) { + if (M % 8 == 0 && N % 8 == 0 && K % 8 == 0) { + hgemm_noTrans_8x8(M, N, K, A, K, B, N, C32, N, alpha, beta); + } else if (M % 4 == 0 && N % 8 == 0 && K % 4 == 0) { + hgemm_noTrans_4x8(M, N, K, A, K, B, N, C32, N, alpha, beta); + } else { + hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta); + } + } else + hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta); +} + +void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K, + const __fp16 *A, unsigned int lda, const __fp16 *B, + unsigned int ldb, __fp16 *C, unsigned int ldc, + float alpha, float beta) { + __fp16 *sa = alignedMalloc(M * K); + __fp16 *sb = alignedMalloc(K * N); + + unsigned int ms, mms, ns, ks; + unsigned int m_min, m2_min, n_min, k_min; + for (ms = 0; ms < M; ms += M_BLOCKING) { + m_min = M - ms; + if (m_min > M_BLOCKING) { + m_min = M_BLOCKING; + } + + for (ks = 0; ks < K; ks += k_min) { + k_min = K - ks; + if (k_min >= (K_BLOCKING << 1)) { + k_min = K_BLOCKING; + } else if (k_min > K_BLOCKING) { + k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1); + } + + n_min = N; + if (N >= N_BLOCKING * 2) { + n_min = N_BLOCKING; + } else if (N > N_BLOCKING) { + n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1); + } + packing_B4(k_min, n_min, B + ks * ldb, ldb, sb); + + for (mms = ms; mms < ms + m_min; mms += m2_min) { + m2_min = (ms + m_min) - mms; + if (m2_min >= 3 * GEMM_UNROLLING_4) { + m2_min = 3 * GEMM_UNROLLING_4; + } else if (m2_min >= 2 * GEMM_UNROLLING_4) { + m2_min = 2 * GEMM_UNROLLING_4; + } else if (m2_min > GEMM_UNROLLING_4) { + m2_min = GEMM_UNROLLING_4; + } + + packing_A4(m2_min, k_min, A + mms * lda + ks, lda, + sa + k_min * (mms - ms)); + + HGEMM_KERNEL_4x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb, + C + mms * ldc, ldc); + } + + for (ns = n_min; ns < N; ns += n_min) { + n_min =
N - ns; + if (n_min >= N_BLOCKING * 2) { + n_min = N_BLOCKING; + } else if (n_min > N_BLOCKING) { + n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1); + } + + packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb); + HGEMM_KERNEL_4x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc); + } + } + } + + free(sa); + free(sb); +} + +void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K, + const __fp16 *A, unsigned int lda, const __fp16 *B, + unsigned int ldb, __fp16 *C, unsigned int ldc, + float alpha, float beta) { + __fp16 *sa = alignedMalloc(M * K); + __fp16 *sb = alignedMalloc(K * N); + + unsigned int ms, mms, ns, ks; + unsigned int m_min, m2_min, n_min, k_min; + unsigned int l1stride = 1; + for (ms = 0; ms < M; ms += M_BLOCKING) { + m_min = M - ms; + if (m_min > M_BLOCKING) { + m_min = M_BLOCKING; + } + + for (ks = 0; ks < K; ks += k_min) { + k_min = K - ks; + if (k_min >= (K_BLOCKING << 1)) { + k_min = K_BLOCKING; + } else if (k_min > K_BLOCKING) { + k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1); + } + + n_min = N; + if (N >= N_BLOCKING * 2) { + n_min = N_BLOCKING; + } else if (N > N_BLOCKING) { + n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) * + GEMM_UNROLLING_8; + } else { + l1stride = 0; + } + packing_B8(k_min, n_min, B + ks * ldb, ldb, sb); + + for (mms = ms; mms < ms + m_min; mms += m2_min) { + m2_min = (ms + m_min) - mms; + if (m2_min >= 3 * GEMM_UNROLLING_4) { + m2_min = 3 * GEMM_UNROLLING_4; + } else if (m2_min >= 2 * GEMM_UNROLLING_4) { + m2_min = 2 * GEMM_UNROLLING_4; + } else if (m2_min > GEMM_UNROLLING_4) { + m2_min = GEMM_UNROLLING_4; + } + + packing_A4(m2_min, k_min, A + mms * lda + ks, lda, + sa + k_min * (mms - ms) * l1stride); + + HGEMM_KERNEL_4x8(m2_min, n_min, k_min, + sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc, + ldc); + } + + for (ns = n_min; ns < N; ns += n_min) { + n_min = N - ns; + if (n_min >= N_BLOCKING * 2) { + n_min = N_BLOCKING; + } else if (n_min > N_BLOCKING) { + n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1); + } + + packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb); + HGEMM_KERNEL_4x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc); + } + } + } + + free(sa); + free(sb); +} + +void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K, + const __fp16 *A, unsigned int lda, const __fp16 *B, + unsigned int ldb, float *C, unsigned int ldc, + float alpha, float beta) { + __fp16 *sa = alignedMalloc(M * K); + __fp16 *sb = alignedMalloc(K * N); + + unsigned int ms, mms, ns, ks; + unsigned int m_min, m2_min, n_min, k_min; + unsigned int l1stride = 1; + for (ms = 0; ms < M; ms += M_BLOCKING) { + m_min = M - ms; + if (m_min > M_BLOCKING) { + m_min = M_BLOCKING; + } + + for (ks = 0; ks < K; ks += k_min) { + k_min = K - ks; + if (k_min >= (K_BLOCKING << 1)) { + k_min = K_BLOCKING; + } else if (k_min > K_BLOCKING) { + k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1); + } + + n_min = N; + if (N >= N_BLOCKING * 2) { + n_min = N_BLOCKING; + } else if (N > N_BLOCKING) { + n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) * + GEMM_UNROLLING_8; + } else { + l1stride = 0; + } + packing_B8(k_min, n_min, B + ks * ldb, ldb, sb); + + for (mms = ms; mms < ms + m_min; mms += m2_min) { + m2_min = (ms + m_min) - mms; + if (m2_min >= 3 * GEMM_UNROLLING_4) { + m2_min = 3 * GEMM_UNROLLING_4; + } else if (m2_min >= 2 * GEMM_UNROLLING_4) { + m2_min = 2 * GEMM_UNROLLING_4; + } else if (m2_min > GEMM_UNROLLING_4) { + m2_min = 
GEMM_UNROLLING_4; + } + + packing_A4(m2_min, k_min, A + mms * lda + ks, lda, + sa + k_min * (mms - ms) * l1stride); + + HGEMM_KERNEL_4x8(m2_min, n_min, k_min, + sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc, + ldc); + } + + for (ns = n_min; ns < N; ns += n_min) { + n_min = N - ns; + if (n_min >= N_BLOCKING * 2) { + n_min = N_BLOCKING; + } else if (n_min > N_BLOCKING) { + n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1); + } + + packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb); + HGEMM_KERNEL_4x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc); + } + } + } + + free(sa); + free(sb); +} + +void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K, + const __fp16 *A, unsigned int lda, const __fp16 *B, + unsigned int ldb, __fp16 *C, unsigned int ldc, + float alpha, float beta) { + + __fp16 *sa = alignedMalloc(M * K); + __fp16 *sb = alignedMalloc(K * N); + + unsigned int ms, mms, ns, ks; + unsigned int m_min, m2_min, n_min, k_min; + for (ms = 0; ms < M; ms += M_BLOCKING) { + m_min = M - ms; + if (m_min > M_BLOCKING) { + m_min = M_BLOCKING; + } + + for (ks = 0; ks < K; ks += k_min) { + k_min = K - ks; + if (k_min >= (K_BLOCKING << 1)) { + k_min = K_BLOCKING; + } else if (k_min > K_BLOCKING) { + k_min = (k_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1); + } + + n_min = N; + if (N >= N_BLOCKING * 2) { + n_min = N_BLOCKING; + } else if (N > N_BLOCKING) { + n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1); + } + packing_B8(k_min, n_min, B + ks * ldb, ldb, sb); + + for (mms = ms; mms < ms + m_min; mms += m2_min) { + m2_min = (ms + m_min) - mms; + if (m2_min >= 3 * GEMM_UNROLLING_8) { + m2_min = 3 * GEMM_UNROLLING_8; + } else if (m2_min >= 2 * GEMM_UNROLLING_8) { + m2_min = 2 * GEMM_UNROLLING_8; + } else if (m2_min > GEMM_UNROLLING_8) { + m2_min = GEMM_UNROLLING_8; + } + + packing_A8(m2_min, k_min, A + mms * lda + ks, lda, + sa + k_min * (mms - ms)); + + HGEMM_KERNEL_8x8(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb, + C + mms * ldc, ldc); + } + + for (ns = n_min; ns < N; ns += n_min) { + n_min = N - ns; + if (n_min >= N_BLOCKING * 2) { + n_min = N_BLOCKING; + } else if (n_min > N_BLOCKING) { + n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1); + } + + packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb); + HGEMM_KERNEL_8x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc); + } + } + } + + free(sa); + free(sb); +} + +void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K, + const __fp16 *A, unsigned int lda, const __fp16 *B, + unsigned int ldb, float *C, unsigned int ldc, + float alpha, float beta) { + + __fp16 *sA = alignedMalloc(M * K); + __fp16 *sB = alignedMalloc(K * N); + + unsigned int ms, ms2, ns, ks; + unsigned int m_min, m2_min, n_min, k_min; + for (ms = 0; ms < M; ms += M_BLOCKING) { + m_min = M - ms; + if (m_min > M_BLOCKING) { + m_min = M_BLOCKING; + } + + for (ks = 0; ks < K; ks += k_min) { + k_min = K - ks; + if (k_min >= (K_BLOCKING << 1)) { + k_min = K_BLOCKING; + } else if (k_min > K_BLOCKING) { + k_min = (k_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1); + } + + n_min = N; + if (N >= N_BLOCKING * 2) { + n_min = N_BLOCKING; + } else if (N > N_BLOCKING) { + n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1); + } + packing_B8(k_min, n_min, B + ks * ldb, ldb, sB); + + for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) { + m2_min = (ms + m_min) - ms2; + if (m2_min >= 3 * GEMM_UNROLLING_8) { + m2_min = 3 * GEMM_UNROLLING_8; + } else if (m2_min >= 
2 * GEMM_UNROLLING_8) { + m2_min = 2 * GEMM_UNROLLING_8; + } else if (m2_min > GEMM_UNROLLING_8) { + m2_min = GEMM_UNROLLING_8; + } + + packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda, + sA + k_min * (ms2 - ms)); + + HGEMM_KERNEL_8x8(m2_min, n_min, k_min, sA + k_min * (ms2 - ms), sB, + C + ms2 * ldc, ldc); + } + + for (ns = n_min; ns < N; ns += n_min) { + n_min = N - ns; + if (n_min >= N_BLOCKING * 2) { + n_min = N_BLOCKING; + } else if (n_min > N_BLOCKING) { + n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1); + } + + packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sB); + HGEMM_KERNEL_8x8(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc); + } + } + } + + free(sA); + free(sB); +} + +void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K, + const __fp16 *A, unsigned int lda, const __fp16 *B, + unsigned int ldb, float *C, unsigned int ldc, + float alpha, float beta) { + + unsigned int k = 0; + unsigned int N8 = (N >> 3) << 3; + __fp16 a[16]; + for (; (K - k) >= 16; k += 16) { + for (unsigned int m = 0; m < M; m++) { + vst1q_f16(&a[0], vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha)); + vst1q_f16(&a[8], vmulq_n_f16(vld1q_f16(&A[m * K + k + 8]), alpha)); + for (unsigned int n = 0; n < N8; n += 8) { + float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 8) * N + n]), a[8]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 9) * N + n]), a[9]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 10) * N + n]), a[10]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 11) * N + n]), a[11]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 12) * N + n]), a[12]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 13) * N + n]), a[13]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 14) * N + n]), a[14]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 15) * N + n]), a[15]); + + float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]), + vcvt_f32_f16(vget_low_f16(b0_7_0))); + float32x4_t c0_7_high_32 = vaddq_f32( + vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0))); + + vst1q_f32(&C[m * N + n], c0_7_low_32); + vst1q_f32(&C[m * N + n + 4], c0_7_high_32); + } + if (N != N8) { + unsigned int n = N8; + __fp16 valsB_0[8]; + __fp16 valsB_1[8]; + __fp16 valsB_2[8]; + __fp16 valsB_3[8]; + __fp16 valsB_4[8]; + __fp16 valsB_5[8]; + __fp16 valsB_6[8]; + __fp16 valsB_7[8]; + __fp16 valsB_8[8]; + __fp16 valsB_9[8]; + __fp16 valsB_10[8]; + __fp16 valsB_11[8]; + __fp16 valsB_12[8]; + __fp16 valsB_13[8]; + __fp16 valsB_14[8]; + __fp16 valsB_15[8]; + float valsC[8]; + for (unsigned int idx = n; idx < N; idx++) { + valsB_0[idx - n] = B[k * N + idx]; + valsB_1[idx - n] = B[(k + 1) * N + idx]; + valsB_2[idx - n] = B[(k + 2) * N + idx]; + valsB_3[idx - n] = B[(k + 3) * N + idx]; + valsB_4[idx - n] = B[(k + 4) * N + idx]; + valsB_5[idx - n] = B[(k + 5) * N + idx]; + valsB_6[idx - n] = B[(k + 6) * N + idx]; + valsB_7[idx - n] = B[(k + 7) * N + idx]; + valsB_8[idx - n] = B[(k + 8) * N + idx]; + valsB_9[idx - n] = B[(k + 9) * 
N + idx]; + valsB_10[idx - n] = B[(k + 10) * N + idx]; + valsB_11[idx - n] = B[(k + 11) * N + idx]; + valsB_12[idx - n] = B[(k + 12) * N + idx]; + valsB_13[idx - n] = B[(k + 13) * N + idx]; + valsB_14[idx - n] = B[(k + 14) * N + idx]; + valsB_15[idx - n] = B[(k + 15) * N + idx]; + valsC[idx - n] = C[m * N + idx]; + } + + float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_4), a[4]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_5), a[5]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_6), a[6]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_7), a[7]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_8), a[8]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_9), a[9]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_10), a[10]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_11), a[11]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_12), a[12]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_13), a[13]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_14), a[14]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_15), a[15]); + + float32x4_t c0_7_low_32 = + vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b))); + + float32x4_t c0_7_high_32 = + vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b))); + + vst1q_f32(valsC, c0_7_low_32); + vst1q_f32(valsC + 4, c0_7_high_32); + + for (unsigned int idx = n; idx < N; idx++) { + C[m * N + idx] = valsC[idx - n]; + } + } + } + } + + for (; (K - k) >= 8; k += 8) { + for (unsigned int m = 0; m < M; m++) { + vst1q_f16(a, vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha)); + + for (unsigned int n = 0; n < N8; n += 8) { + float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]); + + float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]), + vcvt_f32_f16(vget_low_f16(b0_7_0))); + float32x4_t c0_7_high_32 = vaddq_f32( + vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0))); + + vst1q_f32(&C[m * N + n], c0_7_low_32); + vst1q_f32(&C[m * N + n + 4], c0_7_high_32); + } + if (N != N8) { + unsigned int n = N8; + __fp16 valsB_0[8]; + __fp16 valsB_1[8]; + __fp16 valsB_2[8]; + __fp16 valsB_3[8]; + __fp16 valsB_4[8]; + __fp16 valsB_5[8]; + __fp16 valsB_6[8]; + __fp16 valsB_7[8]; + float valsC[8]; + for (unsigned int idx = n; idx < N; idx++) { + valsB_0[idx - n] = B[k * N + idx]; + valsB_1[idx - n] = B[(k + 1) * N + idx]; + valsB_2[idx - n] = B[(k + 2) * N + idx]; + valsB_3[idx - n] = B[(k + 3) * N + idx]; + valsB_4[idx - n] = B[(k + 4) * N + idx]; + valsB_5[idx - n] = B[(k + 5) * N + idx]; + valsB_6[idx - n] = B[(k + 6) * N + idx]; + valsB_7[idx - n] = B[(k + 7) * N + idx]; + valsC[idx - n] = C[m * N + idx]; + } + + float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_4), a[4]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_5), a[5]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_6), a[6]); + b = vfmaq_n_f16(b, 
vld1q_f16(valsB_7), a[7]); + + float32x4_t c0_7_low_32 = + vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b))); + + float32x4_t c0_7_high_32 = + vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b))); + + vst1q_f32(valsC, c0_7_low_32); + vst1q_f32(valsC + 4, c0_7_high_32); + + for (unsigned int idx = n; idx < N; idx++) { + C[m * N + idx] = valsC[idx - n]; + } + } + } + } + + for (; (K - k) >= 4; k += 4) { + for (unsigned int m = 0; m < M; m++) { + vst1_f16(a, vmul_n_f16(vld1_f16(&A[m * K + k]), alpha)); + + for (unsigned int n = 0; n < N8; n += 8) { + + float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]); + float16x8_t b0_7_2 = vmulq_n_f16(vld1q_f16(&B[(k + 2) * N + n]), a[2]); + b0_7_2 = vfmaq_n_f16(b0_7_2, vld1q_f16(&B[(k + 3) * N + n]), a[3]); + + float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]), + vcvt_f32_f16(vget_low_f16(b0_7_0))); + float32x4_t c0_7_high_32 = vaddq_f32( + vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0))); + + c0_7_low_32 = + vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_2))); + c0_7_high_32 = + vaddq_f32(c0_7_high_32, vcvt_f32_f16(vget_high_f16(b0_7_2))); + + vst1q_f32(&C[m * N + n], c0_7_low_32); + vst1q_f32(&C[m * N + n + 4], c0_7_high_32); + } + if (N != N8) { + unsigned int n = N8; + __fp16 valsB_0[8]; + __fp16 valsB_1[8]; + __fp16 valsB_2[8]; + __fp16 valsB_3[8]; + float valsC[8]; + for (unsigned int idx = n; idx < N; idx++) { + valsB_0[idx - n] = B[k * N + idx]; + valsB_1[idx - n] = B[(k + 1) * N + idx]; + valsB_2[idx - n] = B[(k + 2) * N + idx]; + valsB_3[idx - n] = B[(k + 3) * N + idx]; + valsC[idx - n] = C[m * N + idx]; + } + + float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]); + + float32x4_t c0_7_low_32 = + vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b))); + + float32x4_t c0_7_high_32 = + vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b))); + + vst1q_f32(valsC, c0_7_low_32); + vst1q_f32(valsC + 4, c0_7_high_32); + + for (unsigned int idx = n; idx < N; idx++) { + C[m * N + idx] = valsC[idx - n]; + } + } + } + } + + for (; k < K; k++) { + for (unsigned int m = 0; m < M; m++) { + __fp16 a0 = alpha * A[m * K + k]; + + for (unsigned int n = 0; n < N8; n += 8) { + float16x8_t b0_7 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a0); + + float32x4_t c0_7_low_32 = + vaddq_f32(vld1q_f32(&C[m * N + n]), vcvt_f32_f16(vget_low_f16(b0_7))); + + float32x4_t c0_7_high_32 = vaddq_f32(vld1q_f32(&C[m * N + n + 4]), + vcvt_f32_f16(vget_high_f16(b0_7))); + + vst1q_f32(&C[m * N + n], c0_7_low_32); + vst1q_f32(&C[m * N + n + 4], c0_7_high_32); + } + if (N != N8) { + unsigned int n = N8; + __fp16 valsB[8]; + float valsC[8]; + for (unsigned int idx = n; idx < N; idx++) { + valsB[idx - n] = B[k * N + idx]; + valsC[idx - n] = C[m * N + idx]; + } + + float16x8_t b = vmulq_n_f16(vld1q_f16(valsB), a0); + + float32x4_t c0_7_low_32 = + vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b))); + + float32x4_t c0_7_high_32 = + vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b))); + + vst1q_f32(valsC, c0_7_low_32); + vst1q_f32(valsC + 4, c0_7_high_32); + + for (unsigned int idx = n; idx < N; idx++) { + C[m * N + idx] = valsC[idx - n]; + } + } + } + } +} diff --git a/nntrainer/tensor/hgemm/hgemm.h b/nntrainer/tensor/hgemm/hgemm.h index bcd74b68..fde7112a 100644 --- 
a/nntrainer/tensor/hgemm/hgemm.h +++ b/nntrainer/tensor/hgemm/hgemm.h @@ -11,44 +11,136 @@ * */ -#include -#include -#include -#include +/** + * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C, + * @param[in] A __fp16 * for Matrix A + * @param[in] B __fp16 * for Matrix B + * @param[in] C __fp16 * for Matrix C + * @param[in] M number of op(A)'s and C's row + * @param[in] N number of op(B)'s and C's columns + * @param[in] K number of op(A)'s and columns and op(B)'s rows + * @param[in] alpha float number + * @param[in] beta float number + */ +void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C, unsigned int M, + unsigned int N, unsigned int K, float alpha = 1.F, + float beta = 0.F); -#define KERNEL_4x4 hgemm_kernel_4x4 -#define KERNEL_8x8 hgemm_kernel_8x8 +/** + * @brief hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C, + * @param M length of the row of matrix A + * @param N length of the col of matrix B + * @param K length of the col of matrix A + * @param A input matrix A + * @param lda length of the col of matrix C + * @param B input matrix B + * @param ldb length of the col of matrix C + * @param C output matrix C + * @param ldc length of the col of matrix C + * @param[in] alpha float number + * @param[in] beta float number + */ +void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K, + const __fp16 *A, unsigned int lda, const __fp16 *B, + unsigned int ldb, float *C, unsigned int ldc, + float alpha = 1.F, float beta = 0.F); /** * @brief hgemm noTrans computation with 4x4 kernel : C = A*B, - * - * @param m length of the row of matrix A - * @param n length of the col of matrix B - * @param k length of the col of matrix A - * @param a input matrix A + * + * @param M length of the row of matrix A + * @param N length of the col of matrix B + * @param K length of the col of matrix A + * @param A input matrix A + * @param lda length of the col of matrix C + * @param B input matrix B + * @param ldb length of the col of matrix C + * @param C output matrix C + * @param ldc length of the col of matrix C + * @param[in] alpha float number + * @param[in] beta float number + */ +void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K, + const __fp16 *A, unsigned int lda, const __fp16 *B, + unsigned int ldb, __fp16 *C, unsigned int ldc, + float alpha = 1.F, float beta = 0.F); + +/** + * @brief hgemm noTrans computation with 8x8 kernel : C = A*B, + * + * @param M length of the row of matrix A + * @param N length of the col of matrix B + * @param K length of the col of matrix A + * @param A input matrix A * @param lda length of the col of matrix C - * @param b input matrix B + * @param B input matrix B * @param ldb length of the col of matrix C - * @param c output matrix C + * @param C output matrix C * @param ldc length of the col of matrix C + * @param[in] alpha float number + * @param[in] beta float number */ -void hgemm_noTrans_4x4(unsigned int m, unsigned int n, unsigned int k, - const __fp16 *a, unsigned int lda, const __fp16 *b, - unsigned int ldb, __fp16 *c, unsigned int ldc); +void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K, + const __fp16 *A, unsigned int lda, const __fp16 *B, + unsigned int ldb, __fp16 *C, unsigned int ldc, + float alpha = 1.F, float beta = 0.F); /** * @brief hgemm noTrans computation with 8x8 kernel : C = A*B, - * - * @param m length of the row of matrix A - * @param n length of the col of matrix B - * @param k length of the col of matrix A - * @param a input matrix A + * + * @param M 
length of the row of matrix A + * @param N length of the col of matrix B + * @param K length of the col of matrix A + * @param A input matrix A + * @param lda length of the col of matrix C + * @param B input matrix B + * @param ldb length of the col of matrix C + * @param C output matrix C + * @param ldc length of the col of matrix C + * @param[in] alpha float number + * @param[in] beta float number + */ +void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K, + const __fp16 *A, unsigned int lda, const __fp16 *B, + unsigned int ldb, float *C, unsigned int ldc, + float alpha = 1.F, float beta = 0.F); + +/** + * @brief hgemm noTrans computation with 4x8 kernel : C = A*B, + * + * @param M length of the row of matrix A + * @param N length of the col of matrix B + * @param K length of the col of matrix A + * @param A input matrix A + * @param lda length of the col of matrix C + * @param B input matrix B + * @param ldb length of the col of matrix C + * @param C output matrix C + * @param ldc length of the col of matrix C + * @param[in] alpha float number + * @param[in] beta float number + */ +void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K, + const __fp16 *A, unsigned int lda, const __fp16 *B, + unsigned int ldb, __fp16 *C, unsigned int ldc, + float alpha = 1.F, float beta = 0.F); + +/** + * @brief hgemm noTrans computation with 4x8 kernel : C = A*B, + * + * @param M length of the row of matrix A + * @param N length of the col of matrix B + * @param K length of the col of matrix A + * @param A input matrix A * @param lda length of the col of matrix C - * @param b input matrix B + * @param B input matrix B * @param ldb length of the col of matrix C - * @param c output matrix C + * @param C output matrix C * @param ldc length of the col of matrix C + * @param[in] alpha float number + * @param[in] beta float number */ -void hgemm_noTrans_8x8(unsigned int m, unsigned int n, unsigned int k, - const __fp16 *a, unsigned int lda, const __fp16 *b, - unsigned int ldb, __fp16 *c, unsigned int ldc); +void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K, + const __fp16 *A, unsigned int lda, const __fp16 *B, + unsigned int ldb, float *C, unsigned int ldc, + float alpha = 1.F, float beta = 0.F); diff --git a/nntrainer/tensor/hgemm/hgemm_common.h b/nntrainer/tensor/hgemm/hgemm_common.h index f5bb3fa3..68728102 100644 --- a/nntrainer/tensor/hgemm/hgemm_common.h +++ b/nntrainer/tensor/hgemm/hgemm_common.h @@ -10,14 +10,16 @@ * @brief This is common settings for hgemm * */ +#include +#include #define A(i, j) a[(i)*lda + (j)] #define B(i, j) b[(i)*ldb + (j)] #define C(i, j) c[(i)*ldc + (j)] -#define GEMM_N (384) -#define GEMM_K (256) -#define GEMM_M (4096) +#define N_BLOCKING (384) +#define K_BLOCKING (256) +#define M_BLOCKING (4096) #define GEMM_UNROLLING_8 (8) #define GEMM_UNROLLING_4 (4) #define VL_FP16 (8) diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_4x4.h b/nntrainer/tensor/hgemm/hgemm_kernel_4x4.h index fbb2f744..6166b940 100644 --- a/nntrainer/tensor/hgemm/hgemm_kernel_4x4.h +++ b/nntrainer/tensor/hgemm/hgemm_kernel_4x4.h @@ -11,31 +11,29 @@ * */ -#include #include -#include #include /** * @brief hgemm 4x4 kernel sc = sa * sb - * + * * @param m length of the row of matrix A - * @param n length of the col of matrix B + * @param n length of the col of matrix B * @param k length of the col of matrix A * @param sa sub-matrix of input matrix A * @param sb sub-matrix of input matrix B * @param sc sub-matrix of output matrix C * @param ldc leading dimension 
of matrix C */ -void hgemm_kernel_4x4(unsigned int m, unsigned int n, unsigned int k, +void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K, __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) { - assert(m > 0 && n > 0 && k > 0); - assert(m % 4 == 0 && n % 4 == 0 && k % 4 == 0); + assert(M > 0 && N > 0 && K > 0); + assert(M % 4 == 0 && N % 4 == 0 && K % 4 == 0); __fp16 *a = sa, *b = sb, *c = sc; unsigned int i, j, l; - for (i = 0; i < m; i += VL_FP16_HALF) { - for (j = 0; j < n; j += VL_FP16_HALF) { + for (i = 0; i < M; i += VL_FP16_HALF) { + for (j = 0; j < N; j += VL_FP16_HALF) { __builtin_prefetch(b, 0, 3); __builtin_prefetch(a, 0, 3); @@ -44,7 +42,7 @@ void hgemm_kernel_4x4(unsigned int m, unsigned int n, unsigned int k, float16x4_t v26 = {0}; float16x4_t v27 = {0}; - for (l = 0; l < k; l += VL_FP16_HALF) { + for (l = 0; l < K; l += VL_FP16_HALF) { float16x4_t v0 = vld1_f16(b); float16x4_t v16 = vld1_f16(a); @@ -95,12 +93,11 @@ void hgemm_kernel_4x4(unsigned int m, unsigned int n, unsigned int k, vst1_f16(c + 3 * ldc, v27); c += 4; - a -= 4 * k; + a -= 4 * K; } sc += ldc * 4; c = sc; - a += 4 * k; + a += 4 * K; b = sb; } } - diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_pack.h b/nntrainer/tensor/hgemm/hgemm_kernel_pack.h index 60881f15..14c112c1 100644 --- a/nntrainer/tensor/hgemm/hgemm_kernel_pack.h +++ b/nntrainer/tensor/hgemm/hgemm_kernel_pack.h @@ -15,85 +15,85 @@ /** * @brief packing function of input matrix A - * - * @param m length of the row of the matrix - * @param k length of the col of the matrix - * @param from input of original source of the matrix + * + * @param M length of the row of the matrix + * @param K length of the col of the matrix + * @param src input of original source of the matrix * @param lda leading dimension of the matrix - * @param to output of packed data of the matrix + * @param dst output of packed data of the matrix */ -void packA_4(unsigned int m, unsigned int k, const __fp16 *from, - unsigned int lda, const __fp16 *to) { +void packing_A4(unsigned int M, unsigned int K, const __fp16 *src, + unsigned int lda, const __fp16 *dst) { - assert(k != 0 && m != 0 && k % 4 == 0 && m % 4 == 0); + assert(K != 0 && M != 0 && K % 4 == 0 && M % 4 == 0); unsigned int i, j; - __fp16 *a_offset, *a_offset1, *a_offset2, *a_offset3, *a_offset4; - __fp16 *b_offset; - __fp16 ctemp1, ctemp2, ctemp3, ctemp4; - __fp16 ctemp5, ctemp6, ctemp7, ctemp8; - __fp16 ctemp9, ctemp10, ctemp11, ctemp12; - __fp16 ctemp13, ctemp14, ctemp15, ctemp16; + __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4; + __fp16 *b_off; + __fp16 c1, c2, c3, c4; + __fp16 c5, c6, c7, c8; + __fp16 c9, c10, c11, c12; + __fp16 c13, c14, c15, c16; - a_offset = (__fp16 *)from; - b_offset = (__fp16 *)to; + a_off = (__fp16 *)src; + b_off = (__fp16 *)dst; - j = (m >> 2); + j = (M >> 2); do { - a_offset1 = a_offset; - a_offset2 = a_offset1 + lda; - a_offset3 = a_offset2 + lda; - a_offset4 = a_offset3 + lda; - a_offset += 4 * lda; + a_off1 = a_off; + a_off2 = a_off1 + lda; + a_off3 = a_off2 + lda; + a_off4 = a_off3 + lda; + a_off += 4 * lda; - i = (k >> 2); + i = (K >> 2); do { - ctemp1 = *(a_offset1 + 0); - ctemp2 = *(a_offset1 + 1); - ctemp3 = *(a_offset1 + 2); - ctemp4 = *(a_offset1 + 3); - - ctemp5 = *(a_offset2 + 0); - ctemp6 = *(a_offset2 + 1); - ctemp7 = *(a_offset2 + 2); - ctemp8 = *(a_offset2 + 3); - - ctemp9 = *(a_offset3 + 0); - ctemp10 = *(a_offset3 + 1); - ctemp11 = *(a_offset3 + 2); - ctemp12 = *(a_offset3 + 3); - - ctemp13 = *(a_offset4 + 0); - ctemp14 = *(a_offset4 + 1); - ctemp15 = 
*(a_offset4 + 2); - ctemp16 = *(a_offset4 + 3); - - *(b_offset + 0) = ctemp1; - *(b_offset + 1) = ctemp5; - *(b_offset + 2) = ctemp9; - *(b_offset + 3) = ctemp13; - - *(b_offset + 4) = ctemp2; - *(b_offset + 5) = ctemp6; - *(b_offset + 6) = ctemp10; - *(b_offset + 7) = ctemp14; - - *(b_offset + 8) = ctemp3; - *(b_offset + 9) = ctemp7; - *(b_offset + 10) = ctemp11; - *(b_offset + 11) = ctemp15; - - *(b_offset + 12) = ctemp4; - *(b_offset + 13) = ctemp8; - *(b_offset + 14) = ctemp12; - *(b_offset + 15) = ctemp16; - - a_offset1 += 4; - a_offset2 += 4; - a_offset3 += 4; - a_offset4 += 4; - - b_offset += 16; + c1 = *(a_off1 + 0); + c2 = *(a_off1 + 1); + c3 = *(a_off1 + 2); + c4 = *(a_off1 + 3); + + c5 = *(a_off2 + 0); + c6 = *(a_off2 + 1); + c7 = *(a_off2 + 2); + c8 = *(a_off2 + 3); + + c9 = *(a_off3 + 0); + c10 = *(a_off3 + 1); + c11 = *(a_off3 + 2); + c12 = *(a_off3 + 3); + + c13 = *(a_off4 + 0); + c14 = *(a_off4 + 1); + c15 = *(a_off4 + 2); + c16 = *(a_off4 + 3); + + *(b_off + 0) = c1; + *(b_off + 1) = c5; + *(b_off + 2) = c9; + *(b_off + 3) = c13; + + *(b_off + 4) = c2; + *(b_off + 5) = c6; + *(b_off + 6) = c10; + *(b_off + 7) = c14; + + *(b_off + 8) = c3; + *(b_off + 9) = c7; + *(b_off + 10) = c11; + *(b_off + 11) = c15; + + *(b_off + 12) = c4; + *(b_off + 13) = c8; + *(b_off + 14) = c12; + *(b_off + 15) = c16; + + a_off1 += 4; + a_off2 += 4; + a_off3 += 4; + a_off4 += 4; + + b_off += 16; i--; } while (i > 0); j--; @@ -102,54 +102,54 @@ void packA_4(unsigned int m, unsigned int k, const __fp16 *from, /** * @brief packing function of input matrix A - * - * @param m length of the row of the matrix - * @param k length of the col of the matrix - * @param from input of original source of the matrix + * + * @param M length of the row of the matrix + * @param K length of the col of the matrix + * @param src input of original source of the matrix * @param lda leading dimension of the matrix - * @param to output of packed data of the matrix + * @param dst output of packed data of the matrix */ -void packA_8(unsigned int m, unsigned int k, const __fp16 *from, - unsigned int lda, const __fp16 *to) { +void packing_A8(unsigned int M, unsigned int K, const __fp16 *src, + unsigned int lda, const __fp16 *dst) { - assert(k != 0 && m != 0 && k % 8 == 0 && m % 8 == 0); + assert(K != 0 && M != 0 && K % 8 == 0 && M % 8 == 0); uint16x4_t msk = {0xFFFF, 0xFFFF, 0x0000, 0x0000}; uint16x4_t inv_msk = {0x0000, 0x0000, 0xFFFF, 0xFFFF}; - const __fp16 *a_offset = (__fp16 *)from; - __fp16 *b_offset = (__fp16 *)to; - - for (unsigned int i = 0; i < m; i += 8) { - const __fp16 *a_offset1 = a_offset; - const __fp16 *a_offset2 = a_offset1 + lda; - const __fp16 *a_offset3 = a_offset2 + lda; - const __fp16 *a_offset4 = a_offset3 + lda; - const __fp16 *a_offset5 = a_offset4 + lda; - const __fp16 *a_offset6 = a_offset5 + lda; - const __fp16 *a_offset7 = a_offset6 + lda; - const __fp16 *a_offset8 = a_offset7 + lda; - a_offset += 8 * lda; - - for (unsigned int j = 0; j < k; j += 8) { - float16x8_t _v0 = vld1q_f16(a_offset1); - float16x8_t _v1 = vld1q_f16(a_offset2); - float16x8_t _v2 = vld1q_f16(a_offset3); - float16x8_t _v3 = vld1q_f16(a_offset4); - - float16x8_t _v4 = vld1q_f16(a_offset5); - float16x8_t _v5 = vld1q_f16(a_offset6); - float16x8_t _v6 = vld1q_f16(a_offset7); - float16x8_t _v7 = vld1q_f16(a_offset8); - - a_offset1 += 8; - a_offset2 += 8; - a_offset3 += 8; - a_offset4 += 8; - a_offset5 += 8; - a_offset6 += 8; - a_offset7 += 8; - a_offset8 += 8; + const __fp16 *a_off = (__fp16 *)src; + __fp16 *b_off = (__fp16 *)dst; 
+ + for (unsigned int i = 0; i < M; i += 8) { + const __fp16 *a_off1 = a_off; + const __fp16 *a_off2 = a_off1 + lda; + const __fp16 *a_off3 = a_off2 + lda; + const __fp16 *a_off4 = a_off3 + lda; + const __fp16 *a_off5 = a_off4 + lda; + const __fp16 *a_off6 = a_off5 + lda; + const __fp16 *a_off7 = a_off6 + lda; + const __fp16 *a_off8 = a_off7 + lda; + a_off += 8 * lda; + + for (unsigned int j = 0; j < K; j += 8) { + float16x8_t _v0 = vld1q_f16(a_off1); + float16x8_t _v1 = vld1q_f16(a_off2); + float16x8_t _v2 = vld1q_f16(a_off3); + float16x8_t _v3 = vld1q_f16(a_off4); + + float16x8_t _v4 = vld1q_f16(a_off5); + float16x8_t _v5 = vld1q_f16(a_off6); + float16x8_t _v6 = vld1q_f16(a_off7); + float16x8_t _v7 = vld1q_f16(a_off8); + + a_off1 += 8; + a_off2 += 8; + a_off3 += 8; + a_off4 += 8; + a_off5 += 8; + a_off6 += 8; + a_off7 += 8; + a_off8 += 8; float16x8x2_t _vv0 = vtrnq_f16(_v0, _v1); float16x8x2_t _vv1 = vtrnq_f16(_v2, _v3); @@ -224,101 +224,101 @@ void packA_8(unsigned int m, unsigned int k, const __fp16 *from, _v15 = vcombine_f16(vbsl_f16(inv_msk, tmp_high_v11, mid_v11), vbsl_f16(inv_msk, tmp_high_v15, mid_v15)); - vst1q_f16(b_offset + 0, _v8); - vst1q_f16(b_offset + 8, _v12); - vst1q_f16(b_offset + 16, _v9); - vst1q_f16(b_offset + 24, _v13); - vst1q_f16(b_offset + 32, _v10); - vst1q_f16(b_offset + 40, _v14); - vst1q_f16(b_offset + 48, _v11); - vst1q_f16(b_offset + 56, _v15); - b_offset += 64; + vst1q_f16(b_off + 0, _v8); + vst1q_f16(b_off + 8, _v12); + vst1q_f16(b_off + 16, _v9); + vst1q_f16(b_off + 24, _v13); + vst1q_f16(b_off + 32, _v10); + vst1q_f16(b_off + 40, _v14); + vst1q_f16(b_off + 48, _v11); + vst1q_f16(b_off + 56, _v15); + b_off += 64; } } } /** * @brief packing function of input matrix B - * - * @param m length of the row of the matrix - * @param k length of the col of the matrix - * @param from input of original source of the matrix + * + * @param M length of the row of the matrix + * @param K length of the col of the matrix + * @param src input of original source of the matrix * @param ldb leading dimension of the matrix - * @param to output of packed data of the matrix + * @param dst output of packed data of the matrix */ -void packB_4(unsigned int k, unsigned int n, const __fp16 *from, - unsigned int ldb, const __fp16 *to) { - assert(k != 0 && n != 0 && k % 4 == 0 && n % 4 == 0); +void packing_B4(unsigned int K, unsigned int N, const __fp16 *src, + unsigned int ldb, const __fp16 *dst) { + assert(K != 0 && N != 0 && K % 4 == 0 && N % 4 == 0); unsigned int i, j; - __fp16 *a_offset, *a_offset1, *a_offset2, *a_offset3, *a_offset4; - __fp16 *b_offset, *b_offset1; - __fp16 ctemp1, ctemp2, ctemp3, ctemp4; - __fp16 ctemp5, ctemp6, ctemp7, ctemp8; - __fp16 ctemp9, ctemp10, ctemp11, ctemp12; - __fp16 ctemp13, ctemp14, ctemp15, ctemp16; - a_offset = (__fp16 *)from; - b_offset = (__fp16 *)to; + __fp16 *a_off, *a_off1, *a_off2, *a_off3, *a_off4; + __fp16 *b_off, *b_off1; + __fp16 c1, c2, c3, c4; + __fp16 c5, c6, c7, c8; + __fp16 c9, c10, c11, c12; + __fp16 c13, c14, c15, c16; + a_off = (__fp16 *)src; + b_off = (__fp16 *)dst; - j = (k >> 2); + j = (K >> 2); do { - a_offset1 = a_offset; - a_offset2 = a_offset1 + ldb; - a_offset3 = a_offset2 + ldb; - a_offset4 = a_offset3 + ldb; - a_offset += 4 * ldb; + a_off1 = a_off; + a_off2 = a_off1 + ldb; + a_off3 = a_off2 + ldb; + a_off4 = a_off3 + ldb; + a_off += 4 * ldb; - b_offset1 = b_offset; - b_offset += 16; + b_off1 = b_off; + b_off += 16; - i = (n >> 2); + i = (N >> 2); do { - ctemp1 = *(a_offset1 + 0); - ctemp2 = *(a_offset1 + 1); - ctemp3 
= *(a_offset1 + 2); - ctemp4 = *(a_offset1 + 3); - - ctemp5 = *(a_offset2 + 0); - ctemp6 = *(a_offset2 + 1); - ctemp7 = *(a_offset2 + 2); - ctemp8 = *(a_offset2 + 3); - - ctemp9 = *(a_offset3 + 0); - ctemp10 = *(a_offset3 + 1); - ctemp11 = *(a_offset3 + 2); - ctemp12 = *(a_offset3 + 3); - - ctemp13 = *(a_offset4 + 0); - ctemp14 = *(a_offset4 + 1); - ctemp15 = *(a_offset4 + 2); - ctemp16 = *(a_offset4 + 3); - - a_offset1 += 4; - a_offset2 += 4; - a_offset3 += 4; - a_offset4 += 4; - - *(b_offset1 + 0) = ctemp1; - *(b_offset1 + 1) = ctemp2; - *(b_offset1 + 2) = ctemp3; - *(b_offset1 + 3) = ctemp4; - - *(b_offset1 + 4) = ctemp5; - *(b_offset1 + 5) = ctemp6; - *(b_offset1 + 6) = ctemp7; - *(b_offset1 + 7) = ctemp8; - - *(b_offset1 + 8) = ctemp9; - *(b_offset1 + 9) = ctemp10; - *(b_offset1 + 10) = ctemp11; - *(b_offset1 + 11) = ctemp12; - - *(b_offset1 + 12) = ctemp13; - *(b_offset1 + 13) = ctemp14; - *(b_offset1 + 14) = ctemp15; - *(b_offset1 + 15) = ctemp16; - - b_offset1 += k * 4; + c1 = *(a_off1 + 0); + c2 = *(a_off1 + 1); + c3 = *(a_off1 + 2); + c4 = *(a_off1 + 3); + + c5 = *(a_off2 + 0); + c6 = *(a_off2 + 1); + c7 = *(a_off2 + 2); + c8 = *(a_off2 + 3); + + c9 = *(a_off3 + 0); + c10 = *(a_off3 + 1); + c11 = *(a_off3 + 2); + c12 = *(a_off3 + 3); + + c13 = *(a_off4 + 0); + c14 = *(a_off4 + 1); + c15 = *(a_off4 + 2); + c16 = *(a_off4 + 3); + + a_off1 += 4; + a_off2 += 4; + a_off3 += 4; + a_off4 += 4; + + *(b_off1 + 0) = c1; + *(b_off1 + 1) = c2; + *(b_off1 + 2) = c3; + *(b_off1 + 3) = c4; + + *(b_off1 + 4) = c5; + *(b_off1 + 5) = c6; + *(b_off1 + 6) = c7; + *(b_off1 + 7) = c8; + + *(b_off1 + 8) = c9; + *(b_off1 + 9) = c10; + *(b_off1 + 10) = c11; + *(b_off1 + 11) = c12; + + *(b_off1 + 12) = c13; + *(b_off1 + 13) = c14; + *(b_off1 + 14) = c15; + *(b_off1 + 15) = c16; + + b_off1 += K * 4; i--; } while (i > 0); j--; @@ -327,26 +327,26 @@ void packB_4(unsigned int k, unsigned int n, const __fp16 *from, /** * @brief packing function of input matrix B - * - * @param m length of the row of the matrix - * @param k length of the col of the matrix - * @param from input of original source of the matrix + * + * @param M length of the row of the matrix + * @param K length of the col of the matrix + * @param src input of original source of the matrix * @param ldb leading dimension of the matrix - * @param to output of packed data of the matrix + * @param dst output of packed data of the matrix */ -void packB_8(unsigned int k, unsigned int n, const __fp16 *from, - unsigned int ldb, const __fp16 *to) { - assert(k != 0 && n != 0 && n % 8 == 0); - - for (int i = 0; i < k; i++) { - const __fp16 *a_offset1 = from + i * ldb; - __fp16 *b_offset = (__fp16 *)to + i * 8; - for (int j = 0; j < n; j += 8) { - float16x8_t _v0 = vld1q_f16(a_offset1); - a_offset1 += 8; - - vst1q_f16(b_offset, _v0); - b_offset += 8 * k; +void packing_B8(unsigned int K, unsigned int N, const __fp16 *src, + unsigned int ldb, const __fp16 *dst) { + assert(K != 0 && N != 0 && N % 8 == 0); + + for (int i = 0; i < K; i++) { + const __fp16 *a_off = src + i * ldb; + __fp16 *b_off = (__fp16 *)dst + i * 8; + for (int j = 0; j < N; j += 8) { + float16x8_t v = vld1q_f16(a_off); + a_off += 8; + + vst1q_f16(b_off, v); + b_off += 8 * K; } } } diff --git a/nntrainer/tensor/hgemm/hgemm_util.h b/nntrainer/tensor/hgemm/hgemm_util.h index 1996a59e..4c71d0a8 100644 --- a/nntrainer/tensor/hgemm/hgemm_util.h +++ b/nntrainer/tensor/hgemm/hgemm_util.h @@ -14,13 +14,13 @@ /** * @brief aligned dynamic allocation function - * - * @param size amount of data to allocate 
+ * + * @param sz amount of data to allocate * @return __fp16* addr of allocated memory */ -static inline __fp16 *alignedMalloc(int size) { - void *ptr = 0; - int iRet = posix_memalign(&ptr, 64, size * sizeof(__fp16)); +static inline __fp16 *alignedMalloc(int sz) { + void *addr = 0; + int iRet = posix_memalign(&addr, 64, sz * sizeof(__fp16)); assert(0 == iRet); - return (__fp16 *)ptr; + return (__fp16 *)addr; } diff --git a/nntrainer/tensor/hgemm/meson.build b/nntrainer/tensor/hgemm/meson.build new file mode 100644 index 00000000..cf9efc33 --- /dev/null +++ b/nntrainer/tensor/hgemm/meson.build @@ -0,0 +1,21 @@ +hgemm_headers = [ + 'hgemm.h', + 'hgemm_util.h', + 'hgemm_kernel_pack.h', + 'hgemm_kernel_4x4.h', + 'hgemm_kernel_4x8.h', + 'hgemm_kernel_8x8.h', +] + +hgemm_sources = [ + 'hgemm.cpp' +] + +foreach s : hgemm_sources + nntrainer_sources += meson.current_source_dir() / s +endforeach + +foreach h : hgemm_headers + nntrainer_headers += meson.current_source_dir() / h +endforeach +
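
---

Illustrative sketches (not part of the patch):

Dispatch and calling convention. A minimal sketch of how the new entry point declared in hgemm.h is meant to be used; the buffer names and sizes are hypothetical, and an AArch64 toolchain with native `__fp16` support is assumed. With alpha == 1, beta == 0 and all of M, N, K divisible by 8 the call reaches hgemm_noTrans_8x8; with M and K divisible by 4 and N by 8 it reaches hgemm_noTrans_4x8; anything else takes hgemm_noTrans_fallback.

```cpp
// Hypothetical caller of the entry point declared in hgemm.h (sketch only).
#include <vector>

void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
                   unsigned int N, unsigned int K, float alpha = 1.F,
                   float beta = 0.F);

void example_usage() {
  const unsigned int M = 64, N = 64, K = 64; // all divisible by 8
  std::vector<__fp16> A(M * K), B(K * N);    // row-major half-precision inputs
  std::vector<float> C(M * N, 0.0f);         // float32 output / accumulator

  // alpha == 1, beta == 0 and M, N, K % 8 == 0 -> hgemm_noTrans_8x8 path.
  hgemm_noTrans(A.data(), B.data(), C.data(), M, N, K);
}
```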
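Blocking structure. hgemm_noTrans_4x4/4x8/8x8 all follow the same cache-blocking scheme: the output is walked in M_BLOCKING x N_BLOCKING tiles, the K dimension is cut into K_BLOCKING slices, and each pair of packed panels is handed to a small register-tiled kernel. The scalar sketch below only shows the shape of that decomposition with the block sizes from hgemm_common.h; it drops the packing, the NEON kernels and the odd-block adjustments, and uses float instead of `__fp16`, so it illustrates the idea rather than reproducing the patched code.

```cpp
#include <algorithm>

// Block sizes as defined in hgemm_common.h.
constexpr unsigned M_BLOCKING = 4096;
constexpr unsigned K_BLOCKING = 256;
constexpr unsigned N_BLOCKING = 384;

// Scalar sketch of the blocked C += A * B traversal (row-major, no packing).
void blocked_gemm_sketch(unsigned M, unsigned N, unsigned K, const float *A,
                         const float *B, float *C) {
  for (unsigned ms = 0; ms < M; ms += M_BLOCKING) {
    unsigned mb = std::min(M - ms, M_BLOCKING);
    for (unsigned ks = 0; ks < K; ks += K_BLOCKING) {
      unsigned kb = std::min(K - ks, K_BLOCKING);
      for (unsigned ns = 0; ns < N; ns += N_BLOCKING) {
        unsigned nb = std::min(N - ns, N_BLOCKING);
        // In the patch, the A and B tiles are repacked here (packing_A4/A8,
        // packing_B4/B8) and passed to HGEMM_KERNEL_4x4 / 4x8 / 8x8.
        for (unsigned i = 0; i < mb; ++i)
          for (unsigned j = 0; j < nb; ++j) {
            float acc = 0.f;
            for (unsigned l = 0; l < kb; ++l)
              acc += A[(ms + i) * K + (ks + l)] * B[(ks + l) * N + (ns + j)];
            C[(ms + i) * N + (ns + j)] += acc;
          }
      }
    }
  }
}
```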
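Panel packing. packing_A4 re-lays a 4-row strip of A so that, for each k, the four row elements sit next to each other, while packing_B8 copies 8-column slices of B into contiguous panels with stride 8 * K; either way the kernels can stream through the packed buffers linearly. Below is a scalar, float-based sketch of the A-side layout only; the patch does the same thing with unrolled pointer arithmetic (and NEON transposes for packing_A8) on `__fp16` data.

```cpp
// Scalar sketch of the packing_A4 layout: within each 4-row strip of A the
// data is stored column-by-column, so the 4 values needed for one k step are
// adjacent. M % 4 == 0 and K % 4 == 0 are preconditions, as asserted in the
// patch.
void pack_A4_sketch(unsigned M, unsigned K, const float *src, unsigned lda,
                    float *dst) {
  for (unsigned i = 0; i < M; i += 4)   // one 4-row strip at a time
    for (unsigned k = 0; k < K; ++k)    // walk the strip column-wise
      for (unsigned r = 0; r < 4; ++r)  // 4 row elements become contiguous
        *dst++ = src[(i + r) * lda + k];
}
```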
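Fallback accumulation. hgemm_noTrans_fallback scales rows of A by alpha, forms partial products in half precision over a short window of K (16, 8, 4, then single steps), widens the window sum to float32 and adds it into C; the N dimension is vectorized with NEON, and a small stack buffer handles the tail when N is not a multiple of 8. The scalar sketch below shows the same accumulation pattern with a fixed window of 4; it assumes `__fp16` is available and is an approximation of the rounding behavior, not a drop-in replacement for the NEON code.

```cpp
// Scalar sketch of the fallback's mixed-precision accumulation: fp16 partial
// sums over a short K window, widened once per window into the float32 C.
void fallback_sketch(unsigned M, unsigned N, unsigned K, const __fp16 *A,
                     const __fp16 *B, float *C, float alpha) {
  for (unsigned m = 0; m < M; ++m)
    for (unsigned n = 0; n < N; ++n)
      for (unsigned k = 0; k < K; k += 4) {
        const unsigned kw = (K - k < 4) ? (K - k) : 4; // window size
        __fp16 acc = 0;
        for (unsigned l = 0; l < kw; ++l)
          acc += static_cast<__fp16>(alpha * A[m * K + k + l]) *
                 B[(k + l) * N + n];
        C[m * N + n] += static_cast<float>(acc); // widen once per window
      }
}
```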