912b01a64c1aaf1b6b195e5b9a3b93e5a5612060
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / optimized / BinaryArithmeticOps.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
19 #define __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
20
21 #include <functional>
22 #include <limits>
23 #include <utility>
24 #include "cker/neon/neon_check.h"
25 #include "cker/operation/reference/BinaryArithmeticOps.h"
26 #include "cker/Shape.h"
27 #include "cker/Types.h"
28 #include "cker/Utils.h"
29 #include "fixedpoint/fixedpoint.h"
30
31 namespace nnfw
32 {
33 namespace cker
34 {
35 namespace optimized
36 {
37
38 template <typename ElementwiseF, typename ScalarBroadcastF, typename T>
39 inline void BinaryBroadcastFiveFold(const BinaryArithmeticOpParam &params, bool switch_inputs,
40                                     const Shape & /* unswitched_input1_shape */,
41                                     const T *unswitched_input1_data,
42                                     const Shape & /* unswitched_input2_shape */,
43                                     const T *unswitched_input2_data,
44                                     const Shape & /* output_shape */, T *output_data,
45                                     ElementwiseF elementwise_f, ScalarBroadcastF scalar_broadcast_f)
46 {
47   const T *input1_data = switch_inputs ? unswitched_input2_data : unswitched_input1_data;
48   const T *input2_data = switch_inputs ? unswitched_input1_data : unswitched_input2_data;
49
50   // Fivefold nested loops. The second input resets its position for each
51   // iteration of the second loop. The first input resets its position at the
52   // beginning of the fourth loop. The innermost loop is an elementwise add of
53   // sections of the arrays.
54   T *output_data_ptr = output_data;
55   const T *input1_data_ptr = input1_data;
56   const T *input2_data_reset = input2_data;
57   // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
58   // between input shapes. y3 for input 1 is always broadcast, and so the
59   // dimension there is 1, whereas optionally y1 might be broadcast for input 2.
60   // Put another way,
61   // input1.shape.FlatSize = y0 * y1 * y2 * y4,
62   // input2.shape.FlatSize = y0 * y2 * y3 * y4.
63   int y0 = params.broadcast_shape[0];
64   int y1 = params.broadcast_shape[1];
65   int y2 = params.broadcast_shape[2];
66   int y3 = params.broadcast_shape[3];
67   int y4 = params.broadcast_shape[4];
68   if (y4 > 1)
69   {
70     // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
71     // dimension.
72     for (int i0 = 0; i0 < y0; ++i0)
73     {
74       const T *input2_data_ptr = nullptr;
75       for (int i1 = 0; i1 < y1; ++i1)
76       {
77         input2_data_ptr = input2_data_reset;
78         for (int i2 = 0; i2 < y2; ++i2)
79         {
80           for (int i3 = 0; i3 < y3; ++i3)
81           {
82             elementwise_f(y4, params, input1_data_ptr, input2_data_ptr, output_data_ptr);
83             input2_data_ptr += y4;
84             output_data_ptr += y4;
85           }
86           // We have broadcast y4 of input1 data y3 times, and now move on.
87           input1_data_ptr += y4;
88         }
89       }
90       // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
91       input2_data_reset = input2_data_ptr;
92     }
93   }
94   else
95   {
96     // Special case of y4 == 1, in which the innermost loop is a single element
97     // and can be combined with the next (y3) as an inner broadcast.
98     //
99     // Note that this handles the case of pure scalar broadcast when
100     // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
101     // broadcast with batch (as y2 > 1).
102     //
103     // NOTE The process is the same as the above general case except simplified
104     // for y4 == 1 and the loop over y3 is contained within the
105     // AddScalarBroadcast function.
106     for (int i0 = 0; i0 < y0; ++i0)
107     {
108       const T *input2_data_ptr = nullptr;
109       for (int i1 = 0; i1 < y1; ++i1)
110       {
111         input2_data_ptr = input2_data_reset;
112         for (int i2 = 0; i2 < y2; ++i2)
113         {
114           scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr, output_data_ptr);
115           input2_data_ptr += y3;
116           output_data_ptr += y3;
117           input1_data_ptr += 1;
118         }
119       }
120       input2_data_reset = input2_data_ptr;
121     }
122   }
123 }
124
125 inline int32_t quant8_sum(const BinaryArithmeticOpParam &params, const uint8_t input1_data,
126                           const uint8_t input2_data)
127 {
128   const int32_t input1_val = params.input1_offset + input1_data;
129   const int32_t input2_val = params.input2_offset + input2_data;
130   const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
131   const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
132   const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
133       shifted_input1_val, params.input1_multiplier, params.input1_shift);
134   const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
135       shifted_input2_val, params.input2_multiplier, params.input2_shift);
136   const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
137   const int32_t raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp(
138                                  raw_sum, params.output_multiplier, params.output_shift) +
139                              params.output_offset;
140   const int32_t clamped_output = std::min(params.quantized_activation_max,
141                                           std::max(params.quantized_activation_min, raw_output));
142   return clamped_output;
143 }
144
145 inline void AddElementwiseQuant8(int size, const BinaryArithmeticOpParam &params,
146                                  const uint8_t *input1_data, const uint8_t *input2_data,
147                                  uint8_t *output_data)
148 {
149   int i = 0;
150
151 #ifdef USE_NEON
152   const uint8x8_t output_activation_min_vector = vdup_n_u8(params.quantized_activation_min);
153   const uint8x8_t output_activation_max_vector = vdup_n_u8(params.quantized_activation_max);
154   for (; i <= size - 8; i += 8)
155   {
156     const uint8x8_t input1_val_original = vld1_u8(input1_data + i);
157     const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
158     const int16x8_t input1_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
159     const int16x8_t input2_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
160     const int16x8_t input1_val = vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
161     const int16x8_t input2_val = vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
162     const int16x4_t input1_val_high = vget_high_s16(input1_val);
163     const int16x4_t input1_val_low = vget_low_s16(input1_val);
164     const int16x4_t input2_val_high = vget_high_s16(input2_val);
165     const int16x4_t input2_val_low = vget_low_s16(input2_val);
166     int32x4_t x11 = vmovl_s16(input1_val_low);
167     int32x4_t x12 = vmovl_s16(input1_val_high);
168     int32x4_t x21 = vmovl_s16(input2_val_low);
169     int32x4_t x22 = vmovl_s16(input2_val_high);
170     const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
171     x11 = vshlq_s32(x11, left_shift_dup);
172     x12 = vshlq_s32(x12, left_shift_dup);
173     x21 = vshlq_s32(x21, left_shift_dup);
174     x22 = vshlq_s32(x22, left_shift_dup);
175     x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
176     x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
177     x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
178     x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
179     const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
180     const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
181     x11 = vshlq_s32(x11, input1_shift_dup);
182     x12 = vshlq_s32(x12, input1_shift_dup);
183     x21 = vshlq_s32(x21, input2_shift_dup);
184     x22 = vshlq_s32(x22, input2_shift_dup);
185     int32x4_t s1 = vaddq_s32(x11, x21);
186     int32x4_t s2 = vaddq_s32(x12, x22);
187     s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
188     s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
189     using gemmlowp::RoundingDivideByPOT;
190     s1 = RoundingDivideByPOT(s1, -params.output_shift);
191     s2 = RoundingDivideByPOT(s2, -params.output_shift);
192     const int16x4_t s1_narrowed = vmovn_s32(s1);
193     const int16x4_t s2_narrowed = vmovn_s32(s2);
194     const int16x8_t s =
195         vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), vdupq_n_s16(params.output_offset));
196     const uint8x8_t clamped = vmax_u8(output_activation_min_vector,
197                                       vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
198     vst1_u8(output_data + i, clamped);
199   }
200 #endif // NEON
201   for (; i < size; ++i)
202   {
203     const int32_t input1_val = params.input1_offset + input1_data[i];
204     const int32_t input2_val = params.input2_offset + input2_data[i];
205     const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
206     const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
207     const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
208         shifted_input1_val, params.input1_multiplier, params.input1_shift);
209     const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
210         shifted_input2_val, params.input2_multiplier, params.input2_shift);
211     const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
212     const int32_t raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp(
213                                    raw_sum, params.output_multiplier, params.output_shift) +
214                                params.output_offset;
215     const int32_t clamped_output = std::min(params.quantized_activation_max,
216                                             std::max(params.quantized_activation_min, raw_output));
217     output_data[i] = static_cast<uint8_t>(clamped_output);
218   }
219 }
220
221 struct BinaryOpFuncAddFloat
222 {
223 #ifdef USE_NEON
224   static inline float32x4_t calculate(const float32x4_t &a, const float32x4_t &b)
225   {
226     return vaddq_f32(a, b);
227   }
228 #endif // USE_NEON
229   static inline float calculate(const float a, const float b) { return a + b; }
230 };
231
232 struct BinaryOpFuncSubFloat
233 {
234 #ifdef USE_NEON
235   static inline float32x4_t calculate(const float32x4_t &a, const float32x4_t &b)
236   {
237     return vsubq_f32(a, b);
238   }
239 #endif // USE_NEON
240   static inline float calculate(const float a, const float b) { return a - b; }
241 };
242
243 struct BinaryOpFuncMulFloat
244 {
245 #ifdef USE_NEON
246   static inline float32x4_t calculate(const float32x4_t &a, const float32x4_t &b)
247   {
248     return vmulq_f32(a, b);
249   }
250 #endif // USE_NEON
251   static inline float calculate(const float a, const float b) { return a * b; }
252 };
253
254 struct BinaryOpFuncDivFloat
255 {
256 #ifdef USE_NEON
257 #ifdef __aarch64__
258   static inline float32x4_t calculate(const float32x4_t &a, const float32x4_t &b)
259   {
260     return vdivq_f32(a, b);
261   }
262 #endif // __aarch64__
263 #endif // USE_NEON
264   static inline float calculate(const float a, const float b) { return a / b; }
265 };
266
267 template <class BASEOPERATOR> struct BinaryOpFuncSwapArgs
268 {
269   template <typename T> static inline T calculate(const T &a, const T &b)
270   {
271     return BASEOPERATOR::calculate(b, a);
272   }
273 };
274
275 struct BinaryOpActivationFloatNone
276 {
277 #ifdef USE_NEON
278   static inline float32x4_t applyCeiling(const float32x4_t &value, const float32x4_t &ceilingParam)
279   {
280     (void)ceilingParam; // suppress unused argument warning
281     return value;
282   }
283   static inline float32x4_t applyFloor(const float32x4_t &value, const float32x4_t &floorParam)
284   {
285     (void)floorParam;
286     return value;
287   }
288 #endif // USE_NEON
289   static inline float applyCeiling(const float value, const float ceilingParam)
290   {
291     (void)ceilingParam;
292     return value;
293   }
294   static inline float applyFloor(const float value, const float floorParam)
295   {
296     (void)floorParam;
297     return value;
298   }
299 };
300
301 struct BinaryOpActivationFloatMax
302 {
303 #ifdef USE_NEON
304   static inline float32x4_t applyCeiling(const float32x4_t &value, const float32x4_t &ceilingParam)
305   {
306     (void)ceilingParam; // suppress unused argument warning
307     return value;
308   }
309   static inline float32x4_t applyFloor(const float32x4_t &value, const float32x4_t &floorParam)
310   {
311     return vmaxq_f32(value, floorParam);
312   }
313 #endif // USE_NEON
314   static inline float applyCeiling(const float value, const float ceilingParam)
315   {
316     (void)ceilingParam;
317     return value;
318   }
319   static inline float applyFloor(const float value, const float floorParam)
320   {
321     return std::max(value, floorParam);
322   }
323 };
324
325 struct BinaryOpActivationFloatMinMax
326 {
327 #ifdef USE_NEON
328   static inline float32x4_t applyCeiling(const float32x4_t &value, const float32x4_t &ceilingParam)
329   {
330     return vminq_f32(value, ceilingParam);
331   }
332   static inline float32x4_t applyFloor(const float32x4_t &value, const float32x4_t &floorParam)
333   {
334     return vmaxq_f32(value, floorParam);
335   }
336 #endif // USE_NEON
337   static inline float applyCeiling(const float value, const float ceilingParam)
338   {
339     return std::min(value, ceilingParam);
340   }
341   static inline float applyFloor(const float value, const float floorParam)
342   {
343     return std::max(value, floorParam);
344   }
345 };
346
347 template <class OPERATOR, class ACTIVATION>
348 inline void BinaryOpElementwise(int size, const BinaryArithmeticOpParam &params,
349                                 const float *input1_data, const float *input2_data,
350                                 float *output_data)
351 {
352   int i = 0;
353
354 #ifdef USE_NEON
355   const auto activation_min = vdupq_n_f32(params.float_activation_min);
356   const auto activation_max = vdupq_n_f32(params.float_activation_max);
357   for (; i <= size - 16; i += 16)
358   {
359     auto a10 = vld1q_f32(input1_data + i);
360     auto a11 = vld1q_f32(input1_data + i + 4);
361     auto a12 = vld1q_f32(input1_data + i + 8);
362     auto a13 = vld1q_f32(input1_data + i + 12);
363     auto a20 = vld1q_f32(input2_data + i);
364     auto a21 = vld1q_f32(input2_data + i + 4);
365     auto a22 = vld1q_f32(input2_data + i + 8);
366     auto a23 = vld1q_f32(input2_data + i + 12);
367     auto x0 = OPERATOR::calculate(a10, a20);
368     auto x1 = OPERATOR::calculate(a11, a21);
369     auto x2 = OPERATOR::calculate(a12, a22);
370     auto x3 = OPERATOR::calculate(a13, a23);
371     x0 = ACTIVATION::applyFloor(x0, activation_min);
372     x1 = ACTIVATION::applyFloor(x1, activation_min);
373     x2 = ACTIVATION::applyFloor(x2, activation_min);
374     x3 = ACTIVATION::applyFloor(x3, activation_min);
375     x0 = ACTIVATION::applyCeiling(x0, activation_max);
376     x1 = ACTIVATION::applyCeiling(x1, activation_max);
377     x2 = ACTIVATION::applyCeiling(x2, activation_max);
378     x3 = ACTIVATION::applyCeiling(x3, activation_max);
379     vst1q_f32(output_data + i, x0);
380     vst1q_f32(output_data + i + 4, x1);
381     vst1q_f32(output_data + i + 8, x2);
382     vst1q_f32(output_data + i + 12, x3);
383   }
384   for (; i <= size - 4; i += 4)
385   {
386     auto a1 = vld1q_f32(input1_data + i);
387     auto a2 = vld1q_f32(input2_data + i);
388     auto x = OPERATOR::calculate(a1, a2); // vaddq
389     auto x_clamped =
390         ACTIVATION::applyCeiling(ACTIVATION::applyFloor(x, activation_min), activation_max);
391     vst1q_f32(output_data + i, x_clamped);
392   }
393 #endif // USE_NEON
394   for (; i < size; i++)
395   {
396     auto x = OPERATOR::calculate(input1_data[i], input2_data[i]);
397     output_data[i] = ACTIVATION::applyCeiling(
398         ACTIVATION::applyFloor(x, params.float_activation_min), params.float_activation_max);
399   }
400 }
401
402 // Broadcast binary op template that can often be used for inner loop
403 // This function will handle scalar_value (LHS) and vector_values (RHS).
404 // Since it's a float function, input params does not matter here.
405 template <class OPERATOR, class ACTIVATION>
406 inline void BinaryOpScalarBroadcast(int size, const BinaryArithmeticOpParam &params,
407                                     const float broadcast_value, const float *input2_data,
408                                     float *output_data)
409 {
410   int i = 0;
411
412 #ifdef USE_NEON
413   const auto activation_min = vdupq_n_f32(params.float_activation_min);
414   const auto activation_max = vdupq_n_f32(params.float_activation_max);
415   const auto broadcast_value_dup = vdupq_n_f32(broadcast_value);
416   for (; i <= size - 16; i += 16)
417   {
418     auto a20 = vld1q_f32(input2_data + i);
419     auto a21 = vld1q_f32(input2_data + i + 4);
420     auto a22 = vld1q_f32(input2_data + i + 8);
421     auto a23 = vld1q_f32(input2_data + i + 12);
422     auto x0 = OPERATOR::calculate(broadcast_value_dup, a20);
423     auto x1 = OPERATOR::calculate(broadcast_value_dup, a21);
424     auto x2 = OPERATOR::calculate(broadcast_value_dup, a22);
425     auto x3 = OPERATOR::calculate(broadcast_value_dup, a23);
426     x0 = ACTIVATION::applyFloor(x0, activation_min);
427     x1 = ACTIVATION::applyFloor(x1, activation_min);
428     x2 = ACTIVATION::applyFloor(x2, activation_min);
429     x3 = ACTIVATION::applyFloor(x3, activation_min);
430     x0 = ACTIVATION::applyCeiling(x0, activation_max);
431     x1 = ACTIVATION::applyCeiling(x1, activation_max);
432     x2 = ACTIVATION::applyCeiling(x2, activation_max);
433     x3 = ACTIVATION::applyCeiling(x3, activation_max);
434     vst1q_f32(output_data + i, x0);
435     vst1q_f32(output_data + i + 4, x1);
436     vst1q_f32(output_data + i + 8, x2);
437     vst1q_f32(output_data + i + 12, x3);
438   }
439   for (; i <= size - 4; i += 4)
440   {
441     auto a2 = vld1q_f32(input2_data + i);
442     auto x = OPERATOR::calculate(broadcast_value_dup, a2);
443     auto x_clamped =
444         ACTIVATION::applyCeiling(ACTIVATION::applyFloor(x, activation_min), activation_max);
445     vst1q_f32(output_data + i, x_clamped);
446   }
447 #endif // USE_NEON
448   for (; i < size; i++)
449   {
450     auto x = OPERATOR::calculate(broadcast_value, input2_data[i]);
451     output_data[i] = ACTIVATION::applyCeiling(
452         ACTIVATION::applyFloor(x, params.float_activation_min), params.float_activation_max);
453   }
454 }
455
456 using BinaryOpImplFloatFuncs =
457     std::pair<void (*)(int, const BinaryArithmeticOpParam &, const float *, const float *, float *),
458               void (*)(int, const BinaryArithmeticOpParam &, const float, const float *, float *)>;
459
460 template <class FUNC>
461 inline BinaryOpImplFloatFuncs
462 getBinaryOpWithActivationImplFloat(const BinaryArithmeticOpParam &params)
463 {
464   if (params.float_activation_max == std::numeric_limits<float>::max())
465     if (params.float_activation_min == std::numeric_limits<float>::lowest())
466       return BinaryOpImplFloatFuncs(BinaryOpElementwise<FUNC, BinaryOpActivationFloatNone>,
467                                     BinaryOpScalarBroadcast<FUNC, BinaryOpActivationFloatNone>);
468     else
469       return BinaryOpImplFloatFuncs(BinaryOpElementwise<FUNC, BinaryOpActivationFloatMax>,
470                                     BinaryOpScalarBroadcast<FUNC, BinaryOpActivationFloatMax>);
471   else
472     return BinaryOpImplFloatFuncs(BinaryOpElementwise<FUNC, BinaryOpActivationFloatMinMax>,
473                                   BinaryOpScalarBroadcast<FUNC, BinaryOpActivationFloatMinMax>);
474 }
475
476 inline void AddQuant8(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
477                       const uint8_t *input1_data, const Shape &input2_shape,
478                       const uint8_t *input2_data, const Shape &output_shape, uint8_t *output_data)
479 {
480   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
481   AddElementwiseQuant8(flat_size, params, input1_data, input2_data, output_data);
482 }
483
484 inline void Add(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
485                 const float *input1_data, const Shape &input2_shape, const float *input2_data,
486                 const Shape &output_shape, float *output_data)
487 {
488   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
489   auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncAddFloat>(params);
490   (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
491 }
492
493 // Scalar-broadcast add that can be used for inner loop of more general
494 // broadcast add, so that, for example, scalar-broadcast with batch will still
495 // be fast.
496 inline void AddScalarBroadcastQuant8(int size, const BinaryArithmeticOpParam &params,
497                                      uint8_t broadcast_value, const uint8_t *input2_data,
498                                      uint8_t *output_data)
499 {
500   int i = 0;
501   int32_t clamped_output;
502   for (; i < size; ++i)
503   {
504     clamped_output = quant8_sum(params, broadcast_value, input2_data[i]);
505     output_data[i] = static_cast<uint8_t>(clamped_output);
506   }
507 }
508
509 inline void BroadcastAddDispatchQuant8(const BinaryArithmeticOpParam &params,
510                                        const Shape &input1_shape, const uint8_t *input1_data,
511                                        const Shape &input2_shape, const uint8_t *input2_data,
512                                        const Shape &output_shape, uint8_t *output_data)
513 {
514   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
515   {
516     const std::function<uint8_t(const BinaryArithmeticOpParam &, const uint8_t &, const uint8_t &)>
517         fn = [](const BinaryArithmeticOpParam &params, const uint8_t &a,
518                 const uint8_t &b) -> uint8_t {
519       return static_cast<uint8_t>(quant8_sum(params, a, b));
520     };
521     reference::BroadcastBinaryArithmeticOpSlowQuant8(params, input1_shape, input1_data,
522                                                      input2_shape, input2_data, output_shape,
523                                                      output_data, fn);
524   }
525   else
526   {
527     BinaryBroadcastFiveFold(
528         params, params.broadcast_category == BroadcastableOpCategory::kSecondInputBroadcastsFast,
529         input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
530         static_cast<void (*)(int, const BinaryArithmeticOpParam &, const uint8_t *, const uint8_t *,
531                              uint8_t *)>(AddElementwiseQuant8),
532         static_cast<void (*)(int, const BinaryArithmeticOpParam &, uint8_t, const uint8_t *,
533                              uint8_t *)>(AddScalarBroadcastQuant8));
534   }
535 }
536
537 inline void BroadcastAddDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
538                                  const float *input1_data, const Shape &input2_shape,
539                                  const float *input2_data, const Shape &output_shape,
540                                  float *output_data)
541 {
542   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
543   {
544     const std::function<float(const float &, const float &)> fn =
545         [](const float &a, const float &b) -> float { return a + b; };
546     reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
547                                                input2_data, output_shape, output_data, fn);
548   }
549   else
550   {
551     auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncAddFloat>(params);
552
553     BinaryBroadcastFiveFold(params, params.broadcast_category ==
554                                         BroadcastableOpCategory::kSecondInputBroadcastsFast,
555                             input1_shape, input1_data, input2_shape, input2_data, output_shape,
556                             output_data, implFuncs.first, implFuncs.second);
557   }
558 }
559
560 inline void Sub(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
561                 const float *input1_data, const Shape &input2_shape, const float *input2_data,
562                 const Shape &output_shape, float *output_data)
563 {
564   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
565   auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncSubFloat>(params);
566   (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
567 }
568
569 inline void BroadcastSubDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
570                                  const float *input1_data, const Shape &input2_shape,
571                                  const float *input2_data, const Shape &output_shape,
572                                  float *output_data)
573 {
574   if (params.broadcast_category == BroadcastableOpCategory::kFirstInputBroadcastsFast)
575   {
576     auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncSubFloat>(params);
577     BinaryBroadcastFiveFold(params, false, input1_shape, input1_data, input2_shape, input2_data,
578                             output_shape, output_data, implFuncs.first, implFuncs.second);
579   }
580   else if (params.broadcast_category == BroadcastableOpCategory::kSecondInputBroadcastsFast)
581   {
582     auto implFuncs =
583         getBinaryOpWithActivationImplFloat<BinaryOpFuncSwapArgs<BinaryOpFuncSubFloat>>(params);
584     BinaryBroadcastFiveFold(params, true, input1_shape, input1_data, input2_shape, input2_data,
585                             output_shape, output_data, implFuncs.first, implFuncs.second);
586   }
587   else
588   {
589     const std::function<float(const float &, const float &)> fn =
590         [](const float &a, const float &b) -> float { return a - b; };
591     reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
592                                                input2_data, output_shape, output_data, fn);
593   }
594 }
595
596 inline int32_t quant8_mul(const BinaryArithmeticOpParam &params, const uint8_t input1_data,
597                           const uint8_t input2_data)
598 {
599   const int32_t input1_val = params.input1_offset + input1_data;
600   const int32_t input2_val = params.input2_offset + input2_data;
601   const int32_t unclamped_result =
602       params.output_offset + MultiplyByQuantizedMultiplier(input1_val * input2_val,
603                                                            params.output_multiplier,
604                                                            params.output_shift);
605   const int32_t clamped_output = std::min(
606       params.quantized_activation_max, std::max(params.quantized_activation_min, unclamped_result));
607
608   return clamped_output;
609 }
610
611 inline void MulElementwiseQuant8(int size, const BinaryArithmeticOpParam &params,
612                                  const uint8_t *input1_data, const uint8_t *input2_data,
613                                  uint8_t *output_data)
614 {
615   int i = 0;
616
617 #ifdef USE_NEON
618   const auto input1_offset_vector = vdupq_n_s16(params.input1_offset);
619   const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
620   const auto output_offset_vector = vdupq_n_s16(params.output_offset);
621   const auto output_activation_min_vector = vdup_n_u8(params.quantized_activation_min);
622   const auto output_activation_max_vector = vdup_n_u8(params.quantized_activation_max);
623   const int left_shift = std::max(0, params.output_shift);
624   const int right_shift = std::max(0, -params.output_shift);
625   const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
626   for (; i <= size - 8; i += 8)
627   {
628     // We load / store 8 at a time, multiplying as two sets of 4 int32s.
629     const auto input1_val_original = vld1_u8(input1_data + i);
630     const auto input2_val_original = vld1_u8(input2_data + i);
631     const auto input1_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
632     const auto input2_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
633     const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
634     const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
635
636     const auto input1_val_low = vget_low_s16(input1_val);
637     const auto input1_val_high = vget_high_s16(input1_val);
638     const auto input2_val_low = vget_low_s16(input2_val);
639     const auto input2_val_high = vget_high_s16(input2_val);
640
641     auto p1 = vmull_s16(input2_val_low, input1_val_low);
642     auto p2 = vmull_s16(input2_val_high, input1_val_high);
643
644     p1 = vshlq_s32(p1, left_shift_vec);
645     p2 = vshlq_s32(p2, left_shift_vec);
646     p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
647     p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
648     using gemmlowp::RoundingDivideByPOT;
649     p1 = RoundingDivideByPOT(p1, right_shift);
650     p2 = RoundingDivideByPOT(p2, right_shift);
651
652     const auto p1_narrowed = vqmovn_s32(p1);
653     const auto p2_narrowed = vqmovn_s32(p2);
654     const auto p = vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
655     const auto clamped = vmax_u8(output_activation_min_vector,
656                                  vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
657     vst1_u8(output_data + i, clamped);
658   }
659 #endif // NEON
660
661   for (; i < size; ++i)
662   {
663     const int32_t input1_val = params.input1_offset + input1_data[i];
664     const int32_t input2_val = params.input2_offset + input2_data[i];
665     const int32_t unclamped_result =
666         params.output_offset + MultiplyByQuantizedMultiplier(input1_val * input2_val,
667                                                              params.output_multiplier,
668                                                              params.output_shift);
669     const int32_t clamped_output =
670         std::min(params.quantized_activation_max,
671                  std::max(params.quantized_activation_min, unclamped_result));
672     output_data[i] = static_cast<uint8_t>(clamped_output);
673   }
674 }
675
676 inline void MulQuant8(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
677                       const uint8_t *input1_data, const Shape &input2_shape,
678                       const uint8_t *input2_data, const Shape &output_shape, uint8_t *output_data)
679 {
680   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
681   MulElementwiseQuant8(flat_size, params, input1_data, input2_data, output_data);
682 }
683
684 inline void Mul(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
685                 const float *input1_data, const Shape &input2_shape, const float *input2_data,
686                 const Shape &output_shape, float *output_data)
687 {
688   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
689   auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncMulFloat>(params);
690   (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
691 }
692
693 inline void MulSimpleBroadcastQuant8(int size, const BinaryArithmeticOpParam &params,
694                                      const uint8_t broadcast_value, const uint8_t *input2_data,
695                                      uint8_t *output_data)
696 {
697   int i = 0;
698   int32_t clamped_output;
699   for (; i < size; ++i)
700   {
701     clamped_output = quant8_mul(params, broadcast_value, input2_data[i]);
702     output_data[i] = static_cast<uint8_t>(clamped_output);
703   }
704 }
705
706 inline void BroadcastMulDispatchQuant8(const BinaryArithmeticOpParam &params,
707                                        const Shape &input1_shape, const uint8_t *input1_data,
708                                        const Shape &input2_shape, const uint8_t *input2_data,
709                                        const Shape &output_shape, uint8_t *output_data)
710 {
711   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
712   {
713     const std::function<uint8_t(const BinaryArithmeticOpParam &, const uint8_t &, const uint8_t &)>
714         fn = [](const BinaryArithmeticOpParam &params, const uint8_t &a,
715                 const uint8_t &b) -> uint8_t {
716       return static_cast<uint8_t>(quant8_mul(params, a, b));
717     };
718     reference::BroadcastBinaryArithmeticOpSlowQuant8(params, input1_shape, input1_data,
719                                                      input2_shape, input2_data, output_shape,
720                                                      output_data, fn);
721     return;
722   }
723   BinaryBroadcastFiveFold(
724       params, params.broadcast_category == BroadcastableOpCategory::kSecondInputBroadcastsFast,
725       input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
726       static_cast<void (*)(int, const BinaryArithmeticOpParam &, const uint8_t *, const uint8_t *,
727                            uint8_t *)>(MulElementwiseQuant8),
728       static_cast<void (*)(int, const BinaryArithmeticOpParam &, uint8_t, const uint8_t *,
729                            uint8_t *)>(MulSimpleBroadcastQuant8));
730 }
731
732 inline void BroadcastMulDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
733                                  const float *input1_data, const Shape &input2_shape,
734                                  const float *input2_data, const Shape &output_shape,
735                                  float *output_data)
736 {
737   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
738   {
739     // TODO: Use GetBinaryArithmeticFn
740     const std::function<float(const float &, const float &)> fn =
741         [](const float &a, const float &b) -> float { return a * b; };
742     reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
743                                                input2_data, output_shape, output_data, fn);
744     return;
745   }
746   auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncMulFloat>(params);
747   BinaryBroadcastFiveFold(params, params.broadcast_category ==
748                                       BroadcastableOpCategory::kSecondInputBroadcastsFast,
749                           input1_shape, input1_data, input2_shape, input2_data, output_shape,
750                           output_data, implFuncs.first, implFuncs.second);
751 }
752
753 inline void Div(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
754                 const float *input1_data, const Shape &input2_shape, const float *input2_data,
755                 const Shape &output_shape, float *output_data)
756 {
757 #ifdef __aarch64__
758   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
759   auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncDivFloat>(params);
760   (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
761 #else
762   const std::function<float(const float &, const float &)> fn =
763       [](const float &a, const float &b) -> float { return a / b; };
764   reference::BinaryArithmeticOp(params, input1_shape, input1_data, input2_shape, input2_data,
765                                 output_shape, output_data, fn);
766 #endif // __aarch64__
767 }
768
769 inline void BroadcastDivDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
770                                  const float *input1_data, const Shape &input2_shape,
771                                  const float *input2_data, const Shape &output_shape,
772                                  float *output_data)
773 {
774 #ifdef __aarch64__
775   if (params.broadcast_category == BroadcastableOpCategory::kFirstInputBroadcastsFast)
776   {
777     auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncDivFloat>(params);
778     BinaryBroadcastFiveFold(params, false, input1_shape, input1_data, input2_shape, input2_data,
779                             output_shape, output_data, implFuncs.first, implFuncs.second);
780   }
781   else if (params.broadcast_category == BroadcastableOpCategory::kSecondInputBroadcastsFast)
782   {
783     auto implFuncs =
784         getBinaryOpWithActivationImplFloat<BinaryOpFuncSwapArgs<BinaryOpFuncDivFloat>>(params);
785     BinaryBroadcastFiveFold(params, true, input1_shape, input1_data, input2_shape, input2_data,
786                             output_shape, output_data, implFuncs.first, implFuncs.second);
787   }
788   else
789 #endif // __aarch64__
790   {
791     const std::function<float(const float &, const float &)> fn =
792         [](const float &a, const float &b) -> float { return a / b; };
793     reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
794                                                input2_data, output_shape, output_data, fn);
795   }
796 }
797
798 } // namespace optimized
799 } // namespace cker
800 } // namespace nnfw
801
802 #endif // __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__