e08040632083bb4b1edae51d97f447d0bbc3fa79
[platform/core/ml/nnfw.git] / compute / cker / include / cker / NeonTensorUtils.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_NEON_TENSOR_UTILS_H__
19 #define __NNFW_CKER_NEON_TENSOR_UTILS_H__
20
21 #include <ruy/path.h>
22 #include <ruy/ruy.h>
23 #include "cker/Types.h"
24 #include "cker/neon/neon_check.h"
25 #include "cker/ruy/RuySupport.h"
26 #include "util/logging.h"
27 #if defined __linux__ && defined __aarch64__
28 #include <sys/auxv.h>
29 #endif
30
31 #include <cassert>
32 #include <cmath>
33
34 #ifdef USE_NEON
35
36 #define kFloatWeightsPerNeonLane 4
37
38 namespace nnfw
39 {
40 namespace cker
41 {
42
43 namespace
44 {
45
46 constexpr int kFloatValuesPerNeonVector = 4;
47
48 // TODO(ahentz): Clean up.
49 using int8 = std::int8_t;
50 using uint8 = std::uint8_t;
51 using int16 = std::int16_t;
52 using uint16 = std::uint16_t;
53 using int32 = std::int32_t;
54 using uint32 = std::uint32_t;
55
56 template <int PerNeonSize> inline int RoundDownVectors(int size)
57 {
58   return size & ~(PerNeonSize - 1);
59 }
60
61 // Allocates, at least, size bytes of uninitialized storage whose alignment is
62 // specified by alignment. The size parameter must be an integral multiple of
63 // alignment.
64 // Caller is responsible by freeing the allocated memory by calling free on
65 // the passed freeing_buffer pointer.
66 void *aligned_alloc(size_t alignment, size_t size, void **freeing_buffer)
67 {
68   *freeing_buffer = malloc(size + alignment);
69   const size_t offset = ((uintptr_t)*freeing_buffer) % alignment;                          // NOLINT
70   return offset == 0 ? *freeing_buffer : ((char *)*freeing_buffer + (alignment - offset)); // NOLINT
71 }
72
73 inline int32_t AccumulateNeonLane(const int32x4_t lane)
74 {
75 #ifdef __aarch64__
76   return vaddvq_s32(lane);
77 #else
78   int64x2_t pairwiseAdded = vpaddlq_s32(lane);
79   return vgetq_lane_s64(pairwiseAdded, 0) + vgetq_lane_s64(pairwiseAdded, 1);
80 #endif
81 }
82
83 } // namespace
84
85 // The implementation of dotprod detection is copied from ruy's internal
86 // function DetectDotprod().
87 // At the moment it's only implemented on Linux ARM64. Consider syncing again
88 // with ruy in the future to share improvements.
89 #if defined __linux__ && defined __aarch64__
90 inline bool DetectDotprodByLinuxAuxvMethod()
91 {
92   // This is the value of HWCAP_ASIMDDP in sufficiently recent Linux headers,
93   // however we need to support building against older headers for the time
94   // being.
95   const int kLocalHwcapAsimddp = 1 << 20;
96   return getauxval(AT_HWCAP) & kLocalHwcapAsimddp;
97 }
98 #endif
99
100 inline bool DetectArmNeonDotprod()
101 {
102 #if defined __linux__ && defined __aarch64__
103   return DetectDotprodByLinuxAuxvMethod();
104 #endif
105
106   return false;
107 }
108
109 inline bool HasSdotInstruction()
110 {
111   static const bool has_dotprod = DetectArmNeonDotprod();
112   return has_dotprod;
113 }
114
115 #ifdef __aarch64__
116 // We interleave vector data to make the dot product logic more efficient.
117 // Suppose that vectors is:
118 //     a0 a1 a2 a3 a4 a5 ...
119 //     b0 b1 b2 b3 b4 b5 ...
120 //     c0 c1 c2 c3 c4 c5 ...
121 //     d0 d1 d2 d3 d4 d5 ...
122 //     e0 e1 e2 e3 e4 e5 ...
123 // This code interleaves them like this:
124 //     a0 a1 a2 a3 b0 b1 b2 b3 c0 c1 c2 c3 d0 d1 d2 d3 a4 a5 a6 a7 b4 ...
125 //     e0 e1 e2 e3 f0 f1 f2 f3 ...
126 // Once the data is interleaved, each 16-byte read from the vectors pointer
127 // contains 4 bytes from each of 4 vectors.
128 inline const int8_t *ShuffleVectors(const int8_t *vectors, const int n_batch, const int m_cols,
129                                     void **shuffled_vectors_free)
130 {
131   const int kWeightsPerUint32 = 4;
132
133   int8 *shuffled_vectors = reinterpret_cast<int8 *>(
134       aligned_alloc(kWeightsPerUint32, n_batch * m_cols, shuffled_vectors_free));
135
136   for (int i = 0; i < n_batch; i += 4)
137   {
138     int8 *shuffled_vectors_ptr = shuffled_vectors + (i * m_cols);
139     const int8 *unshuffled_vec0_ptr = reinterpret_cast<const int8 *>(vectors) + (i * m_cols);
140     const int8 *unshuffled_vec1_ptr = reinterpret_cast<const int8 *>(vectors) + ((i + 1) * m_cols);
141     const int8 *unshuffled_vec2_ptr = reinterpret_cast<const int8 *>(vectors) + ((i + 2) * m_cols);
142     const int8 *unshuffled_vec3_ptr = reinterpret_cast<const int8 *>(vectors) + ((i + 3) * m_cols);
143     const int8 *const end_vec0_ptr = unshuffled_vec1_ptr;
144
145     while (unshuffled_vec0_ptr != end_vec0_ptr)
146     {
147       asm volatile(
148           // This code path requires that (n_cols % 16) == 0 so we can safely
149           // read in 16-byte chunks from each row.
150           "ld1 {v0.16b}, [%[unshuffled_vec0_ptr]], #16\n"
151           "ld1 {v1.16b}, [%[unshuffled_vec1_ptr]], #16\n"
152           "ld1 {v2.16b}, [%[unshuffled_vec2_ptr]], #16\n"
153           "ld1 {v3.16b}, [%[unshuffled_vec3_ptr]], #16\n"
154
155           "st4 {v0.s, v1.s, v2.s, v3.s}[0], [%[shuffled_vectors_ptr]], #16\n"
156           "st4 {v0.s, v1.s, v2.s, v3.s}[1], [%[shuffled_vectors_ptr]], #16\n"
157           "st4 {v0.s, v1.s, v2.s, v3.s}[2], [%[shuffled_vectors_ptr]], #16\n"
158           "st4 {v0.s, v1.s, v2.s, v3.s}[3], [%[shuffled_vectors_ptr]], #16\n"
159
160           : [unshuffled_vec0_ptr] "+r"(unshuffled_vec0_ptr),
161             [unshuffled_vec1_ptr] "+r"(unshuffled_vec1_ptr),
162             [unshuffled_vec2_ptr] "+r"(unshuffled_vec2_ptr),
163             [unshuffled_vec3_ptr] "+r"(unshuffled_vec3_ptr),
164             [shuffled_vectors_ptr] "+r"(shuffled_vectors_ptr)
165           :
166           : "v0", "v1", "v2", "v3", "cc", "memory");
167     }
168   }
169
170   return reinterpret_cast<const int8_t *>(shuffled_vectors);
171 }
172
173 // Notes about the speed of this version vs. the baseline (from memory):
174 // - With 256K of L1, we can keep a lot of vectors in cache.
175 //   I recall a reasonable speedup just by rearranging the loop to have
176 //   row on the outside and batch on the inside.
177 // - I also recall getting a nice speedup from sdot.
178 // - I tried many times to do better than the current implementation, using
179 //   loop unrolling and instruction reordering to avoid stalls, etc.
180 //   but I was not able to do significantly better. This code is, however,
181 //   much worse than what the processor spec sheet suggests is possible.
182 static void DotprodMatrixBatchFourVectorMultiplyAccumulate(const int8_t *__restrict__ matrix,
183                                                            const int m_rows, const int m_cols,
184                                                            const int8_t *vectors,
185                                                            const float *scaling_factors,
186                                                            int n_batch, float *__restrict__ result)
187 {
188   void *shuffled_vectors_free;
189
190   const int8_t *shuffled_vectors = ShuffleVectors(vectors, n_batch, m_cols, &shuffled_vectors_free);
191
192   for (int row = 0; row < m_rows; row += 2)
193   {
194     for (int batch = 0; batch < n_batch; batch += 4)
195     {
196       float *result_ptr = result + (batch * m_rows) + row;
197       const int8 *mat_ptr0 = matrix + (row * m_cols);
198       const int8 *mat_ptr1 = matrix + ((row + 1) * m_cols);
199       const int8 *mat_ptr0_end = mat_ptr1;
200       const int8 *vec_ptr = shuffled_vectors + (batch * m_cols);
201       const float *scaling_factors_ptr = scaling_factors + batch;
202       const uint64_t wide_rows = m_rows * sizeof(float);
203       const int8 *mat_ptr2 = matrix + ((row + 2) * m_cols);
204       const int8 *mat_ptr3 = matrix + ((row + 3) * m_cols);
205
206       asm volatile(
207           // Zero out the accumulator registers.
208           "dup v0.4s, wzr\n"
209           "dup v1.4s, wzr\n"
210           "dup v2.4s, wzr\n"
211           "dup v3.4s, wzr\n"
212
213           "1:\n" // batch_cols_loop
214
215           // Read 16 more bytes from a pair of matrix rows.
216           "ld1 {v12.16b}, [%[mat_ptr0]], #16\n"
217
218           // Prefetch two rows ahead.
219           "prfm pldl1strm, [%[mat_ptr2]]\n"
220           "prfm pldl1strm, [%[mat_ptr3]]\n"
221
222           // Read from input vectors 4 times; 64 bytes total.
223           // Each 16-byte register contains parts of 4 vectors; see the
224           // shuffle logic above.
225
226           // From Benoit, places to look in the future:
227           // - Move load instructions further from sdot
228           // - Switch loop use-then-reload
229           // - Do partial unrolling to use register space better
230           "ld1 {v8.16b}, [%[vec_ptr]], #16\n"
231           ".word 0x4f8ce100  // sdot v0.4s, v8.16b, v12.4b[0]\n"
232           "ld1 {v9.16b}, [%[vec_ptr]], #16\n"
233           ".word 0x4face121  // sdot v1.4s, v9.16b, v12.4b[1]\n"
234           "ld1 {v10.16b}, [%[vec_ptr]], #16\n"
235           ".word 0x4f8ce940  // sdot v0.4s, v10.16b, v12.4b[2]\n"
236           "ld1 {v11.16b}, [%[vec_ptr]], #16\n"
237           ".word 0x4face961  // sdot v1.4s, v11.16b, v12.4b[3]\n"
238
239           // Update prefetch pointers.
240           "add %[mat_ptr2], %[mat_ptr2], #16\n"
241           "add %[mat_ptr3], %[mat_ptr3], #16\n"
242
243           // Re-use those vectors for the next row as well.
244           "ld1 {v13.16b}, [%[mat_ptr1]], #16\n"
245           ".word 0x4f8de102  // sdot v2.4s, v8.16b, v13.4b[0]\n"
246           ".word 0x4fade123  // sdot v3.4s, v9.16b, v13.4b[1]\n"
247           ".word 0x4f8de942  // sdot v2.4s, v10.16b, v13.4b[2]\n"
248           ".word 0x4fade963  // sdot v3.4s, v11.16b, v13.4b[3]\n"
249
250           // If we're not done with these rows, continue.
251           "cmp %[mat_ptr0], %[mat_ptr0_end]\n"
252           "bne 1b\n" // batch_cols_loop
253
254           // Done with the rows, sum the results.
255           "add v0.4s, v0.4s, v1.4s\n"
256           "add v2.4s, v2.4s, v3.4s\n"
257
258           // Convert the per-vector sums to floating point.
259           "scvtf v0.4s, v0.4s\n"
260           "scvtf v1.4s, v2.4s\n"
261
262           // Fetch scale factors.
263           "ld1 {v4.4s}, [%[scaling_factors_ptr]]\n"
264
265           // Multiply scale factors times sums.
266           "fmul v0.4s, v4.4s, v0.4s\n"
267           "fmul v1.4s, v4.4s, v1.4s\n"
268
269           // Load previous result values.
270           // The result position is:
271           //   result[batch * m_rows + row]
272           // Here that is factored into:
273           //   result_ptr = result + row
274           //   *result_ptr = res[0]
275           //   (uint8*)result_ptr += (m_rows * sizeof(float))
276           //   *result_ptr = res[1]
277           //   ...
278           // Since we're reading two rows at a time, though, we read both
279           //   result[batch * m_rows + row]
280           // and
281           //   result[batch * m_rows + row + 1]
282           "ld2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
283           "ld2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
284           "ld2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
285           "ld2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
286
287           // Go back to the starting position (subtract wide_rows * 4).
288           "sub %[result_ptr], %[result_ptr], %[wide_rows], lsl #2\n"
289
290           // Add previous result values.
291           "fadd v9.4s, v9.4s, v0.4s\n"
292           "fadd v10.4s, v10.4s, v1.4s\n"
293
294           // Store results.
295           "st2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
296           "st2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
297           "st2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
298           "st2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
299           : [mat_ptr0] "+r"(mat_ptr0), [mat_ptr1] "+r"(mat_ptr1), [vec_ptr] "+r"(vec_ptr),
300             [result_ptr] "+r"(result_ptr), [mat_ptr2] "+r"(mat_ptr2), [mat_ptr3] "+r"(mat_ptr3)
301           : [mat_ptr0_end] "r"(mat_ptr0_end), [scaling_factors_ptr] "r"(scaling_factors_ptr),
302             [wide_rows] "r"(wide_rows)
303           : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
304             "v13", "cc", "memory");
305     }
306   }
307
308   free(shuffled_vectors_free);
309 }
310
311 static void DotprodMatrixBatchFourVectorMultiplyAccumulate(
312     const int8_t *__restrict__ matrix, const int m_rows, const int m_cols, const int8_t *vectors,
313     const float *scaling_factors, int n_batch, float *__restrict__ result,
314     const float *per_channel_scale, const int32_t *input_offset, int32_t *row_sums)
315 {
316   void *shuffled_vectors_free;
317   const int8_t *shuffled_vectors = ShuffleVectors(vectors, n_batch, m_cols, &shuffled_vectors_free);
318
319   for (int row = 0; row < m_rows; row += 2)
320   {
321     const float *channel_scales_ptr = per_channel_scale + row;
322     int32_t *row_sums_ptr = row_sums ? row_sums + row : nullptr;
323     for (int batch = 0; batch < n_batch; batch += 4)
324     {
325       float *result_ptr = result + (batch * m_rows) + row;
326       const int8 *mat_ptr0 = matrix + (row * m_cols);
327       const int8 *mat_ptr1 = matrix + ((row + 1) * m_cols);
328       const int8 *mat_ptr0_end = mat_ptr1;
329       const int8 *vec_ptr = shuffled_vectors + (batch * m_cols);
330       const float *scaling_factors_ptr = scaling_factors + batch;
331       const uint64_t wide_rows = m_rows * sizeof(float);
332       const int32_t *batch_offsets_ptr = input_offset + batch;
333       const int32_t is_channel_scale_nullptr = per_channel_scale == nullptr;
334       const int32_t is_row_sums_nullptr = row_sums_ptr == nullptr;
335       asm volatile("dup v0.4s, wzr\n"
336                    "dup v1.4s, wzr\n"
337                    "dup v2.4s, wzr\n"
338                    "dup v3.4s, wzr\n"
339                    // Load zero points.
340                    "ld1 {v7.4s}, [%[batch_offsets_ptr]]\n"
341                    "ld1 {v4.4s}, [%[scaling_factors_ptr]]\n"
342                    // Zero out zero point accumulators.
343                    "dup v14.4s, wzr\n"
344                    "dup v15.4s, wzr\n"
345
346                    // Load per channel scales if not null.
347                    "cmp %w[is_channel_scale_nullptr], #0\n"
348                    "bne 1f\n"
349                    "ld1r {v16.4s}, [%[channel_scales_ptr]], #4\n"
350                    "ld1r {v17.4s}, [%[channel_scales_ptr]]\n"
351                    "fmul v16.4s, v16.4s, v4.4s\n"
352                    "fmul v17.4s, v17.4s, v4.4s\n"
353                    "b 2f\n"
354                    "1:\n"
355                    "mov v16.16b, v4.16b\n"
356                    "mov v17.16b, v4.16b\n"
357                    "2:\n"
358                    "ld1 {v12.16b}, [%[mat_ptr0]], #16\n"
359                    "ld1 {v8.16b}, [%[vec_ptr]], #16\n"
360                    ".word 0x4f8ce100  // sdot v0.4s, v8.16b, v12.4b[0]\n"
361                    "ld1 {v9.16b}, [%[vec_ptr]], #16\n"
362                    ".word 0x4face121  // sdot v1.4s, v9.16b, v12.4b[1]\n"
363                    "ld1 {v10.16b}, [%[vec_ptr]], #16\n"
364                    ".word 0x4f8ce940  // sdot v0.4s, v10.16b, v12.4b[2]\n"
365                    "ld1 {v11.16b}, [%[vec_ptr]], #16\n"
366                    ".word 0x4face961  // sdot v1.4s, v11.16b, v12.4b[3]\n"
367                    "ld1 {v13.16b}, [%[mat_ptr1]], #16\n"
368                    ".word 0x4f8de102  // sdot v2.4s, v8.16b, v13.4b[0]\n"
369                    ".word 0x4fade123  // sdot v3.4s, v9.16b, v13.4b[1]\n"
370                    ".word 0x4f8de942  // sdot v2.4s, v10.16b, v13.4b[2]\n"
371                    ".word 0x4fade963  // sdot v3.4s, v11.16b, v13.4b[3]\n"
372                    "cmp %w[is_row_sums_nullptr], #1\n"
373                    "bne 3f\n"
374                    // Accumulate row_sums for zero point calculations.
375                    "saddlp v12.8h, v12.16b\n"
376                    "saddlp v13.8h, v13.16b\n"
377                    "sadalp v14.4s, v12.8h\n"
378                    "sadalp v15.4s, v13.8h\n"
379                    "3:\n"
380                    "cmp %[mat_ptr0], %[mat_ptr0_end]\n"
381                    "bne 2b\n"
382                    "add v0.4s, v0.4s, v1.4s\n"
383                    "add v2.4s, v2.4s, v3.4s\n"
384
385                    "cmp %w[is_row_sums_nullptr], #1\n"
386                    "bne 4f\n"
387                    // Calculate zero point offsets.
388                    "addv s14, v14.4s\n"
389                    "addv s15, v15.4s\n"
390                    "dup v14.4s, v14.s[0]\n"
391                    "dup v15.4s, v15.s[0]\n"
392                    "b 5f\n"
393                    "4:\n"
394                    "ld1r {v14.4s}, [%[row_sums_ptr]], #4\n"
395                    "ld1r {v15.4s}, [%[row_sums_ptr]]\n"
396                    "5:\n"
397
398                    "mul v14.4s, v14.4s, v7.4s\n"
399                    "mul v15.4s, v15.4s, v7.4s\n"
400                    "sub v0.4s, v0.4s, v14.4s\n"
401                    "sub v2.4s, v2.4s, v15.4s\n"
402
403                    "scvtf v0.4s, v0.4s\n"
404                    "scvtf v1.4s, v2.4s\n"
405
406                    // Multiply scale.
407                    "fmul v0.4s, v16.4s, v0.4s\n"
408                    "fmul v1.4s, v17.4s, v1.4s\n"
409
410                    "ld2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
411                    "ld2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
412                    "ld2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
413                    "ld2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
414                    "sub %[result_ptr], %[result_ptr], %[wide_rows], lsl #2\n"
415                    "fadd v9.4s, v9.4s, v0.4s\n"
416                    "fadd v10.4s, v10.4s, v1.4s\n"
417                    "st2 {v9.s, v10.s}[0], [%[result_ptr]], %[wide_rows]\n"
418                    "st2 {v9.s, v10.s}[1], [%[result_ptr]], %[wide_rows]\n"
419                    "st2 {v9.s, v10.s}[2], [%[result_ptr]], %[wide_rows]\n"
420                    "st2 {v9.s, v10.s}[3], [%[result_ptr]], %[wide_rows]\n"
421                    : [mat_ptr0] "+r"(mat_ptr0), [mat_ptr1] "+r"(mat_ptr1), [vec_ptr] "+r"(vec_ptr),
422                      [result_ptr] "+r"(result_ptr), [row_sums_ptr] "+r"(row_sums_ptr)
423                    : [mat_ptr0_end] "r"(mat_ptr0_end),
424                      [scaling_factors_ptr] "r"(scaling_factors_ptr), [wide_rows] "r"(wide_rows),
425                      [channel_scales_ptr] "r"(channel_scales_ptr),
426                      [batch_offsets_ptr] "r"(batch_offsets_ptr),
427                      [is_channel_scale_nullptr] "r"(is_channel_scale_nullptr),
428                      [is_row_sums_nullptr] "r"(is_row_sums_nullptr)
429                    : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
430                      "v12", "v13", "v14", "v15", "v16", "v17", "w0", "w1", "cc", "memory");
431     }
432   }
433
434   free(shuffled_vectors_free);
435 }
436
437 // The DotprodMatrixBatchFourVectorMultiplyAccumulate kernel processes 4
438 // vectors in the same time as the baseline processes 1 vector. However, it
439 // requires 4 vectors of input.
440 //
441 // To take advantage of this speed difference, we add some zero-valued
442 // vectors to the batch so that n_batch is a multiple of 4. Then we execute
443 // DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate on that padded batch,
444 // then extract just the results we want at the end (ignoring the extra padding
445 // outputs).
446 //
447 // The relative cost of the padding is large when the matrix is smaller than
448 // 128x128, so we don't use this code path on small matrices. On larger
449 // matrices, the computation cost dwarfs the padding cost, making this code
450 // viable.
451 //
452 // If we ignore the cost of padding, this kernel is:
453 //    1x the speed of NeonMatrixBatchVectorMultiplyImpl for n_batch = 1
454 //    2x the speed of NeonMatrixBatchVectorMultiplyImpl for n_batch = 2
455 //    3x the speed of NeonMatrixBatchVectorMultiplyImpl for n_batch = 3
456 //    ...
457 //
458 // We don't use this kernel when n_batch = 1 because the baseline kernel
459 // is fine for that case.
460 inline void DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
461     const int8_t *__restrict__ matrix, const int m_rows, const int m_cols, const int8_t *vectors,
462     const float *scaling_factors, int n_batch, float *__restrict__ result,
463     const float *per_channel_scale, const int32_t *input_offset, int32_t *row_sums)
464 {
465   const int kWeightsPerUint32 = 4;
466
467   // Round to the nearest multiple of 4.
468   int batch_round_up = n_batch;
469   if (n_batch % 4 != 0)
470   {
471     batch_round_up += (4 - n_batch % 4);
472   }
473   assert(n_batch <= batch_round_up);
474
475   void *padded_vectors_free;
476   const int padded_vectors_size = batch_round_up * m_cols;
477   int8_t *padded_vectors = reinterpret_cast<int8_t *>(
478       aligned_alloc(kWeightsPerUint32, padded_vectors_size, &padded_vectors_free));
479   memset(padded_vectors, 0, padded_vectors_size);
480
481   void *padded_result_free;
482   const int result_size = n_batch * m_rows * sizeof(float);
483   const int padded_result_size = batch_round_up * m_rows * sizeof(float);
484   float *padded_result = reinterpret_cast<float *>(
485       aligned_alloc(kWeightsPerUint32, padded_result_size, &padded_result_free));
486   memcpy(padded_result, result, result_size);
487   memset(reinterpret_cast<char *>(padded_result) + result_size, 0,
488          padded_result_size - result_size);
489
490   // Copy the input into the padded data structure.
491   assert(n_batch * m_cols <= padded_vectors_size);
492   memcpy(padded_vectors, vectors, n_batch * m_cols);
493
494   void *padded_scaling_factors_free;
495   const int padded_scaling_factors_size = batch_round_up * sizeof(float);
496   float *padded_scaling_factors = reinterpret_cast<float *>(
497       aligned_alloc(kWeightsPerUint32, padded_scaling_factors_size, &padded_scaling_factors_free));
498   assert(static_cast<int>(n_batch * sizeof(float)) <= padded_scaling_factors_size);
499   assert(static_cast<int>(batch_round_up * sizeof(float)) <= padded_scaling_factors_size);
500   memset(padded_scaling_factors, 0, batch_round_up * sizeof(float));
501   memcpy(padded_scaling_factors, scaling_factors, n_batch * sizeof(float));
502
503   if (input_offset != nullptr)
504   {
505     void *padded_input_offset_free;
506     const int padded_input_offset_size = batch_round_up * sizeof(int32_t);
507     int32_t *padded_input_offset = reinterpret_cast<int32_t *>(
508         aligned_alloc(kWeightsPerUint32, padded_input_offset_size, &padded_input_offset_free));
509     assert(static_cast<int>(n_batch * sizeof(int32_t)) <= padded_input_offset_size);
510     assert(static_cast<int>(batch_round_up * sizeof(int32_t)) <= padded_input_offset_size);
511     memset(padded_input_offset, 0, batch_round_up * sizeof(int32_t));
512     memcpy(padded_input_offset, input_offset, n_batch * sizeof(int32_t));
513
514     // Call the main kernel.
515     DotprodMatrixBatchFourVectorMultiplyAccumulate(
516         matrix, m_rows, m_cols, padded_vectors, padded_scaling_factors, batch_round_up,
517         padded_result, per_channel_scale, padded_input_offset, row_sums);
518
519     free(padded_input_offset_free);
520   }
521   else
522   {
523     // Call the main kernel.
524     DotprodMatrixBatchFourVectorMultiplyAccumulate(matrix, m_rows, m_cols, padded_vectors,
525                                                    padded_scaling_factors, batch_round_up,
526                                                    padded_result);
527   }
528   memcpy(result, padded_result, result_size);
529
530   free(padded_result_free);
531   free(padded_vectors_free);
532   free(padded_scaling_factors_free);
533 }
534
535 inline void DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
536     const int8_t *__restrict__ matrix, const int m_rows, const int m_cols, const int8_t *vectors,
537     const float *scaling_factors, int n_batch, float *__restrict__ result)
538 {
539   DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(
540       matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
541       /*per_channel_scale=*/nullptr, /*input_offset=*/nullptr,
542       /*row_sums=*/nullptr);
543 }
544 #endif // __aarch64__
545
546 inline void NeonCwiseClipping(float *vector, const int v_size, const float clipping_value)
547 {
548   const float32x4_t clipping_value_f32x4 = vmovq_n_f32(clipping_value);
549   const float32x4_t neg_clipping_value_f32x4 = vmovq_n_f32(-clipping_value);
550
551   int i = 0;
552   for (; i <= v_size - kFloatValuesPerNeonVector; i += kFloatValuesPerNeonVector)
553   {
554     // Load from memory to vector.
555     float32x4_t v_f32x4 = vld1q_f32(vector + i);
556     // Clip between clipping_value and -clipping_value.
557     v_f32x4 = vminq_f32(clipping_value_f32x4, v_f32x4);
558     v_f32x4 = vmaxq_f32(neg_clipping_value_f32x4, v_f32x4);
559     // Save to output.
560     vst1q_f32(vector + i, v_f32x4);
561   }
562   for (; i < v_size; i++)
563   {
564     vector[i] = std::max(std::min(clipping_value, vector[i]), -clipping_value);
565   }
566 }
567
568 inline bool NeonIsZeroVector(const float *vector, int v_size)
569 {
570   // If v_size is not divisible by kFloatWeightsPerNeonLane, we cannot
571   // use the main vectorized loop, and we need to process sequentially.
572   // postamble_start shows the start index where this should happen.
573   const int postamble_start = v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
574
575   const float32x4_t zero_x4_float = vmovq_n_f32(0.0f);
576   for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane)
577   {
578     const float32x4_t i_x4_float = vld1q_f32(vector + v);
579     uint32x4_t cmp_result = vceqq_f32(i_x4_float, zero_x4_float);
580     if (vgetq_lane_u32(cmp_result, 0) == 0)
581       return false;
582     if (vgetq_lane_u32(cmp_result, 1) == 0)
583       return false;
584     if (vgetq_lane_u32(cmp_result, 2) == 0)
585       return false;
586     if (vgetq_lane_u32(cmp_result, 3) == 0)
587       return false;
588   }
589
590   // Postamble loop
591   for (int v = postamble_start; v < v_size; ++v)
592   {
593     if (vector[v] != 0.0)
594       return false;
595   }
596   return true;
597 }
598
599 inline void NeonCpuBackendGemm(const int8_t *input, const int32_t *bias,
600                                const int8_t *input_to_gate_weights, int32_t n_batch,
601                                int32_t n_input, int32_t n_output, int32_t, int32_t *scratch,
602                                ruy::Context *ruy_context)
603 {
604   MatrixParams<int8_t> lhs_params;
605   lhs_params.order = Order::kRowMajor;
606   lhs_params.rows = n_output;
607   lhs_params.cols = n_input;
608   lhs_params.cache_policy = CachePolicy::kAlwaysCache;
609
610   MatrixParams<int8_t> rhs_params;
611   rhs_params.order = Order::kColMajor;
612   rhs_params.rows = n_input;
613   rhs_params.cols = n_batch;
614
615   MatrixParams<int32_t> dst_params;
616   dst_params.order = Order::kColMajor;
617   dst_params.rows = n_output;
618   dst_params.cols = n_batch;
619
620   GemmParams<int32_t, int32_t> gemm_params;
621   if (bias)
622   {
623     gemm_params.bias = bias;
624   }
625
626   // Below code is from tflite::cpu_backend_gemm::detail::GemmImplUsingRuy
627   ruy::Matrix<int8_t> ruy_lhs;
628   ruy::Matrix<int8_t> ruy_rhs;
629   ruy::Matrix<int32_t> ruy_dst;
630   // Note that cache is always enabled for input and weight tensors
631   ruy_support::MakeRuyMatrix(lhs_params, input_to_gate_weights, &ruy_lhs, true);
632   ruy_support::MakeRuyMatrix(rhs_params, input, &ruy_rhs, true);
633   ruy_support::MakeRuyMatrix(dst_params, scratch, &ruy_dst);
634
635   ruy::BasicSpec<int32_t, int32_t> ruy_mul_params;
636   ruy_support::MakeRuyMulParams(gemm_params, &ruy_mul_params);
637
638   ruy::Mul(ruy_lhs, ruy_rhs, ruy_mul_params, ruy_context, &ruy_dst);
639 }
640
641 inline void NeonSub1Vector(const float *vector, int v_size, float *result)
642 {
643   // If v_size is not divisible by the vector size, then we need to process the
644   // final few elements sequentially. postamble_start shows the start index
645   // where this should happen.
646   const int postamble_start = RoundDownVectors<kFloatValuesPerNeonVector>(v_size);
647
648   float32x4_t one_f32x4 = vmovq_n_f32(1.0);
649   int v = 0;
650   for (; v < postamble_start; v += kFloatValuesPerNeonVector)
651   {
652     // Load 4 float values from the current pointers of the input column and
653     // subtract from 1.
654     float32x4_t v_f32x4 = vld1q_f32(vector + v);
655     float32x4_t result_f32x4 = vsubq_f32(one_f32x4, v_f32x4);
656     // Save to output.
657     vst1q_f32(result + v, result_f32x4);
658   }
659   for (; v < v_size; v++)
660   {
661     result[v] = 1.0f - vector[v];
662   }
663 }
664
665 inline void NeonSymmetricQuantizeFloats(const float *values, const int size,
666                                         int8_t *quantized_values, float *min, float *max,
667                                         float *scaling_factor)
668 {
669   // TODO(raziel): vectorize min/max calculation.
670   auto minmax = std::minmax_element(values, values + size);
671   *min = *minmax.first;
672   *max = *minmax.second;
673   const int kScale = 127;
674   const float range = std::max(std::abs(*min), std::abs(*max));
675   if (range == 0)
676   {
677     memset(quantized_values, 0, size * sizeof(int8_t));
678     *scaling_factor = 1;
679     return;
680   }
681   *scaling_factor = range / kScale;
682   const float scaling_factor_inv = kScale / range;
683
684   const int postamble_start = size - (size & (2 * kFloatWeightsPerNeonLane - 1));
685
686   // Vectorized constants.
687   const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv);
688   const float32x4_t point5_f32x4 = vmovq_n_f32(0.5);
689   const float32x4_t zero_f32x4 = vmovq_n_f32(0.0);
690   const int32x4_t scale_i32x4 = vmovq_n_s32(kScale);
691   const int32x4_t neg_scale_i32x4 = vmovq_n_s32(-kScale);
692
693   for (int i = 0; i < postamble_start; i += 2 * kFloatWeightsPerNeonLane)
694   {
695     // Implements the vectorized version of the following:
696     // const int32_t quantized_value = static_cast<int32>(
697     //    std::round(*scaling_factor * values[i]));
698     // Since the vectorized round intrinsics (vrndqa_f32) is not supported
699     // on all Neon flavors, we use the following method for rounding: if (x
700     // < 0) (int)(x - 0.5) if (x >= 0) (int)(x + 0.5)
701     float32x4_t value0_f32x4 = vld1q_f32(&values[i]);
702     float32x4_t value1_f32x4 = vld1q_f32(&values[i + kFloatWeightsPerNeonLane]);
703     float32x4_t mul0_f32x4 = vmulq_f32(value0_f32x4, q_factor_f32x4);
704     float32x4_t mul1_f32x4 = vmulq_f32(value1_f32x4, q_factor_f32x4);
705
706     int32x4_t cmp_with_zero0_ui32x4 = (int32x4_t)vcltq_f32(mul0_f32x4, zero_f32x4); // NOLINT
707     int32x4_t cmp_with_zero1_ui32x4 = (int32x4_t)vcltq_f32(mul1_f32x4, zero_f32x4); // NOLINT
708
709     float32x4_t cmp_with_zero0_f32x4 = vcvtq_f32_s32(cmp_with_zero0_ui32x4);
710     float32x4_t cmp_with_zero1_f32x4 = vcvtq_f32_s32(cmp_with_zero1_ui32x4);
711     cmp_with_zero0_f32x4 = vaddq_f32(cmp_with_zero0_f32x4, point5_f32x4);
712     cmp_with_zero1_f32x4 = vaddq_f32(cmp_with_zero1_f32x4, point5_f32x4);
713
714     mul0_f32x4 = vaddq_f32(mul0_f32x4, cmp_with_zero0_f32x4);
715     mul1_f32x4 = vaddq_f32(mul1_f32x4, cmp_with_zero1_f32x4);
716
717     int32x4_t f2i0_i32x4 = vcvtq_s32_f32(mul0_f32x4);
718     int32x4_t f2i1_i32x4 = vcvtq_s32_f32(mul1_f32x4);
719
720     // Implements the vectorized version of the folowing block:
721     //  quantized_values[i] = std::min(kScale, std::max(-kScale,
722     //  quantized_value));
723     int32x4_t max0_i32x4 = vmaxq_s32(f2i0_i32x4, neg_scale_i32x4);
724     int32x4_t max1_i32x4 = vmaxq_s32(f2i1_i32x4, neg_scale_i32x4);
725     int32x4_t min0_i32x4 = vminq_s32(max0_i32x4, scale_i32x4);
726     int32x4_t min1_i32x4 = vminq_s32(max1_i32x4, scale_i32x4);
727
728     int16x4_t min0_16x4 = vmovn_s32(min0_i32x4);
729     int16x4_t min1_16x4 = vmovn_s32(min1_i32x4);
730
731     int16x8_t min_16x8 = vcombine_s16(min0_16x4, min1_16x4);
732     int8x8_t min_s8x8 = vqmovn_s16(min_16x8);
733     vst1_s8(&quantized_values[i], min_s8x8);
734   }
735
736   for (int i = postamble_start; i < size; ++i)
737   {
738     const int32_t quantized_value =
739         static_cast<int32_t>(std::round(scaling_factor_inv * values[i]));
740     quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
741   }
742 }
743
744 inline void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t *__restrict__ matrix,
745                                                     const int m_rows, const int m_cols,
746                                                     const int8_t *__restrict__ vectors,
747                                                     const float *scaling_factors, int n_batch,
748                                                     float *__restrict__ result, int result_stride)
749 {
750 #ifdef __aarch64__
751   if (HasSdotInstruction() && m_cols % 16 == 0 && m_rows % 2 == 0 && m_rows >= n_batch)
752   {
753     if (n_batch % 4 == 0 && result_stride == 1)
754     {
755       // Benchmarks suggest that it's always better to use the batch code
756       // when we can, even on small matrices.
757       DotprodMatrixBatchFourVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
758                                                      scaling_factors, n_batch, result);
759       return;
760     }
761     else if (result_stride == 1 && n_batch >= 2 && m_rows * m_cols >= 128 * 128)
762     {
763       DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
764                                                            scaling_factors, n_batch, result);
765       return;
766     }
767   }
768 #endif // __aarch64__
769
770   static const int kWeightsPerUint32 = 4;
771   static const int kWeightsPerNeonLane = 16;
772   // Assuming *matrix is kWeightsPerUint32-byte aligned,
773   // every row of the matrix is also
774   // kWeightsPerUint32-byte aligned as long as cols is
775   // a multiple of kWeightsPerUint32. The assumption
776   // is currently satisfied by TFLite's 16-byte memory
777   // alignment scheme.
778   //
779   // Otherwise, we allocate an aligned memory block and set
780   // a flag to later copy rows from matrix to the block
781   // for aligned multiplication.
782   bool unaligned = false;
783   int8_t *aligned_row = nullptr;
784   void *aligned_row_free = nullptr;
785   if ((m_cols & (kWeightsPerUint32 - 1)) != 0)
786   {
787     unaligned = true;
788     aligned_row = (int8_t *)aligned_alloc(kWeightsPerUint32, m_cols, // NOLINT
789                                           &aligned_row_free);
790   }
791   void *aligned_vec_free = nullptr;
792   int8_t *aligned_vec = (int8_t *)aligned_alloc(kWeightsPerUint32, m_cols, // NOLINT
793                                                 &aligned_vec_free);
794
795   // If m_cols is not at least kWeightsPerNeonLane, we cannot use the main
796   // vectorized loop, and we need to process sequentially. postamble_half_start
797   // shows the start index where this should happen. Between postamble_start and
798   // postamble_half_start we can still process kWeightsPerNeonLane >> 1 in a
799   // vectorized form.
800   const int postamble_half_start = m_cols & ~(kWeightsPerNeonLane - 1);
801   const int postamble_start = m_cols & ~((kWeightsPerNeonLane >> 1) - 1);
802
803   for (int batch = 0; batch < n_batch; ++batch)
804   {
805     const float batch_scaling_factor = scaling_factors[batch];
806     // Copy the vector data to an aligned vector.
807     memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8_t) * m_cols);
808     // Compute dot-product for every column.
809     for (int row = 0; row < m_rows; ++row, result += result_stride)
810     {
811       // Get the address of the first element of the row.
812       int8_t *row_ptr = (int8_t *)matrix + row * m_cols; // NOLINT
813       if (unaligned)
814       {
815         memcpy(aligned_row, row_ptr, sizeof(int8_t) * m_cols);
816         row_ptr = aligned_row;
817       }
818
819       // Initialize the dot product sum for the row to 0.
820       int32x4_t dotprod_32x4 = vmovq_n_s32(0);
821
822       // Prefetch the row to cache.
823       __builtin_prefetch(row_ptr, 0 /* prefetch for read */, 3 /* temporal locality */);
824
825       // For every block of 16 8-bit elements.
826       int col = 0;
827       for (; col < postamble_half_start; col += kWeightsPerNeonLane)
828       {
829         // Load 16 8-bit values from the row and vector, each, to operate on.
830         // Here the assumption is that each buffer is 4-byte aligned. Otherwise,
831         // performance may suffer significantly.
832         assert( // NOLINT
833             ((uintptr_t)(&row_ptr[col]) & (kWeightsPerUint32 - 1)) == 0);
834         const int8x16_t s1_8x16 = vld1q_s8((const int8_t *)(aligned_vec + col));
835         const int8x16_t s2_8x16 = vld1q_s8((const int8_t *)(row_ptr + col));
836         // Multiply the low bits (i.e. the lower 8 8bit numbers in the
837         // registers).
838         int16x8_t prod_16x8 = vmull_s8(vget_low_s8(s1_8x16), vget_low_s8(s2_8x16));
839         // Multiply the high bits (i.e. the higher 8 8bit numbers in the
840         // registers), and accumulate with the result of the low bits product.
841         // The assumption here is that overflow will not happen as we quantize
842         // our values to be in the range [-127, 127]. As such the sum of the 2
843         // products is always strictly smaller than 15-bits (32767 in absolute
844         // value).
845         prod_16x8 = vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));
846
847         dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
848       } // for col
849
850       // Half iteration dealing only 8 elements
851       // TODO(raziel): if (ABSL_PREDICT_FALSE(col < postamble_start))
852       if (col < postamble_start)
853       {
854         // Load 8 8-bit values from the row and column each to operate on.
855         // Here the assumption is that each buffer is 4-bytes aligned.
856         // Otherwise, performance may suffer significantly.
857         assert( // NOLINT
858             ((uintptr_t)(&row_ptr[col]) & (kWeightsPerUint32 - 1)) == 0);
859         const int8x8_t s1_8x8 = vld1_s8((const int8_t *)(aligned_vec + col));
860         const int8x8_t s2_8x8 = vld1_s8((const int8_t *)(row_ptr + col));
861         const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8);
862         dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);
863         col += (kWeightsPerNeonLane >> 1);
864       }
865       // Add the 4 intermediate sum values to get the final dot-prod value for
866       // this row.
867       int32_t dotprod = AccumulateNeonLane(dotprod_32x4);
868       // Postamble loop.
869       // TODO(raziel): if (ABSL_PREDICT_FALSE(col < m_cols))
870       for (; col < m_cols; ++col)
871       {
872         dotprod += row_ptr[col] * aligned_vec[col];
873       } // for col
874
875       *result += dotprod * batch_scaling_factor;
876     } // for row
877   }   // for batch
878
879   if (unaligned)
880   {
881     free(aligned_row_free);
882   }
883   free(aligned_vec_free);
884 }
885
886 inline void NeonMatrixBatchVectorMultiplyAccumulate(const float *matrix, int m_rows, int m_cols,
887                                                     const float *vector, int n_batch, float *result,
888                                                     int result_stride)
889 {
890   // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
891   // vectorized loop, and we need to process sequentially. postamble_start shows
892   // the start index where this should happen.
893   const int postamble_start = m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1));
894
895   for (int b = 0; b < n_batch; b++)
896   {
897     float *result_in_batch = result + b * m_rows * result_stride;
898     const float *vector_in_batch = vector + b * m_cols;
899     const float *matrix_row = matrix;
900
901     // Main matrix by vector multiplication loop
902     for (int r = 0; r < m_rows; r++)
903     {
904       float32x4_t acc_32x4 = vmovq_n_f32(0.0);
905       for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane)
906       {
907         // Load 4 float values from vector and matrix row.
908         float32x4_t vector_f32x4 = vld1q_f32(vector_in_batch + c);
909         float32x4_t matrix_f32x4 = vld1q_f32(matrix_row + c);
910         // Multiply the vector and matrix row and add to accumulator.
911         acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4);
912       }
913       // Add the 4 intermediate sum values to get the final dot-prod value for
914       // this column.
915       *result_in_batch += (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) +
916                            vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3));
917       for (int c = postamble_start; c < m_cols; c++)
918       {
919         *result_in_batch += matrix_row[c] * vector_in_batch[c];
920       }
921       matrix_row += m_cols;
922       result_in_batch += result_stride;
923     }
924   }
925 }
926
927 inline void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t *__restrict__ matrix,
928                                                     const int m_rows, const int m_cols,
929                                                     const int8_t *__restrict__ vectors,
930                                                     const float *scaling_factors, int n_batch,
931                                                     int32_t *scratch, float *__restrict__ result,
932                                                     int result_stride, ruy::Context *ruy_context)
933 {
934   if (m_rows % 4 == 0 && result_stride == 1)
935   {
936     const int32_t *bias = static_cast<const int32_t *>(nullptr);
937     NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows,
938                        /*output_zp =*/0, scratch, ruy_context);
939
940     // Multiply by float scaling factors and write to result
941     const int total_size = n_batch * m_rows;
942     int i = 0;
943     for (; i <= total_size - 8; i += 8, result += 8 * result_stride)
944     {
945       const float batch_scaling_factor0 = scaling_factors[i / m_rows];
946       const float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows];
947       const float32x4_t scaling_factor0 = vdupq_n_f32(batch_scaling_factor0);
948       const float32x4_t scaling_factor1 = vdupq_n_f32(batch_scaling_factor1);
949       const int32x4_t scratch_val0 = vld1q_s32(scratch + i);
950       const int32x4_t scratch_val1 = vld1q_s32(scratch + i + 4);
951       const float32x4_t float_val0 = vcvtq_f32_s32(scratch_val0);
952       const float32x4_t float_val1 = vcvtq_f32_s32(scratch_val1);
953       const float32x4_t result0 = vmlaq_f32(vld1q_f32(result), float_val0, scaling_factor0);
954       const float32x4_t result1 =
955           vmlaq_f32(vld1q_f32(result + 4 * result_stride), float_val1, scaling_factor1);
956       vst1q_f32(result, result0);
957       vst1q_f32(result + 4 * result_stride, result1);
958     }
959     scratch += i;
960     for (; i < total_size; i++, result += result_stride)
961     {
962       const float batch_scaling_factor = scaling_factors[i / m_rows];
963       int32_t x = *(scratch++);
964       *result += x * batch_scaling_factor;
965     }
966     return;
967   }
968   NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors, scaling_factors, n_batch,
969                                           result, result_stride);
970 }
971
972 } // namespace cker
973 } // namespace nnfw
974
975 #endif // USE_NEON
976
977 #endif // __NNFW_CKER_NEON_TENSOR_UTILS_H__