// Computes Y = alpha * A * X + beta * Y for a row-major fp16 matrix A (rows x cols),
// accumulating in fp32 for accuracy.
void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
uint32_t cols, float alpha, float beta) {
const __fp16 *__restrict x;
- const float32x4_t v_beta_32 = vmovq_n_f32(beta);
float Y32[rows];
unsigned int idx = 0;
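// Scale Y by beta into the fp32 buffer Y32; the row dot products accumulate there to avoid fp16 rounding/overflow.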
for (; rows - idx >= 8; idx += 8) {
float32x4_t y0_3 = vcvt_f32_f16(vld1_f16(&Y[idx]));
float32x4_t y4_7 = vcvt_f32_f16(vld1_f16(&Y[idx + 4]));
- y0_3 = vmulq_f32(y0_3, v_beta_32);
- y4_7 = vmulq_f32(y4_7, v_beta_32);
+ y0_3 = vmulq_n_f32(y0_3, beta);
+ y4_7 = vmulq_n_f32(y4_7, beta);
vst1q_f32(&Y32[idx], y0_3);
vst1q_f32(&Y32[idx + 4], y4_7);
}
for (; rows - idx >= 4; idx += 4) {
float32x4_t y0_3_32 = vcvt_f32_f16(vld1_f16(&Y[idx]));
- y0_3_32 = vmulq_f32(y0_3_32, v_beta_32);
+ y0_3_32 = vmulq_n_f32(y0_3_32, beta);
vst1q_f32(&Y32[idx], y0_3_32);
}
for (; idx < rows; ++idx) {
Y32[idx] = beta * Y[idx];
}
idx = 0;
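// Main column loop: 120 columns (15 fp16x8 vectors) per iteration.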
- for (; cols - idx >= 64; idx += 64) {
+ for (; cols - idx >= 120; idx += 120) {
float16x8_t x0_7 = vld1q_f16(&X[idx]);
float16x8_t x8_15 = vld1q_f16(&X[idx + 8]);
float16x8_t x16_23 = vld1q_f16(&X[idx + 16]);
float16x8_t x24_31 = vld1q_f16(&X[idx + 24]);
float16x8_t x32_39 = vld1q_f16(&X[idx + 32]);
float16x8_t x40_47 = vld1q_f16(&X[idx + 40]);
float16x8_t x48_55 = vld1q_f16(&X[idx + 48]);
float16x8_t x56_63 = vld1q_f16(&X[idx + 56]);
+ float16x8_t x64_71 = vld1q_f16(&X[idx + 64]);
+ float16x8_t x72_79 = vld1q_f16(&X[idx + 72]);
+ float16x8_t x80_87 = vld1q_f16(&X[idx + 80]);
+
+ float16x8_t x88_95 = vld1q_f16(&X[idx + 88]);
+ float16x8_t x96_103 = vld1q_f16(&X[idx + 96]);
+ float16x8_t x104_111 = vld1q_f16(&X[idx + 104]);
+ float16x8_t x112_119 = vld1q_f16(&X[idx + 112]);
+
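// Fold alpha into X once, so the per-row inner loop is pure loads and FMAs.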
if (alpha != 1.0) {
x0_7 = vmulq_n_f16(x0_7, alpha);
x8_15 = vmulq_n_f16(x8_15, alpha);
x16_23 = vmulq_n_f16(x16_23, alpha);
x24_31 = vmulq_n_f16(x24_31, alpha);
x32_39 = vmulq_n_f16(x32_39, alpha);
x40_47 = vmulq_n_f16(x40_47, alpha);
x48_55 = vmulq_n_f16(x48_55, alpha);
x56_63 = vmulq_n_f16(x56_63, alpha);
+
+ x64_71 = vmulq_n_f16(x64_71, alpha);
+ x72_79 = vmulq_n_f16(x72_79, alpha);
+ x80_87 = vmulq_n_f16(x80_87, alpha);
+ x88_95 = vmulq_n_f16(x88_95, alpha);
+ x96_103 = vmulq_n_f16(x96_103, alpha);
+ x104_111 = vmulq_n_f16(x104_111, alpha);
+ x112_119 = vmulq_n_f16(x112_119, alpha);
}
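// For each row, form the 120-element dot product in fp16 and reduce both halves to fp32.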
const __fp16 *__restrict w;
- float yVal_low;
- float yVal_high;
+ for (unsigned int j = 0; j < rows; ++j) {
+ w = &A[j * cols + idx];
+ float16x8_t y = vmulq_f16(vld1q_f16(&w[0]), x0_7);
+ y = vfmaq_f16(y, vld1q_f16(&w[8]), x8_15);
+ y = vfmaq_f16(y, vld1q_f16(&w[16]), x16_23);
+ y = vfmaq_f16(y, vld1q_f16(&w[24]), x24_31);
+
+ y = vfmaq_f16(y, vld1q_f16(&w[32]), x32_39);
+ y = vfmaq_f16(y, vld1q_f16(&w[40]), x40_47);
+ y = vfmaq_f16(y, vld1q_f16(&w[48]), x48_55);
+ y = vfmaq_f16(y, vld1q_f16(&w[56]), x56_63);
+
+ y = vfmaq_f16(y, vld1q_f16(&w[64]), x64_71);
+ y = vfmaq_f16(y, vld1q_f16(&w[72]), x72_79);
+ y = vfmaq_f16(y, vld1q_f16(&w[80]), x80_87);
+
+ y = vfmaq_f16(y, vld1q_f16(&w[88]), x88_95);
+ y = vfmaq_f16(y, vld1q_f16(&w[96]), x96_103);
+ y = vfmaq_f16(y, vld1q_f16(&w[104]), x104_111);
+ y = vfmaq_f16(y, vld1q_f16(&w[112]), x112_119);
+
+ Y32[j] += vaddvq_f32(vcvt_f32_f16(vget_low_f16(y))) +
+ vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
+ }
+ }
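+ // 64-column tail: same pattern with eight vectors of X.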
+ for (; cols - idx >= 64; idx += 64) {
+ float16x8_t x0_7 = vld1q_f16(&X[idx]);
+ float16x8_t x8_15 = vld1q_f16(&X[idx + 8]);
+ float16x8_t x16_23 = vld1q_f16(&X[idx + 16]);
+ float16x8_t x24_31 = vld1q_f16(&X[idx + 24]);
+
+ float16x8_t x32_39 = vld1q_f16(&X[idx + 32]);
+ float16x8_t x40_47 = vld1q_f16(&X[idx + 40]);
+ float16x8_t x48_55 = vld1q_f16(&X[idx + 48]);
+ float16x8_t x56_63 = vld1q_f16(&X[idx + 56]);
+
+ if (alpha != 1.0) {
+ x0_7 = vmulq_n_f16(x0_7, alpha);
+ x8_15 = vmulq_n_f16(x8_15, alpha);
+ x16_23 = vmulq_n_f16(x16_23, alpha);
+ x24_31 = vmulq_n_f16(x24_31, alpha);
+ x32_39 = vmulq_n_f16(x32_39, alpha);
+ x40_47 = vmulq_n_f16(x40_47, alpha);
+ x48_55 = vmulq_n_f16(x48_55, alpha);
+ x56_63 = vmulq_n_f16(x56_63, alpha);
+ }
+
+ const __fp16 *__restrict w;
for (unsigned int j = 0; j < rows; ++j) {
w = &A[j * cols + idx];
float16x8_t wvec0_7 = vld1q_f16(&w[0]);
float16x8_t wvec8_15 = vld1q_f16(&w[8]);
float16x8_t wvec16_23 = vld1q_f16(&w[16]);
float16x8_t wvec24_31 = vld1q_f16(&w[24]);
float16x8_t wvec32_39 = vld1q_f16(&w[32]);
float16x8_t wvec40_47 = vld1q_f16(&w[40]);
float16x8_t wvec48_55 = vld1q_f16(&w[48]);
float16x8_t wvec56_63 = vld1q_f16(&w[56]);

float16x8_t y = vmulq_f16(wvec0_7, x0_7);
y = vfmaq_f16(y, wvec8_15, x8_15);
y = vfmaq_f16(y, wvec16_23, x16_23);
y = vfmaq_f16(y, wvec24_31, x24_31);
y = vfmaq_f16(y, wvec32_39, x32_39);
y = vfmaq_f16(y, wvec40_47, x40_47);
y = vfmaq_f16(y, wvec48_55, x48_55);
y = vfmaq_f16(y, wvec56_63, x56_63);
- yVal_low = vaddvq_f32(vcvt_f32_f16(vget_low_f16(y)));
- yVal_high = vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
-
- Y32[j] += yVal_low + yVal_high;
+ Y32[j] += vaddvq_f32(vcvt_f32_f16(vget_low_f16(y))) +
+ vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
}
}
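// 32-column tail.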
for (; cols - idx >= 32; idx += 32) {
float16x8_t x0_7 = vld1q_f16(&X[idx]);
float16x8_t x8_15 = vld1q_f16(&X[idx + 8]);
float16x8_t x16_23 = vld1q_f16(&X[idx + 16]);
float16x8_t x24_31 = vld1q_f16(&X[idx + 24]);

if (alpha != 1.0) {
x0_7 = vmulq_n_f16(x0_7, alpha);
x8_15 = vmulq_n_f16(x8_15, alpha);
x16_23 = vmulq_n_f16(x16_23, alpha);
x24_31 = vmulq_n_f16(x24_31, alpha);
}

const __fp16 *__restrict w;
- float yVal_low;
- float yVal_high;
-
for (unsigned int j = 0; j < rows; ++j) {
w = &A[j * cols + idx];
float16x8_t wvec0_7 = vld1q_f16(&w[0]);
float16x8_t wvec8_15 = vld1q_f16(&w[8]);
float16x8_t wvec16_23 = vld1q_f16(&w[16]);
float16x8_t wvec24_31 = vld1q_f16(&w[24]);

float16x8_t y = vmulq_f16(wvec0_7, x0_7);
y = vfmaq_f16(y, wvec8_15, x8_15);
y = vfmaq_f16(y, wvec16_23, x16_23);
y = vfmaq_f16(y, wvec24_31, x24_31);
- yVal_low = vaddvq_f32(vcvt_f32_f16(vget_low_f16(y)));
- yVal_high = vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
-
- Y32[j] += yVal_low + yVal_high;
+ Y32[j] += vaddvq_f32(vcvt_f32_f16(vget_low_f16(y))) +
+ vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
}
}
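// 16-column tail.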
for (; cols - idx >= 16; idx += 16) {
float16x8_t x0_7 = vld1q_f16(&X[idx]);
float16x8_t x8_15 = vld1q_f16(&X[idx + 8]);

if (alpha != 1.0) {
x0_7 = vmulq_n_f16(x0_7, alpha);
x8_15 = vmulq_n_f16(x8_15, alpha);
}

const __fp16 *__restrict w;
- float yVal_low;
- float yVal_high;
for (unsigned int j = 0; j < rows; ++j) {
w = &A[j * cols + idx];
float16x8_t wvec0_7 = vld1q_f16(&w[0]);
float16x8_t wvec8_15 = vld1q_f16(&w[8]);
float16x8_t y = vmulq_f16(wvec0_7, x0_7);
y = vfmaq_f16(y, wvec8_15, x8_15);
- yVal_low = vaddvq_f32(vcvt_f32_f16(vget_low_f16(y)));
- yVal_high = vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
-
- Y32[j] += yVal_low + yVal_high;
+ Y32[j] += vaddvq_f32(vcvt_f32_f16(vget_low_f16(y))) +
+ vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
}
}
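// 8-column tail.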
for (; cols - idx >= 8; idx += 8) {
float16x8_t x0_7 = vld1q_f16(&X[idx]);

if (alpha != 1.0) {
x0_7 = vmulq_n_f16(x0_7, alpha);
}

const __fp16 *__restrict w;
- float yVal_low;
- float yVal_high;
-
for (unsigned int j = 0; j < rows; ++j) {
w = &A[j * cols + idx];
float16x8_t wvec0_7 = vld1q_f16(&w[0]);
float16x8_t y = vmulq_f16(wvec0_7, x0_7);
- yVal_low = vaddvq_f32(vcvt_f32_f16(vget_low_f16(y)));
- yVal_high = vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
-
- Y32[j] += yVal_low + yVal_high;
+ Y32[j] += vaddvq_f32(vcvt_f32_f16(vget_low_f16(y))) +
+ vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
}
}
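// 4-column tail of the 16-row-unrolled branch: x ... x16 hold alpha * X[i] ... alpha * X[i + 15].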
for (; cols - idx >= 4; idx += 4) {
float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * cols + idx]), x);
wvec0_7_f16 =
vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 1) * cols + idx]), x2);
wvec0_7_f16 =
vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 2) * cols + idx]), x3);
wvec0_7_f16 =
vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 3) * cols + idx]), x4);
wvec0_7_f16 =
vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 4) * cols + idx]), x5);
wvec0_7_f16 =
vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 5) * cols + idx]), x6);
wvec0_7_f16 =
vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 6) * cols + idx]), x7);
wvec0_7_f16 =
vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 7) * cols + idx]), x8);
float16x8_t w2vec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[(i + 8) * cols + idx]), x9);
w2vec0_7_f16 =
vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 9) * cols + idx]), x10);
w2vec0_7_f16 =
vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 10) * cols + idx]), x11);
w2vec0_7_f16 =
vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 11) * cols + idx]), x12);
w2vec0_7_f16 =
vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 12) * cols + idx]), x13);
w2vec0_7_f16 =
vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 13) * cols + idx]), x14);
w2vec0_7_f16 =
vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 14) * cols + idx]), x15);
w2vec0_7_f16 =
vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 15) * cols + idx]), x16);
-
float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]),
vcvt_f32_f16(vget_low_f16(wvec0_7_f16)));
y0_3 = vaddq_f32(y0_3, vcvt_f32_f16(vget_low_f16(w2vec0_7_f16)));
vst1q_f32(&Y32[idx], y0_3);
}
}
- } else
- if (rows % 8 == 0) {
+ } else if (rows % 8 == 0) {
for (unsigned int i = 0; i < rows; i += 8) {
__fp16 x = alpha * (X[i]);
__fp16 x2 = alpha * (X[i + 1]);