Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / NeonTensorUtils.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_NEON_TENSOR_UTILS_H__
19 #define __NNFW_CKER_NEON_TENSOR_UTILS_H__
20
21 #include <ruy/path.h>
22 #include <ruy/ruy.h>
23 #include <ruy/detect_arm.h>
24 #include "cker/Types.h"
25 #include "cker/neon/neon_check.h"
26 #include "cker/ruy/RuySupport.h"
27 #include "util/logging.h"
28
29 #include <cassert>
30 #include <cmath>
31
32 #ifdef USE_NEON
33
34 #define kFloatWeightsPerNeonLane 4
35
36 namespace nnfw
37 {
38 namespace cker
39 {
40
41 namespace
42 {
43
44 // TODO(ahentz): Clean up.
45 using int8 = std::int8_t;
46 using uint8 = std::uint8_t;
47 using int16 = std::int16_t;
48 using uint16 = std::uint16_t;
49 using int32 = std::int32_t;
50 using uint32 = std::uint32_t;
51
52 // Allocates, at least, size bytes of uninitialized storage whose alignment is
53 // specified by alignment. The size parameter must be an integral multiple of
54 // alignment.
55 // Caller is responsible by freeing the allocated memory by calling free on
56 // the passed freeing_buffer pointer.
57 void *aligned_alloc(size_t alignment, size_t size, void **freeing_buffer)
58 {
59   *freeing_buffer = malloc(size + alignment);
60   const size_t offset = ((uintptr_t)*freeing_buffer) % alignment;                          // NOLINT
61   return offset == 0 ? *freeing_buffer : ((char *)*freeing_buffer + (alignment - offset)); // NOLINT
62 }
63
64 inline int32_t AccumulateNeonLane(const int32x4_t lane)
65 {
66 #ifdef __aarch64__
67   return vaddvq_s32(lane);
68 #else
69   int64x2_t pairwiseAdded = vpaddlq_s32(lane);
70   return vgetq_lane_s64(pairwiseAdded, 0) + vgetq_lane_s64(pairwiseAdded, 1);
71 #endif
72 }
73
74 } // namespace
75
76 #ifdef __aarch64__
77
78 bool HasSdotInstruction()
79 {
80   static const bool has_dotprod = ruy::DetectDotprod();
81   return has_dotprod;
82 }
83
84 // We interleave vector data to make the dot product logic more efficient.
85 // Suppose that vectors is:
86 //     a0 a1 a2 a3 a4 a5 ...
87 //     b0 b1 b2 b3 b4 b5 ...
88 //     c0 c1 c2 c3 c4 c5 ...
89 //     d0 d1 d2 d3 d4 d5 ...
90 //     e0 e1 e2 e3 e4 e5 ...
91 // This code interleaves them like this:
92 //     a0 a1 a2 a3 b0 b1 b2 b3 c0 c1 c2 c3 d0 d1 d2 d3 a4 a5 a6 a7 b4 ...
93 //     e0 e1 e2 e3 f0 f1 f2 f3 ...
94 // Once the data is interleaved, each 16-byte read from the vectors pointer
95 // contains 4 bytes from each of 4 vectors.
96 const int8_t *ShuffleVectors(const int8_t *vectors, const int n_batch, const int m_cols,
97                              void **shuffled_vectors_free)
98 {
99   const int kWeightsPerUint32 = 4;
100
101   int8 *shuffled_vectors = reinterpret_cast<int8 *>(
102       aligned_alloc(kWeightsPerUint32, n_batch * m_cols, shuffled_vectors_free));
103
104   for (int i = 0; i < n_batch; i += 4)
105   {
106     int8 *shuffled_vectors_ptr = shuffled_vectors + (i * m_cols);
107     const int8 *unshuffled_vec0_ptr = reinterpret_cast<const int8 *>(vectors) + (i * m_cols);
108     const int8 *unshuffled_vec1_ptr = reinterpret_cast<const int8 *>(vectors) + ((i + 1) * m_cols);
109     const int8 *unshuffled_vec2_ptr = reinterpret_cast<const int8 *>(vectors) + ((i + 2) * m_cols);
110     const int8 *unshuffled_vec3_ptr = reinterpret_cast<const int8 *>(vectors) + ((i + 3) * m_cols);
111     const int8 *const end_vec0_ptr = unshuffled_vec1_ptr;
112
113     while (unshuffled_vec0_ptr != end_vec0_ptr)
114     {
115       asm volatile(
116           // This code path requires that (n_cols % 16) == 0 so we can safely
117           // read in 16-byte chunks from each row.
118           "ld1 {v0.16b}, [%[unshuffled_vec0_ptr]], #16\n"
119           "ld1 {v1.16b}, [%[unshuffled_vec1_ptr]], #16\n"
120           "ld1 {v2.16b}, [%[unshuffled_vec2_ptr]], #16\n"
121           "ld1 {v3.16b}, [%[unshuffled_vec3_ptr]], #16\n"
122
123           "st4 {v0.s, v1.s, v2.s, v3.s}[0], [%[shuffled_vectors_ptr]], #16\n"
124           "st4 {v0.s, v1.s, v2.s, v3.s}[1], [%[shuffled_vectors_ptr]], #16\n"
125           "st4 {v0.s, v1.s, v2.s, v3.s}[2], [%[shuffled_vectors_ptr]], #16\n"
126           "st4 {v0.s, v1.s, v2.s, v3.s}[3], [%[shuffled_vectors_ptr]], #16\n"
127
128           : [unshuffled_vec0_ptr] "+r"(unshuffled_vec0_ptr),
129             [unshuffled_vec1_ptr] "+r"(unshuffled_vec1_ptr),
130             [unshuffled_vec2_ptr] "+r"(unshuffled_vec2_ptr),
131             [unshuffled_vec3_ptr] "+r"(unshuffled_vec3_ptr),
132             [shuffled_vectors_ptr] "+r"(shuffled_vectors_ptr)
133           :
134           : "v0", "v1", "v2", "v3", "cc", "memory");
135     }
136   }
137
138   return reinterpret_cast<const int8_t *>(shuffled_vectors);
139 }
140
141 // Notes about the speed of this version vs. the baseline (from memory):
142 // - With 256K of L1, we can keep a lot of vectors in cache.
143 //   I recall a reasonable speedup just by rearranging the loop to have
144 //   row on the outside and batch on the inside.
145 // - I also recall getting a nice speedup from sdot.
146 // - I tried many times to do better than the current implementation, using
147 //   loop unrolling and instruction reordering to avoid stalls, etc.
148 //   but I was not able to do significantly better. This code is, however,
149 //   much worse than what the processor spec sheet suggests is possible.
150 static void DotprodMatrixBatchFourVectorMultiplyAccumulate(const int8_t *__restrict__ matrix,
151                                                            const int m_rows, const int m_cols,
152                                                            const int8_t *vectors,
153                                                            const float *scaling_factors,
154                                                            int n_batch, float *__restrict__ result)
155 {
156   void *shuffled_vectors_free;
157
158   const int8_t *shuffled_vectors = ShuffleVectors(vectors, n_batch, m_cols, &shuffled_vectors_free);
159
160   for (int row = 0; row < m_rows; row += 2)
161   {
162     for (int batch = 0; batch < n_batch; batch += 4)
163     {
164       float *result_ptr = result + (batch * m_rows) + row;
165       const int8 *mat_ptr0 = matrix + (row * m_cols);
166       const int8 *mat_ptr1 = matrix + ((row + 1) * m_cols);
167       const int8 *mat_ptr0_end = mat_ptr1;
168       const int8 *vec_ptr = shuffled_vectors + (batch * m_cols);
169       const float *scaling_factors_ptr = scaling_factors + batch;
170       const uint64_t wide_rows = m_rows * sizeof(float);
171       const int8 *mat_ptr2 = matrix + ((row + 2) * m_cols);
172       const int8 *mat_ptr3 = matrix + ((row + 3) * m_cols);
173
174       asm volatile(
175           // Zero out the accumulator registers.
176           "dup v0.4s, wzr\n"
177           "dup v1.4s, wzr\n"
178           "dup v2.4s, wzr\n"
179           "dup v3.4s, wzr\n"
180
181           "1:\n" // batch_cols_loop
182
183           // Read 16 more bytes from a pair of matrix rows.
184           "ld1 {v12.16b}, [%[mat_ptr0]], #16\n"
185
186           // Prefetch two rows ahead.
187           "prfm pldl1strm, [%[mat_ptr2]]\n"
188           "prfm pldl1strm, [%[mat_ptr3]]\n"
189
190           // Read from input vectors 4 times; 64 bytes total.
191           // Each 16-byte register contains parts of 4 vectors; see the
192           // shuffle logic above.
193
194           // From Benoit, places to look in the future:
195           // - Move load instructions further from sdot
196           // - Switch loop use-then-reload
197           // - Do partial unrolling to use register space better
198           "ld1 {v8.16b}, [%[vec_ptr]], #16\n"
199           ".word 0x4f8ce100  // sdot v0.4s, v8.16b, v12.4b[0]\n"
200           "ld1 {v9.16b}, [%[vec_ptr]], #16\n"
201           ".word 0x4face121  // sdot v1.4s, v9.16b, v12.4b[1]\n"
202           "ld1 {v10.16b}, [%[vec_ptr]], #16\n"
203           ".word 0x4f8ce940  // sdot v0.4s, v10.16b, v12.4b[2]\n"
204           "ld1 {v11.16b}, [%[vec_ptr]], #16\n"
205           ".word 0x4face961  // sdot v1.4s, v11.16b, v12.4b[3]\n"
206
207           // Update prefetch pointers.
208           "add %[mat_ptr2], %[mat_ptr2], #16\n"
209           "add %[mat_ptr3], %[mat_ptr3], #16\n"
210
211           // Re-use those vectors for the next row as well.
212           "ld1 {v13.16b}, [%[mat_ptr1]], #16\n"
213           ".word 0x4f8de102  // sdot v2.4s, v8.16b, v13.4b[0]\n"
214           ".word 0x4fade123  // sdot v3.4s, v9.16b, v13.4b[1]\n"
215           ".word 0x4f8de942  // sdot v2.4s, v10.16b, v13.4b[2]\n"
216           ".word 0x4fade963  // sdot v3.4s, v11.16b, v13.4b[3]\n"
217
218           // If we're not done with these rows, continue.
219           "cmp %[mat_ptr0], %[mat_ptr0_end]\n"
220           "bne 1b\n" // batch_cols_loop
221
222           // Done with the rows, sum the results.
223           "add v0.4s, v0.4s, v1.4s\n"
224           "add v2.4s, v2.4s, v3.4s\n"
225
226           // Convert the per-vector sums to floating point.
227           "scvtf v0.4s, v0.4s\n"
228           "scvtf v1.4s, v2.4s\n"
229
230           // Fetch scale factors.
231           "ld1 {v4.4s}, [%[scaling_factors_ptr]]\n"
232
233           // Multiply scale factors times sums.
234           "fmul v0.4s, v4.4s, v0.4s\n"
235           "fmul v1.4s, v4.4s, v1.4s\n"
236
237           // Load previous result values.
238           // The result position is:
239           //   result[batch * m_rows + row]
240           // Here that is factored into:
241           //   result_ptr = result + row
242           //   *result_ptr = res[0]
243           //   (uint8*)result_ptr += (m_rows * sizeof(float))
244           //   *result_ptr = res[1]
245           //   ...
246           // Since we're reading two rows at a time, though, we read both
247           //   result[batch * m_rows + row]
248           // and
249           //   result[batch * m_rows + row + 1]
250           "ld2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
251           "ld2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
252           "ld2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
253           "ld2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
254
255           // Go back to the starting position (subtract wide_rows * 4).
256           "sub %[result_ptr], %[result_ptr], %[wide_rows], lsl #2\n"
257
258           // Add previous result values.
259           "fadd v9.4s, v9.4s, v0.4s\n"
260           "fadd v10.4s, v10.4s, v1.4s\n"
261
262           // Store results.
263           "st2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
264           "st2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
265           "st2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
266           "st2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
267           : [mat_ptr0] "+r"(mat_ptr0), [mat_ptr1] "+r"(mat_ptr1), [vec_ptr] "+r"(vec_ptr),
268             [result_ptr] "+r"(result_ptr), [mat_ptr2] "+r"(mat_ptr2), [mat_ptr3] "+r"(mat_ptr3)
269           : [mat_ptr0_end] "r"(mat_ptr0_end), [scaling_factors_ptr] "r"(scaling_factors_ptr),
270             [wide_rows] "r"(wide_rows)
271           : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
272             "v13", "cc", "memory");
273     }
274   }
275
276   free(shuffled_vectors_free);
277 }
278
279 static void DotprodMatrixBatchFourVectorMultiplyAccumulate(
280     const int8_t *__restrict__ matrix, const int m_rows, const int m_cols, const int8_t *vectors,
281     const float *scaling_factors, int n_batch, float *__restrict__ result,
282     const float *per_channel_scale, const int32_t *input_offset, int32_t *row_sums)
283 {
284   void *shuffled_vectors_free;
285   const int8_t *shuffled_vectors = ShuffleVectors(vectors, n_batch, m_cols, &shuffled_vectors_free);
286
287   for (int row = 0; row < m_rows; row += 2)
288   {
289     const float *channel_scales_ptr = per_channel_scale + row;
290     int32_t *row_sums_ptr = row_sums ? row_sums + row : nullptr;
291     for (int batch = 0; batch < n_batch; batch += 4)
292     {
293       float *result_ptr = result + (batch * m_rows) + row;
294       const int8 *mat_ptr0 = matrix + (row * m_cols);
295       const int8 *mat_ptr1 = matrix + ((row + 1) * m_cols);
296       const int8 *mat_ptr0_end = mat_ptr1;
297       const int8 *vec_ptr = shuffled_vectors + (batch * m_cols);
298       const float *scaling_factors_ptr = scaling_factors + batch;
299       const uint64_t wide_rows = m_rows * sizeof(float);
300       const int32_t *batch_offsets_ptr = input_offset + batch;
301       const int32_t is_channel_scale_nullptr = per_channel_scale == nullptr;
302       const int32_t is_row_sums_nullptr = row_sums_ptr == nullptr;
303       asm volatile("dup v0.4s, wzr\n"
304                    "dup v1.4s, wzr\n"
305                    "dup v2.4s, wzr\n"
306                    "dup v3.4s, wzr\n"
307                    // Load zero points.
308                    "ld1 {v7.4s}, [%[batch_offsets_ptr]]\n"
309                    "ld1 {v4.4s}, [%[scaling_factors_ptr]]\n"
310                    // Zero out zero point accumulators.
311                    "dup v14.4s, wzr\n"
312                    "dup v15.4s, wzr\n"
313
314                    // Load per channel scales if not null.
315                    "cmp %w[is_channel_scale_nullptr], #0\n"
316                    "bne 1f\n"
317                    "ld1r {v16.4s}, [%[channel_scales_ptr]], #4\n"
318                    "ld1r {v17.4s}, [%[channel_scales_ptr]]\n"
319                    "fmul v16.4s, v16.4s, v4.4s\n"
320                    "fmul v17.4s, v17.4s, v4.4s\n"
321                    "b 2f\n"
322                    "1:\n"
323                    "mov v16.16b, v4.16b\n"
324                    "mov v17.16b, v4.16b\n"
325                    "2:\n"
326                    "ld1 {v12.16b}, [%[mat_ptr0]], #16\n"
327                    "ld1 {v8.16b}, [%[vec_ptr]], #16\n"
328                    ".word 0x4f8ce100  // sdot v0.4s, v8.16b, v12.4b[0]\n"
329                    "ld1 {v9.16b}, [%[vec_ptr]], #16\n"
330                    ".word 0x4face121  // sdot v1.4s, v9.16b, v12.4b[1]\n"
331                    "ld1 {v10.16b}, [%[vec_ptr]], #16\n"
332                    ".word 0x4f8ce940  // sdot v0.4s, v10.16b, v12.4b[2]\n"
333                    "ld1 {v11.16b}, [%[vec_ptr]], #16\n"
334                    ".word 0x4face961  // sdot v1.4s, v11.16b, v12.4b[3]\n"
335                    "ld1 {v13.16b}, [%[mat_ptr1]], #16\n"
336                    ".word 0x4f8de102  // sdot v2.4s, v8.16b, v13.4b[0]\n"
337                    ".word 0x4fade123  // sdot v3.4s, v9.16b, v13.4b[1]\n"
338                    ".word 0x4f8de942  // sdot v2.4s, v10.16b, v13.4b[2]\n"
339                    ".word 0x4fade963  // sdot v3.4s, v11.16b, v13.4b[3]\n"
340                    "cmp %w[is_row_sums_nullptr], #1\n"
341                    "bne 3f\n"
342                    // Accumulate row_sums for zero point calculations.
343                    "saddlp v12.8h, v12.16b\n"
344                    "saddlp v13.8h, v13.16b\n"
345                    "sadalp v14.4s, v12.8h\n"
346                    "sadalp v15.4s, v13.8h\n"
347                    "3:\n"
348                    "cmp %[mat_ptr0], %[mat_ptr0_end]\n"
349                    "bne 2b\n"
350                    "add v0.4s, v0.4s, v1.4s\n"
351                    "add v2.4s, v2.4s, v3.4s\n"
352
353                    "cmp %w[is_row_sums_nullptr], #1\n"
354                    "bne 4f\n"
355                    // Calculate zero point offsets.
356                    "addv s14, v14.4s\n"
357                    "addv s15, v15.4s\n"
358                    "dup v14.4s, v14.s[0]\n"
359                    "dup v15.4s, v15.s[0]\n"
360                    "b 5f\n"
361                    "4:\n"
362                    "ld1r {v14.4s}, [%[row_sums_ptr]], #4\n"
363                    "ld1r {v15.4s}, [%[row_sums_ptr]]\n"
364                    "5:\n"
365
366                    "mul v14.4s, v14.4s, v7.4s\n"
367                    "mul v15.4s, v15.4s, v7.4s\n"
368                    "sub v0.4s, v0.4s, v14.4s\n"
369                    "sub v2.4s, v2.4s, v15.4s\n"
370
371                    "scvtf v0.4s, v0.4s\n"
372                    "scvtf v1.4s, v2.4s\n"
373
374                    // Multiply scale.
375                    "fmul v0.4s, v16.4s, v0.4s\n"
376                    "fmul v1.4s, v17.4s, v1.4s\n"
377
378                    "ld2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
379                    "ld2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
380                    "ld2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
381                    "ld2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
382                    "sub %[result_ptr], %[result_ptr], %[wide_rows], lsl #2\n"
383                    "fadd v9.4s, v9.4s, v0.4s\n"
384                    "fadd v10.4s, v10.4s, v1.4s\n"
385                    "st2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
386                    "st2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
387                    "st2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
388                    "st2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
389                    : [mat_ptr0] "+r"(mat_ptr0), [mat_ptr1] "+r"(mat_ptr1), [vec_ptr] "+r"(vec_ptr),
390                      [result_ptr] "+r"(result_ptr), [row_sums_ptr] "+r"(row_sums_ptr)
391                    : [mat_ptr0_end] "r"(mat_ptr0_end),
392                      [scaling_factors_ptr] "r"(scaling_factors_ptr), [wide_rows] "r"(wide_rows),
393                      [channel_scales_ptr] "r"(channel_scales_ptr),
394                      [batch_offsets_ptr] "r"(batch_offsets_ptr),
395                      [is_channel_scale_nullptr] "r"(is_channel_scale_nullptr),
396                      [is_row_sums_nullptr] "r"(is_row_sums_nullptr)
397                    : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
398                      "v12", "v13", "v14", "v15", "v16", "v17", "w0", "w1", "cc", "memory");
399     }
400   }
401
402   free(shuffled_vectors_free);
403 }
404
405 // The DotprodMatrixBatchFourVectorMultiplyAccumulate kernel processes 4
406 // vectors in the same time as the baseline processes 1 vector. However, it
407 // requires 4 vectors of input.
408 //
409 // To take advantage of this speed difference, we add some zero-valued
410 // vectors to the batch so that n_batch is a multiple of 4. Then we execute
411 // DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate on that padded batch,
412 // then extract just the results we want at the end (ignoring the extra padding
413 // outputs).
414 //
415 // The relative cost of the padding is large when the matrix is smaller than
416 // 128x128, so we don't use this code path on small matrices. On larger
417 // matrices, the computation cost dwarfs the padding cost, making this code
418 // viable.
419 //
420 // If we ignore the cost of padding, this kernel is:
421 //    1x the speed of NeonMatrixBatchVectorMultiplyImpl for n_batch = 1
422 //    2x the speed of NeonMatrixBatchVectorMultiplyImpl for n_batch = 2
423 //    3x the speed of NeonMatrixBatchVectorMultiplyImpl for n_batch = 3
424 //    ...
425 //
426 // We don't use this kernel when n_batch = 1 because the baseline kernel
427 // is fine for that case.
428 void DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
429     const int8_t *__restrict__ matrix, const int m_rows, const int m_cols, const int8_t *vectors,
430     const float *scaling_factors, int n_batch, float *__restrict__ result,
431     const float *per_channel_scale, const int32_t *input_offset, int32_t *row_sums)
432 {
433   const int kWeightsPerUint32 = 4;
434
435   // Round to the nearest multiple of 4.
436   int batch_round_up = n_batch;
437   if (n_batch % 4 != 0)
438   {
439     batch_round_up += (4 - n_batch % 4);
440   }
441   assert(n_batch <= batch_round_up);
442
443   void *padded_vectors_free;
444   const int padded_vectors_size = batch_round_up * m_cols;
445   int8_t *padded_vectors = reinterpret_cast<int8_t *>(
446       aligned_alloc(kWeightsPerUint32, padded_vectors_size, &padded_vectors_free));
447   memset(padded_vectors, 0, padded_vectors_size);
448
449   void *padded_result_free;
450   const int result_size = n_batch * m_rows * sizeof(float);
451   const int padded_result_size = batch_round_up * m_rows * sizeof(float);
452   float *padded_result = reinterpret_cast<float *>(
453       aligned_alloc(kWeightsPerUint32, padded_result_size, &padded_result_free));
454   memcpy(padded_result, result, result_size);
455   memset(reinterpret_cast<char *>(padded_result) + result_size, 0,
456          padded_result_size - result_size);
457
458   // Copy the input into the padded data structure.
459   assert(n_batch * m_cols <= padded_vectors_size);
460   memcpy(padded_vectors, vectors, n_batch * m_cols);
461
462   void *padded_scaling_factors_free;
463   const int padded_scaling_factors_size = batch_round_up * sizeof(float);
464   float *padded_scaling_factors = reinterpret_cast<float *>(
465       aligned_alloc(kWeightsPerUint32, padded_scaling_factors_size, &padded_scaling_factors_free));
466   assert(static_cast<int>(n_batch * sizeof(float)) <= padded_scaling_factors_size);
467   assert(static_cast<int>(batch_round_up * sizeof(float)) <= padded_scaling_factors_size);
468   memset(padded_scaling_factors, 0, batch_round_up * sizeof(float));
469   memcpy(padded_scaling_factors, scaling_factors, n_batch * sizeof(float));
470
471   if (input_offset != nullptr)
472   {
473     void *padded_input_offset_free;
474     const int padded_input_offset_size = batch_round_up * sizeof(int32_t);
475     int32_t *padded_input_offset = reinterpret_cast<int32_t *>(
476         aligned_alloc(kWeightsPerUint32, padded_input_offset_size, &padded_input_offset_free));
477     assert(static_cast<int>(n_batch * sizeof(int32_t)) <= padded_input_offset_size);
478     assert(static_cast<int>(batch_round_up * sizeof(int32_t)) <= padded_input_offset_size);
479     memset(padded_input_offset, 0, batch_round_up * sizeof(int32_t));
480     memcpy(padded_input_offset, input_offset, n_batch * sizeof(int32_t));
481
482     // Call the main kernel.
483     DotprodMatrixBatchFourVectorMultiplyAccumulate(
484         matrix, m_rows, m_cols, padded_vectors, padded_scaling_factors, batch_round_up,
485         padded_result, per_channel_scale, padded_input_offset, row_sums);
486
487     free(padded_input_offset_free);
488   }
489   else
490   {
491     // Call the main kernel.
492     DotprodMatrixBatchFourVectorMultiplyAccumulate(matrix, m_rows, m_cols, padded_vectors,
493                                                    padded_scaling_factors, batch_round_up,
494                                                    padded_result);
495   }
496   memcpy(result, padded_result, result_size);
497
498   free(padded_result_free);
499   free(padded_vectors_free);
500   free(padded_scaling_factors_free);
501 }
502
503 void DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(const int8_t *__restrict__ matrix,
504                                                           const int m_rows, const int m_cols,
505                                                           const int8_t *vectors,
506                                                           const float *scaling_factors, int n_batch,
507                                                           float *__restrict__ result)
508 {
509   DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
510       matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
511       /*per_channel_scale=*/nullptr, /*input_offset=*/nullptr,
512       /*row_sums=*/nullptr);
513 }
514 #endif // __aarch64__
515
516 bool NeonIsZeroVector(const float *vector, int v_size)
517 {
518   // If v_size is not divisible by kFloatWeightsPerNeonLane, we cannot
519   // use the main vectorized loop, and we need to process sequentially.
520   // postamble_start shows the start index where this should happen.
521   const int postamble_start = v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
522
523   const float32x4_t zero_x4_float = vmovq_n_f32(0.0f);
524   for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane)
525   {
526     const float32x4_t i_x4_float = vld1q_f32(vector + v);
527     uint32x4_t cmp_result = vceqq_f32(i_x4_float, zero_x4_float);
528     if (vgetq_lane_u32(cmp_result, 0) == 0)
529       return false;
530     if (vgetq_lane_u32(cmp_result, 1) == 0)
531       return false;
532     if (vgetq_lane_u32(cmp_result, 2) == 0)
533       return false;
534     if (vgetq_lane_u32(cmp_result, 3) == 0)
535       return false;
536   }
537
538   // Postamble loop
539   for (int v = postamble_start; v < v_size; ++v)
540   {
541     if (vector[v] != 0.0)
542       return false;
543   }
544   return true;
545 }
546
547 void NeonCpuBackendGemm(const int8_t *input, const int32_t *bias,
548                         const int8_t *input_to_gate_weights, int32_t n_batch, int32_t n_input,
549                         int32_t n_output, int32_t, int32_t *scratch, ruy::Context *ruy_context)
550 {
551   MatrixParams<int8_t> lhs_params;
552   lhs_params.order = Order::kRowMajor;
553   lhs_params.rows = n_output;
554   lhs_params.cols = n_input;
555   lhs_params.cacheable = true;
556
557   MatrixParams<int8_t> rhs_params;
558   rhs_params.order = Order::kColMajor;
559   rhs_params.rows = n_input;
560   rhs_params.cols = n_batch;
561
562   MatrixParams<int32_t> dst_params;
563   dst_params.order = Order::kColMajor;
564   dst_params.rows = n_output;
565   dst_params.cols = n_batch;
566
567   GemmParams<int32_t, int32_t> gemm_params;
568   if (bias)
569   {
570     gemm_params.bias = bias;
571   }
572
573   // Below code is from tflite::cpu_backend_gemm::detail::GemmImplUsingRuy
574   ruy::Matrix<int8_t> ruy_lhs;
575   ruy::Matrix<int8_t> ruy_rhs;
576   ruy::Matrix<int32_t> ruy_dst;
577   ruy_support::MakeRuyMatrix(lhs_params, input_to_gate_weights, &ruy_lhs);
578   ruy_support::MakeRuyMatrix(rhs_params, input, &ruy_rhs);
579   ruy_support::MakeRuyMatrix(dst_params, scratch, &ruy_dst);
580
581   ruy::BasicSpec<int32_t, int32_t> ruy_spec;
582   ruy_support::MakeRuySpec(gemm_params, &ruy_spec);
583
584   constexpr ruy::Path kRuyPath = ruy::kAllPaths;
585   ruy::Mul<kRuyPath>(ruy_lhs, ruy_rhs, ruy_spec, ruy_context, &ruy_dst);
586 }
587
588 void NeonSymmetricQuantizeFloats(const float *values, const int size, int8_t *quantized_values,
589                                  float *min, float *max, float *scaling_factor)
590 {
591   // TODO(raziel): vectorize min/max calculation.
592   auto minmax = std::minmax_element(values, values + size);
593   *min = *minmax.first;
594   *max = *minmax.second;
595   const int kScale = 127;
596   const float range = std::max(std::abs(*min), std::abs(*max));
597   if (range == 0)
598   {
599     memset(quantized_values, 0, size * sizeof(int8_t));
600     *scaling_factor = 1;
601     return;
602   }
603   *scaling_factor = range / kScale;
604   const float scaling_factor_inv = kScale / range;
605
606   const int postamble_start = size - (size & (2 * kFloatWeightsPerNeonLane - 1));
607
608   // Vectorized constants.
609   const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv);
610   const float32x4_t point5_f32x4 = vmovq_n_f32(0.5);
611   const float32x4_t zero_f32x4 = vmovq_n_f32(0.0);
612   const int32x4_t scale_i32x4 = vmovq_n_s32(kScale);
613   const int32x4_t neg_scale_i32x4 = vmovq_n_s32(-kScale);
614
615   for (int i = 0; i < postamble_start; i += 2 * kFloatWeightsPerNeonLane)
616   {
617     // Implements the vectorized version of the following:
618     // const int32_t quantized_value = static_cast<int32>(
619     //    std::round(*scaling_factor * values[i]));
620     // Since the vectorized round intrinsics (vrndqa_f32) is not supported
621     // on all Neon flavors, we use the following method for rounding: if (x
622     // < 0) (int)(x - 0.5) if (x >= 0) (int)(x + 0.5)
623     float32x4_t value0_f32x4 = vld1q_f32(&values[i]);
624     float32x4_t value1_f32x4 = vld1q_f32(&values[i + kFloatWeightsPerNeonLane]);
625     float32x4_t mul0_f32x4 = vmulq_f32(value0_f32x4, q_factor_f32x4);
626     float32x4_t mul1_f32x4 = vmulq_f32(value1_f32x4, q_factor_f32x4);
627
628     int32x4_t cmp_with_zero0_ui32x4 = (int32x4_t)vcltq_f32(mul0_f32x4, zero_f32x4); // NOLINT
629     int32x4_t cmp_with_zero1_ui32x4 = (int32x4_t)vcltq_f32(mul1_f32x4, zero_f32x4); // NOLINT
630
631     float32x4_t cmp_with_zero0_f32x4 = vcvtq_f32_s32(cmp_with_zero0_ui32x4);
632     float32x4_t cmp_with_zero1_f32x4 = vcvtq_f32_s32(cmp_with_zero1_ui32x4);
633     cmp_with_zero0_f32x4 = vaddq_f32(cmp_with_zero0_f32x4, point5_f32x4);
634     cmp_with_zero1_f32x4 = vaddq_f32(cmp_with_zero1_f32x4, point5_f32x4);
635
636     mul0_f32x4 = vaddq_f32(mul0_f32x4, cmp_with_zero0_f32x4);
637     mul1_f32x4 = vaddq_f32(mul1_f32x4, cmp_with_zero1_f32x4);
638
639     int32x4_t f2i0_i32x4 = vcvtq_s32_f32(mul0_f32x4);
640     int32x4_t f2i1_i32x4 = vcvtq_s32_f32(mul1_f32x4);
641
642     // Implements the vectorized version of the folowing block:
643     //  quantized_values[i] = std::min(kScale, std::max(-kScale,
644     //  quantized_value));
645     int32x4_t max0_i32x4 = vmaxq_s32(f2i0_i32x4, neg_scale_i32x4);
646     int32x4_t max1_i32x4 = vmaxq_s32(f2i1_i32x4, neg_scale_i32x4);
647     int32x4_t min0_i32x4 = vminq_s32(max0_i32x4, scale_i32x4);
648     int32x4_t min1_i32x4 = vminq_s32(max1_i32x4, scale_i32x4);
649
650     int16x4_t min0_16x4 = vmovn_s32(min0_i32x4);
651     int16x4_t min1_16x4 = vmovn_s32(min1_i32x4);
652
653     int16x8_t min_16x8 = vcombine_s16(min0_16x4, min1_16x4);
654     int8x8_t min_s8x8 = vqmovn_s16(min_16x8);
655     vst1_s8(&quantized_values[i], min_s8x8);
656   }
657
658   for (int i = postamble_start; i < size; ++i)
659   {
660     const int32_t quantized_value =
661         static_cast<int32_t>(std::round(scaling_factor_inv * values[i]));
662     quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
663   }
664 }
665
666 void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t *__restrict__ matrix, const int m_rows,
667                                              const int m_cols, const int8_t *__restrict__ vectors,
668                                              const float *scaling_factors, int n_batch,
669                                              float *__restrict__ result, int result_stride)
670 {
671 #ifdef __aarch64__
672   if (HasSdotInstruction() && m_cols % 16 == 0 && m_rows % 2 == 0 && m_rows >= n_batch)
673   {
674     if (n_batch % 4 == 0 && result_stride == 1)
675     {
676       // Benchmarks suggest that it's always better to use the batch code
677       // when we can, even on small matrices.
678       DotprodMatrixBatchFourVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
679                                                      scaling_factors, n_batch, result);
680       return;
681     }
682     else if (result_stride == 1 && n_batch >= 2 && m_rows * m_cols >= 128 * 128)
683     {
684       DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
685                                                            scaling_factors, n_batch, result);
686       return;
687     }
688   }
689 #endif // __aarch64__
690
691   static const int kWeightsPerUint32 = 4;
692   static const int kWeightsPerNeonLane = 16;
693   // Assuming *matrix is kWeightsPerUint32-byte aligned,
694   // every row of the matrix is also
695   // kWeightsPerUint32-byte aligned as long as cols is
696   // a multiple of kWeightsPerUint32. The assumption
697   // is currently satisfied by TFLite's 16-byte memory
698   // alignment scheme.
699   //
700   // Otherwise, we allocate an aligned memory block and set
701   // a flag to later copy rows from matrix to the block
702   // for aligned multiplication.
703   bool unaligned = false;
704   int8_t *aligned_row = nullptr;
705   void *aligned_row_free = nullptr;
706   if ((m_cols & (kWeightsPerUint32 - 1)) != 0)
707   {
708     unaligned = true;
709     aligned_row = (int8_t *)aligned_alloc(kWeightsPerUint32, m_cols, // NOLINT
710                                           &aligned_row_free);
711   }
712   void *aligned_vec_free = nullptr;
713   int8_t *aligned_vec = (int8_t *)aligned_alloc(kWeightsPerUint32, m_cols, // NOLINT
714                                                 &aligned_vec_free);
715
716   // If m_cols is not at least kWeightsPerNeonLane, we cannot use the main
717   // vectorized loop, and we need to process sequentially. postamble_half_start
718   // shows the start index where this should happen. Between postamble_start and
719   // postamble_half_start we can still process kWeightsPerNeonLane >> 1 in a
720   // vectorized form.
721   const int postamble_half_start = m_cols & ~(kWeightsPerNeonLane - 1);
722   const int postamble_start = m_cols & ~((kWeightsPerNeonLane >> 1) - 1);
723
724   for (int batch = 0; batch < n_batch; ++batch)
725   {
726     const float batch_scaling_factor = scaling_factors[batch];
727     // Copy the vector data to an aligned vector.
728     memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8_t) * m_cols);
729     // Compute dot-product for every column.
730     for (int row = 0; row < m_rows; ++row, result += result_stride)
731     {
732       // Get the address of the first element of the row.
733       int8_t *row_ptr = (int8_t *)matrix + row * m_cols; // NOLINT
734       if (unaligned)
735       {
736         memcpy(aligned_row, row_ptr, sizeof(int8_t) * m_cols);
737         row_ptr = aligned_row;
738       }
739
740       // Initialize the dot product sum for the row to 0.
741       int32x4_t dotprod_32x4 = vmovq_n_s32(0);
742
743       // Prefetch the row to cache.
744       __builtin_prefetch(row_ptr, 0 /* prefetch for read */, 3 /* temporal locality */);
745
746       // For every block of 16 8-bit elements.
747       int col = 0;
748       for (; col < postamble_half_start; col += kWeightsPerNeonLane)
749       {
750         // Load 16 8-bit values from the row and vector, each, to operate on.
751         // Here the assumption is that each buffer is 4-byte aligned. Otherwise,
752         // performance may suffer significantly.
753         assert( // NOLINT
754             ((uintptr_t)(&row_ptr[col]) & (kWeightsPerUint32 - 1)) == 0);
755         const int8x16_t s1_8x16 = vld1q_s8((const int8_t *)(aligned_vec + col));
756         const int8x16_t s2_8x16 = vld1q_s8((const int8_t *)(row_ptr + col));
757         // Multiply the low bits (i.e. the lower 8 8bit numbers in the
758         // registers).
759         int16x8_t prod_16x8 = vmull_s8(vget_low_s8(s1_8x16), vget_low_s8(s2_8x16));
760         // Multiply the high bits (i.e. the higher 8 8bit numbers in the
761         // registers), and accumulate with the result of the low bits product.
762         // The assumption here is that overflow will not happen as we quantize
763         // our values to be in the range [-127, 127]. As such the sum of the 2
764         // products is always strictly smaller than 15-bits (32767 in absolute
765         // value).
766         prod_16x8 = vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));
767
768         dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
769       } // for col
770
771       // Half iteration dealing only 8 elements
772       // TODO(raziel): if (ABSL_PREDICT_FALSE(col < postamble_start))
773       if (col < postamble_start)
774       {
775         // Load 8 8-bit values from the row and column each to operate on.
776         // Here the assumption is that each buffer is 4-bytes aligned.
777         // Otherwise, performance may suffer significantly.
778         assert( // NOLINT
779             ((uintptr_t)(&row_ptr[col]) & (kWeightsPerUint32 - 1)) == 0);
780         const int8x8_t s1_8x8 = vld1_s8((const int8_t *)(aligned_vec + col));
781         const int8x8_t s2_8x8 = vld1_s8((const int8_t *)(row_ptr + col));
782         const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8);
783         dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
784         col += (kWeightsPerNeonLane >> 1);
785       }
786       // Add the 4 intermediate sum values to get the final dot-prod value for
787       // this row.
788       int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
789       // Postamble loop.
790       // TODO(raziel): if (ABSL_PREDICT_FALSE(col < m_cols))
791       for (; col < m_cols; ++col)
792       {
793         dotprod += row_ptr[col] * aligned_vec[col];
794       } // for col
795
796       *result += dotprod * batch_scaling_factor;
797     } // for row
798   }   // for batch
799
800   if (unaligned)
801   {
802     free(aligned_row_free);
803   }
804   free(aligned_vec_free);
805 }
806
807 void NeonMatrixBatchVectorMultiplyAccumulate(const float *matrix, int m_rows, int m_cols,
808                                              const float *vector, int n_batch, float *result,
809                                              int result_stride)
810 {
811   // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
812   // vectorized loop, and we need to process sequentially. postamble_start shows
813   // the start index where this should happen.
814   const int postamble_start = m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1));
815
816   for (int b = 0; b < n_batch; b++)
817   {
818     float *result_in_batch = result + b * m_rows * result_stride;
819     const float *vector_in_batch = vector + b * m_cols;
820     const float *matrix_row = matrix;
821
822     // Main matrix by vector multiplication loop
823     for (int r = 0; r < m_rows; r++)
824     {
825       float32x4_t acc_32x4 = vmovq_n_f32(0.0);
826       for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane)
827       {
828         // Load 4 float values from vector and matrix row.
829         float32x4_t vector_f32x4 = vld1q_f32(vector_in_batch + c);
830         float32x4_t matrix_f32x4 = vld1q_f32(matrix_row + c);
831         // Multiply the vector and matrix row and add to accumulator.
832         acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4);
833       }
834       // Add the 4 intermediate sum values to get the final dot-prod value for
835       // this column.
836       *result_in_batch += (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) +
837                            vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3));
838       for (int c = postamble_start; c < m_cols; c++)
839       {
840         *result_in_batch += matrix_row[c] * vector_in_batch[c];
841       }
842       matrix_row += m_cols;
843       result_in_batch += result_stride;
844     }
845   }
846 }
847
848 void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t *__restrict__ matrix, const int m_rows,
849                                              const int m_cols, const int8_t *__restrict__ vectors,
850                                              const float *scaling_factors, int n_batch,
851                                              int32_t *scratch, float *__restrict__ result,
852                                              int result_stride, ruy::Context *ruy_context)
853 {
854   if (m_rows % 4 == 0 && result_stride == 1)
855   {
856     const int32_t *bias = static_cast<const int32_t *>(nullptr);
857     NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows,
858                        /*output_zp =*/0, scratch, ruy_context);
859
860     // Multiply by float scaling factors and write to result
861     const int total_size = n_batch * m_rows;
862     int i = 0;
863     for (; i <= total_size - 8; i += 8, result += 8 * result_stride)
864     {
865       const float batch_scaling_factor0 = scaling_factors[i / m_rows];
866       const float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows];
867       const float32x4_t scaling_factor0 = vdupq_n_f32(batch_scaling_factor0);
868       const float32x4_t scaling_factor1 = vdupq_n_f32(batch_scaling_factor1);
869       const int32x4_t scratch_val0 = vld1q_s32(scratch + i);
870       const int32x4_t scratch_val1 = vld1q_s32(scratch + i + 4);
871       const float32x4_t float_val0 = vcvtq_f32_s32(scratch_val0);
872       const float32x4_t float_val1 = vcvtq_f32_s32(scratch_val1);
873       const float32x4_t result0 = vmlaq_f32(vld1q_f32(result), float_val0, scaling_factor0);
874       const float32x4_t result1 =
875           vmlaq_f32(vld1q_f32(result + 4 * result_stride), float_val1, scaling_factor1);
876       vst1q_f32(result, result0);
877       vst1q_f32(result + 4 * result_stride, result1);
878     }
879     scratch += i;
880     for (; i < total_size; i++, result += result_stride)
881     {
882       const float batch_scaling_factor = scaling_factors[i / m_rows];
883       int32_t x = *(scratch++);
884       *result += x * batch_scaling_factor;
885     }
886     return;
887   }
888   NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors, scaling_factors, n_batch,
889                                           result, result_stride);
890 }
891
892 } // namespace cker
893 } // namespace nnfw
894
895 #endif // USE_NEON
896
897 #endif // __NNFW_CKER_NEON_TENSOR_UTILS_H__