Rewrite a fast GEMV path for two goals:
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 9 Apr 2018 17:59:46 +0000 (10:59 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 9 Apr 2018 18:02:10 +0000 (11:02 -0700)
1. Avoid cache aliasing issues on CPUs with 4-way set associative L1 cache.
   That includes Cortex-A53.
2. Be a good basis to port to assembly.

PiperOrigin-RevId: 192152277

tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h

index 9a27461..5acf1ea 100644 (file)
@@ -554,88 +554,261 @@ inline void GEMVForLstmCellWithSymmetricRange(
   // exercises it). We just guard our assumptions about size evenness with
   // the following assertions.
   TFLITE_DCHECK(!(output_size % 4));
-  TFLITE_DCHECK(!(input_size % 8));
+  TFLITE_DCHECK(!(input_size % 64));
   const int32* bias_ptr = bias_data;
   int16* output_ptr = output_data;
   const uint8x16_t signbit = vdupq_n_u8(0x80);
   for (int in = 0; in < input_size; in += 32) {
     optimized_ops_preload_l1_keep(input_data + in);
   }
+  const int left_shift = accum_shift > 0 ? accum_shift : 0;
+  const int right_shift = accum_shift > 0 ? 0 : -accum_shift;
   for (int out = 0; out < output_size; out += 4) {
-    const uint8* weights_ptr_0 = weights_data + out * input_size;
-    const uint8* weights_ptr_1 = weights_ptr_0 + 1 * input_size;
-    const uint8* weights_ptr_2 = weights_ptr_0 + 2 * input_size;
-    const uint8* weights_ptr_3 = weights_ptr_0 + 3 * input_size;
+    // Load the bias values
+    int32x4_t bias_vec = vld1q_s32(bias_ptr);
+    bias_ptr += 4;
 
-    int32x4_t acc_0 = vdupq_n_s32(0);
-    int32x4_t acc_1 = vdupq_n_s32(0);
-    int32x4_t acc_2 = vdupq_n_s32(0);
-    int32x4_t acc_3 = vdupq_n_s32(0);
-    int in = 0;
-    const int kReadAhead = 256;
-    // Handle 16 levels of depth at a time.
-    for (; in < input_size; in += 16) {
-      int8x16_t weights_val_0 =
-          vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_0)));
-      int8x16_t weights_val_1 =
-          vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_1)));
-      int8x16_t weights_val_2 =
-          vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_2)));
-      int8x16_t weights_val_3 =
-          vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_3)));
-      int8x16_t input_val =
-          vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(input_data + in)));
-      int16x8_t acc16_0 =
-          vmull_s8(vget_low_s8(weights_val_0), vget_low_s8(input_val));
-      int16x8_t acc16_1 =
-          vmull_s8(vget_low_s8(weights_val_1), vget_low_s8(input_val));
-      int16x8_t acc16_2 =
-          vmull_s8(vget_low_s8(weights_val_2), vget_low_s8(input_val));
-      int16x8_t acc16_3 =
-          vmull_s8(vget_low_s8(weights_val_3), vget_low_s8(input_val));
-      acc16_0 = vmlal_s8(acc16_0, vget_high_s8(weights_val_0),
-                         vget_high_s8(input_val));
-      acc16_1 = vmlal_s8(acc16_1, vget_high_s8(weights_val_1),
-                         vget_high_s8(input_val));
-      acc16_2 = vmlal_s8(acc16_2, vget_high_s8(weights_val_2),
-                         vget_high_s8(input_val));
-      acc16_3 = vmlal_s8(acc16_3, vget_high_s8(weights_val_3),
-                         vget_high_s8(input_val));
-      acc_0 = vpadalq_s16(acc_0, acc16_0);
-      acc_1 = vpadalq_s16(acc_1, acc16_1);
-      acc_2 = vpadalq_s16(acc_2, acc16_2);
-      acc_3 = vpadalq_s16(acc_3, acc16_3);
-      weights_ptr_0 += 16;
-      weights_ptr_1 += 16;
-      weights_ptr_2 += 16;
-      weights_ptr_3 += 16;
-      optimized_ops_preload_l1_stream(weights_ptr_0 + kReadAhead);
-      optimized_ops_preload_l1_stream(weights_ptr_1 + kReadAhead);
-      optimized_ops_preload_l1_stream(weights_ptr_2 + kReadAhead);
-      optimized_ops_preload_l1_stream(weights_ptr_3 + kReadAhead);
+    // Clear accumulators. We use 2 accumulator registers per row,
+    // for 4 rows. row_accumRN is the N-th accumulator for row R.
+    int32x4_t row_accum00 = vdupq_n_s32(0);
+    int32x4_t row_accum01 = vdupq_n_s32(0);
+    int32x4_t row_accum10 = vdupq_n_s32(0);
+    int32x4_t row_accum11 = vdupq_n_s32(0);
+    int32x4_t row_accum20 = vdupq_n_s32(0);
+    int32x4_t row_accum21 = vdupq_n_s32(0);
+    int32x4_t row_accum30 = vdupq_n_s32(0);
+    int32x4_t row_accum31 = vdupq_n_s32(0);
+
+    // kReadAhead parametrizes how far ahead we prefetch weights into L1 cache.
+    const int kReadAhead = 512;
+    // Prefetch the first weights values.
+    for (int k = 0; k < kReadAhead; k += 64) {
+      optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
+                                      k);
+      optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
+                                      k);
+      optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
+                                      k);
+      optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
+                                      k);
+    }
+    // Loop along the rows, handling 64 bytes per iteration because that's
+    // cache line size on most current ARM-architecture CPUs.
+    for (int in = 0; in < input_size; in += 64) {
+      // Prefetch some future weights values.
+      optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
+                                      in + kReadAhead);
+      optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
+                                      in + kReadAhead);
+      optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
+                                      in + kReadAhead);
+      optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
+                                      in + kReadAhead);
+
+      // We will use 2 local 16-bit accumulators per row, for 2 rows.
+      // See below (*) for the rationale of processing only 2 rows at a time.
+      // local_accumRN is the N-th local accumulator for row R.
+      int16x8_t local_accum00;
+      int16x8_t local_accum01;
+      int16x8_t local_accum10;
+      int16x8_t local_accum11;
+
+      // Load 64 bytes of input activations values. Convert to signed int8
+      // by flipping the sign bit (i.e. subtracting 128, the required
+      // zero_point value).
+      int8x16_t input0 = vreinterpretq_s8_u8(
+          veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 0)));
+      int8x16_t input1 = vreinterpretq_s8_u8(
+          veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 1)));
+      int8x16_t input2 = vreinterpretq_s8_u8(
+          veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 2)));
+      int8x16_t input3 = vreinterpretq_s8_u8(
+          veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 3)));
+
+      // Beginning of the core accumulation. Notice how while we have 4
+      // rows to process, this code is taking care of only 2 rows at a time,
+      // thus being divided into two parts looking similar ("Rows 0 and 1" and
+      // "Rows 2 and 3").
+      //
+      // (*) The rationale for handling only 2 rows at a time is to avoid
+      // cache aliasing issues on 4-way set-associative L1-cache CPUs, such
+      // as Cortex-A53. With sufficiently large, power-of-two matrix dimensions,
+      // we may find ourselves in a situation where rows alias each other in
+      // the L1 cache, and moreover may also mutually alias with the input
+      // activations. If we try to load 4 rows at a time, together with the
+      // input activations, that may be 5 mutually-aliasing vectors, resulting
+      // in constant mutual eviction from L1 cache. Handling 2 rows at a time
+      // here largely mitigates these issues, and seems at least to be very
+      // effective on Cortex-A53:
+      //                          Before       After
+      // big (Cortex-A73)         2.85 ms      2.85 ms
+      // little (Cortex-A53)      11.0 ms      5.16 ms
+
+      // Rows 0 and 1:
+      // Load 64 bytes of weights values from each row. Convert to signed int8
+      // by flipping the sign bit (i.e. subtracting 128, the required
+      // zero_point value).
+      int8x16_t weights00 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 0)));
+      int8x16_t weights01 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 1)));
+      int8x16_t weights02 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 2)));
+      int8x16_t weights03 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 3)));
+      int8x16_t weights10 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 0)));
+      int8x16_t weights11 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 1)));
+      int8x16_t weights12 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 2)));
+      int8x16_t weights13 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 3)));
+      // Multiply-accumulate into local 16-bit accumulators.
+      // We can accumulate two products without overflow because weights are
+      // required to never be -128, so each product is at most 127^2 in absolute
+      // value.
+      local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
+      local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
+      local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
+      local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
+      local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
+                               vget_high_s8(input0));
+      local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
+                               vget_high_s8(input1));
+      local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
+                               vget_high_s8(input0));
+      local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
+                               vget_high_s8(input1));
+      // Pairwise add and accumulate into 32-bit accumulators
+      row_accum00 = vpadalq_s16(row_accum00, local_accum00);
+      row_accum01 = vpadalq_s16(row_accum01, local_accum01);
+      row_accum10 = vpadalq_s16(row_accum10, local_accum10);
+      row_accum11 = vpadalq_s16(row_accum11, local_accum11);
+      // Multiply-accumulate into local 16-bit accumulators.
+      // We can accumulate two products without overflow because weights are
+      // required to never be -128, so each product is at most 127^2 in absolute
+      // value.
+      local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
+      local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
+      local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
+      local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
+      local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
+                               vget_high_s8(input2));
+      local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
+                               vget_high_s8(input3));
+      local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
+                               vget_high_s8(input2));
+      local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
+                               vget_high_s8(input3));
+      // Pairwise add and accumulate into 32-bit accumulators
+      row_accum00 = vpadalq_s16(row_accum00, local_accum00);
+      row_accum01 = vpadalq_s16(row_accum01, local_accum01);
+      row_accum10 = vpadalq_s16(row_accum10, local_accum10);
+      row_accum11 = vpadalq_s16(row_accum11, local_accum11);
+
+      // Rows 2 and 3:
+      // Load 64 bytes of weights values from each row. Convert to signed int8
+      // by flipping the sign bit (i.e. subtracting 128, the required
+      // zero_point value).
+      weights00 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 0)));
+      weights01 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 1)));
+      weights02 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 2)));
+      weights03 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 3)));
+      weights10 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 0)));
+      weights11 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 1)));
+      weights12 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 2)));
+      weights13 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 3)));
+      // Multiply-accumulate into local 16-bit accumulators.
+      // We can accumulate two products without overflow because weights are
+      // required to never be -128, so each product is at most 127^2 in absolute
+      // value.
+      local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
+      local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
+      local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
+      local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
+      local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
+                               vget_high_s8(input0));
+      local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
+                               vget_high_s8(input1));
+      local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
+                               vget_high_s8(input0));
+      local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
+                               vget_high_s8(input1));
+      // Pairwise add and accumulate into 32-bit accumulators
+      row_accum20 = vpadalq_s16(row_accum20, local_accum00);
+      row_accum21 = vpadalq_s16(row_accum21, local_accum01);
+      row_accum30 = vpadalq_s16(row_accum30, local_accum10);
+      row_accum31 = vpadalq_s16(row_accum31, local_accum11);
+      // Multiply-accumulate into local 16-bit accumulators.
+      // We can accumulate two products without overflow because weights are
+      // required to never be -128, so each product is at most 127^2 in absolute
+      // value.
+      local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
+      local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
+      local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
+      local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
+      local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
+                               vget_high_s8(input2));
+      local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
+                               vget_high_s8(input3));
+      local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
+                               vget_high_s8(input2));
+      local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
+                               vget_high_s8(input3));
+      // Pairwise add and accumulate into 32-bit accumulators
+      row_accum20 = vpadalq_s16(row_accum20, local_accum00);
+      row_accum21 = vpadalq_s16(row_accum21, local_accum01);
+      row_accum30 = vpadalq_s16(row_accum30, local_accum10);
+      row_accum31 = vpadalq_s16(row_accum31, local_accum11);
     }
+
+    row_accum00 = vaddq_s32(row_accum00, row_accum01);
+    row_accum10 = vaddq_s32(row_accum10, row_accum11);
+    row_accum20 = vaddq_s32(row_accum20, row_accum21);
+    row_accum30 = vaddq_s32(row_accum30, row_accum31);
     // Horizontally reduce accumulators
     int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
         pairwise_reduced_acc_2, pairwise_reduced_acc_3;
     pairwise_reduced_acc_0 =
-        vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0));
+        vpadd_s32(vget_low_s32(row_accum00), vget_high_s32(row_accum00));
     pairwise_reduced_acc_1 =
-        vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1));
+        vpadd_s32(vget_low_s32(row_accum10), vget_high_s32(row_accum10));
     pairwise_reduced_acc_2 =
-        vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2));
+        vpadd_s32(vget_low_s32(row_accum20), vget_high_s32(row_accum20));
     pairwise_reduced_acc_3 =
-        vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3));
+        vpadd_s32(vget_low_s32(row_accum30), vget_high_s32(row_accum30));
     const int32x2_t reduced_lo =
         vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
     const int32x2_t reduced_hi =
         vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
     // Add bias values.
-    int32x4_t bias_vec = vld1q_s32(bias_ptr);
-    bias_ptr += 4;
     reduced = vaddq_s32(reduced, bias_vec);
-    int left_shift = accum_shift > 0 ? accum_shift : 0;
-    int right_shift = accum_shift > 0 ? 0 : -accum_shift;
     reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
     // Multiply by the fixed-point multiplier.
     reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
@@ -962,7 +1135,7 @@ inline void FullyConnected(
 #ifdef GEMMLOWP_NEON
   if (batches == 1 && input_offset == -128 && output_activation_min == -32768 &&
       output_activation_max == 32767) {
-    if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 16)) {
+    if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
       GEMVForLstmCellWithSymmetricRange(input_data, input_dims, filter_data,
                                         filter_dims, bias_data_int32, bias_dims,
                                         output_multiplier, -output_shift,