Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / optimized / BinaryArithmeticOps.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
19 #define __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
20
21 #include <functional>
22 #include "cker/neon/neon_check.h"
23 #include "cker/operation/reference/BinaryArithmeticOps.h"
24 #include "cker/Shape.h"
25 #include "cker/Types.h"
26 #include "cker/Utils.h"
27 #include "fixedpoint/fixedpoint.h"
28
29 namespace nnfw
30 {
31 namespace cker
32 {
33 namespace optimized
34 {
35
36 template <typename ElementwiseF, typename ScalarBroadcastF, typename T>
37 inline void BinaryBroadcastFiveFold(const BinaryArithmeticOpParam &params,
38                                     const Shape & /* unswitched_input1_shape */,
39                                     const T *unswitched_input1_data,
40                                     const Shape & /* unswitched_input2_shape */,
41                                     const T *unswitched_input2_data,
42                                     const Shape & /* output_shape */, T *output_data,
43                                     ElementwiseF elementwise_f, ScalarBroadcastF scalar_broadcast_f)
44 {
45   const bool use_unswitched =
46       params.broadcast_category == BroadcastableOpCategory::kFirstInputBroadcastsFast;
47
48   const T *input1_data = use_unswitched ? unswitched_input1_data : unswitched_input2_data;
49   const T *input2_data = use_unswitched ? unswitched_input2_data : unswitched_input1_data;
50
51   // Fivefold nested loops. The second input resets its position for each
52   // iteration of the second loop. The first input resets its position at the
53   // beginning of the fourth loop. The innermost loop is an elementwise add of
54   // sections of the arrays.
55   T *output_data_ptr = output_data;
56   const T *input1_data_ptr = input1_data;
57   const T *input2_data_reset = input2_data;
58   // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
59   // between input shapes. y3 for input 1 is always broadcast, and so the
60   // dimension there is 1, whereas optionally y1 might be broadcast for input 2.
61   // Put another way,
62   // input1.shape.FlatSize = y0 * y1 * y2 * y4,
63   // input2.shape.FlatSize = y0 * y2 * y3 * y4.
64   int y0 = params.broadcast_shape[0];
65   int y1 = params.broadcast_shape[1];
66   int y2 = params.broadcast_shape[2];
67   int y3 = params.broadcast_shape[3];
68   int y4 = params.broadcast_shape[4];
69   if (y4 > 1)
70   {
71     // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
72     // dimension.
73     for (int i0 = 0; i0 < y0; ++i0)
74     {
75       const T *input2_data_ptr = nullptr;
76       for (int i1 = 0; i1 < y1; ++i1)
77       {
78         input2_data_ptr = input2_data_reset;
79         for (int i2 = 0; i2 < y2; ++i2)
80         {
81           for (int i3 = 0; i3 < y3; ++i3)
82           {
83             elementwise_f(y4, params, input1_data_ptr, input2_data_ptr, output_data_ptr);
84             input2_data_ptr += y4;
85             output_data_ptr += y4;
86           }
87           // We have broadcast y4 of input1 data y3 times, and now move on.
88           input1_data_ptr += y4;
89         }
90       }
91       // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
92       input2_data_reset = input2_data_ptr;
93     }
94   }
95   else
96   {
97     // Special case of y4 == 1, in which the innermost loop is a single element
98     // and can be combined with the next (y3) as an inner broadcast.
99     //
100     // Note that this handles the case of pure scalar broadcast when
101     // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
102     // broadcast with batch (as y2 > 1).
103     //
104     // NOTE The process is the same as the above general case except simplified
105     // for y4 == 1 and the loop over y3 is contained within the
106     // AddScalarBroadcast function.
107     for (int i0 = 0; i0 < y0; ++i0)
108     {
109       const T *input2_data_ptr = nullptr;
110       for (int i1 = 0; i1 < y1; ++i1)
111       {
112         input2_data_ptr = input2_data_reset;
113         for (int i2 = 0; i2 < y2; ++i2)
114         {
115           scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr, output_data_ptr);
116           input2_data_ptr += y3;
117           output_data_ptr += y3;
118           input1_data_ptr += 1;
119         }
120       }
121       input2_data_reset = input2_data_ptr;
122     }
123   }
124 }
125
126 inline int32_t quant8_sum(const BinaryArithmeticOpParam &params, const uint8_t input1_data,
127                           const uint8_t input2_data)
128 {
129   const int32_t input1_val = params.input1_offset + input1_data;
130   const int32_t input2_val = params.input2_offset + input2_data;
131   const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
132   const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
133   const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
134       shifted_input1_val, params.input1_multiplier, params.input1_shift);
135   const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
136       shifted_input2_val, params.input2_multiplier, params.input2_shift);
137   const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
138   const int32_t raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp(
139                                  raw_sum, params.output_multiplier, params.output_shift) +
140                              params.output_offset;
141   const int32_t clamped_output = std::min(params.quantized_activation_max,
142                                           std::max(params.quantized_activation_min, raw_output));
143   return clamped_output;
144 }
145
146 inline void AddElementwiseQuant8(int size, const BinaryArithmeticOpParam &params,
147                                  const uint8_t *input1_data, const uint8_t *input2_data,
148                                  uint8_t *output_data)
149 {
150   int i = 0;
151
152 #ifdef USE_NEON
153   const uint8x8_t output_activation_min_vector = vdup_n_u8(params.quantized_activation_min);
154   const uint8x8_t output_activation_max_vector = vdup_n_u8(params.quantized_activation_max);
155   for (; i <= size - 8; i += 8)
156   {
157     const uint8x8_t input1_val_original = vld1_u8(input1_data + i);
158     const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
159     const int16x8_t input1_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
160     const int16x8_t input2_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
161     const int16x8_t input1_val = vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
162     const int16x8_t input2_val = vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
163     const int16x4_t input1_val_high = vget_high_s16(input1_val);
164     const int16x4_t input1_val_low = vget_low_s16(input1_val);
165     const int16x4_t input2_val_high = vget_high_s16(input2_val);
166     const int16x4_t input2_val_low = vget_low_s16(input2_val);
167     int32x4_t x11 = vmovl_s16(input1_val_low);
168     int32x4_t x12 = vmovl_s16(input1_val_high);
169     int32x4_t x21 = vmovl_s16(input2_val_low);
170     int32x4_t x22 = vmovl_s16(input2_val_high);
171     const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
172     x11 = vshlq_s32(x11, left_shift_dup);
173     x12 = vshlq_s32(x12, left_shift_dup);
174     x21 = vshlq_s32(x21, left_shift_dup);
175     x22 = vshlq_s32(x22, left_shift_dup);
176     x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
177     x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
178     x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
179     x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
180     const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
181     const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
182     x11 = vshlq_s32(x11, input1_shift_dup);
183     x12 = vshlq_s32(x12, input1_shift_dup);
184     x21 = vshlq_s32(x21, input2_shift_dup);
185     x22 = vshlq_s32(x22, input2_shift_dup);
186     int32x4_t s1 = vaddq_s32(x11, x21);
187     int32x4_t s2 = vaddq_s32(x12, x22);
188     s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
189     s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
190     using gemmlowp::RoundingDivideByPOT;
191     s1 = RoundingDivideByPOT(s1, -params.output_shift);
192     s2 = RoundingDivideByPOT(s2, -params.output_shift);
193     const int16x4_t s1_narrowed = vmovn_s32(s1);
194     const int16x4_t s2_narrowed = vmovn_s32(s2);
195     const int16x8_t s =
196         vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), vdupq_n_s16(params.output_offset));
197     const uint8x8_t clamped = vmax_u8(output_activation_min_vector,
198                                       vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
199     vst1_u8(output_data + i, clamped);
200   }
201 #endif // NEON
202   for (; i < size; ++i)
203   {
204     const int32_t input1_val = params.input1_offset + input1_data[i];
205     const int32_t input2_val = params.input2_offset + input2_data[i];
206     const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
207     const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
208     const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
209         shifted_input1_val, params.input1_multiplier, params.input1_shift);
210     const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
211         shifted_input2_val, params.input2_multiplier, params.input2_shift);
212     const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
213     const int32_t raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp(
214                                    raw_sum, params.output_multiplier, params.output_shift) +
215                                params.output_offset;
216     const int32_t clamped_output = std::min(params.quantized_activation_max,
217                                             std::max(params.quantized_activation_min, raw_output));
218     output_data[i] = static_cast<uint8_t>(clamped_output);
219   }
220 }
221
222 inline void AddElementwise(int size, const BinaryArithmeticOpParam &params,
223                            const float *input1_data, const float *input2_data, float *output_data)
224 {
225   int i = 0;
226
227 #ifdef USE_NEON
228   const auto activation_min = vdupq_n_f32(params.float_activation_min);
229   const auto activation_max = vdupq_n_f32(params.float_activation_max);
230   for (; i <= size - 16; i += 16)
231   {
232     auto a10 = vld1q_f32(input1_data + i);
233     auto a11 = vld1q_f32(input1_data + i + 4);
234     auto a12 = vld1q_f32(input1_data + i + 8);
235     auto a13 = vld1q_f32(input1_data + i + 12);
236     auto a20 = vld1q_f32(input2_data + i);
237     auto a21 = vld1q_f32(input2_data + i + 4);
238     auto a22 = vld1q_f32(input2_data + i + 8);
239     auto a23 = vld1q_f32(input2_data + i + 12);
240     auto x0 = vaddq_f32(a10, a20);
241     auto x1 = vaddq_f32(a11, a21);
242     auto x2 = vaddq_f32(a12, a22);
243     auto x3 = vaddq_f32(a13, a23);
244     x0 = vmaxq_f32(activation_min, x0);
245     x1 = vmaxq_f32(activation_min, x1);
246     x2 = vmaxq_f32(activation_min, x2);
247     x3 = vmaxq_f32(activation_min, x3);
248     x0 = vminq_f32(activation_max, x0);
249     x1 = vminq_f32(activation_max, x1);
250     x2 = vminq_f32(activation_max, x2);
251     x3 = vminq_f32(activation_max, x3);
252     vst1q_f32(output_data + i, x0);
253     vst1q_f32(output_data + i + 4, x1);
254     vst1q_f32(output_data + i + 8, x2);
255     vst1q_f32(output_data + i + 12, x3);
256   }
257   for (; i <= size - 4; i += 4)
258   {
259     auto a1 = vld1q_f32(input1_data + i);
260     auto a2 = vld1q_f32(input2_data + i);
261     auto x = vaddq_f32(a1, a2);
262     x = vmaxq_f32(activation_min, x);
263     x = vminq_f32(activation_max, x);
264     vst1q_f32(output_data + i, x);
265   }
266 #endif // NEON
267   for (; i < size; i++)
268   {
269     auto x = input1_data[i] + input2_data[i];
270     output_data[i] = ActivationFunctionWithMinMax<float>(x, params.float_activation_min,
271                                                          params.float_activation_max);
272   }
273 }
274
275 inline void AddQuant8(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
276                       const uint8_t *input1_data, const Shape &input2_shape,
277                       const uint8_t *input2_data, const Shape &output_shape, uint8_t *output_data)
278 {
279   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
280   AddElementwiseQuant8(flat_size, params, input1_data, input2_data, output_data);
281 }
282
283 inline void Add(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
284                 const float *input1_data, const Shape &input2_shape, const float *input2_data,
285                 const Shape &output_shape, float *output_data)
286 {
287   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
288   AddElementwise(flat_size, params, input1_data, input2_data, output_data);
289 }
290
291 // Scalar-broadcast add that can be used for inner loop of more general
292 // broadcast add, so that, for example, scalar-broadcast with batch will still
293 // be fast.
294 inline void AddScalarBroadcastQuant8(int size, const BinaryArithmeticOpParam &params,
295                                      uint8_t broadcast_value, const uint8_t *input2_data,
296                                      uint8_t *output_data)
297 {
298   int i = 0;
299   int32_t clamped_output;
300   for (; i < size; ++i)
301   {
302     clamped_output = quant8_sum(params, broadcast_value, input2_data[i]);
303     output_data[i] = static_cast<uint8_t>(clamped_output);
304   }
305 }
306
307 inline void AddScalarBroadcast(int size, const BinaryArithmeticOpParam &params,
308                                float broadcast_value, const float *input2_data, float *output_data)
309 {
310   int i = 0;
311 #ifdef USE_NEON
312   const float32x4_t output_activation_min_vector = vdupq_n_f32(params.float_activation_min);
313   const float32x4_t output_activation_max_vector = vdupq_n_f32(params.float_activation_max);
314   const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
315   for (; i <= size - 4; i += 4)
316   {
317     const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
318
319     const float32x4_t output = vaddq_f32(input2_val_original, broadcast_value_dup);
320
321     const float32x4_t clamped =
322         vmaxq_f32(output_activation_min_vector, vminq_f32(output_activation_max_vector, output));
323     vst1q_f32(output_data + i, clamped);
324   }
325 #endif // NEON
326   for (; i < size; ++i)
327   {
328     auto x = broadcast_value + input2_data[i];
329     output_data[i] = ActivationFunctionWithMinMax<float>(x, params.float_activation_min,
330                                                          params.float_activation_max);
331   }
332 }
333
334 inline void BroadcastAddDispatchQuant8(const BinaryArithmeticOpParam &params,
335                                        const Shape &input1_shape, const uint8_t *input1_data,
336                                        const Shape &input2_shape, const uint8_t *input2_data,
337                                        const Shape &output_shape, uint8_t *output_data)
338 {
339   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
340   {
341     const std::function<uint8_t(const BinaryArithmeticOpParam &, const uint8_t &, const uint8_t &)>
342         fn = [](const BinaryArithmeticOpParam &params, const uint8_t &a,
343                 const uint8_t &b) -> uint8_t {
344       return static_cast<uint8_t>(quant8_sum(params, a, b));
345     };
346     reference::BroadcastBinaryArithmeticOpSlowQuant8(params, input1_shape, input1_data,
347                                                      input2_shape, input2_data, output_shape,
348                                                      output_data, fn);
349   }
350   else
351   {
352     BinaryBroadcastFiveFold(
353         params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
354         static_cast<void (*)(int, const BinaryArithmeticOpParam &, const uint8_t *, const uint8_t *,
355                              uint8_t *)>(AddElementwiseQuant8),
356         static_cast<void (*)(int, const BinaryArithmeticOpParam &, uint8_t, const uint8_t *,
357                              uint8_t *)>(AddScalarBroadcastQuant8));
358   }
359 }
360
361 inline void BroadcastAddDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
362                                  const float *input1_data, const Shape &input2_shape,
363                                  const float *input2_data, const Shape &output_shape,
364                                  float *output_data)
365 {
366   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
367   {
368     const std::function<float(const float &, const float &)> fn =
369         [](const float &a, const float &b) -> float { return a + b; };
370     reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
371                                                input2_data, output_shape, output_data, fn);
372   }
373   else
374   {
375     BinaryBroadcastFiveFold(
376         params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
377         static_cast<void (*)(int, const BinaryArithmeticOpParam &, const float *, const float *,
378                              float *)>(AddElementwise),
379         static_cast<void (*)(int, const BinaryArithmeticOpParam &, float, const float *, float *)>(
380             AddScalarBroadcast));
381   }
382 }
383
384 inline void Sub(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
385                 const float *input1_data, const Shape &input2_shape, const float *input2_data,
386                 const Shape &output_shape, float *output_data)
387 {
388   int i = 0;
389   const int size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
390 #ifdef USE_NEON
391   const auto activation_min = vdupq_n_f32(params.float_activation_min);
392   const auto activation_max = vdupq_n_f32(params.float_activation_max);
393   for (; i <= size - 16; i += 16)
394   {
395     auto a10 = vld1q_f32(input1_data + i);
396     auto a11 = vld1q_f32(input1_data + i + 4);
397     auto a12 = vld1q_f32(input1_data + i + 8);
398     auto a13 = vld1q_f32(input1_data + i + 12);
399     auto a20 = vld1q_f32(input2_data + i);
400     auto a21 = vld1q_f32(input2_data + i + 4);
401     auto a22 = vld1q_f32(input2_data + i + 8);
402     auto a23 = vld1q_f32(input2_data + i + 12);
403     auto x0 = vsubq_f32(a10, a20);
404     auto x1 = vsubq_f32(a11, a21);
405     auto x2 = vsubq_f32(a12, a22);
406     auto x3 = vsubq_f32(a13, a23);
407     x0 = vmaxq_f32(activation_min, x0);
408     x1 = vmaxq_f32(activation_min, x1);
409     x2 = vmaxq_f32(activation_min, x2);
410     x3 = vmaxq_f32(activation_min, x3);
411     x0 = vminq_f32(activation_max, x0);
412     x1 = vminq_f32(activation_max, x1);
413     x2 = vminq_f32(activation_max, x2);
414     x3 = vminq_f32(activation_max, x3);
415     vst1q_f32(output_data + i, x0);
416     vst1q_f32(output_data + i + 4, x1);
417     vst1q_f32(output_data + i + 8, x2);
418     vst1q_f32(output_data + i + 12, x3);
419   }
420   for (; i <= size - 4; i += 4)
421   {
422     auto a1 = vld1q_f32(input1_data + i);
423     auto a2 = vld1q_f32(input2_data + i);
424     auto x = vsubq_f32(a1, a2);
425     x = vmaxq_f32(activation_min, x);
426     x = vminq_f32(activation_max, x);
427     vst1q_f32(output_data + i, x);
428   }
429 #endif // NEON
430
431   for (; i < size; i++)
432   {
433     auto x = input1_data[i] - input2_data[i];
434     output_data[i] =
435         ActivationFunctionWithMinMax(x, params.float_activation_min, params.float_activation_max);
436   }
437 }
438
439 inline int32_t quant8_mul(const BinaryArithmeticOpParam &params, const uint8_t input1_data,
440                           const uint8_t input2_data)
441 {
442   const int32_t input1_val = params.input1_offset + input1_data;
443   const int32_t input2_val = params.input2_offset + input2_data;
444   const int32_t unclamped_result =
445       params.output_offset + MultiplyByQuantizedMultiplier(input1_val * input2_val,
446                                                            params.output_multiplier,
447                                                            params.output_shift);
448   const int32_t clamped_output = std::min(
449       params.quantized_activation_max, std::max(params.quantized_activation_min, unclamped_result));
450
451   return clamped_output;
452 }
453
454 inline void MulElementwiseQuant8(int size, const BinaryArithmeticOpParam &params,
455                                  const uint8_t *input1_data, const uint8_t *input2_data,
456                                  uint8_t *output_data)
457 {
458   int i = 0;
459
460 #ifdef USE_NEON
461   const auto input1_offset_vector = vdupq_n_s16(params.input1_offset);
462   const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
463   const auto output_offset_vector = vdupq_n_s16(params.output_offset);
464   const auto output_activation_min_vector = vdup_n_u8(params.quantized_activation_min);
465   const auto output_activation_max_vector = vdup_n_u8(params.quantized_activation_max);
466   const int left_shift = std::max(0, params.output_shift);
467   const int right_shift = std::max(0, -params.output_shift);
468   const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
469   for (; i <= size - 8; i += 8)
470   {
471     // We load / store 8 at a time, multiplying as two sets of 4 int32s.
472     const auto input1_val_original = vld1_u8(input1_data + i);
473     const auto input2_val_original = vld1_u8(input2_data + i);
474     const auto input1_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
475     const auto input2_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
476     const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
477     const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
478
479     const auto input1_val_low = vget_low_s16(input1_val);
480     const auto input1_val_high = vget_high_s16(input1_val);
481     const auto input2_val_low = vget_low_s16(input2_val);
482     const auto input2_val_high = vget_high_s16(input2_val);
483
484     auto p1 = vmull_s16(input2_val_low, input1_val_low);
485     auto p2 = vmull_s16(input2_val_high, input1_val_high);
486
487     p1 = vshlq_s32(p1, left_shift_vec);
488     p2 = vshlq_s32(p2, left_shift_vec);
489     p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
490     p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
491     using gemmlowp::RoundingDivideByPOT;
492     p1 = RoundingDivideByPOT(p1, right_shift);
493     p2 = RoundingDivideByPOT(p2, right_shift);
494
495     const auto p1_narrowed = vqmovn_s32(p1);
496     const auto p2_narrowed = vqmovn_s32(p2);
497     const auto p = vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
498     const auto clamped = vmax_u8(output_activation_min_vector,
499                                  vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
500     vst1_u8(output_data + i, clamped);
501   }
502 #endif // NEON
503
504   for (; i < size; ++i)
505   {
506     const int32_t input1_val = params.input1_offset + input1_data[i];
507     const int32_t input2_val = params.input2_offset + input2_data[i];
508     const int32_t unclamped_result =
509         params.output_offset + MultiplyByQuantizedMultiplier(input1_val * input2_val,
510                                                              params.output_multiplier,
511                                                              params.output_shift);
512     const int32_t clamped_output =
513         std::min(params.quantized_activation_max,
514                  std::max(params.quantized_activation_min, unclamped_result));
515     output_data[i] = static_cast<uint8_t>(clamped_output);
516   }
517 }
518
519 inline void MulElementwise(int size, const BinaryArithmeticOpParam &params,
520                            const float *input1_data, const float *input2_data, float *output_data)
521 {
522   int i = 0;
523
524 #ifdef USE_NEON
525   const auto activation_min = vdupq_n_f32(params.float_activation_min);
526   const auto activation_max = vdupq_n_f32(params.float_activation_max);
527   for (; i <= size - 16; i += 16)
528   {
529     auto a10 = vld1q_f32(input1_data + i);
530     auto a11 = vld1q_f32(input1_data + i + 4);
531     auto a12 = vld1q_f32(input1_data + i + 8);
532     auto a13 = vld1q_f32(input1_data + i + 12);
533     auto a20 = vld1q_f32(input2_data + i);
534     auto a21 = vld1q_f32(input2_data + i + 4);
535     auto a22 = vld1q_f32(input2_data + i + 8);
536     auto a23 = vld1q_f32(input2_data + i + 12);
537     auto x0 = vmulq_f32(a10, a20);
538     auto x1 = vmulq_f32(a11, a21);
539     auto x2 = vmulq_f32(a12, a22);
540     auto x3 = vmulq_f32(a13, a23);
541     x0 = vmaxq_f32(activation_min, x0);
542     x1 = vmaxq_f32(activation_min, x1);
543     x2 = vmaxq_f32(activation_min, x2);
544     x3 = vmaxq_f32(activation_min, x3);
545     x0 = vminq_f32(activation_max, x0);
546     x1 = vminq_f32(activation_max, x1);
547     x2 = vminq_f32(activation_max, x2);
548     x3 = vminq_f32(activation_max, x3);
549     vst1q_f32(output_data + i, x0);
550     vst1q_f32(output_data + i + 4, x1);
551     vst1q_f32(output_data + i + 8, x2);
552     vst1q_f32(output_data + i + 12, x3);
553   }
554   for (; i <= size - 4; i += 4)
555   {
556     auto a1 = vld1q_f32(input1_data + i);
557     auto a2 = vld1q_f32(input2_data + i);
558     auto x = vmulq_f32(a1, a2);
559     x = vmaxq_f32(activation_min, x);
560     x = vminq_f32(activation_max, x);
561     vst1q_f32(output_data + i, x);
562   }
563 #endif // NEON
564
565   for (; i < size; i++)
566   {
567     auto x = input1_data[i] * input2_data[i];
568     output_data[i] =
569         ActivationFunctionWithMinMax(x, params.float_activation_min, params.float_activation_max);
570   }
571 }
572
573 inline void MulQuant8(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
574                       const uint8_t *input1_data, const Shape &input2_shape,
575                       const uint8_t *input2_data, const Shape &output_shape, uint8_t *output_data)
576 {
577   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
578   MulElementwiseQuant8(flat_size, params, input1_data, input2_data, output_data);
579 }
580
581 inline void Mul(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
582                 const float *input1_data, const Shape &input2_shape, const float *input2_data,
583                 const Shape &output_shape, float *output_data)
584 {
585   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
586   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
587 }
588
589 inline void MulSimpleBroadcastQuant8(int size, const BinaryArithmeticOpParam &params,
590                                      const uint8_t broadcast_value, const uint8_t *input2_data,
591                                      uint8_t *output_data)
592 {
593   int i = 0;
594   int32_t clamped_output;
595   for (; i < size; ++i)
596   {
597     clamped_output = quant8_mul(params, broadcast_value, input2_data[i]);
598     output_data[i] = static_cast<uint8_t>(clamped_output);
599   }
600 }
601
602 // Broadcast mul that can often be used for inner loop of broadcast Mul.
603 // This function will handle scalar_value (LHS) * vector_values (RHS).
604 // Since it's a float function, input params does not matter here.
605 inline void MulSimpleBroadcast(int size, const BinaryArithmeticOpParam &params,
606                                const float broadcast_value, const float *input2_data,
607                                float *output_data)
608 {
609   int i = 0;
610 #ifdef USE_NEON
611   const float32x4_t output_activation_min_vector = vdupq_n_f32(params.float_activation_min);
612   const float32x4_t output_activation_max_vector = vdupq_n_f32(params.float_activation_max);
613   const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
614   for (; i <= size - 4; i += 4)
615   {
616     const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
617
618     const float32x4_t output = vmulq_f32(input2_val_original, broadcast_value_dup);
619
620     const float32x4_t clamped =
621         vmaxq_f32(output_activation_min_vector, vminq_f32(output_activation_max_vector, output));
622     vst1q_f32(output_data + i, clamped);
623   }
624 #endif // NEON
625
626   for (; i < size; ++i)
627   {
628     float x = broadcast_value * input2_data[i];
629     output_data[i] =
630         ActivationFunctionWithMinMax(x, params.float_activation_min, params.float_activation_max);
631   }
632 }
633
634 inline void BroadcastMulDispatchQuant8(const BinaryArithmeticOpParam &params,
635                                        const Shape &input1_shape, const uint8_t *input1_data,
636                                        const Shape &input2_shape, const uint8_t *input2_data,
637                                        const Shape &output_shape, uint8_t *output_data)
638 {
639   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
640   {
641     const std::function<uint8_t(const BinaryArithmeticOpParam &, const uint8_t &, const uint8_t &)>
642         fn = [](const BinaryArithmeticOpParam &params, const uint8_t &a,
643                 const uint8_t &b) -> uint8_t {
644       return static_cast<uint8_t>(quant8_mul(params, a, b));
645     };
646     reference::BroadcastBinaryArithmeticOpSlowQuant8(params, input1_shape, input1_data,
647                                                      input2_shape, input2_data, output_shape,
648                                                      output_data, fn);
649     return;
650   }
651   BinaryBroadcastFiveFold(
652       params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
653       static_cast<void (*)(int, const BinaryArithmeticOpParam &, const uint8_t *, const uint8_t *,
654                            uint8_t *)>(MulElementwiseQuant8),
655       static_cast<void (*)(int, const BinaryArithmeticOpParam &, uint8_t, const uint8_t *,
656                            uint8_t *)>(MulSimpleBroadcastQuant8));
657 }
658
659 inline void BroadcastMulDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
660                                  const float *input1_data, const Shape &input2_shape,
661                                  const float *input2_data, const Shape &output_shape,
662                                  float *output_data)
663 {
664   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
665   {
666     // TODO: Use GetBinaryArithmeticFn
667     const std::function<float(const float &, const float &)> fn =
668         [](const float &a, const float &b) -> float { return a * b; };
669     reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
670                                                input2_data, output_shape, output_data, fn);
671     return;
672   }
673   BinaryBroadcastFiveFold(
674       params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
675       static_cast<void (*)(int, const BinaryArithmeticOpParam &, const float *, const float *,
676                            float *)>(MulElementwise),
677       static_cast<void (*)(int, const BinaryArithmeticOpParam &, float, const float *, float *)>(
678           MulSimpleBroadcast));
679 }
680
681 } // namespace optimized
682 } // namespace cker
683 } // namespace nnfw
684
685 #endif // __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__