From: A. Unique TensorFlower
Date: Mon, 9 Apr 2018 17:59:46 +0000 (-0700)
Subject: Rewrite a fast GEMV path for two goals:
X-Git-Tag: tflite-v0.1.7~16^2^2~32
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9d1bf2bd4723fd3d0a012891bc54cc9db54bd9cd;p=platform%2Fupstream%2Ftensorflow.git

Rewrite a fast GEMV path for two goals:

1. Avoid cache aliasing issues on CPUs with 4-way set associative L1 cache.
   That includes Cortex-A53.
2. Be a good basis to port to assembly.

PiperOrigin-RevId: 192152277
---

diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 9a27461..5acf1ea 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -554,88 +554,261 @@ inline void GEMVForLstmCellWithSymmetricRange(
   // exercises it). We just guard our assumptions about size evenness with
   // the following assertions.
   TFLITE_DCHECK(!(output_size % 4));
-  TFLITE_DCHECK(!(input_size % 8));
+  TFLITE_DCHECK(!(input_size % 64));
   const int32* bias_ptr = bias_data;
   int16* output_ptr = output_data;
   const uint8x16_t signbit = vdupq_n_u8(0x80);
   for (int in = 0; in < input_size; in += 32) {
     optimized_ops_preload_l1_keep(input_data + in);
   }
+  const int left_shift = accum_shift > 0 ? accum_shift : 0;
+  const int right_shift = accum_shift > 0 ? 0 : -accum_shift;
   for (int out = 0; out < output_size; out += 4) {
-    const uint8* weights_ptr_0 = weights_data + out * input_size;
-    const uint8* weights_ptr_1 = weights_ptr_0 + 1 * input_size;
-    const uint8* weights_ptr_2 = weights_ptr_0 + 2 * input_size;
-    const uint8* weights_ptr_3 = weights_ptr_0 + 3 * input_size;
+    // Load the bias values
+    int32x4_t bias_vec = vld1q_s32(bias_ptr);
+    bias_ptr += 4;
 
-    int32x4_t acc_0 = vdupq_n_s32(0);
-    int32x4_t acc_1 = vdupq_n_s32(0);
-    int32x4_t acc_2 = vdupq_n_s32(0);
-    int32x4_t acc_3 = vdupq_n_s32(0);
-    int in = 0;
-    const int kReadAhead = 256;
-    // Handle 16 levels of depth at a time.
-    for (; in < input_size; in += 16) {
-      int8x16_t weights_val_0 =
-          vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_0)));
-      int8x16_t weights_val_1 =
-          vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_1)));
-      int8x16_t weights_val_2 =
-          vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_2)));
-      int8x16_t weights_val_3 =
-          vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_3)));
-      int8x16_t input_val =
-          vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(input_data + in)));
-      int16x8_t acc16_0 =
-          vmull_s8(vget_low_s8(weights_val_0), vget_low_s8(input_val));
-      int16x8_t acc16_1 =
-          vmull_s8(vget_low_s8(weights_val_1), vget_low_s8(input_val));
-      int16x8_t acc16_2 =
-          vmull_s8(vget_low_s8(weights_val_2), vget_low_s8(input_val));
-      int16x8_t acc16_3 =
-          vmull_s8(vget_low_s8(weights_val_3), vget_low_s8(input_val));
-      acc16_0 = vmlal_s8(acc16_0, vget_high_s8(weights_val_0),
-                         vget_high_s8(input_val));
-      acc16_1 = vmlal_s8(acc16_1, vget_high_s8(weights_val_1),
-                         vget_high_s8(input_val));
-      acc16_2 = vmlal_s8(acc16_2, vget_high_s8(weights_val_2),
-                         vget_high_s8(input_val));
-      acc16_3 = vmlal_s8(acc16_3, vget_high_s8(weights_val_3),
-                         vget_high_s8(input_val));
-      acc_0 = vpadalq_s16(acc_0, acc16_0);
-      acc_1 = vpadalq_s16(acc_1, acc16_1);
-      acc_2 = vpadalq_s16(acc_2, acc16_2);
-      acc_3 = vpadalq_s16(acc_3, acc16_3);
-      weights_ptr_0 += 16;
-      weights_ptr_1 += 16;
-      weights_ptr_2 += 16;
-      weights_ptr_3 += 16;
-      optimized_ops_preload_l1_stream(weights_ptr_0 + kReadAhead);
-      optimized_ops_preload_l1_stream(weights_ptr_1 + kReadAhead);
-      optimized_ops_preload_l1_stream(weights_ptr_2 + kReadAhead);
-      optimized_ops_preload_l1_stream(weights_ptr_3 + kReadAhead);
+    // Clear accumulators. We use 2 accumulator registers per row,
+    // for 4 rows. row_accumRN is the N-th accumulator for row R.
+    int32x4_t row_accum00 = vdupq_n_s32(0);
+    int32x4_t row_accum01 = vdupq_n_s32(0);
+    int32x4_t row_accum10 = vdupq_n_s32(0);
+    int32x4_t row_accum11 = vdupq_n_s32(0);
+    int32x4_t row_accum20 = vdupq_n_s32(0);
+    int32x4_t row_accum21 = vdupq_n_s32(0);
+    int32x4_t row_accum30 = vdupq_n_s32(0);
+    int32x4_t row_accum31 = vdupq_n_s32(0);
+
+    // kReadAhead parametrizes how far ahead we prefetch weights into L1 cache.
+    const int kReadAhead = 512;
+    // Prefetch the first weights values.
+    for (int k = 0; k < kReadAhead; k += 64) {
+      optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
+                                      k);
+      optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
+                                      k);
+      optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
+                                      k);
+      optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
+                                      k);
+    }
+    // Loop along the rows, handling 64 bytes per iteration because that's
+    // cache line size on most current ARM-architecture CPUs.
+    for (int in = 0; in < input_size; in += 64) {
+      // Prefetch some future weights values.
+      optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
+                                      in + kReadAhead);
+      optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
+                                      in + kReadAhead);
+      optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
+                                      in + kReadAhead);
+      optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
+                                      in + kReadAhead);
+
+      // We will use 2 local 16-bit accumulators per row, for 2 rows.
+      // See below (*) for the rationale of processing only 2 rows at a time.
+      // local_accumRN is the N-th local accumulator for row R.
+      int16x8_t local_accum00;
+      int16x8_t local_accum01;
+      int16x8_t local_accum10;
+      int16x8_t local_accum11;
+
+      // Load 64 bytes of input activations values. Convert to signed int8
+      // by flipping the sign bit (i.e. subtracting 128, the required
+      // zero_point value).
+      int8x16_t input0 = vreinterpretq_s8_u8(
+          veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 0)));
+      int8x16_t input1 = vreinterpretq_s8_u8(
+          veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 1)));
+      int8x16_t input2 = vreinterpretq_s8_u8(
+          veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 2)));
+      int8x16_t input3 = vreinterpretq_s8_u8(
+          veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 3)));
+
+      // Beginning of the core accumulation. Notice how while we have 4
+      // rows to process, this code is taking care of only 2 rows at a time,
+      // thus being divided into two parts looking similar ("Rows 0 and 1" and
+      // "Rows 2 and 3").
+      //
+      // (*) The rationale for handling only 2 rows at a time is to avoid
+      // cache aliasing issues on 4-way set-associative L1-cache CPUs, such
+      // as Cortex-A53. With sufficiently large, power-of-two matrix dimensions,
+      // we may find ourselves in a situation where rows alias each other in
+      // the L1 cache, and moreover may also mutually alias with the input
+      // activations. If we try to load 4 rows at a time, together with the
+      // input activations, that may be 5 mutually-aliasing vectors, resulting
+      // in constant mutual eviction from L1 cache. Handling 2 rows at a time
+      // here largely mitigates these issues, and seems at least to be very
+      // effective on Cortex-A53:
+      //                        Before     After
+      //   big (Cortex-A73)     2.85 ms    2.85 ms
+      //   little (Cortex-A53)  11.0 ms    5.16 ms
+
+      // Rows 0 and 1:
+      // Load 64 bytes of weights values from each row. Convert to signed int8
+      // by flipping the sign bit (i.e. subtracting 128, the required
+      // zero_point value).
+      int8x16_t weights00 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 0)));
+      int8x16_t weights01 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 1)));
+      int8x16_t weights02 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 2)));
+      int8x16_t weights03 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 3)));
+      int8x16_t weights10 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 0)));
+      int8x16_t weights11 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 1)));
+      int8x16_t weights12 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 2)));
+      int8x16_t weights13 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 3)));
+      // Multiply-accumulate into local 16-bit accumulators.
+      // We can accumulate two products without overflow because weights are
+      // required to never be -128, so each product is at most 127^2 in absolute
+      // value.
+      local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
+      local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
+      local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
+      local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
+      local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
+                               vget_high_s8(input0));
+      local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
+                               vget_high_s8(input1));
+      local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
+                               vget_high_s8(input0));
+      local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
+                               vget_high_s8(input1));
+      // Pairwise add and accumulate into 32-bit accumulators
+      row_accum00 = vpadalq_s16(row_accum00, local_accum00);
+      row_accum01 = vpadalq_s16(row_accum01, local_accum01);
+      row_accum10 = vpadalq_s16(row_accum10, local_accum10);
+      row_accum11 = vpadalq_s16(row_accum11, local_accum11);
+      // Multiply-accumulate into local 16-bit accumulators.
+      // We can accumulate two products without overflow because weights are
+      // required to never be -128, so each product is at most 127^2 in absolute
+      // value.
+      local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
+      local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
+      local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
+      local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
+      local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
+                               vget_high_s8(input2));
+      local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
+                               vget_high_s8(input3));
+      local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
+                               vget_high_s8(input2));
+      local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
+                               vget_high_s8(input3));
+      // Pairwise add and accumulate into 32-bit accumulators
+      row_accum00 = vpadalq_s16(row_accum00, local_accum00);
+      row_accum01 = vpadalq_s16(row_accum01, local_accum01);
+      row_accum10 = vpadalq_s16(row_accum10, local_accum10);
+      row_accum11 = vpadalq_s16(row_accum11, local_accum11);
+
+      // Rows 2 and 3:
+      // Load 64 bytes of weights values from each row. Convert to signed int8
+      // by flipping the sign bit (i.e. subtracting 128, the required
+      // zero_point value).
+      weights00 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 0)));
+      weights01 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 1)));
+      weights02 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 2)));
+      weights03 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 3)));
+      weights10 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 0)));
+      weights11 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 1)));
+      weights12 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 2)));
+      weights13 = vreinterpretq_s8_u8(veorq_u8(
+          signbit,
+          vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 3)));
+      // Multiply-accumulate into local 16-bit accumulators.
+      // We can accumulate two products without overflow because weights are
+      // required to never be -128, so each product is at most 127^2 in absolute
+      // value.
+      local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
+      local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
+      local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
+      local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
+      local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
+                               vget_high_s8(input0));
+      local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
+                               vget_high_s8(input1));
+      local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
+                               vget_high_s8(input0));
+      local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
+                               vget_high_s8(input1));
+      // Pairwise add and accumulate into 32-bit accumulators
+      row_accum20 = vpadalq_s16(row_accum20, local_accum00);
+      row_accum21 = vpadalq_s16(row_accum21, local_accum01);
+      row_accum30 = vpadalq_s16(row_accum30, local_accum10);
+      row_accum31 = vpadalq_s16(row_accum31, local_accum11);
+      // Multiply-accumulate into local 16-bit accumulators.
+      // We can accumulate two products without overflow because weights are
+      // required to never be -128, so each product is at most 127^2 in absolute
+      // value.
+      local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
+      local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
+      local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
+      local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
+      local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
+                               vget_high_s8(input2));
+      local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
+                               vget_high_s8(input3));
+      local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
+                               vget_high_s8(input2));
+      local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
+                               vget_high_s8(input3));
+      // Pairwise add and accumulate into 32-bit accumulators
+      row_accum20 = vpadalq_s16(row_accum20, local_accum00);
+      row_accum21 = vpadalq_s16(row_accum21, local_accum01);
+      row_accum30 = vpadalq_s16(row_accum30, local_accum10);
+      row_accum31 = vpadalq_s16(row_accum31, local_accum11);
     }
+
+    row_accum00 = vaddq_s32(row_accum00, row_accum01);
+    row_accum10 = vaddq_s32(row_accum10, row_accum11);
+    row_accum20 = vaddq_s32(row_accum20, row_accum21);
+    row_accum30 = vaddq_s32(row_accum30, row_accum31);
     // Horizontally reduce accumulators
     int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
         pairwise_reduced_acc_2, pairwise_reduced_acc_3;
     pairwise_reduced_acc_0 =
-        vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0));
+        vpadd_s32(vget_low_s32(row_accum00), vget_high_s32(row_accum00));
     pairwise_reduced_acc_1 =
-        vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1));
+        vpadd_s32(vget_low_s32(row_accum10), vget_high_s32(row_accum10));
     pairwise_reduced_acc_2 =
-        vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2));
+        vpadd_s32(vget_low_s32(row_accum20), vget_high_s32(row_accum20));
     pairwise_reduced_acc_3 =
-        vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3));
+        vpadd_s32(vget_low_s32(row_accum30), vget_high_s32(row_accum30));
     const int32x2_t reduced_lo =
         vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
     const int32x2_t reduced_hi =
         vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
     // Add bias values.
-    int32x4_t bias_vec = vld1q_s32(bias_ptr);
-    bias_ptr += 4;
     reduced = vaddq_s32(reduced, bias_vec);
-    int left_shift = accum_shift > 0 ? accum_shift : 0;
-    int right_shift = accum_shift > 0 ? 0 : -accum_shift;
     reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
     // Multiply by the fixed-point multiplier.
     reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
@@ -962,7 +1135,7 @@ inline void FullyConnected(
 #ifdef GEMMLOWP_NEON
   if (batches == 1 && input_offset == -128 && output_activation_min == -32768 &&
       output_activation_max == 32767) {
-    if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 16)) {
+    if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
      GEMVForLstmCellWithSymmetricRange(input_data, input_dims, filter_data,
                                        filter_dims, bias_data_int32, bias_dims,
                                        output_multiplier, -output_shift,
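
Note on the accumulation pattern (not part of the patch): the kernel's inner loop is built on the vmull_s8 / vmlal_s8 / vpadalq_s16 sequence, with the uint8-to-int8 conversion done by XORing the sign bit. The minimal standalone sketch below shows that pattern on a single row; the function name DotProductU8Symmetric and the main() driver are illustrative assumptions, and it assumes an AArch64/NEON toolchain where <arm_neon.h> is available.

#include <arm_neon.h>

#include <cstdint>
#include <cstdio>

// Illustrative only: dot product of `size` uint8 values, both operands with
// zero_point 128, i.e. sum of (w[i] - 128) * (x[i] - 128). `size` is assumed
// to be a multiple of 16, and weights are assumed never to be 0 (so that the
// offset weight is never -128), matching the kernel's requirement above.
int32_t DotProductU8Symmetric(const uint8_t* w, const uint8_t* x, int size) {
  const uint8x16_t signbit = vdupq_n_u8(0x80);
  int32x4_t acc = vdupq_n_s32(0);
  for (int i = 0; i < size; i += 16) {
    // Flipping the sign bit reinterprets uint8 in [0, 255] as int8 in
    // [-128, 127], which is exactly "subtract the zero_point of 128".
    int8x16_t wv = vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(w + i)));
    int8x16_t xv = vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(x + i)));
    // After the vmlal_s8, each 16-bit lane holds the sum of two int8*int8
    // products. Since |weight| <= 127 and |input| <= 128, each product is at
    // most 16256 in absolute value, and 2 * 16256 = 32512 fits in int16.
    int16x8_t acc16 = vmull_s8(vget_low_s8(wv), vget_low_s8(xv));
    acc16 = vmlal_s8(acc16, vget_high_s8(wv), vget_high_s8(xv));
    // Pairwise-add the eight 16-bit lanes into the four 32-bit lanes.
    acc = vpadalq_s16(acc, acc16);
  }
  // Horizontal reduction of the four 32-bit lanes, as in the kernel's
  // vpadd_s32-based reduction.
  int32x2_t sum2 = vpadd_s32(vget_low_s32(acc), vget_high_s32(acc));
  return vget_lane_s32(vpadd_s32(sum2, sum2), 0);
}

int main() {
  uint8_t w[64], x[64];
  int32_t reference = 0;
  for (int i = 0; i < 64; ++i) {
    w[i] = static_cast<uint8_t>(1 + (i * 7) % 255);  // never 0 => never -128
    x[i] = static_cast<uint8_t>((i * 13) % 256);
    reference +=
        (static_cast<int>(w[i]) - 128) * (static_cast<int>(x[i]) - 128);
  }
  std::printf("neon=%d reference=%d\n", DotProductU8Symmetric(w, x, 64),
              reference);
  return 0;
}

The scalar reference loop in main() should print the same value as the NEON path, which is the easiest way to sanity-check the sign-bit trick and the two-products-per-lane overflow argument.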
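The cache-aliasing rationale in the new comment can also be made concrete with a small back-of-the-envelope computation. The sketch below assumes a hypothetical 32 KiB, 4-way set-associative L1 data cache with 64-byte lines (one common Cortex-A53 configuration) and a hypothetical power-of-two row stride of 8192 bytes; both numbers are illustrative assumptions, not taken from the patch.

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical L1D geometry: 32 KiB, 4-way set-associative, 64-byte lines,
  // hence 32768 / (4 * 64) = 128 sets. The set index is bits [12:6] of the
  // address, so addresses that differ by a multiple of 64 * 128 = 8 KiB map
  // to the same set.
  const uint64_t kLineSize = 64;
  const uint64_t kNumSets = 128;
  const uint64_t row_stride = 8192;  // hypothetical input_size in bytes
  const uint64_t weights_base = 0x10000000;
  const uint64_t input_base = weights_base + 4 * row_stride;  // also aliases

  auto set_index = [&](uint64_t addr) -> uint64_t {
    return (addr / kLineSize) % kNumSets;
  };

  for (int r = 0; r < 4; ++r) {
    std::printf("weights row %d -> L1 set %llu\n", r,
                static_cast<unsigned long long>(
                    set_index(weights_base + r * row_stride)));
  }
  std::printf("input activations -> L1 set %llu\n",
              static_cast<unsigned long long>(set_index(input_base)));
  // All five streams map to the same set, but only 4 ways are available, so
  // they keep evicting one another. Touching only 2 weight rows at a time
  // (plus the input) keeps at most 3 such streams live, which fits in 4 ways.
  return 0;
}

This only illustrates the mechanism the comment describes; the before/after timings quoted in the comment (2.85 ms on Cortex-A73, 11.0 ms to 5.16 ms on Cortex-A53) are the commit's own measurements.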