arm_compute v17.04
[platform/upstream/armcl.git] / src / core / NEON / kernels / NEGEMMMatrixMultiplyKernel.cpp
index 3a51cc2..46430fc 100644 (file)
@@ -99,7 +99,8 @@ void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, IT
         int i = 0;
         for(; i <= (num_elems_vec_a - 4); i += 4)
         {
-            const float32x4_t a0 = vld1q_f32(&vec_a[i]);
+            const float32x2_t a0l = vld1_f32(&vec_a[i]);
+            const float32x2_t a0h = vld1_f32(&vec_a[i] + 2);
 
             const float32x4_t b00 = vld1q_f32(&matrix_b[0 + (i + 0) * in_b_stride]);
             const float32x4_t b01 = vld1q_f32(&matrix_b[4 + (i + 0) * in_b_stride]);
@@ -121,25 +122,25 @@ void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, IT
             const float32x4_t b32 = vld1q_f32(&matrix_b[8 + (i + 3) * in_b_stride]);
             const float32x4_t b33 = vld1q_f32(&matrix_b[12 + (i + 3) * in_b_stride]);
 
-            acc0 = vmlaq_lane_f32(acc0, b00, vget_low_f32(a0), 0);
-            acc1 = vmlaq_lane_f32(acc1, b01, vget_low_f32(a0), 0);
-            acc2 = vmlaq_lane_f32(acc2, b02, vget_low_f32(a0), 0);
-            acc3 = vmlaq_lane_f32(acc3, b03, vget_low_f32(a0), 0);
-
-            acc0 = vmlaq_lane_f32(acc0, b10, vget_low_f32(a0), 1);
-            acc1 = vmlaq_lane_f32(acc1, b11, vget_low_f32(a0), 1);
-            acc2 = vmlaq_lane_f32(acc2, b12, vget_low_f32(a0), 1);
-            acc3 = vmlaq_lane_f32(acc3, b13, vget_low_f32(a0), 1);
-
-            acc0 = vmlaq_lane_f32(acc0, b20, vget_high_f32(a0), 0);
-            acc1 = vmlaq_lane_f32(acc1, b21, vget_high_f32(a0), 0);
-            acc2 = vmlaq_lane_f32(acc2, b22, vget_high_f32(a0), 0);
-            acc3 = vmlaq_lane_f32(acc3, b23, vget_high_f32(a0), 0);
-
-            acc0 = vmlaq_lane_f32(acc0, b30, vget_high_f32(a0), 1);
-            acc1 = vmlaq_lane_f32(acc1, b31, vget_high_f32(a0), 1);
-            acc2 = vmlaq_lane_f32(acc2, b32, vget_high_f32(a0), 1);
-            acc3 = vmlaq_lane_f32(acc3, b33, vget_high_f32(a0), 1);
+            acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
+            acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
+            acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
+            acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);
+
+            acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
+            acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
+            acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
+            acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);
+
+            acc0 = vmlaq_lane_f32(acc0, b20, a0h, 0);
+            acc1 = vmlaq_lane_f32(acc1, b21, a0h, 0);
+            acc2 = vmlaq_lane_f32(acc2, b22, a0h, 0);
+            acc3 = vmlaq_lane_f32(acc3, b23, a0h, 0);
+
+            acc0 = vmlaq_lane_f32(acc0, b30, a0h, 1);
+            acc1 = vmlaq_lane_f32(acc1, b31, a0h, 1);
+            acc2 = vmlaq_lane_f32(acc2, b32, a0h, 1);
+            acc3 = vmlaq_lane_f32(acc3, b33, a0h, 1);
         }
 
         for(; i < num_elems_vec_a; i++)
@@ -181,7 +182,9 @@ template <bool multiply_alpha>
 void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
 {
     const size_t in_b_stride          = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
-    const size_t out_stride           = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
+    const size_t out_stride1          = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
+    const size_t out_stride2          = out_stride1 * 2;
+    const size_t out_stride3          = out_stride1 * 3;
     const int    num_elems_matrix_b_x = input1->info()->dimension(0);
 
     // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix
@@ -197,51 +200,130 @@ void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, IT
         win_b = window;
     }
     // Set step_x and step_y for matrix B. Scale by a factor of 4 the X range as the input transposed matrix A has 4 times less the cols of the output matrix
-    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, in_b_stride));
+    // The step along the x direction is 4 times the in_b_stride because for each iteration we compute 4 blocks of size 4x4
+    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 4 * in_b_stride));
     win_b.set(Window::DimY, Window::Dimension(0, 1, 0));
 
     Iterator ina(input0, win_a);
     Iterator inb(input1, win_b);
     Iterator out(output, window);
 
+    // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
+    // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration
+    // All the values needed for computing a single 4x4 block will be read from consecutive memory positions
     execute_window_loop(window, [&](const Coordinates & id)
     {
-        const auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
-        const auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
-
-        float32x4_t acc0 = vdupq_n_f32(0.f);
-        float32x4_t acc1 = vdupq_n_f32(0.f);
-        float32x4_t acc2 = vdupq_n_f32(0.f);
-        float32x4_t acc3 = vdupq_n_f32(0.f);
+        auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
+        auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
+        auto mtx_b1 = mtx_b0 + in_b_stride;
+        auto mtx_b2 = mtx_b1 + in_b_stride;
+        auto mtx_b3 = mtx_b2 + in_b_stride;
+
+        float32x4_t acc00 = vdupq_n_f32(0.f);
+        float32x4_t acc10 = vdupq_n_f32(0.f);
+        float32x4_t acc20 = vdupq_n_f32(0.f);
+        float32x4_t acc30 = vdupq_n_f32(0.f);
+
+        float32x4_t acc01 = vdupq_n_f32(0.f);
+        float32x4_t acc11 = vdupq_n_f32(0.f);
+        float32x4_t acc21 = vdupq_n_f32(0.f);
+        float32x4_t acc31 = vdupq_n_f32(0.f);
+
+        float32x4_t acc02 = vdupq_n_f32(0.f);
+        float32x4_t acc12 = vdupq_n_f32(0.f);
+        float32x4_t acc22 = vdupq_n_f32(0.f);
+        float32x4_t acc32 = vdupq_n_f32(0.f);
+
+        float32x4_t acc03 = vdupq_n_f32(0.f);
+        float32x4_t acc13 = vdupq_n_f32(0.f);
+        float32x4_t acc23 = vdupq_n_f32(0.f);
+        float32x4_t acc33 = vdupq_n_f32(0.f);
 
         for(int k = 0; k < num_elems_matrix_b_x; k += 4)
         {
-            const float32x4_t a00 = vld1q_f32(mtx_a0 + k);
-            const float32x4_t b00 = vld1q_f32(mtx_b0 + k);
-
-            // Accumulation 0
-            acc0 = vmlaq_lane_f32(acc0, b00, vget_low_f32(a00), 0);
-            acc1 = vmlaq_lane_f32(acc1, b00, vget_low_f32(a00), 1);
-            acc2 = vmlaq_lane_f32(acc2, b00, vget_high_f32(a00), 0);
-            acc3 = vmlaq_lane_f32(acc3, b00, vget_high_f32(a00), 1);
+            const float32x4_t a    = vld1q_f32(mtx_a0);
+            const float32x2_t a00l = vget_low_f32(a);
+            const float32x2_t a00h = vget_high_f32(a);
+            const float32x4_t b00  = vld1q_f32(mtx_b0);
+            const float32x4_t b10  = vld1q_f32(mtx_b1);
+            const float32x4_t b20  = vld1q_f32(mtx_b2);
+            const float32x4_t b30  = vld1q_f32(mtx_b3);
+
+            // 4x4 block 0
+            acc00 = vmlaq_lane_f32(acc00, b00, a00l, 0);
+            acc10 = vmlaq_lane_f32(acc10, b00, a00l, 1);
+            acc20 = vmlaq_lane_f32(acc20, b00, a00h, 0);
+            acc30 = vmlaq_lane_f32(acc30, b00, a00h, 1);
+
+            // 4x4 block 1
+            acc01 = vmlaq_lane_f32(acc01, b10, a00l, 0);
+            acc11 = vmlaq_lane_f32(acc11, b10, a00l, 1);
+            acc21 = vmlaq_lane_f32(acc21, b10, a00h, 0);
+            acc31 = vmlaq_lane_f32(acc31, b10, a00h, 1);
+
+            // 4x4 block 2
+            acc02 = vmlaq_lane_f32(acc02, b20, a00l, 0);
+            acc12 = vmlaq_lane_f32(acc12, b20, a00l, 1);
+            acc22 = vmlaq_lane_f32(acc22, b20, a00h, 0);
+            acc32 = vmlaq_lane_f32(acc32, b20, a00h, 1);
+
+            // 4x4 block 3
+            acc03 = vmlaq_lane_f32(acc03, b30, a00l, 0);
+            acc13 = vmlaq_lane_f32(acc13, b30, a00l, 1);
+            acc23 = vmlaq_lane_f32(acc23, b30, a00h, 0);
+            acc33 = vmlaq_lane_f32(acc33, b30, a00h, 1);
+
+            mtx_a0 += 4;
+            mtx_b0 += 4;
+            mtx_b1 += 4;
+            mtx_b2 += 4;
+            mtx_b3 += 4;
         }
 
         // Multiply by the weight of matrix product (alpha)
         if(multiply_alpha)
         {
             const float32x4_t alpha_f32 = vdupq_n_f32(alpha);
-            acc0                        = vmulq_f32(acc0, alpha_f32);
-            acc1                        = vmulq_f32(acc1, alpha_f32);
-            acc2                        = vmulq_f32(acc2, alpha_f32);
-            acc3                        = vmulq_f32(acc3, alpha_f32);
+            acc00                       = vmulq_f32(acc00, alpha_f32);
+            acc10                       = vmulq_f32(acc10, alpha_f32);
+            acc20                       = vmulq_f32(acc20, alpha_f32);
+            acc30                       = vmulq_f32(acc30, alpha_f32);
+            acc01                       = vmulq_f32(acc01, alpha_f32);
+            acc11                       = vmulq_f32(acc11, alpha_f32);
+            acc21                       = vmulq_f32(acc21, alpha_f32);
+            acc31                       = vmulq_f32(acc31, alpha_f32);
+            acc02                       = vmulq_f32(acc02, alpha_f32);
+            acc12                       = vmulq_f32(acc12, alpha_f32);
+            acc22                       = vmulq_f32(acc22, alpha_f32);
+            acc32                       = vmulq_f32(acc32, alpha_f32);
+            acc03                       = vmulq_f32(acc03, alpha_f32);
+            acc13                       = vmulq_f32(acc13, alpha_f32);
+            acc23                       = vmulq_f32(acc23, alpha_f32);
+            acc33                       = vmulq_f32(acc33, alpha_f32);
         }
 
-        const auto mtx_out = reinterpret_cast<float *>(out.ptr());
-
-        vst1q_f32(mtx_out + 0 * out_stride, acc0);
-        vst1q_f32(mtx_out + 1 * out_stride, acc1);
-        vst1q_f32(mtx_out + 2 * out_stride, acc2);
-        vst1q_f32(mtx_out + 3 * out_stride, acc3);
+        const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
+        const auto mtx_out1 = mtx_out0 + 4;
+        const auto mtx_out2 = mtx_out1 + 4;
+        const auto mtx_out3 = mtx_out2 + 4;
+
+        // Store the 4 blocks
+        vst1q_f32(mtx_out0, acc00);
+        vst1q_f32(mtx_out1, acc01);
+        vst1q_f32(mtx_out2, acc02);
+        vst1q_f32(mtx_out3, acc03);
+        vst1q_f32(mtx_out0 + out_stride1, acc10);
+        vst1q_f32(mtx_out1 + out_stride1, acc11);
+        vst1q_f32(mtx_out2 + out_stride1, acc12);
+        vst1q_f32(mtx_out3 + out_stride1, acc13);
+        vst1q_f32(mtx_out0 + out_stride2, acc20);
+        vst1q_f32(mtx_out1 + out_stride2, acc21);
+        vst1q_f32(mtx_out2 + out_stride2, acc22);
+        vst1q_f32(mtx_out3 + out_stride2, acc23);
+        vst1q_f32(mtx_out0 + out_stride3, acc30);
+        vst1q_f32(mtx_out1 + out_stride3, acc31);
+        vst1q_f32(mtx_out2 + out_stride3, acc32);
+        vst1q_f32(mtx_out3 + out_stride3, acc33);
     },
     ina, inb, out);
 }
@@ -425,7 +507,7 @@ void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor
             }
             case DataType::F32:
             {
-                num_elems_processed_per_iteration_x = 4;
+                num_elems_processed_per_iteration_x = 16;
                 break;
             }
             default: