/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef __NNFW_CKER_NEON_TENSOR_UTILS_H__
#define __NNFW_CKER_NEON_TENSOR_UTILS_H__

#include <ruy/detect_arm.h>
#include "cker/Types.h"
#include "cker/neon/neon_check.h"
#include "cker/ruy/RuySupport.h"
#include "util/logging.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#define kFloatWeightsPerNeonLane 4

// TODO(ahentz): Clean up.
using int8 = std::int8_t;
using uint8 = std::uint8_t;
using int16 = std::int16_t;
using uint16 = std::uint16_t;
using int32 = std::int32_t;
using uint32 = std::uint32_t;
// Allocates, at least, size bytes of uninitialized storage whose alignment is
// specified by alignment. The size parameter must be an integral multiple of
// alignment.
// Caller is responsible for freeing the allocated memory by calling free on
// the passed freeing_buffer pointer.
void *aligned_alloc(size_t alignment, size_t size, void **freeing_buffer)
{
  *freeing_buffer = malloc(size + alignment);
  const size_t offset = ((uintptr_t)*freeing_buffer) % alignment;                            // NOLINT
  return offset == 0 ? *freeing_buffer : ((char *)*freeing_buffer + (alignment - offset));   // NOLINT
}
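
// Illustrative usage (not part of the kernels below): allocate a 16-byte
// aligned scratch buffer and release it through the freeing pointer:
//   void *scratch_free = nullptr;
//   int8_t *scratch = reinterpret_cast<int8_t *>(aligned_alloc(16, 64, &scratch_free));
//   // ... use scratch ...
//   free(scratch_free);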
inline int32_t AccumulateNeonLane(const int32x4_t lane)
{
#ifdef __aarch64__
  return vaddvq_s32(lane);
#else
  int64x2_t pairwiseAdded = vpaddlq_s32(lane);
  return vgetq_lane_s64(pairwiseAdded, 0) + vgetq_lane_s64(pairwiseAdded, 1);
#endif
}
#ifdef __aarch64__

bool HasSdotInstruction()
{
  static const bool has_dotprod = ruy::DetectDotprod();
  return has_dotprod;
}
// We interleave vector data to make the dot product logic more efficient.
// Suppose that vectors is:
//     a0 a1 a2 a3 a4 a5 ...
//     b0 b1 b2 b3 b4 b5 ...
//     c0 c1 c2 c3 c4 c5 ...
//     d0 d1 d2 d3 d4 d5 ...
//     e0 e1 e2 e3 e4 e5 ...
// This code interleaves them like this:
//     a0 a1 a2 a3 b0 b1 b2 b3 c0 c1 c2 c3 d0 d1 d2 d3 a4 a5 a6 a7 b4 ...
//     ... e0 e1 e2 e3 f0 f1 f2 f3 ...
// Once the data is interleaved, each 16-byte read from the vectors pointer
// contains 4 bytes from each of 4 vectors.
const int8_t *ShuffleVectors(const int8_t *vectors, const int n_batch, const int m_cols,
                             void **shuffled_vectors_free)
{
  const int kWeightsPerUint32 = 4;

  int8 *shuffled_vectors = reinterpret_cast<int8 *>(
    aligned_alloc(kWeightsPerUint32, n_batch * m_cols, shuffled_vectors_free));

  for (int i = 0; i < n_batch; i += 4)
  {
    int8 *shuffled_vectors_ptr = shuffled_vectors + (i * m_cols);
    const int8 *unshuffled_vec0_ptr = reinterpret_cast<const int8 *>(vectors) + (i * m_cols);
    const int8 *unshuffled_vec1_ptr = reinterpret_cast<const int8 *>(vectors) + ((i + 1) * m_cols);
    const int8 *unshuffled_vec2_ptr = reinterpret_cast<const int8 *>(vectors) + ((i + 2) * m_cols);
    const int8 *unshuffled_vec3_ptr = reinterpret_cast<const int8 *>(vectors) + ((i + 3) * m_cols);
    const int8 *const end_vec0_ptr = unshuffled_vec1_ptr;

    while (unshuffled_vec0_ptr != end_vec0_ptr)
    {
      asm volatile(
        // This code path requires that (n_cols % 16) == 0 so we can safely
        // read in 16-byte chunks from each row.
        "ld1 {v0.16b}, [%[unshuffled_vec0_ptr]], #16\n"
        "ld1 {v1.16b}, [%[unshuffled_vec1_ptr]], #16\n"
        "ld1 {v2.16b}, [%[unshuffled_vec2_ptr]], #16\n"
        "ld1 {v3.16b}, [%[unshuffled_vec3_ptr]], #16\n"

        "st4 {v0.s, v1.s, v2.s, v3.s}[0], [%[shuffled_vectors_ptr]], #16\n"
        "st4 {v0.s, v1.s, v2.s, v3.s}[1], [%[shuffled_vectors_ptr]], #16\n"
        "st4 {v0.s, v1.s, v2.s, v3.s}[2], [%[shuffled_vectors_ptr]], #16\n"
        "st4 {v0.s, v1.s, v2.s, v3.s}[3], [%[shuffled_vectors_ptr]], #16\n"

        : [unshuffled_vec0_ptr] "+r"(unshuffled_vec0_ptr),
          [unshuffled_vec1_ptr] "+r"(unshuffled_vec1_ptr),
          [unshuffled_vec2_ptr] "+r"(unshuffled_vec2_ptr),
          [unshuffled_vec3_ptr] "+r"(unshuffled_vec3_ptr),
          [shuffled_vectors_ptr] "+r"(shuffled_vectors_ptr)
        :
        : "v0", "v1", "v2", "v3", "cc", "memory");
    }
  }

  return reinterpret_cast<const int8_t *>(shuffled_vectors);
}
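
// For illustration only: a scalar sketch of the interleaving performed by
// ShuffleVectors above, assuming n_batch is a multiple of 4 and m_cols is a
// multiple of 4 (the assembly path additionally requires m_cols % 16 == 0).
// This helper is ours and is not used by the kernels below.
inline void ShuffleVectorsScalarSketch(const int8_t *vectors, int n_batch, int m_cols,
                                       int8_t *shuffled)
{
  for (int i = 0; i < n_batch; i += 4)
  {
    int8_t *out = shuffled + i * m_cols;
    for (int col = 0; col < m_cols; col += 4)
    {
      // Emit 4 consecutive bytes from each of the 4 vectors in this group, so a
      // single 16-byte read of `shuffled` covers all 4 vectors.
      for (int vec = 0; vec < 4; ++vec)
      {
        const int8_t *in = vectors + (i + vec) * m_cols + col;
        for (int b = 0; b < 4; ++b)
        {
          *out++ = in[b];
        }
      }
    }
  }
}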
// Notes about the speed of this version vs. the baseline (from memory):
// - With 256K of L1, we can keep a lot of vectors in cache.
//   I recall a reasonable speedup just by rearranging the loop to have
//   row on the outside and batch on the inside.
// - I also recall getting a nice speedup from sdot.
// - I tried many times to do better than the current implementation, using
//   loop unrolling and instruction reordering to avoid stalls, etc.
//   but I was not able to do significantly better. This code is, however,
//   much worse than what the processor spec sheet suggests is possible.
static void DotprodMatrixBatchFourVectorMultiplyAccumulate(const int8_t *__restrict__ matrix,
                                                           const int m_rows, const int m_cols,
                                                           const int8_t *vectors,
                                                           const float *scaling_factors,
                                                           int n_batch, float *__restrict__ result)
{
  void *shuffled_vectors_free;

  const int8_t *shuffled_vectors = ShuffleVectors(vectors, n_batch, m_cols, &shuffled_vectors_free);

  for (int row = 0; row < m_rows; row += 2)
  {
    for (int batch = 0; batch < n_batch; batch += 4)
    {
      float *result_ptr = result + (batch * m_rows) + row;
      const int8 *mat_ptr0 = matrix + (row * m_cols);
      const int8 *mat_ptr1 = matrix + ((row + 1) * m_cols);
      const int8 *mat_ptr0_end = mat_ptr1;
      const int8 *vec_ptr = shuffled_vectors + (batch * m_cols);
      const float *scaling_factors_ptr = scaling_factors + batch;
      const uint64_t wide_rows = m_rows * sizeof(float);
      const int8 *mat_ptr2 = matrix + ((row + 2) * m_cols);
      const int8 *mat_ptr3 = matrix + ((row + 3) * m_cols);

      asm volatile(
        // Zero out the accumulator registers.
        "dup v0.4s, wzr\n"
        "dup v1.4s, wzr\n"
        "dup v2.4s, wzr\n"
        "dup v3.4s, wzr\n"

        "1:\n" // batch_cols_loop

        // Read 16 more bytes from a pair of matrix rows.
        "ld1 {v12.16b}, [%[mat_ptr0]], #16\n"

        // Prefetch two rows ahead.
        "prfm pldl1strm, [%[mat_ptr2]]\n"
        "prfm pldl1strm, [%[mat_ptr3]]\n"

        // Read from input vectors 4 times; 64 bytes total.
        // Each 16-byte register contains parts of 4 vectors; see the
        // shuffle logic above.

        // From Benoit, places to look in the future:
        // - Move load instructions further from sdot
        // - Switch loop use-then-reload
        // - Do partial unrolling to use register space better
        "ld1 {v8.16b}, [%[vec_ptr]], #16\n"
        ".word 0x4f8ce100 // sdot v0.4s, v8.16b, v12.4b[0]\n"
        "ld1 {v9.16b}, [%[vec_ptr]], #16\n"
        ".word 0x4face121 // sdot v1.4s, v9.16b, v12.4b[1]\n"
        "ld1 {v10.16b}, [%[vec_ptr]], #16\n"
        ".word 0x4f8ce940 // sdot v0.4s, v10.16b, v12.4b[2]\n"
        "ld1 {v11.16b}, [%[vec_ptr]], #16\n"
        ".word 0x4face961 // sdot v1.4s, v11.16b, v12.4b[3]\n"

        // Update prefetch pointers.
        "add %[mat_ptr2], %[mat_ptr2], #16\n"
        "add %[mat_ptr3], %[mat_ptr3], #16\n"

        // Re-use those vectors for the next row as well.
        "ld1 {v13.16b}, [%[mat_ptr1]], #16\n"
        ".word 0x4f8de102 // sdot v2.4s, v8.16b, v13.4b[0]\n"
        ".word 0x4fade123 // sdot v3.4s, v9.16b, v13.4b[1]\n"
        ".word 0x4f8de942 // sdot v2.4s, v10.16b, v13.4b[2]\n"
        ".word 0x4fade963 // sdot v3.4s, v11.16b, v13.4b[3]\n"

        // If we're not done with these rows, continue.
        "cmp %[mat_ptr0], %[mat_ptr0_end]\n"
        "bne 1b\n" // batch_cols_loop

        // Done with the rows, sum the results.
        "add v0.4s, v0.4s, v1.4s\n"
        "add v2.4s, v2.4s, v3.4s\n"

        // Convert the per-vector sums to floating point.
        "scvtf v0.4s, v0.4s\n"
        "scvtf v1.4s, v2.4s\n"

        // Fetch scale factors.
        "ld1 {v4.4s}, [%[scaling_factors_ptr]]\n"

        // Multiply scale factors times sums.
        "fmul v0.4s, v4.4s, v0.4s\n"
        "fmul v1.4s, v4.4s, v1.4s\n"

        // Load previous result values.
        // The result position is:
        //   result[batch * m_rows + row]
        // Here that is factored into:
        //   result_ptr = result + row
        //   *result_ptr = res[0]
        //   (uint8*)result_ptr += (m_rows * sizeof(float))
        //   *result_ptr = res[1]
        //   ...
        // Since we're reading two rows at a time, though, we read both
        //   result[batch * m_rows + row]
        // and
        //   result[batch * m_rows + row + 1]
        "ld2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
        "ld2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
        "ld2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
        "ld2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"

        // Go back to the starting position (subtract wide_rows * 4).
        "sub %[result_ptr], %[result_ptr], %[wide_rows], lsl #2\n"

        // Add previous result values.
        "fadd v9.4s, v9.4s, v0.4s\n"
        "fadd v10.4s, v10.4s, v1.4s\n"

        // Store results.
        "st2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
        "st2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
        "st2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
        "st2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
        : [mat_ptr0] "+r"(mat_ptr0), [mat_ptr1] "+r"(mat_ptr1), [vec_ptr] "+r"(vec_ptr),
          [result_ptr] "+r"(result_ptr), [mat_ptr2] "+r"(mat_ptr2), [mat_ptr3] "+r"(mat_ptr3)
        : [mat_ptr0_end] "r"(mat_ptr0_end), [scaling_factors_ptr] "r"(scaling_factors_ptr),
          [wide_rows] "r"(wide_rows)
        : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
          "v13", "cc", "memory");
    }
  }

  free(shuffled_vectors_free);
}
static void DotprodMatrixBatchFourVectorMultiplyAccumulate(
  const int8_t *__restrict__ matrix, const int m_rows, const int m_cols, const int8_t *vectors,
  const float *scaling_factors, int n_batch, float *__restrict__ result,
  const float *per_channel_scale, const int32_t *input_offset, int32_t *row_sums)
{
  void *shuffled_vectors_free;
  const int8_t *shuffled_vectors = ShuffleVectors(vectors, n_batch, m_cols, &shuffled_vectors_free);

  for (int row = 0; row < m_rows; row += 2)
  {
    const float *channel_scales_ptr = per_channel_scale + row;
    int32_t *row_sums_ptr = row_sums ? row_sums + row : nullptr;
    for (int batch = 0; batch < n_batch; batch += 4)
    {
      float *result_ptr = result + (batch * m_rows) + row;
      const int8 *mat_ptr0 = matrix + (row * m_cols);
      const int8 *mat_ptr1 = matrix + ((row + 1) * m_cols);
      const int8 *mat_ptr0_end = mat_ptr1;
      const int8 *vec_ptr = shuffled_vectors + (batch * m_cols);
      const float *scaling_factors_ptr = scaling_factors + batch;
      const uint64_t wide_rows = m_rows * sizeof(float);
      const int32_t *batch_offsets_ptr = input_offset + batch;
      const int32_t is_channel_scale_nullptr = per_channel_scale == nullptr;
      const int32_t is_row_sums_nullptr = row_sums_ptr == nullptr;
      asm volatile("dup v0.4s, wzr\n"
                   "dup v1.4s, wzr\n"
                   "dup v2.4s, wzr\n"
                   "dup v3.4s, wzr\n"
                   // Load zero points.
                   "ld1 {v7.4s}, [%[batch_offsets_ptr]]\n"
                   "ld1 {v4.4s}, [%[scaling_factors_ptr]]\n"
                   // Zero out zero point accumulators.
                   "dup v14.4s, wzr\n"
                   "dup v15.4s, wzr\n"

                   // Load per channel scales if not null.
                   "cmp %w[is_channel_scale_nullptr], #0\n"
                   "bne 1f\n"
                   "ld1r {v16.4s}, [%[channel_scales_ptr]], #4\n"
                   "ld1r {v17.4s}, [%[channel_scales_ptr]]\n"
                   "fmul v16.4s, v16.4s, v4.4s\n"
                   "fmul v17.4s, v17.4s, v4.4s\n"
                   "b 2f\n"
                   "1:\n"
                   "mov v16.16b, v4.16b\n"
                   "mov v17.16b, v4.16b\n"
                   "2:\n" // batch_cols_loop
                   "ld1 {v12.16b}, [%[mat_ptr0]], #16\n"
                   "ld1 {v8.16b}, [%[vec_ptr]], #16\n"
                   ".word 0x4f8ce100 // sdot v0.4s, v8.16b, v12.4b[0]\n"
                   "ld1 {v9.16b}, [%[vec_ptr]], #16\n"
                   ".word 0x4face121 // sdot v1.4s, v9.16b, v12.4b[1]\n"
                   "ld1 {v10.16b}, [%[vec_ptr]], #16\n"
                   ".word 0x4f8ce940 // sdot v0.4s, v10.16b, v12.4b[2]\n"
                   "ld1 {v11.16b}, [%[vec_ptr]], #16\n"
                   ".word 0x4face961 // sdot v1.4s, v11.16b, v12.4b[3]\n"
                   "ld1 {v13.16b}, [%[mat_ptr1]], #16\n"
                   ".word 0x4f8de102 // sdot v2.4s, v8.16b, v13.4b[0]\n"
                   ".word 0x4fade123 // sdot v3.4s, v9.16b, v13.4b[1]\n"
                   ".word 0x4f8de942 // sdot v2.4s, v10.16b, v13.4b[2]\n"
                   ".word 0x4fade963 // sdot v3.4s, v11.16b, v13.4b[3]\n"
                   "cmp %w[is_row_sums_nullptr], #1\n"
                   "bne 3f\n"
                   // Accumulate row_sums for zero point calculations.
                   "saddlp v12.8h, v12.16b\n"
                   "saddlp v13.8h, v13.16b\n"
                   "sadalp v14.4s, v12.8h\n"
                   "sadalp v15.4s, v13.8h\n"
                   "3:\n"
                   "cmp %[mat_ptr0], %[mat_ptr0_end]\n"
                   "bne 2b\n" // batch_cols_loop
                   "add v0.4s, v0.4s, v1.4s\n"
                   "add v2.4s, v2.4s, v3.4s\n"

                   "cmp %w[is_row_sums_nullptr], #1\n"
                   "bne 4f\n"
                   // Calculate zero point offsets.
                   "addv s14, v14.4s\n"
                   "addv s15, v15.4s\n"
                   "dup v14.4s, v14.s[0]\n"
                   "dup v15.4s, v15.s[0]\n"
                   "b 5f\n"
                   "4:\n"
                   "ld1r {v14.4s}, [%[row_sums_ptr]], #4\n"
                   "ld1r {v15.4s}, [%[row_sums_ptr]]\n"
                   "5:\n"

                   "mul v14.4s, v14.4s, v7.4s\n"
                   "mul v15.4s, v15.4s, v7.4s\n"
                   "sub v0.4s, v0.4s, v14.4s\n"
                   "sub v2.4s, v2.4s, v15.4s\n"

                   "scvtf v0.4s, v0.4s\n"
                   "scvtf v1.4s, v2.4s\n"

                   "fmul v0.4s, v16.4s, v0.4s\n"
                   "fmul v1.4s, v17.4s, v1.4s\n"

                   "ld2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
                   "ld2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
                   "ld2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
                   "ld2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
                   "sub %[result_ptr], %[result_ptr], %[wide_rows], lsl #2\n"
                   "fadd v9.4s, v9.4s, v0.4s\n"
                   "fadd v10.4s, v10.4s, v1.4s\n"
                   "st2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
                   "st2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
                   "st2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
                   "st2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
                   : [mat_ptr0] "+r"(mat_ptr0), [mat_ptr1] "+r"(mat_ptr1), [vec_ptr] "+r"(vec_ptr),
                     [result_ptr] "+r"(result_ptr), [row_sums_ptr] "+r"(row_sums_ptr)
                   : [mat_ptr0_end] "r"(mat_ptr0_end),
                     [scaling_factors_ptr] "r"(scaling_factors_ptr), [wide_rows] "r"(wide_rows),
                     [channel_scales_ptr] "r"(channel_scales_ptr),
                     [batch_offsets_ptr] "r"(batch_offsets_ptr),
                     [is_channel_scale_nullptr] "r"(is_channel_scale_nullptr),
                     [is_row_sums_nullptr] "r"(is_row_sums_nullptr)
                   : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
                     "v12", "v13", "v14", "v15", "v16", "v17", "w0", "w1", "cc", "memory");
    }
  }

  free(shuffled_vectors_free);
}
// The DotprodMatrixBatchFourVectorMultiplyAccumulate kernel processes 4
// vectors in the same time as the baseline processes 1 vector. However, it
// requires 4 vectors of input.
//
// To take advantage of this speed difference, we add some zero-valued
// vectors to the batch so that n_batch is a multiple of 4. Then we execute
// DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate on that padded batch,
// then extract just the results we want at the end (ignoring the extra padding
// outputs).
//
// The relative cost of the padding is large when the matrix is smaller than
// 128x128, so we don't use this code path on small matrices. On larger
// matrices, the computation cost dwarfs the padding cost, making this code
// path profitable.
//
// If we ignore the cost of padding, this kernel is:
//    1x the speed of NeonMatrixBatchVectorMultiplyImpl for n_batch = 1
//    2x the speed of NeonMatrixBatchVectorMultiplyImpl for n_batch = 2
//    3x the speed of NeonMatrixBatchVectorMultiplyImpl for n_batch = 3
//    and so on.
//
// We don't use this kernel when n_batch = 1 because the baseline kernel
// is fine for that case.
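//
// For example (illustrative numbers): with n_batch = 6 the batch is padded to
// batch_round_up = 8, so two all-zero input vectors (and the corresponding
// 2 * m_rows padded result entries) are processed and then dropped by the final
// memcpy back into `result`.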
void DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
  const int8_t *__restrict__ matrix, const int m_rows, const int m_cols, const int8_t *vectors,
  const float *scaling_factors, int n_batch, float *__restrict__ result,
  const float *per_channel_scale, const int32_t *input_offset, int32_t *row_sums)
{
  const int kWeightsPerUint32 = 4;

  // Round to the nearest multiple of 4.
  int batch_round_up = n_batch;
  if (n_batch % 4 != 0)
  {
    batch_round_up += (4 - n_batch % 4);
  }
  assert(n_batch <= batch_round_up);

  void *padded_vectors_free;
  const int padded_vectors_size = batch_round_up * m_cols;
  int8_t *padded_vectors = reinterpret_cast<int8_t *>(
    aligned_alloc(kWeightsPerUint32, padded_vectors_size, &padded_vectors_free));
  memset(padded_vectors, 0, padded_vectors_size);

  void *padded_result_free;
  const int result_size = n_batch * m_rows * sizeof(float);
  const int padded_result_size = batch_round_up * m_rows * sizeof(float);
  float *padded_result = reinterpret_cast<float *>(
    aligned_alloc(kWeightsPerUint32, padded_result_size, &padded_result_free));
  memcpy(padded_result, result, result_size);
  memset(reinterpret_cast<char *>(padded_result) + result_size, 0,
         padded_result_size - result_size);

  // Copy the input into the padded data structure.
  assert(n_batch * m_cols <= padded_vectors_size);
  memcpy(padded_vectors, vectors, n_batch * m_cols);

  void *padded_scaling_factors_free;
  const int padded_scaling_factors_size = batch_round_up * sizeof(float);
  float *padded_scaling_factors = reinterpret_cast<float *>(
    aligned_alloc(kWeightsPerUint32, padded_scaling_factors_size, &padded_scaling_factors_free));
  assert(static_cast<int>(n_batch * sizeof(float)) <= padded_scaling_factors_size);
  assert(static_cast<int>(batch_round_up * sizeof(float)) <= padded_scaling_factors_size);
  memset(padded_scaling_factors, 0, batch_round_up * sizeof(float));
  memcpy(padded_scaling_factors, scaling_factors, n_batch * sizeof(float));

  if (input_offset != nullptr)
  {
    void *padded_input_offset_free;
    const int padded_input_offset_size = batch_round_up * sizeof(int32_t);
    int32_t *padded_input_offset = reinterpret_cast<int32_t *>(
      aligned_alloc(kWeightsPerUint32, padded_input_offset_size, &padded_input_offset_free));
    assert(static_cast<int>(n_batch * sizeof(int32_t)) <= padded_input_offset_size);
    assert(static_cast<int>(batch_round_up * sizeof(int32_t)) <= padded_input_offset_size);
    memset(padded_input_offset, 0, batch_round_up * sizeof(int32_t));
    memcpy(padded_input_offset, input_offset, n_batch * sizeof(int32_t));

    // Call the main kernel.
    DotprodMatrixBatchFourVectorMultiplyAccumulate(
      matrix, m_rows, m_cols, padded_vectors, padded_scaling_factors, batch_round_up,
      padded_result, per_channel_scale, padded_input_offset, row_sums);

    free(padded_input_offset_free);
  }
  else
  {
    // Call the main kernel.
    DotprodMatrixBatchFourVectorMultiplyAccumulate(matrix, m_rows, m_cols, padded_vectors,
                                                   padded_scaling_factors, batch_round_up,
                                                   padded_result);
  }
  memcpy(result, padded_result, result_size);

  free(padded_result_free);
  free(padded_vectors_free);
  free(padded_scaling_factors_free);
}

void DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(const int8_t *__restrict__ matrix,
                                                          const int m_rows, const int m_cols,
                                                          const int8_t *vectors,
                                                          const float *scaling_factors, int n_batch,
                                                          float *__restrict__ result)
{
  DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
    matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
    /*per_channel_scale=*/nullptr, /*input_offset=*/nullptr,
    /*row_sums=*/nullptr);
}
#endif // __aarch64__
bool NeonIsZeroVector(const float *vector, int v_size)
{
  // If v_size is not divisible by kFloatWeightsPerNeonLane, we cannot
  // use the main vectorized loop, and we need to process sequentially.
  // postamble_start shows the start index where this should happen.
  const int postamble_start = v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
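  // For example (illustrative): with v_size = 10, postamble_start = 8, so
  // elements 0..7 are checked four at a time below and elements 8..9 fall
  // through to the scalar loop.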

  const float32x4_t zero_x4_float = vmovq_n_f32(0.0f);
  for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane)
  {
    const float32x4_t i_x4_float = vld1q_f32(vector + v);
    uint32x4_t cmp_result = vceqq_f32(i_x4_float, zero_x4_float);
    if (vgetq_lane_u32(cmp_result, 0) == 0)
      return false;
    if (vgetq_lane_u32(cmp_result, 1) == 0)
      return false;
    if (vgetq_lane_u32(cmp_result, 2) == 0)
      return false;
    if (vgetq_lane_u32(cmp_result, 3) == 0)
      return false;
  }

  for (int v = postamble_start; v < v_size; ++v)
  {
    if (vector[v] != 0.0)
      return false;
  }
  return true;
}
void NeonCpuBackendGemm(const int8_t *input, const int32_t *bias,
                        const int8_t *input_to_gate_weights, int32_t n_batch, int32_t n_input,
                        int32_t n_output, int32_t, int32_t *scratch, ruy::Context *ruy_context)
{
  MatrixParams<int8_t> lhs_params;
  lhs_params.order = Order::kRowMajor;
  lhs_params.rows = n_output;
  lhs_params.cols = n_input;
  lhs_params.cacheable = true;

  MatrixParams<int8_t> rhs_params;
  rhs_params.order = Order::kColMajor;
  rhs_params.rows = n_input;
  rhs_params.cols = n_batch;

  MatrixParams<int32_t> dst_params;
  dst_params.order = Order::kColMajor;
  dst_params.rows = n_output;
  dst_params.cols = n_batch;

  GemmParams<int32_t, int32_t> gemm_params;
  if (bias)
  {
    gemm_params.bias = bias;
  }

  // Below code is from tflite::cpu_backend_gemm::detail::GemmImplUsingRuy
  ruy::Matrix<int8_t> ruy_lhs;
  ruy::Matrix<int8_t> ruy_rhs;
  ruy::Matrix<int32_t> ruy_dst;
  ruy_support::MakeRuyMatrix(lhs_params, input_to_gate_weights, &ruy_lhs);
  ruy_support::MakeRuyMatrix(rhs_params, input, &ruy_rhs);
  ruy_support::MakeRuyMatrix(dst_params, scratch, &ruy_dst);

  ruy::BasicSpec<int32_t, int32_t> ruy_spec;
  ruy_support::MakeRuySpec(gemm_params, &ruy_spec);

  constexpr ruy::Path kRuyPath = ruy::kAllPaths;
  ruy::Mul<kRuyPath>(ruy_lhs, ruy_rhs, ruy_spec, ruy_context, &ruy_dst);
}
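
// Shape sketch (illustrative numbers, not from the original code): with
// n_input = 16, n_output = 8 and n_batch = 2, the call above computes an 8x2
// int32 result in `scratch` as the product of the 8x16 row-major weight matrix
// and the 16x2 column-major input, plus `bias` when it is non-null.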
void NeonSymmetricQuantizeFloats(const float *values, const int size, int8_t *quantized_values,
                                 float *min, float *max, float *scaling_factor)
{
  // TODO(raziel): vectorize min/max calculation.
  auto minmax = std::minmax_element(values, values + size);
  *min = *minmax.first;
  *max = *minmax.second;
  const int kScale = 127;
  const float range = std::max(std::abs(*min), std::abs(*max));
  if (range == 0)
  {
    memset(quantized_values, 0, size * sizeof(int8_t));
    *scaling_factor = 1;
    return;
  }
  *scaling_factor = range / kScale;
  const float scaling_factor_inv = kScale / range;

  const int postamble_start = size - (size & (2 * kFloatWeightsPerNeonLane - 1));

  // Vectorized constants.
  const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv);
  const float32x4_t point5_f32x4 = vmovq_n_f32(0.5);
  const float32x4_t zero_f32x4 = vmovq_n_f32(0.0);
  const int32x4_t scale_i32x4 = vmovq_n_s32(kScale);
  const int32x4_t neg_scale_i32x4 = vmovq_n_s32(-kScale);

  for (int i = 0; i < postamble_start; i += 2 * kFloatWeightsPerNeonLane)
  {
    // Implements the vectorized version of the following:
    //   const int32_t quantized_value = static_cast<int32>(
    //       std::round(scaling_factor_inv * values[i]));
    // Since the vectorized round intrinsic (vrndaq_f32) is not supported
    // on all Neon flavors, we use the following method for rounding:
    //   if (x < 0) (int)(x - 0.5)
    //   if (x >= 0) (int)(x + 0.5)
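    // Worked example (illustrative, not from the original comment): for
    // x = 2.7 the "less than zero" mask is 0, so we add 0.5 and truncate to 3;
    // for x = -2.7 the mask converts to -1.0, so we add (-1.0 + 0.5) = -0.5 and
    // truncate to -3, i.e. round-half-away-from-zero.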
    float32x4_t value0_f32x4 = vld1q_f32(&values[i]);
    float32x4_t value1_f32x4 = vld1q_f32(&values[i + kFloatWeightsPerNeonLane]);
    float32x4_t mul0_f32x4 = vmulq_f32(value0_f32x4, q_factor_f32x4);
    float32x4_t mul1_f32x4 = vmulq_f32(value1_f32x4, q_factor_f32x4);

    int32x4_t cmp_with_zero0_ui32x4 = (int32x4_t)vcltq_f32(mul0_f32x4, zero_f32x4); // NOLINT
    int32x4_t cmp_with_zero1_ui32x4 = (int32x4_t)vcltq_f32(mul1_f32x4, zero_f32x4); // NOLINT

    float32x4_t cmp_with_zero0_f32x4 = vcvtq_f32_s32(cmp_with_zero0_ui32x4);
    float32x4_t cmp_with_zero1_f32x4 = vcvtq_f32_s32(cmp_with_zero1_ui32x4);
    cmp_with_zero0_f32x4 = vaddq_f32(cmp_with_zero0_f32x4, point5_f32x4);
    cmp_with_zero1_f32x4 = vaddq_f32(cmp_with_zero1_f32x4, point5_f32x4);

    mul0_f32x4 = vaddq_f32(mul0_f32x4, cmp_with_zero0_f32x4);
    mul1_f32x4 = vaddq_f32(mul1_f32x4, cmp_with_zero1_f32x4);

    int32x4_t f2i0_i32x4 = vcvtq_s32_f32(mul0_f32x4);
    int32x4_t f2i1_i32x4 = vcvtq_s32_f32(mul1_f32x4);

    // Implements the vectorized version of the following block:
    //   quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
    int32x4_t max0_i32x4 = vmaxq_s32(f2i0_i32x4, neg_scale_i32x4);
    int32x4_t max1_i32x4 = vmaxq_s32(f2i1_i32x4, neg_scale_i32x4);
    int32x4_t min0_i32x4 = vminq_s32(max0_i32x4, scale_i32x4);
    int32x4_t min1_i32x4 = vminq_s32(max1_i32x4, scale_i32x4);

    int16x4_t min0_16x4 = vmovn_s32(min0_i32x4);
    int16x4_t min1_16x4 = vmovn_s32(min1_i32x4);

    int16x8_t min_16x8 = vcombine_s16(min0_16x4, min1_16x4);
    int8x8_t min_s8x8 = vqmovn_s16(min_16x8);
    vst1_s8(&quantized_values[i], min_s8x8);
  }

  for (int i = postamble_start; i < size; ++i)
  {
    const int32_t quantized_value =
      static_cast<int32_t>(std::round(scaling_factor_inv * values[i]));
    quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
  }
}
void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t *__restrict__ matrix, const int m_rows,
                                             const int m_cols, const int8_t *__restrict__ vectors,
                                             const float *scaling_factors, int n_batch,
                                             float *__restrict__ result, int result_stride)
{
#ifdef __aarch64__
  if (HasSdotInstruction() && m_cols % 16 == 0 && m_rows % 2 == 0 && m_rows >= n_batch)
  {
    if (n_batch % 4 == 0 && result_stride == 1)
    {
      // Benchmarks suggest that it's always better to use the batch code
      // when we can, even on small matrices.
      DotprodMatrixBatchFourVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
                                                     scaling_factors, n_batch, result);
      return;
    }
    else if (result_stride == 1 && n_batch >= 2 && m_rows * m_cols >= 128 * 128)
    {
      DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
                                                           scaling_factors, n_batch, result);
      return;
    }
  }
#endif // __aarch64__
  static const int kWeightsPerUint32 = 4;
  static const int kWeightsPerNeonLane = 16;
  // Assuming *matrix is kWeightsPerUint32-byte aligned,
  // every row of the matrix is also
  // kWeightsPerUint32-byte aligned as long as cols is
  // a multiple of kWeightsPerUint32. The assumption
  // is currently satisfied by TFLite's 16-byte memory
  // alignment scheme.
  //
  // Otherwise, we allocate an aligned memory block and set
  // a flag to later copy rows from matrix to the block
  // for aligned multiplication.
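  //
  // For example (illustrative): m_cols = 64 keeps every row 4-byte aligned and
  // the fast path is used directly, while m_cols = 30 sets `unaligned` and each
  // row is first copied into `aligned_row` before the NEON loads.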
  bool unaligned = false;
  int8_t *aligned_row = nullptr;
  void *aligned_row_free = nullptr;
  if ((m_cols & (kWeightsPerUint32 - 1)) != 0)
  {
    unaligned = true;
    aligned_row = (int8_t *)aligned_alloc(kWeightsPerUint32, m_cols, // NOLINT
                                          &aligned_row_free);
  }
  void *aligned_vec_free = nullptr;
  int8_t *aligned_vec = (int8_t *)aligned_alloc(kWeightsPerUint32, m_cols, // NOLINT
                                                &aligned_vec_free);

  // If m_cols is not at least kWeightsPerNeonLane, we cannot use the main
  // vectorized loop, and we need to process sequentially. postamble_half_start
  // shows the start index where this should happen. Between postamble_half_start
  // and postamble_start we can still process kWeightsPerNeonLane >> 1 in a
  // vectorized form.
  const int postamble_half_start = m_cols & ~(kWeightsPerNeonLane - 1);
  const int postamble_start = m_cols & ~((kWeightsPerNeonLane >> 1) - 1);
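  // For example (illustrative): m_cols = 29 gives postamble_half_start = 16 and
  // postamble_start = 24, so columns 0..15 use the 16-wide loop, 16..23 use the
  // 8-wide half iteration, and 24..28 are handled by the scalar loop.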

  for (int batch = 0; batch < n_batch; ++batch)
  {
    const float batch_scaling_factor = scaling_factors[batch];
    // Copy the vector data to an aligned vector.
    memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8_t) * m_cols);
    // Compute dot-product for every column.
    for (int row = 0; row < m_rows; ++row, result += result_stride)
    {
      // Get the address of the first element of the row.
      int8_t *row_ptr = (int8_t *)matrix + row * m_cols; // NOLINT
      if (unaligned)
      {
        memcpy(aligned_row, row_ptr, sizeof(int8_t) * m_cols);
        row_ptr = aligned_row;
      }

      // Initialize the dot product sum for the row to 0.
      int32x4_t dotprod_32x4 = vmovq_n_s32(0);

      // Prefetch the row to cache.
      __builtin_prefetch(row_ptr, 0 /* prefetch for read */, 3 /* temporal locality */);

      // For every block of 16 8-bit elements.
      int col = 0;
      for (; col < postamble_half_start; col += kWeightsPerNeonLane)
      {
        // Load 16 8-bit values from the row and vector, each, to operate on.
        // Here the assumption is that each buffer is 4-byte aligned. Otherwise,
        // performance may suffer significantly.
        assert( // NOLINT
          ((uintptr_t)(&row_ptr[col]) & (kWeightsPerUint32 - 1)) == 0);
        const int8x16_t s1_8x16 = vld1q_s8((const int8_t *)(aligned_vec + col));
        const int8x16_t s2_8x16 = vld1q_s8((const int8_t *)(row_ptr + col));
        // Multiply the low bits (i.e. the lower 8 8bit numbers in the
        // registers).
        int16x8_t prod_16x8 = vmull_s8(vget_low_s8(s1_8x16), vget_low_s8(s2_8x16));
        // Multiply the high bits (i.e. the higher 8 8bit numbers in the
        // registers), and accumulate with the result of the low bits product.
        // The assumption here is that overflow will not happen as we quantize
        // our values to be in the range [-127, 127]. As such the sum of the 2
        // products is always strictly smaller than 15-bits (32767 in absolute
        // value).
        prod_16x8 = vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));

        dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
      }

      // Half iteration dealing with only 8 elements.
      // TODO(raziel): if (ABSL_PREDICT_FALSE(col < postamble_start))
      if (col < postamble_start)
      {
        // Load 8 8-bit values from the row and column each to operate on.
        // Here the assumption is that each buffer is 4-bytes aligned.
        // Otherwise, performance may suffer significantly.
        assert( // NOLINT
          ((uintptr_t)(&row_ptr[col]) & (kWeightsPerUint32 - 1)) == 0);
        const int8x8_t s1_8x8 = vld1_s8((const int8_t *)(aligned_vec + col));
        const int8x8_t s2_8x8 = vld1_s8((const int8_t *)(row_ptr + col));
        const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8);
        dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
        col += (kWeightsPerNeonLane >> 1);
      }

      // Add the 4 intermediate sum values to get the final dot-prod value for
      // this row.
      int32_t dotprod = AccumulateNeonLane(dotprod_32x4);

      // TODO(raziel): if (ABSL_PREDICT_FALSE(col < m_cols))
      for (; col < m_cols; ++col)
      {
        dotprod += row_ptr[col] * aligned_vec[col];
      }

      *result += dotprod * batch_scaling_factor;
    }
  }

  if (unaligned)
  {
    free(aligned_row_free);
  }
  free(aligned_vec_free);
}
void NeonMatrixBatchVectorMultiplyAccumulate(const float *matrix, int m_rows, int m_cols,
                                             const float *vector, int n_batch, float *result,
                                             int result_stride)
{
  // If m_cols is not divisible by kFloatWeightsPerNeonLane, we cannot use the
  // main vectorized loop, and we need to process sequentially. postamble_start
  // shows the start index where this should happen.
  const int postamble_start = m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1));

  for (int b = 0; b < n_batch; b++)
  {
    float *result_in_batch = result + b * m_rows * result_stride;
    const float *vector_in_batch = vector + b * m_cols;
    const float *matrix_row = matrix;

    // Main matrix by vector multiplication loop
    for (int r = 0; r < m_rows; r++)
    {
      float32x4_t acc_32x4 = vmovq_n_f32(0.0);
      for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane)
      {
        // Load 4 float values from vector and matrix row.
        float32x4_t vector_f32x4 = vld1q_f32(vector_in_batch + c);
        float32x4_t matrix_f32x4 = vld1q_f32(matrix_row + c);
        // Multiply the vector and matrix row and add to accumulator.
        acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4);
      }
      // Add the 4 intermediate sum values to get the final dot-prod value for
      // this row.
      *result_in_batch += (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) +
                           vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3));
      for (int c = postamble_start; c < m_cols; c++)
      {
        *result_in_batch += matrix_row[c] * vector_in_batch[c];
      }
      matrix_row += m_cols;
      result_in_batch += result_stride;
    }
  }
}
void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t *__restrict__ matrix, const int m_rows,
                                             const int m_cols, const int8_t *__restrict__ vectors,
                                             const float *scaling_factors, int n_batch,
                                             int32_t *scratch, float *__restrict__ result,
                                             int result_stride, ruy::Context *ruy_context)
{
  if (m_rows % 4 == 0 && result_stride == 1)
  {
    const int32_t *bias = static_cast<const int32_t *>(nullptr);
    NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows,
                       /*output_zp =*/0, scratch, ruy_context);

    // Multiply by float scaling factors and write to result
    const int total_size = n_batch * m_rows;
    int i = 0;
    for (; i <= total_size - 8; i += 8, result += 8 * result_stride)
    {
      const float batch_scaling_factor0 = scaling_factors[i / m_rows];
      const float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows];
      const float32x4_t scaling_factor0 = vdupq_n_f32(batch_scaling_factor0);
      const float32x4_t scaling_factor1 = vdupq_n_f32(batch_scaling_factor1);
      const int32x4_t scratch_val0 = vld1q_s32(scratch + i);
      const int32x4_t scratch_val1 = vld1q_s32(scratch + i + 4);
      const float32x4_t float_val0 = vcvtq_f32_s32(scratch_val0);
      const float32x4_t float_val1 = vcvtq_f32_s32(scratch_val1);
      const float32x4_t result0 = vmlaq_f32(vld1q_f32(result), float_val0, scaling_factor0);
      const float32x4_t result1 =
        vmlaq_f32(vld1q_f32(result + 4 * result_stride), float_val1, scaling_factor1);
      vst1q_f32(result, result0);
      vst1q_f32(result + 4 * result_stride, result1);
    }
    scratch += i;
    for (; i < total_size; i++, result += result_stride)
    {
      const float batch_scaling_factor = scaling_factors[i / m_rows];
      int32_t x = *(scratch++);
      *result += x * batch_scaling_factor;
    }
    return;
  }
  NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors, scaling_factors, n_batch,
                                          result, result_stride);
}
#endif // __NNFW_CKER_NEON_TENSOR_UTILS_H__