int i = 0;
for(; i <= (num_elems_vec_a - 4); i += 4)
{
- const float32x4_t a0 = vld1q_f32(&vec_a[i]);
+ const float32x2_t a0l = vld1_f32(&vec_a[i]);
+ const float32x2_t a0h = vld1_f32(&vec_a[i + 2]);
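+ // Loading the two 64-bit halves separately lets the lane multiplies below use a0l/a0h directly, avoiding vget_low/vget_high in the hot loop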
const float32x4_t b00 = vld1q_f32(&matrix_b[0 + (i + 0) * in_b_stride]);
const float32x4_t b01 = vld1q_f32(&matrix_b[4 + (i + 0) * in_b_stride]);
const float32x4_t b32 = vld1q_f32(&matrix_b[8 + (i + 3) * in_b_stride]);
const float32x4_t b33 = vld1q_f32(&matrix_b[12 + (i + 3) * in_b_stride]);
- acc0 = vmlaq_lane_f32(acc0, b00, vget_low_f32(a0), 0);
- acc1 = vmlaq_lane_f32(acc1, b01, vget_low_f32(a0), 0);
- acc2 = vmlaq_lane_f32(acc2, b02, vget_low_f32(a0), 0);
- acc3 = vmlaq_lane_f32(acc3, b03, vget_low_f32(a0), 0);
-
- acc0 = vmlaq_lane_f32(acc0, b10, vget_low_f32(a0), 1);
- acc1 = vmlaq_lane_f32(acc1, b11, vget_low_f32(a0), 1);
- acc2 = vmlaq_lane_f32(acc2, b12, vget_low_f32(a0), 1);
- acc3 = vmlaq_lane_f32(acc3, b13, vget_low_f32(a0), 1);
-
- acc0 = vmlaq_lane_f32(acc0, b20, vget_high_f32(a0), 0);
- acc1 = vmlaq_lane_f32(acc1, b21, vget_high_f32(a0), 0);
- acc2 = vmlaq_lane_f32(acc2, b22, vget_high_f32(a0), 0);
- acc3 = vmlaq_lane_f32(acc3, b23, vget_high_f32(a0), 0);
-
- acc0 = vmlaq_lane_f32(acc0, b30, vget_high_f32(a0), 1);
- acc1 = vmlaq_lane_f32(acc1, b31, vget_high_f32(a0), 1);
- acc2 = vmlaq_lane_f32(acc2, b32, vget_high_f32(a0), 1);
- acc3 = vmlaq_lane_f32(acc3, b33, vget_high_f32(a0), 1);
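+ // Each group below accumulates one element of vec_a against 16 consecutive elements of the matching row of matrix_b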
+ acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
+ acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
+ acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
+ acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
+
+ acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
+ acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
+ acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
+ acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
+
+ acc0 = vmlaq_lane_f32(acc0, b20, a0h, 0);
+ acc1 = vmlaq_lane_f32(acc1, b21, a0h, 0);
+ acc2 = vmlaq_lane_f32(acc2, b22, a0h, 0);
+ acc3 = vmlaq_lane_f32(acc3, b23, a0h, 0);
+
+ acc0 = vmlaq_lane_f32(acc0, b30, a0h, 1);
+ acc1 = vmlaq_lane_f32(acc1, b31, a0h, 1);
+ acc2 = vmlaq_lane_f32(acc2, b32, a0h, 1);
+ acc3 = vmlaq_lane_f32(acc3, b33, a0h, 1);
}
for(; i < num_elems_vec_a; i++)
template <bool multiply_alpha>
void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
- const size_t out_stride = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
+ const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
+ const size_t out_stride2 = out_stride1 * 2;
+ const size_t out_stride3 = out_stride1 * 3;
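+ // Precompute the element strides of output rows 2 and 3 for the stores of the four 4x4 blocks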
const int num_elems_matrix_b_x = input1->info()->dimension(0);
// Set step_x and step_y for matrix A. Scale the Y range by a factor of 4, as the interleaved input matrix A has 4 times fewer rows than the output matrix
win_b = window;
}
// Set step_x and step_y for matrix B. Scale the X range by a factor of 4, as the transposed input matrix B has 4 times fewer columns than the output matrix
- win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, in_b_stride));
+ // The step along the x direction is 4 times in_b_stride because each iteration computes 4 blocks of size 4x4
+ win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 4 * in_b_stride));
win_b.set(Window::DimY, Window::Dimension(0, 1, 0));
Iterator ina(input0, win_a);
Iterator inb(input1, win_b);
Iterator out(output, window);
+ // The implementation assumes that matrix A and matrix B have been reshaped with NEGEMMInterleave4x4 and NEGEMMTranspose1xW respectively
+ // The reshaping keeps the implementation cache-friendly and avoids the data rearrangements needed to compute 16x4 elements per iteration
+ // All the values needed to compute a single 4x4 block are read from consecutive memory positions
execute_window_loop(window, [&](const Coordinates & id)
{
- const auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
- const auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
-
- float32x4_t acc0 = vdupq_n_f32(0.f);
- float32x4_t acc1 = vdupq_n_f32(0.f);
- float32x4_t acc2 = vdupq_n_f32(0.f);
- float32x4_t acc3 = vdupq_n_f32(0.f);
+ auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
+ auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
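+ // The next three rows of the reshaped matrix B each feed one more 4x4 output block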
+ auto mtx_b1 = mtx_b0 + in_b_stride;
+ auto mtx_b2 = mtx_b1 + in_b_stride;
+ auto mtx_b3 = mtx_b2 + in_b_stride;
+
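+ // Accumulators accRC: row R of the 4x4 output block C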
+ float32x4_t acc00 = vdupq_n_f32(0.f);
+ float32x4_t acc10 = vdupq_n_f32(0.f);
+ float32x4_t acc20 = vdupq_n_f32(0.f);
+ float32x4_t acc30 = vdupq_n_f32(0.f);
+
+ float32x4_t acc01 = vdupq_n_f32(0.f);
+ float32x4_t acc11 = vdupq_n_f32(0.f);
+ float32x4_t acc21 = vdupq_n_f32(0.f);
+ float32x4_t acc31 = vdupq_n_f32(0.f);
+
+ float32x4_t acc02 = vdupq_n_f32(0.f);
+ float32x4_t acc12 = vdupq_n_f32(0.f);
+ float32x4_t acc22 = vdupq_n_f32(0.f);
+ float32x4_t acc32 = vdupq_n_f32(0.f);
+
+ float32x4_t acc03 = vdupq_n_f32(0.f);
+ float32x4_t acc13 = vdupq_n_f32(0.f);
+ float32x4_t acc23 = vdupq_n_f32(0.f);
+ float32x4_t acc33 = vdupq_n_f32(0.f);
for(int k = 0; k < num_elems_matrix_b_x; k += 4)
{
- const float32x4_t a00 = vld1q_f32(mtx_a0 + k);
- const float32x4_t b00 = vld1q_f32(mtx_b0 + k);
-
- // Accumulation 0
- acc0 = vmlaq_lane_f32(acc0, b00, vget_low_f32(a00), 0);
- acc1 = vmlaq_lane_f32(acc1, b00, vget_low_f32(a00), 1);
- acc2 = vmlaq_lane_f32(acc2, b00, vget_high_f32(a00), 0);
- acc3 = vmlaq_lane_f32(acc3, b00, vget_high_f32(a00), 1);
+ const float32x4_t a = vld1q_f32(mtx_a0);
+ const float32x2_t a00l = vget_low_f32(a);
+ const float32x2_t a00h = vget_high_f32(a);
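+ // The same 4 values of A are reused by all four blocks; split once into halves for the lane multiplies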
+ const float32x4_t b00 = vld1q_f32(mtx_b0);
+ const float32x4_t b10 = vld1q_f32(mtx_b1);
+ const float32x4_t b20 = vld1q_f32(mtx_b2);
+ const float32x4_t b30 = vld1q_f32(mtx_b3);
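+ // b00, b10, b20 and b30 each hold 4 consecutive values of B and feed blocks 0 to 3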
+
+ // 4x4 block 0
+ acc00 = vmlaq_lane_f32(acc00, b00, a00l, 0);
+ acc10 = vmlaq_lane_f32(acc10, b00, a00l, 1);
+ acc20 = vmlaq_lane_f32(acc20, b00, a00h, 0);
+ acc30 = vmlaq_lane_f32(acc30, b00, a00h, 1);
+
+ // 4x4 block 1
+ acc01 = vmlaq_lane_f32(acc01, b10, a00l, 0);
+ acc11 = vmlaq_lane_f32(acc11, b10, a00l, 1);
+ acc21 = vmlaq_lane_f32(acc21, b10, a00h, 0);
+ acc31 = vmlaq_lane_f32(acc31, b10, a00h, 1);
+
+ // 4x4 block 2
+ acc02 = vmlaq_lane_f32(acc02, b20, a00l, 0);
+ acc12 = vmlaq_lane_f32(acc12, b20, a00l, 1);
+ acc22 = vmlaq_lane_f32(acc22, b20, a00h, 0);
+ acc32 = vmlaq_lane_f32(acc32, b20, a00h, 1);
+
+ // 4x4 block 3
+ acc03 = vmlaq_lane_f32(acc03, b30, a00l, 0);
+ acc13 = vmlaq_lane_f32(acc13, b30, a00l, 1);
+ acc23 = vmlaq_lane_f32(acc23, b30, a00h, 0);
+ acc33 = vmlaq_lane_f32(acc33, b30, a00h, 1);
+
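+ // Advance matrix A and the four rows of matrix B by 4 elements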
+ mtx_a0 += 4;
+ mtx_b0 += 4;
+ mtx_b1 += 4;
+ mtx_b2 += 4;
+ mtx_b3 += 4;
}
// Multiply by the weight of the matrix product (alpha)
if(multiply_alpha)
{
const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
- acc0 = vmulq_f32(acc0, alpha_f32);
- acc1 = vmulq_f32(acc1, alpha_f32);
- acc2 = vmulq_f32(acc2, alpha_f32);
- acc3 = vmulq_f32(acc3, alpha_f32);
+ acc00 = vmulq_f32(acc00, alpha_f32);
+ acc10 = vmulq_f32(acc10, alpha_f32);
+ acc20 = vmulq_f32(acc20, alpha_f32);
+ acc30 = vmulq_f32(acc30, alpha_f32);
+ acc01 = vmulq_f32(acc01, alpha_f32);
+ acc11 = vmulq_f32(acc11, alpha_f32);
+ acc21 = vmulq_f32(acc21, alpha_f32);
+ acc31 = vmulq_f32(acc31, alpha_f32);
+ acc02 = vmulq_f32(acc02, alpha_f32);
+ acc12 = vmulq_f32(acc12, alpha_f32);
+ acc22 = vmulq_f32(acc22, alpha_f32);
+ acc32 = vmulq_f32(acc32, alpha_f32);
+ acc03 = vmulq_f32(acc03, alpha_f32);
+ acc13 = vmulq_f32(acc13, alpha_f32);
+ acc23 = vmulq_f32(acc23, alpha_f32);
+ acc33 = vmulq_f32(acc33, alpha_f32);
}
- const auto mtx_out = reinterpret_cast<float *>(out.ptr());
-
- vst1q_f32(mtx_out + 0 * out_stride, acc0);
- vst1q_f32(mtx_out + 1 * out_stride, acc1);
- vst1q_f32(mtx_out + 2 * out_stride, acc2);
- vst1q_f32(mtx_out + 3 * out_stride, acc3);
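+ // The four 4x4 blocks are contiguous along x, 4 floats apart; output rows are out_stride1 elements apart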
+ const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
+ const auto mtx_out1 = mtx_out0 + 4;
+ const auto mtx_out2 = mtx_out1 + 4;
+ const auto mtx_out3 = mtx_out2 + 4;
+
+ // Store the 4 blocks
+ vst1q_f32(mtx_out0, acc00);
+ vst1q_f32(mtx_out1, acc01);
+ vst1q_f32(mtx_out2, acc02);
+ vst1q_f32(mtx_out3, acc03);
+ vst1q_f32(mtx_out0 + out_stride1, acc10);
+ vst1q_f32(mtx_out1 + out_stride1, acc11);
+ vst1q_f32(mtx_out2 + out_stride1, acc12);
+ vst1q_f32(mtx_out3 + out_stride1, acc13);
+ vst1q_f32(mtx_out0 + out_stride2, acc20);
+ vst1q_f32(mtx_out1 + out_stride2, acc21);
+ vst1q_f32(mtx_out2 + out_stride2, acc22);
+ vst1q_f32(mtx_out3 + out_stride2, acc23);
+ vst1q_f32(mtx_out0 + out_stride3, acc30);
+ vst1q_f32(mtx_out1 + out_stride3, acc31);
+ vst1q_f32(mtx_out2 + out_stride3, acc32);
+ vst1q_f32(mtx_out3 + out_stride3, acc33);
},
ina, inb, out);
}
}
case DataType::F32:
{
- num_elems_processed_per_iteration_x = 4;
+ num_elems_processed_per_iteration_x = 16;
break;
}
default: