// exercises it). We just guard our assumptions about size evenness with
// the following assertions.
TFLITE_DCHECK(!(output_size % 4));
- TFLITE_DCHECK(!(input_size % 8));
+ TFLITE_DCHECK(!(input_size % 64));
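+ // (The inner loop below consumes 64 input bytes, i.e. one cache line,
+ // per iteration, hence the multiple-of-64 requirement on input_size.)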
const int32* bias_ptr = bias_data;
int16* output_ptr = output_data;
const uint8x16_t signbit = vdupq_n_u8(0x80);
for (int in = 0; in < input_size; in += 32) {
optimized_ops_preload_l1_keep(input_data + in);
}
+ const int left_shift = accum_shift > 0 ? accum_shift : 0;
+ const int right_shift = accum_shift > 0 ? 0 : -accum_shift;
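+ // Split accum_shift into non-negative left and right shift amounts; the
+ // left shift is applied before the fixed-point multiply. Both values are
+ // loop-invariant, so compute them once outside the output loop.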
for (int out = 0; out < output_size; out += 4) {
- const uint8* weights_ptr_0 = weights_data + out * input_size;
- const uint8* weights_ptr_1 = weights_ptr_0 + 1 * input_size;
- const uint8* weights_ptr_2 = weights_ptr_0 + 2 * input_size;
- const uint8* weights_ptr_3 = weights_ptr_0 + 3 * input_size;
+ // Load the bias values
+ int32x4_t bias_vec = vld1q_s32(bias_ptr);
+ bias_ptr += 4;
- int32x4_t acc_0 = vdupq_n_s32(0);
- int32x4_t acc_1 = vdupq_n_s32(0);
- int32x4_t acc_2 = vdupq_n_s32(0);
- int32x4_t acc_3 = vdupq_n_s32(0);
- int in = 0;
- const int kReadAhead = 256;
- // Handle 16 levels of depth at a time.
- for (; in < input_size; in += 16) {
- int8x16_t weights_val_0 =
- vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_0)));
- int8x16_t weights_val_1 =
- vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_1)));
- int8x16_t weights_val_2 =
- vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_2)));
- int8x16_t weights_val_3 =
- vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_3)));
- int8x16_t input_val =
- vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(input_data + in)));
- int16x8_t acc16_0 =
- vmull_s8(vget_low_s8(weights_val_0), vget_low_s8(input_val));
- int16x8_t acc16_1 =
- vmull_s8(vget_low_s8(weights_val_1), vget_low_s8(input_val));
- int16x8_t acc16_2 =
- vmull_s8(vget_low_s8(weights_val_2), vget_low_s8(input_val));
- int16x8_t acc16_3 =
- vmull_s8(vget_low_s8(weights_val_3), vget_low_s8(input_val));
- acc16_0 = vmlal_s8(acc16_0, vget_high_s8(weights_val_0),
- vget_high_s8(input_val));
- acc16_1 = vmlal_s8(acc16_1, vget_high_s8(weights_val_1),
- vget_high_s8(input_val));
- acc16_2 = vmlal_s8(acc16_2, vget_high_s8(weights_val_2),
- vget_high_s8(input_val));
- acc16_3 = vmlal_s8(acc16_3, vget_high_s8(weights_val_3),
- vget_high_s8(input_val));
- acc_0 = vpadalq_s16(acc_0, acc16_0);
- acc_1 = vpadalq_s16(acc_1, acc16_1);
- acc_2 = vpadalq_s16(acc_2, acc16_2);
- acc_3 = vpadalq_s16(acc_3, acc16_3);
- weights_ptr_0 += 16;
- weights_ptr_1 += 16;
- weights_ptr_2 += 16;
- weights_ptr_3 += 16;
- optimized_ops_preload_l1_stream(weights_ptr_0 + kReadAhead);
- optimized_ops_preload_l1_stream(weights_ptr_1 + kReadAhead);
- optimized_ops_preload_l1_stream(weights_ptr_2 + kReadAhead);
- optimized_ops_preload_l1_stream(weights_ptr_3 + kReadAhead);
+ // Clear accumulators. We use 2 accumulator registers per row,
+ // for 4 rows. row_accumRN is the N-th accumulator for row R.
+ int32x4_t row_accum00 = vdupq_n_s32(0);
+ int32x4_t row_accum01 = vdupq_n_s32(0);
+ int32x4_t row_accum10 = vdupq_n_s32(0);
+ int32x4_t row_accum11 = vdupq_n_s32(0);
+ int32x4_t row_accum20 = vdupq_n_s32(0);
+ int32x4_t row_accum21 = vdupq_n_s32(0);
+ int32x4_t row_accum30 = vdupq_n_s32(0);
+ int32x4_t row_accum31 = vdupq_n_s32(0);
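+ // (Two accumulators per row let the two vpadalq_s16 chains per 64-byte
+ // chunk proceed independently, which helps instruction-level parallelism.)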
+
+ // kReadAhead parametrizes how far ahead we prefetch weights into L1 cache.
+ const int kReadAhead = 512;
+ // Prefetch the first weight values.
+ for (int k = 0; k < kReadAhead; k += 64) {
+ optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
+ k);
+ optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
+ k);
+ optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
+ k);
+ optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
+ k);
+ }
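+ // (Each preload pulls one 64-byte cache line, so this warms the first
+ // kReadAhead / 64 = 8 lines of each of the 4 rows.)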
+ // Loop along the rows, handling 64 bytes per iteration because that's the
+ // cache line size on most current ARM-architecture CPUs.
+ for (int in = 0; in < input_size; in += 64) {
+ // Prefetch some future weight values.
+ optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
+ in + kReadAhead);
+ optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
+ in + kReadAhead);
+ optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
+ in + kReadAhead);
+ optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
+ in + kReadAhead);
+
+ // We will use 2 local 16-bit accumulators per row, for 2 rows.
+ // See below (*) for the rationale of processing only 2 rows at a time.
+ // local_accumRN is the N-th local accumulator for row R.
+ int16x8_t local_accum00;
+ int16x8_t local_accum01;
+ int16x8_t local_accum10;
+ int16x8_t local_accum11;
+
+ // Load 64 bytes of input activation values. Convert to signed int8
+ // by flipping the sign bit (i.e. subtracting 128, the required
+ // zero_point value).
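+ // (For example, uint8 0x00 becomes 0x80, i.e. int8 -128, and uint8 0xFF
+ // becomes 0x7F, i.e. int8 +127.)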
+ int8x16_t input0 = vreinterpretq_s8_u8(
+ veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 0)));
+ int8x16_t input1 = vreinterpretq_s8_u8(
+ veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 1)));
+ int8x16_t input2 = vreinterpretq_s8_u8(
+ veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 2)));
+ int8x16_t input3 = vreinterpretq_s8_u8(
+ veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 3)));
+
+ // Beginning of the core accumulation. Although we have 4 rows to
+ // process, this code handles only 2 rows at a time, so it is split into
+ // two similar-looking parts ("Rows 0 and 1" and "Rows 2 and 3").
+ //
+ // (*) The rationale for handling only 2 rows at a time is to avoid
+ // cache aliasing issues on 4-way set-associative L1-cache CPUs, such
+ // as Cortex-A53. With sufficiently large, power-of-two matrix dimensions,
+ // we may find ourselves in a situation where rows alias each other in
+ // the L1 cache, and moreover may also mutually alias with the input
+ // activations. If we try to load 4 rows at a time, together with the
+ // input activations, that may be 5 mutually-aliasing vectors, resulting
+ // in constant mutual eviction from L1 cache. Handling 2 rows at a time
+ // here largely mitigates these issues, and seems to be very effective,
+ // at least on Cortex-A53:
+ // Before After
+ // big (Cortex-A73) 2.85 ms 2.85 ms
+ // little (Cortex-A53) 11.0 ms 5.16 ms
+
+ // Rows 0 and 1:
+ // Load 64 bytes of weight values from each row. Convert to signed int8
+ // by flipping the sign bit (i.e. subtracting 128, the required
+ // zero_point value).
+ int8x16_t weights00 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 0)));
+ int8x16_t weights01 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 1)));
+ int8x16_t weights02 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 2)));
+ int8x16_t weights03 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 3)));
+ int8x16_t weights10 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 0)));
+ int8x16_t weights11 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 1)));
+ int8x16_t weights12 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 2)));
+ int8x16_t weights13 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 3)));
+ // Multiply-accumulate into local 16-bit accumulators.
+ // We can accumulate two products without overflow because weights are
+ // required to never be -128, so each product is at most 127 * 128 = 16256
+ // in absolute value, and two of them sum to at most 32512 < 32768.
+ local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
+ local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
+ local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
+ local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
+ local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
+ vget_high_s8(input0));
+ local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
+ vget_high_s8(input1));
+ local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
+ vget_high_s8(input0));
+ local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
+ vget_high_s8(input1));
+ // Pairwise add and accumulate into 32-bit accumulators
+ row_accum00 = vpadalq_s16(row_accum00, local_accum00);
+ row_accum01 = vpadalq_s16(row_accum01, local_accum01);
+ row_accum10 = vpadalq_s16(row_accum10, local_accum10);
+ row_accum11 = vpadalq_s16(row_accum11, local_accum11);
+ // Multiply-accumulate into local 16-bit accumulators, again without
+ // overflow (see the rationale above).
+ local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
+ local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
+ local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
+ local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
+ local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
+ vget_high_s8(input2));
+ local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
+ vget_high_s8(input3));
+ local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
+ vget_high_s8(input2));
+ local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
+ vget_high_s8(input3));
+ // Pairwise add and accumulate into 32-bit accumulators
+ row_accum00 = vpadalq_s16(row_accum00, local_accum00);
+ row_accum01 = vpadalq_s16(row_accum01, local_accum01);
+ row_accum10 = vpadalq_s16(row_accum10, local_accum10);
+ row_accum11 = vpadalq_s16(row_accum11, local_accum11);
+
+ // Rows 2 and 3:
+ // Load 64 bytes of weight values from each row. Convert to signed int8
+ // by flipping the sign bit (i.e. subtracting 128, the required
+ // zero_point value).
+ weights00 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 0)));
+ weights01 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 1)));
+ weights02 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 2)));
+ weights03 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 3)));
+ weights10 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 0)));
+ weights11 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 1)));
+ weights12 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 2)));
+ weights13 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 3)));
+ // Multiply-accumulate into local 16-bit accumulators, again without
+ // overflow (see the rationale above).
+ local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
+ local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
+ local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
+ local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
+ local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
+ vget_high_s8(input0));
+ local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
+ vget_high_s8(input1));
+ local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
+ vget_high_s8(input0));
+ local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
+ vget_high_s8(input1));
+ // Pairwise add and accumulate into 32-bit accumulators
+ row_accum20 = vpadalq_s16(row_accum20, local_accum00);
+ row_accum21 = vpadalq_s16(row_accum21, local_accum01);
+ row_accum30 = vpadalq_s16(row_accum30, local_accum10);
+ row_accum31 = vpadalq_s16(row_accum31, local_accum11);
+ // Multiply-accumulate into local 16-bit accumulators, again without
+ // overflow (see the rationale above).
+ local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
+ local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
+ local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
+ local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
+ local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
+ vget_high_s8(input2));
+ local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
+ vget_high_s8(input3));
+ local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
+ vget_high_s8(input2));
+ local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
+ vget_high_s8(input3));
+ // Pairwise add and accumulate into 32-bit accumulators
+ row_accum20 = vpadalq_s16(row_accum20, local_accum00);
+ row_accum21 = vpadalq_s16(row_accum21, local_accum01);
+ row_accum30 = vpadalq_s16(row_accum30, local_accum10);
+ row_accum31 = vpadalq_s16(row_accum31, local_accum11);
}
+
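+ // Combine the two 32-bit accumulators of each row into one.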
+ row_accum00 = vaddq_s32(row_accum00, row_accum01);
+ row_accum10 = vaddq_s32(row_accum10, row_accum11);
+ row_accum20 = vaddq_s32(row_accum20, row_accum21);
+ row_accum30 = vaddq_s32(row_accum30, row_accum31);
// Horizontally reduce accumulators
int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
pairwise_reduced_acc_2, pairwise_reduced_acc_3;
pairwise_reduced_acc_0 =
- vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0));
+ vpadd_s32(vget_low_s32(row_accum00), vget_high_s32(row_accum00));
pairwise_reduced_acc_1 =
- vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1));
+ vpadd_s32(vget_low_s32(row_accum10), vget_high_s32(row_accum10));
pairwise_reduced_acc_2 =
- vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2));
+ vpadd_s32(vget_low_s32(row_accum20), vget_high_s32(row_accum20));
pairwise_reduced_acc_3 =
- vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3));
+ vpadd_s32(vget_low_s32(row_accum30), vget_high_s32(row_accum30));
const int32x2_t reduced_lo =
vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
const int32x2_t reduced_hi =
vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
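+ // Each lane of 'reduced' now holds the full dot product for one of the
+ // 4 output rows (out + 0 .. out + 3).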
// Add bias values.
- int32x4_t bias_vec = vld1q_s32(bias_ptr);
- bias_ptr += 4;
reduced = vaddq_s32(reduced, bias_vec);
- int left_shift = accum_shift > 0 ? accum_shift : 0;
- int right_shift = accum_shift > 0 ? 0 : -accum_shift;
reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
// Multiply by the fixed-point multiplier.
reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
#ifdef GEMMLOWP_NEON
if (batches == 1 && input_offset == -128 && output_activation_min == -32768 &&
output_activation_max == 32767) {
- if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 16)) {
+ if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
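+ // This fast path consumes 64 bytes of accum_depth (one cache line) per
+ // inner iteration, hence the multiple-of-64 requirement above.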
GEMVForLstmCellWithSymmetricRange(input_data, input_dims, filter_data,
filter_dims, bias_data_int32, bias_dims,
output_multiplier, -output_shift,