*/
#include <hgemm.h>
+#include <hgemm_kernel_1x4.h>
#include <hgemm_kernel_1x8.h>
#include <hgemm_kernel_4x4.h>
#include <hgemm_kernel_4x8.h>
#include <hgemm_kernel_pack.h>
#include <hgemm_util.h>
+#define HGEMM_KERNEL_1x4 hgemm_kernel_1x4
#define HGEMM_KERNEL_4x4 hgemm_kernel_4x4
#define HGEMM_KERNEL_1x8 hgemm_kernel_1x8
#define HGEMM_KERNEL_4x8 hgemm_kernel_4x8
hgemm_noTrans_4x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
} else if (N % 8 == 0) {
hgemm_noTrans_1x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ } else if (N % 4 == 0) {
+ hgemm_noTrans_1x4(M, N, K, A, K, B, N, C32, N, alpha, beta);
} else {
hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
}
hgemm_noTrans_1x8(M, N, K, A, K, B, N, C, N, alpha, beta);
} else if (M % 4 == 0 && N % 4 == 0 && K % 4 == 0) {
hgemm_noTrans_4x4(M, N, K, A, K, B, N, C, N, alpha, beta);
+ } else if (N % 4 == 0) {
+ hgemm_noTrans_1x4(M, N, K, A, K, B, N, C, N, alpha, beta);
}
}
}
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha, float beta) {
+ __fp16 *sa = alignedMalloc(M * K);
+ __fp16 *sb = alignedMalloc(K * N);
+
+ unsigned int ms, mms, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
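+  // Block over M, K, and N so the packed panels of A (sa) and B (sb) stay
+  // cache-resident while the 1x4 micro-kernel runs over them.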
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
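+      // Cap the K panel at K_BLOCKING; a remainder between K_BLOCKING and
+      // 2 * K_BLOCKING is split into two near-equal panels, rounded up to
+      // the unrolling factor.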
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
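+      // Pack the first n_min columns of this K panel of B into sb so the
+      // kernel reads B contiguously.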
+ packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
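+      // Walk the current M block in small row chunks: pack each chunk of A
+      // once and run the 1x4 kernel against the packed B panel for the first
+      // block of columns.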
+ for (mms = ms; mms < ms + m_min; mms += m2_min) {
+ m2_min = (ms + m_min) - mms;
+ if (m2_min >= 3 * GEMM_UNROLLING_1) {
+ m2_min = 3 * GEMM_UNROLLING_1;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
+ m2_min = 2 * GEMM_UNROLLING_1;
+ } else if (m2_min > GEMM_UNROLLING_1) {
+ m2_min = GEMM_UNROLLING_1;
+ }
+
+ packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
+ sa + k_min * (mms - ms));
+
+ HGEMM_KERNEL_1x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+ C + mms * ldc, ldc);
+ }
+
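+      // The remaining N blocks reuse the already packed A panel; only the
+      // matching panel of B has to be repacked.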
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+ }
+
+ packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+ HGEMM_KERNEL_1x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sa);
+ free(sb);
+}
+
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha, float beta) {
+ __fp16 *sa = alignedMalloc(M * K);
+ __fp16 *sb = alignedMalloc(K * N);
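+  // Identical blocking to the fp16 overload above; only the type of C
+  // differs, so the micro-kernel widens its partial sums to fp32 before
+  // accumulating into C.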
+
+ unsigned int ms, mms, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+ packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+ for (mms = ms; mms < ms + m_min; mms += m2_min) {
+ m2_min = (ms + m_min) - mms;
+ if (m2_min >= 3 * GEMM_UNROLLING_1) {
+ m2_min = 3 * GEMM_UNROLLING_1;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
+ m2_min = 2 * GEMM_UNROLLING_1;
+ } else if (m2_min > GEMM_UNROLLING_1) {
+ m2_min = GEMM_UNROLLING_1;
+ }
+
+ packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
+ sa + k_min * (mms - ms));
+
+ HGEMM_KERNEL_1x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+ C + mms * ldc, ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+ }
+
+ packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+ HGEMM_KERNEL_1x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sa);
+ free(sb);
+}
+
void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, __fp16 *C, unsigned int ldc,
unsigned int ldb, float *C, unsigned int ldc,
float alpha = 1.F, float beta = 0.F);
+/**
+ * @brief hgemm noTrans computation with 1x4 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 1x4 kernel : C = A*B, with C stored in single precision (fp32)
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
/**
* @brief hgemm noTrans computation with 4x4 kernel : C = A*B,
*
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Debadri Samaddar <s.debadri@samsung.com>
+ *
+ * @file hgemm_kernel_1x4.h
+ * @date 23 April 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Debadri Samaddar <s.debadri@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is half-precision GEMM 1x4 kernel
+ *
+ */
+
+#include <hgemm_common.h>
+#include <stdlib.h>
+
+/**
+ * @brief hgemm 1x4 kernel sc += sa * sb
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading dimension of matrix C
+ */
+void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
+ assert(M > 0 && N > 0 && K > 0);
+ assert(N % 4 == 0);
+
+ __fp16 *a = sa, *b = sb, *c = sc;
+ unsigned int i, j, l;
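+  // i walks the rows of A, j walks C in 1x4 tiles, and l consumes K four
+  // values at a time from the packed panels sa and sb.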
+ for (i = 0; i < M; i++) {
+ for (j = 0; j < N; j += VL_FP16_HALF) {
+ __builtin_prefetch(b, 0, 3);
+ __builtin_prefetch(a, 0, 3);
+
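+      // Each l iteration multiplies four consecutive A scalars against a
+      // 4x4 block of packed B and folds the products into a single 1x4
+      // accumulator; K is assumed to be a multiple of 4 here.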
+ for (l = 0; l < K; l += VL_FP16_HALF) {
+ float16x4_t v24 = {0.F};
+ float16x4_t v0 = vld1_f16(b);
+ float16_t v16 = *a;
+
+ v24 = vfma_n_f16(v24, v0, v16);
+
+ float16x4_t v1 = vld1_f16(b + 4);
+ float16_t v17 = *(a + 1);
+
+ v24 = vfma_n_f16(v24, v1, v17);
+
+ float16x4_t v2 = vld1_f16(b + 8);
+ float16_t v18 = *(a + 2);
+
+ v24 = vfma_n_f16(v24, v2, v18);
+
+ float16x4_t v3 = vld1_f16(b + 12);
+ float16_t v19 = *(a + 3);
+
+ v24 = vfma_n_f16(v24, v3, v19);
+
+ __builtin_prefetch(b + 16, 0, 3);
+ __builtin_prefetch(a + 4, 0, 3);
+
+ b += 16;
+ a += 4;
+
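+        // Add the partial 1x4 tile into C before moving to the next four k values.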
+ v24 = vadd_f16(vld1_f16(c), v24);
+
+ vst1_f16(c, v24);
+ }
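+      // Advance to the next 4-column tile of C and rewind A back to the
+      // start of the current row.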
+ c += 4;
+ a -= K;
+ }
+ sc += ldc;
+ c = sc;
+ a += K;
+ b = sb;
+ }
+}
+
+/**
+ * @brief hgemm 1x4 kernel sc += sa * sb, with sc accumulated in single precision
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading dimension of matrix C
+ */
+void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
+ __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
+ assert(M > 0 && N > 0 && K > 0);
+ assert(N % 4 == 0);
+
+ __fp16 *a = sa, *b = sb;
+ float *c = sc;
+ unsigned int i, j, l;
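+  // Same traversal as the fp16 kernel above; the accumulator is widened to
+  // fp32 just before it is added into C.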
+ for (i = 0; i < M; i++) {
+ for (j = 0; j < N; j += VL_FP16_HALF) {
+ __builtin_prefetch(b, 0, 3);
+ __builtin_prefetch(a, 0, 3);
+
+ for (l = 0; l < K; l += VL_FP16_HALF) {
+ float16x4_t v24 = {0.F};
+ float16x4_t v0 = vld1_f16(b);
+ float16_t v16 = *a;
+
+ v24 = vfma_n_f16(v24, v0, v16);
+
+ float16x4_t v1 = vld1_f16(b + 4);
+ float16_t v17 = *(a + 1);
+
+ v24 = vfma_n_f16(v24, v1, v17);
+
+ float16x4_t v2 = vld1_f16(b + 8);
+ float16_t v18 = *(a + 2);
+
+ v24 = vfma_n_f16(v24, v2, v18);
+
+ float16x4_t v3 = vld1_f16(b + 12);
+ float16_t v19 = *(a + 3);
+
+ v24 = vfma_n_f16(v24, v3, v19);
+
+ __builtin_prefetch(b + 16, 0, 3);
+ __builtin_prefetch(a + 4, 0, 3);
+
+ b += 16;
+ a += 4;
+
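+        // Widen the fp16 partial sums to fp32 and accumulate them into C.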
+ vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(v24)));
+ }
+ c += 4;
+ a -= K;
+ }
+ sc += ldc;
+ c = sc;
+ a += K;
+ b = sb;
+ }
+}
EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
}
+TEST(nntrainer_Tensor, dot_gemm_50_768_516) {
+ /// @note GEMM : A X B = C
+ int batch = 1;
+ int channel = 1;
+ int height = 50;
+ int width = 768;
+
+ int height_b = 768;
+ int width_b = 516;
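+  // width_b = 516 is divisible by 4 but not by 8, so this case should be
+  // routed through the new hgemm_noTrans_1x4 path.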
+
+ bool transA = false;
+ bool transB = false;
+
+ nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+ nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+ nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+ nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+ nntrainer::Tensor A(batch, channel, height, width, t_type_nchw_fp16);
+ nntrainer::Tensor B(batch, channel, height_b, width_b, t_type_nchw_fp16);
+
+ nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+ nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32);
+
+ const float alpha = 1e-1;
+ const int MOD = 10;
+
+ GEN_TEST_INPUT(A, ((i * (batch * height * channel) + j * (batch * height) +
+ k * (width) + l + 1) %
+ MOD) *
+ alpha);
+ GEN_TEST_INPUT_B(B, ((i * (batch * height_b * channel) +
+ j * (batch * height_b) + k * (width_b) + l + 1) %
+ MOD) *
+ alpha);
+ GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) +
+ j * (batch * height) + k * (width) + l + 1) %
+ MOD) *
+ alpha);
+ GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) +
+ j * (batch * height_b) + k * (width_b) + l + 1) %
+ MOD) *
+ alpha);
+
+ nntrainer::Tensor C = A.dot(B, transA, transB);
+
+ nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB);
+
+ float mseErrorNeon =
+ mse<__fp16>(C.getData<__fp16>(), C_fp32.getData<float>(), C.size());
+
+ double cosSimNeon = cosine_similarity<__fp16>(
+ C.getData<__fp16>(), C_fp32.getData<float>(), C.size());
+
+ const float epsilon = 1e-3 * width;
+
+ EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon);
+ EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
+}
+
TEST(nntrainer_Tensor, dot_gemv_768_96000) {
/// @note GEMV : A X B = C
int batch = 1;