const unsigned int ldc) {
#ifdef USE__FP16
- if ((M % 8 == 0) && (N % 8 == 0) && (K % 8 == 0)) {
+ if ((N % 8 == 0) && (K % 8 == 0)) {
nntrainer::neon::sgemm_neon_fp16(A, B, C, M, N, K, alpha, beta,
TransA == CblasTrans,
TransB == CblasTrans);
float16x8_t v_beta = vmovq_n_f16(beta);
// performing beta*C
- for (unsigned int idx = 0; idx < (M * N); idx += 8) {
+ unsigned int idx = 0;
+ unsigned int size = M * N;
+ // Scale C by beta eight fp16 lanes at a time. The bound is written as
+ // (size - idx) >= 8 on its own: chaining it as `idx < (size - idx) >= 8`
+ // compares a 0/1 result against 8 and is never true. The subtraction
+ // cannot wrap because idx only advances while at least 8 elements remain.
+ for (; (size - idx) >= 8; idx += 8) {
float16x8_t c = vld1q_f16(&C[idx]);
c = vmulq_f16(v_beta, c);
vst1q_f16(&C[idx], c);
}
+ // remaining values if dimensions not a multiple of 8
+ for (; idx < size; idx++) {
+ C[idx] *= beta;
+ }
+
__fp16 r[4];
if (!TransA && TransB) {
for (unsigned int k = 0; k < K; k++) {
__fp16 b = alpha * B[n * K + k];
- for (unsigned int m = 0; m < M; m += 8) {
+ unsigned int m = 0;
+ for (; (M - m) >= 8; m += 8) {
float16x8_t a = vld1q_f16(&A[k * M + m]);
a = vmulq_n_f16(a, b);
vst1q_f16(vals, a);
for (unsigned int idx = m; idx < m + 8; idx++)
C[idx * N + n] += vals[idx - m];
}
+
+ // remaining when M is not a multiple of 8: fewer than 8 elements of A are left for this k
+ if (m < M) {
+ for (idx = m; idx < M; idx++) { // NOTE(review): idx is assigned, not declared, here - confirm a declaration is in scope
+ vals[idx - m] = A[k * M + idx]; // stage the leftover elements in vals so they can be loaded as one vector
+ }
+
+ float16x8_t a = vld1q_f16(vals); // loads all 8 lanes; lanes at and beyond M - m hold whatever vals contained before
+ a = vmulq_n_f16(a, b); // scale every lane by b (alpha * B[n * K + k])
+ vst1q_f16(vals, a);
+
+ // calculations for all remaining M values: only the first M - m lanes of vals are accumulated into C
+ for (idx = m; idx < M; idx++)
+ C[idx * N + n] += vals[idx - m];
+ }
}
}
}