2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
19 #define __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
24 #include "cker/neon/neon_check.h"
25 #include "cker/operation/reference/BinaryArithmeticOps.h"
26 #include "cker/Shape.h"
27 #include "cker/Types.h"
28 #include "cker/Utils.h"
29 #include "fixedpoint/fixedpoint.h"
// Runs a broadcast binary op whose broadcast structure the caller has already
// flattened into five nested dimensions (params.broadcast_shape[0..4]).
// elementwise_f processes a contiguous run of y4 elements at a time, and
// scalar_broadcast_f handles the y4 == 1 case, where one operand degenerates
// to a single scalar. The Shape parameters are intentionally unused: the
// fivefold counts in params fully describe the iteration.
template <typename ElementwiseF, typename ScalarBroadcastF, typename T>
inline void BinaryBroadcastFiveFold(const BinaryArithmeticOpParam &params, bool switch_inputs,
                                    const Shape & /* unswitched_input1_shape */,
                                    const T *unswitched_input1_data,
                                    const Shape & /* unswitched_input2_shape */,
                                    const T *unswitched_input2_data,
                                    const Shape & /* output_shape */, T *output_data,
                                    ElementwiseF elementwise_f, ScalarBroadcastF scalar_broadcast_f)
  // Canonicalize operand order so the loop below always sees input1 as the
  // operand that broadcasts along y3 (the caller passes switch_inputs
  // accordingly; for non-commutative ops it also swaps the functor arguments).
  const T *input1_data = switch_inputs ? unswitched_input2_data : unswitched_input1_data;
  const T *input2_data = switch_inputs ? unswitched_input1_data : unswitched_input2_data;
  // Fivefold nested loops. The second input resets its position for each
  // iteration of the second loop. The first input resets its position at the
  // beginning of the fourth loop. The innermost loop is an elementwise add of
  // sections of the arrays.
  T *output_data_ptr = output_data;
  const T *input1_data_ptr = input1_data;
  const T *input2_data_reset = input2_data;
  // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
  // between input shapes. y3 for input 1 is always broadcast, and so the
  // dimension there is 1, whereas optionally y1 might be broadcast for input 2.
  // input1.shape.FlatSize = y0 * y1 * y2 * y4,
  // input2.shape.FlatSize = y0 * y2 * y3 * y4.
  int y0 = params.broadcast_shape[0];
  int y1 = params.broadcast_shape[1];
  int y2 = params.broadcast_shape[2];
  int y3 = params.broadcast_shape[3];
  int y4 = params.broadcast_shape[4];
  // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
  for (int i0 = 0; i0 < y0; ++i0)
    const T *input2_data_ptr = nullptr;
    for (int i1 = 0; i1 < y1; ++i1)
      input2_data_ptr = input2_data_reset;
      for (int i2 = 0; i2 < y2; ++i2)
        for (int i3 = 0; i3 < y3; ++i3)
          // One contiguous run of y4 elements; input1's run is reused for
          // every i3 while input2 and the output advance.
          elementwise_f(y4, params, input1_data_ptr, input2_data_ptr, output_data_ptr);
          input2_data_ptr += y4;
          output_data_ptr += y4;
        // We have broadcast y4 of input1 data y3 times, and now move on.
        input1_data_ptr += y4;
    // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
    input2_data_reset = input2_data_ptr;
  // Special case of y4 == 1, in which the innermost loop is a single element
  // and can be combined with the next (y3) as an inner broadcast.
  // Note that this handles the case of pure scalar broadcast when
  // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
  // broadcast with batch (as y2 > 1).
  // NOTE The process is the same as the above general case except simplified
  // for y4 == 1 and the loop over y3 is contained within the
  // AddScalarBroadcast function.
  for (int i0 = 0; i0 < y0; ++i0)
    const T *input2_data_ptr = nullptr;
    for (int i1 = 0; i1 < y1; ++i1)
      input2_data_ptr = input2_data_reset;
      for (int i2 = 0; i2 < y2; ++i2)
        // scalar_broadcast_f applies *input1_data_ptr against y3 elements.
        scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr, output_data_ptr);
        input2_data_ptr += y3;
        output_data_ptr += y3;
        input1_data_ptr += 1;
    input2_data_reset = input2_data_ptr;
125 inline int32_t quant8_sum(const BinaryArithmeticOpParam ¶ms, const uint8_t input1_data,
126 const uint8_t input2_data)
128 const int32_t input1_val = params.input1_offset + input1_data;
129 const int32_t input2_val = params.input2_offset + input2_data;
130 const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
131 const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
132 const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
133 shifted_input1_val, params.input1_multiplier, params.input1_shift);
134 const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
135 shifted_input2_val, params.input2_multiplier, params.input2_shift);
136 const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
137 const int32_t raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp(
138 raw_sum, params.output_multiplier, params.output_shift) +
139 params.output_offset;
140 const int32_t clamped_output = std::min(params.quantized_activation_max,
141 std::max(params.quantized_activation_min, raw_output));
142 return clamped_output;
// Quantized (uint8) elementwise add over `size` elements.
// NEON path: 8 values per iteration, widened u8 -> s16 -> two s32x4 halves,
// mirroring quant8_sum's scalar math (offset, left shift, per-input rescale
// via saturating-rounding-doubling high multiply, sum, output rescale,
// clamp). The scalar loop at the end finishes any remainder.
inline void AddElementwiseQuant8(int size, const BinaryArithmeticOpParam &params,
                                 const uint8_t *input1_data, const uint8_t *input2_data,
                                 uint8_t *output_data)
  // Activation bounds splatted across all 8 output lanes.
  const uint8x8_t output_activation_min_vector = vdup_n_u8(params.quantized_activation_min);
  const uint8x8_t output_activation_max_vector = vdup_n_u8(params.quantized_activation_max);
  for (; i <= size - 8; i += 8)
    // Load 8 values per input and widen u8 -> s16.
    const uint8x8_t input1_val_original = vld1_u8(input1_data + i);
    const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
    const int16x8_t input1_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
    const int16x8_t input2_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
    // Apply the input zero-point offsets.
    const int16x8_t input1_val = vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
    const int16x8_t input2_val = vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
    // Split into low/high halves and widen to s32 for the multiplier math.
    const int16x4_t input1_val_high = vget_high_s16(input1_val);
    const int16x4_t input1_val_low = vget_low_s16(input1_val);
    const int16x4_t input2_val_high = vget_high_s16(input2_val);
    const int16x4_t input2_val_low = vget_low_s16(input2_val);
    int32x4_t x11 = vmovl_s16(input1_val_low);
    int32x4_t x12 = vmovl_s16(input1_val_high);
    int32x4_t x21 = vmovl_s16(input2_val_low);
    int32x4_t x22 = vmovl_s16(input2_val_high);
    // Up-shift into the common high-precision domain (== * (1 << left_shift)).
    const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
    x11 = vshlq_s32(x11, left_shift_dup);
    x12 = vshlq_s32(x12, left_shift_dup);
    x21 = vshlq_s32(x21, left_shift_dup);
    x22 = vshlq_s32(x22, left_shift_dup);
    // Per-input rescale: saturating rounding doubling high multiply ...
    x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
    x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
    x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
    x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
    // ... followed by the (negative-exponent) shift.
    const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
    const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
    x11 = vshlq_s32(x11, input1_shift_dup);
    x12 = vshlq_s32(x12, input1_shift_dup);
    x21 = vshlq_s32(x21, input2_shift_dup);
    x22 = vshlq_s32(x22, input2_shift_dup);
    // Sum, then rescale into the output quantization.
    int32x4_t s1 = vaddq_s32(x11, x21);
    int32x4_t s2 = vaddq_s32(x12, x22);
    s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
    s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
    using gemmlowp::RoundingDivideByPOT;
    s1 = RoundingDivideByPOT(s1, -params.output_shift);
    s2 = RoundingDivideByPOT(s2, -params.output_shift);
    // Narrow to s16, add the output offset, saturate-narrow to u8 and clamp.
    const int16x4_t s1_narrowed = vmovn_s32(s1);
    const int16x4_t s2_narrowed = vmovn_s32(s2);
    vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), vdupq_n_s16(params.output_offset));
    const uint8x8_t clamped = vmax_u8(output_activation_min_vector,
                                      vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
    vst1_u8(output_data + i, clamped);
  // Scalar tail: identical math to quant8_sum, one element at a time.
  for (; i < size; ++i)
    const int32_t input1_val = params.input1_offset + input1_data[i];
    const int32_t input2_val = params.input2_offset + input2_data[i];
    const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
    const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
    const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
      shifted_input1_val, params.input1_multiplier, params.input1_shift);
    const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
      shifted_input2_val, params.input2_multiplier, params.input2_shift);
    const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
    const int32_t raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp(
                                 raw_sum, params.output_multiplier, params.output_shift) +
                               params.output_offset;
    const int32_t clamped_output = std::min(params.quantized_activation_max,
                                            std::max(params.quantized_activation_min, raw_output));
    output_data[i] = static_cast<uint8_t>(clamped_output);
// Op policy: float addition. Provides a NEON 4-lane overload and a scalar
// overload; overload resolution inside BinaryOpElementwise /
// BinaryOpScalarBroadcast picks the right one.
struct BinaryOpFuncAddFloat
  static inline float32x4_t calculate(const float32x4_t &a, const float32x4_t &b)
    return vaddq_f32(a, b);
  static inline float calculate(const float a, const float b) { return a + b; }
// Op policy: float subtraction (a - b). NEON 4-lane overload plus scalar
// fallback; note Sub is non-commutative, so broadcast dispatch may wrap this
// in BinaryOpFuncSwapArgs.
struct BinaryOpFuncSubFloat
  static inline float32x4_t calculate(const float32x4_t &a, const float32x4_t &b)
    return vsubq_f32(a, b);
  static inline float calculate(const float a, const float b) { return a - b; }
// Op policy: float multiplication. NEON 4-lane overload plus scalar fallback.
struct BinaryOpFuncMulFloat
  static inline float32x4_t calculate(const float32x4_t &a, const float32x4_t &b)
    return vmulq_f32(a, b);
  static inline float calculate(const float a, const float b) { return a * b; }
// Op policy: float division (a / b). The vector overload uses vdivq_f32,
// which only exists on aarch64 (see the #endif guard below); other targets
// get only the scalar overload.
struct BinaryOpFuncDivFloat
  static inline float32x4_t calculate(const float32x4_t &a, const float32x4_t &b)
    return vdivq_f32(a, b);
#endif // __aarch64__
  static inline float calculate(const float a, const float b) { return a / b; }
// Adapter that forwards to BASEOPERATOR with its two operands exchanged.
// Used for non-commutative ops (Sub, Div) when the fivefold broadcast loop
// runs with its inputs switched, so the math still computes a OP b.
template <class BASEOPERATOR> struct BinaryOpFuncSwapArgs
{
  template <typename T> static inline T calculate(const T &lhs, const T &rhs)
  {
    // Deliberately reversed argument order.
    return BASEOPERATOR::calculate(rhs, lhs);
  }
};
// Activation policy: no clamping. applyCeiling/applyFloor are pass-through
// hooks used when params impose no activation bounds.
// NOTE(review): the `return value;` statements are outside the visible span —
// confirm against the full file that every hook returns `value` unchanged.
struct BinaryOpActivationFloatNone
  static inline float32x4_t applyCeiling(const float32x4_t &value, const float32x4_t &ceilingParam)
    (void)ceilingParam; // suppress unused argument warning
  static inline float32x4_t applyFloor(const float32x4_t &value, const float32x4_t &floorParam)
  static inline float applyCeiling(const float value, const float ceilingParam)
  static inline float applyFloor(const float value, const float floorParam)
// Activation policy: lower bound only (e.g. RELU). applyFloor clamps to the
// minimum via max(); applyCeiling is a pass-through (ceilingParam unused).
struct BinaryOpActivationFloatMax
  static inline float32x4_t applyCeiling(const float32x4_t &value, const float32x4_t &ceilingParam)
    (void)ceilingParam; // suppress unused argument warning
  static inline float32x4_t applyFloor(const float32x4_t &value, const float32x4_t &floorParam)
    return vmaxq_f32(value, floorParam);
  static inline float applyCeiling(const float value, const float ceilingParam)
  static inline float applyFloor(const float value, const float floorParam)
    return std::max(value, floorParam);
// Activation policy: full two-sided clamp (e.g. RELU6). applyCeiling clamps
// with min(), applyFloor with max(), in both the NEON and scalar overloads.
struct BinaryOpActivationFloatMinMax
  static inline float32x4_t applyCeiling(const float32x4_t &value, const float32x4_t &ceilingParam)
    return vminq_f32(value, ceilingParam);
  static inline float32x4_t applyFloor(const float32x4_t &value, const float32x4_t &floorParam)
    return vmaxq_f32(value, floorParam);
  static inline float applyCeiling(const float value, const float ceilingParam)
    return std::min(value, ceilingParam);
  static inline float applyFloor(const float value, const float floorParam)
    return std::max(value, floorParam);
// Float elementwise kernel: output[i] = clamp(OPERATOR(in1[i], in2[i])).
// OPERATOR is one of the BinaryOpFunc* policies and ACTIVATION one of the
// BinaryOpActivationFloat* policies, both resolved at compile time.
// NEON path processes 16 floats per iteration (4 x float32x4), then 4, then
// a scalar loop finishes the remainder.
template <class OPERATOR, class ACTIVATION>
inline void BinaryOpElementwise(int size, const BinaryArithmeticOpParam &params,
                                const float *input1_data, const float *input2_data,
  // Activation bounds splatted across all four lanes.
  const auto activation_min = vdupq_n_f32(params.float_activation_min);
  const auto activation_max = vdupq_n_f32(params.float_activation_max);
  for (; i <= size - 16; i += 16)
    // Four independent 4-lane chunks per iteration to expose ILP.
    auto a10 = vld1q_f32(input1_data + i);
    auto a11 = vld1q_f32(input1_data + i + 4);
    auto a12 = vld1q_f32(input1_data + i + 8);
    auto a13 = vld1q_f32(input1_data + i + 12);
    auto a20 = vld1q_f32(input2_data + i);
    auto a21 = vld1q_f32(input2_data + i + 4);
    auto a22 = vld1q_f32(input2_data + i + 8);
    auto a23 = vld1q_f32(input2_data + i + 12);
    auto x0 = OPERATOR::calculate(a10, a20);
    auto x1 = OPERATOR::calculate(a11, a21);
    auto x2 = OPERATOR::calculate(a12, a22);
    auto x3 = OPERATOR::calculate(a13, a23);
    // Clamp: floor first, then ceiling.
    x0 = ACTIVATION::applyFloor(x0, activation_min);
    x1 = ACTIVATION::applyFloor(x1, activation_min);
    x2 = ACTIVATION::applyFloor(x2, activation_min);
    x3 = ACTIVATION::applyFloor(x3, activation_min);
    x0 = ACTIVATION::applyCeiling(x0, activation_max);
    x1 = ACTIVATION::applyCeiling(x1, activation_max);
    x2 = ACTIVATION::applyCeiling(x2, activation_max);
    x3 = ACTIVATION::applyCeiling(x3, activation_max);
    vst1q_f32(output_data + i, x0);
    vst1q_f32(output_data + i + 4, x1);
    vst1q_f32(output_data + i + 8, x2);
    vst1q_f32(output_data + i + 12, x3);
  // 4-wide cleanup loop.
  for (; i <= size - 4; i += 4)
    auto a1 = vld1q_f32(input1_data + i);
    auto a2 = vld1q_f32(input2_data + i);
    auto x = OPERATOR::calculate(a1, a2); // vaddq
    ACTIVATION::applyCeiling(ACTIVATION::applyFloor(x, activation_min), activation_max);
    vst1q_f32(output_data + i, x_clamped);
  // Scalar tail.
  for (; i < size; i++)
    auto x = OPERATOR::calculate(input1_data[i], input2_data[i]);
    output_data[i] = ACTIVATION::applyCeiling(
      ACTIVATION::applyFloor(x, params.float_activation_min), params.float_activation_max);
// Broadcast binary op template that can often be used for the inner loop of a
// more general broadcast op. It applies a single scalar value (LHS) against a
// run of vector values (RHS). Since this is a float function, the quantization
// fields of the input params do not matter here.
// Float scalar-broadcast kernel: output[i] = clamp(OPERATOR(scalar, in2[i])).
// Same structure as BinaryOpElementwise, but the first operand is a single
// value splatted across all lanes. Serves as the scalar_broadcast_f of
// BinaryBroadcastFiveFold.
template <class OPERATOR, class ACTIVATION>
inline void BinaryOpScalarBroadcast(int size, const BinaryArithmeticOpParam &params,
                                    const float broadcast_value, const float *input2_data,
  const auto activation_min = vdupq_n_f32(params.float_activation_min);
  const auto activation_max = vdupq_n_f32(params.float_activation_max);
  // Splat the broadcast scalar once, outside the loops.
  const auto broadcast_value_dup = vdupq_n_f32(broadcast_value);
  for (; i <= size - 16; i += 16)
    auto a20 = vld1q_f32(input2_data + i);
    auto a21 = vld1q_f32(input2_data + i + 4);
    auto a22 = vld1q_f32(input2_data + i + 8);
    auto a23 = vld1q_f32(input2_data + i + 12);
    auto x0 = OPERATOR::calculate(broadcast_value_dup, a20);
    auto x1 = OPERATOR::calculate(broadcast_value_dup, a21);
    auto x2 = OPERATOR::calculate(broadcast_value_dup, a22);
    auto x3 = OPERATOR::calculate(broadcast_value_dup, a23);
    // Clamp: floor first, then ceiling.
    x0 = ACTIVATION::applyFloor(x0, activation_min);
    x1 = ACTIVATION::applyFloor(x1, activation_min);
    x2 = ACTIVATION::applyFloor(x2, activation_min);
    x3 = ACTIVATION::applyFloor(x3, activation_min);
    x0 = ACTIVATION::applyCeiling(x0, activation_max);
    x1 = ACTIVATION::applyCeiling(x1, activation_max);
    x2 = ACTIVATION::applyCeiling(x2, activation_max);
    x3 = ACTIVATION::applyCeiling(x3, activation_max);
    vst1q_f32(output_data + i, x0);
    vst1q_f32(output_data + i + 4, x1);
    vst1q_f32(output_data + i + 8, x2);
    vst1q_f32(output_data + i + 12, x3);
  // 4-wide cleanup loop.
  for (; i <= size - 4; i += 4)
    auto a2 = vld1q_f32(input2_data + i);
    auto x = OPERATOR::calculate(broadcast_value_dup, a2);
    ACTIVATION::applyCeiling(ACTIVATION::applyFloor(x, activation_min), activation_max);
    vst1q_f32(output_data + i, x_clamped);
  // Scalar tail.
  for (; i < size; i++)
    auto x = OPERATOR::calculate(broadcast_value, input2_data[i]);
    output_data[i] = ACTIVATION::applyCeiling(
      ACTIVATION::applyFloor(x, params.float_activation_min), params.float_activation_max);
// Pair of kernels for one float binary op:
//   first  — elementwise kernel: (size, params, in1[], in2[], out[])
//   second — scalar-broadcast kernel: (size, params, scalar, in2[], out[])
using BinaryOpImplFloatFuncs =
  std::pair<void (*)(int, const BinaryArithmeticOpParam &, const float *, const float *, float *),
            void (*)(int, const BinaryArithmeticOpParam &, const float, const float *, float *)>;
460 template <class FUNC>
461 inline BinaryOpImplFloatFuncs
462 getBinaryOpWithActivationImplFloat(const BinaryArithmeticOpParam ¶ms)
464 if (params.float_activation_max == std::numeric_limits<float>::max())
465 if (params.float_activation_min == std::numeric_limits<float>::lowest())
466 return BinaryOpImplFloatFuncs(BinaryOpElementwise<FUNC, BinaryOpActivationFloatNone>,
467 BinaryOpScalarBroadcast<FUNC, BinaryOpActivationFloatNone>);
469 return BinaryOpImplFloatFuncs(BinaryOpElementwise<FUNC, BinaryOpActivationFloatMax>,
470 BinaryOpScalarBroadcast<FUNC, BinaryOpActivationFloatMax>);
472 return BinaryOpImplFloatFuncs(BinaryOpElementwise<FUNC, BinaryOpActivationFloatMinMax>,
473 BinaryOpScalarBroadcast<FUNC, BinaryOpActivationFloatMinMax>);
476 inline void AddQuant8(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
477 const uint8_t *input1_data, const Shape &input2_shape,
478 const uint8_t *input2_data, const Shape &output_shape, uint8_t *output_data)
480 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
481 AddElementwiseQuant8(flat_size, params, input1_data, input2_data, output_data);
484 inline void Add(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
485 const float *input1_data, const Shape &input2_shape, const float *input2_data,
486 const Shape &output_shape, float *output_data)
488 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
489 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncAddFloat>(params);
490 (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
// Scalar-broadcast add that can be used for the inner loop of a more general
// broadcast add, so that, for example, scalar-broadcast with batch will still
// be fast.
496 inline void AddScalarBroadcastQuant8(int size, const BinaryArithmeticOpParam ¶ms,
497 uint8_t broadcast_value, const uint8_t *input2_data,
498 uint8_t *output_data)
501 int32_t clamped_output;
502 for (; i < size; ++i)
504 clamped_output = quant8_sum(params, broadcast_value, input2_data[i]);
505 output_data[i] = static_cast<uint8_t>(clamped_output);
// Dispatches quantized broadcast add: arbitrary broadcast shapes fall back to
// the slow generic reference kernel; fivefold-compatible shapes use the fast
// BinaryBroadcastFiveFold loop with AddElementwiseQuant8 /
// AddScalarBroadcastQuant8 as inner kernels.
inline void BroadcastAddDispatchQuant8(const BinaryArithmeticOpParam &params,
                                       const Shape &input1_shape, const uint8_t *input1_data,
                                       const Shape &input2_shape, const uint8_t *input2_data,
                                       const Shape &output_shape, uint8_t *output_data)
  if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
    // Slow path: per-element quant8_sum through a std::function.
    const std::function<uint8_t(const BinaryArithmeticOpParam &, const uint8_t &, const uint8_t &)>
      fn = [](const BinaryArithmeticOpParam &params, const uint8_t &a,
              const uint8_t &b) -> uint8_t {
      return static_cast<uint8_t>(quant8_sum(params, a, b));
    reference::BroadcastBinaryArithmeticOpSlowQuant8(params, input1_shape, input1_data,
                                                     input2_shape, input2_data, output_shape,
  // Fast path. switch_inputs is true when the second input is the fast
  // broadcaster; add is commutative, so no argument swap is needed.
  BinaryBroadcastFiveFold(
    params, params.broadcast_category == BroadcastableOpCategory::kSecondInputBroadcastsFast,
    input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
    static_cast<void (*)(int, const BinaryArithmeticOpParam &, const uint8_t *, const uint8_t *,
                         uint8_t *)>(AddElementwiseQuant8),
    static_cast<void (*)(int, const BinaryArithmeticOpParam &, uint8_t, const uint8_t *,
                         uint8_t *)>(AddScalarBroadcastQuant8));
// Dispatches float broadcast add: generic broadcasts go to the slow reference
// kernel; fivefold-compatible shapes use the fast loop with the
// activation-specialized add kernels. Add is commutative, so switched inputs
// need no argument swap.
inline void BroadcastAddDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
                                 const float *input1_data, const Shape &input2_shape,
                                 const float *input2_data, const Shape &output_shape,
  if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
    // Slow generic fallback.
    const std::function<float(const float &, const float &)> fn =
      [](const float &a, const float &b) -> float { return a + b; };
    reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
                                               input2_data, output_shape, output_data, fn);
  // Fast fivefold path with activation-specialized kernels.
  auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncAddFloat>(params);
  BinaryBroadcastFiveFold(params, params.broadcast_category ==
                                    BroadcastableOpCategory::kSecondInputBroadcastsFast,
                          input1_shape, input1_data, input2_shape, input2_data, output_shape,
                          output_data, implFuncs.first, implFuncs.second);
560 inline void Sub(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
561 const float *input1_data, const Shape &input2_shape, const float *input2_data,
562 const Shape &output_shape, float *output_data)
564 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
565 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncSubFloat>(params);
566 (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
// Dispatches float broadcast subtract. Sub is non-commutative, so when the
// second input is the fast-broadcasting one, the fivefold loop runs with
// switched inputs and the operand order is restored by wrapping the functor
// in BinaryOpFuncSwapArgs. Everything else falls back to the reference
// implementation.
inline void BroadcastSubDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
                                 const float *input1_data, const Shape &input2_shape,
                                 const float *input2_data, const Shape &output_shape,
  if (params.broadcast_category == BroadcastableOpCategory::kFirstInputBroadcastsFast)
    auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncSubFloat>(params);
    BinaryBroadcastFiveFold(params, false, input1_shape, input1_data, input2_shape, input2_data,
                            output_shape, output_data, implFuncs.first, implFuncs.second);
  else if (params.broadcast_category == BroadcastableOpCategory::kSecondInputBroadcastsFast)
    // Inputs are switched (true below), so SwapArgs re-reverses the operands
    // and a - b is still computed correctly.
      getBinaryOpWithActivationImplFloat<BinaryOpFuncSwapArgs<BinaryOpFuncSubFloat>>(params);
    BinaryBroadcastFiveFold(params, true, input1_shape, input1_data, input2_shape, input2_data,
                            output_shape, output_data, implFuncs.first, implFuncs.second);
  // Generic fallback through the slow reference kernel.
  const std::function<float(const float &, const float &)> fn =
    [](const float &a, const float &b) -> float { return a - b; };
  reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
                                             input2_data, output_shape, output_data, fn);
596 inline int32_t quant8_mul(const BinaryArithmeticOpParam ¶ms, const uint8_t input1_data,
597 const uint8_t input2_data)
599 const int32_t input1_val = params.input1_offset + input1_data;
600 const int32_t input2_val = params.input2_offset + input2_data;
601 const int32_t unclamped_result =
602 params.output_offset + MultiplyByQuantizedMultiplier(input1_val * input2_val,
603 params.output_multiplier,
604 params.output_shift);
605 const int32_t clamped_output = std::min(
606 params.quantized_activation_max, std::max(params.quantized_activation_min, unclamped_result));
608 return clamped_output;
// Quantized (uint8) elementwise multiply over `size` elements.
// NEON path: 8 values per iteration, widened u8 -> s16, multiplied in two
// s32x4 halves, rescaled with the output multiplier/shift, offset, and
// clamped — mirroring quant8_mul's scalar math. A scalar tail loop finishes
// any remainder.
inline void MulElementwiseQuant8(int size, const BinaryArithmeticOpParam &params,
                                 const uint8_t *input1_data, const uint8_t *input2_data,
                                 uint8_t *output_data)
  // Offsets and activation bounds splatted once outside the loop.
  const auto input1_offset_vector = vdupq_n_s16(params.input1_offset);
  const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
  const auto output_offset_vector = vdupq_n_s16(params.output_offset);
  const auto output_activation_min_vector = vdup_n_u8(params.quantized_activation_min);
  const auto output_activation_max_vector = vdup_n_u8(params.quantized_activation_max);
  // Split the signed output_shift into a left part and a right part so each
  // can be applied with the matching NEON primitive.
  const int left_shift = std::max(0, params.output_shift);
  const int right_shift = std::max(0, -params.output_shift);
  const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
  for (; i <= size - 8; i += 8)
    // We load / store 8 at a time, multiplying as two sets of 4 int32s.
    const auto input1_val_original = vld1_u8(input1_data + i);
    const auto input2_val_original = vld1_u8(input2_data + i);
    const auto input1_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
    const auto input2_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
    // Apply the input zero-point offsets.
    const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
    const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
    const auto input1_val_low = vget_low_s16(input1_val);
    const auto input1_val_high = vget_high_s16(input1_val);
    const auto input2_val_low = vget_low_s16(input2_val);
    const auto input2_val_high = vget_high_s16(input2_val);
    // Widening multiplies: s16 x s16 -> s32.
    auto p1 = vmull_s16(input2_val_low, input1_val_low);
    auto p2 = vmull_s16(input2_val_high, input1_val_high);
    // Rescale into the output quantization: left shift, saturating rounding
    // doubling high multiply, then rounding right shift.
    p1 = vshlq_s32(p1, left_shift_vec);
    p2 = vshlq_s32(p2, left_shift_vec);
    p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
    p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
    using gemmlowp::RoundingDivideByPOT;
    p1 = RoundingDivideByPOT(p1, right_shift);
    p2 = RoundingDivideByPOT(p2, right_shift);
    // Narrow, add the output offset, saturate-narrow to u8 and clamp.
    const auto p1_narrowed = vqmovn_s32(p1);
    const auto p2_narrowed = vqmovn_s32(p2);
    const auto p = vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
    const auto clamped = vmax_u8(output_activation_min_vector,
                                 vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
    vst1_u8(output_data + i, clamped);
  // Scalar tail: identical math to quant8_mul, one element at a time.
  for (; i < size; ++i)
    const int32_t input1_val = params.input1_offset + input1_data[i];
    const int32_t input2_val = params.input2_offset + input2_data[i];
    const int32_t unclamped_result =
      params.output_offset + MultiplyByQuantizedMultiplier(input1_val * input2_val,
                                                           params.output_multiplier,
                                                           params.output_shift);
    const int32_t clamped_output =
      std::min(params.quantized_activation_max,
               std::max(params.quantized_activation_min, unclamped_result));
    output_data[i] = static_cast<uint8_t>(clamped_output);
676 inline void MulQuant8(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
677 const uint8_t *input1_data, const Shape &input2_shape,
678 const uint8_t *input2_data, const Shape &output_shape, uint8_t *output_data)
680 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
681 MulElementwiseQuant8(flat_size, params, input1_data, input2_data, output_data);
684 inline void Mul(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
685 const float *input1_data, const Shape &input2_shape, const float *input2_data,
686 const Shape &output_shape, float *output_data)
688 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
689 auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncMulFloat>(params);
690 (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
693 inline void MulSimpleBroadcastQuant8(int size, const BinaryArithmeticOpParam ¶ms,
694 const uint8_t broadcast_value, const uint8_t *input2_data,
695 uint8_t *output_data)
698 int32_t clamped_output;
699 for (; i < size; ++i)
701 clamped_output = quant8_mul(params, broadcast_value, input2_data[i]);
702 output_data[i] = static_cast<uint8_t>(clamped_output);
// Dispatches quantized broadcast multiply: arbitrary broadcast shapes fall
// back to the slow generic reference kernel; fivefold-compatible shapes use
// BinaryBroadcastFiveFold with MulElementwiseQuant8 /
// MulSimpleBroadcastQuant8 as inner kernels.
inline void BroadcastMulDispatchQuant8(const BinaryArithmeticOpParam &params,
                                       const Shape &input1_shape, const uint8_t *input1_data,
                                       const Shape &input2_shape, const uint8_t *input2_data,
                                       const Shape &output_shape, uint8_t *output_data)
  if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
    // Slow path: per-element quant8_mul through a std::function.
    const std::function<uint8_t(const BinaryArithmeticOpParam &, const uint8_t &, const uint8_t &)>
      fn = [](const BinaryArithmeticOpParam &params, const uint8_t &a,
              const uint8_t &b) -> uint8_t {
      return static_cast<uint8_t>(quant8_mul(params, a, b));
    reference::BroadcastBinaryArithmeticOpSlowQuant8(params, input1_shape, input1_data,
                                                     input2_shape, input2_data, output_shape,
  // Fast path. Mul is commutative, so switched inputs need no argument swap.
  BinaryBroadcastFiveFold(
    params, params.broadcast_category == BroadcastableOpCategory::kSecondInputBroadcastsFast,
    input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
    static_cast<void (*)(int, const BinaryArithmeticOpParam &, const uint8_t *, const uint8_t *,
                         uint8_t *)>(MulElementwiseQuant8),
    static_cast<void (*)(int, const BinaryArithmeticOpParam &, uint8_t, const uint8_t *,
                         uint8_t *)>(MulSimpleBroadcastQuant8));
// Dispatches float broadcast multiply: generic broadcasts go to the slow
// reference kernel; fivefold-compatible shapes use the fast loop with the
// activation-specialized multiply kernels. Mul is commutative, so switched
// inputs need no argument swap.
inline void BroadcastMulDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
                                 const float *input1_data, const Shape &input2_shape,
                                 const float *input2_data, const Shape &output_shape,
  if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
    // TODO: Use GetBinaryArithmeticFn
    const std::function<float(const float &, const float &)> fn =
      [](const float &a, const float &b) -> float { return a * b; };
    reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
                                               input2_data, output_shape, output_data, fn);
  // Fast fivefold path with activation-specialized kernels.
  auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncMulFloat>(params);
  BinaryBroadcastFiveFold(params, params.broadcast_category ==
                                    BroadcastableOpCategory::kSecondInputBroadcastsFast,
                          input1_shape, input1_data, input2_shape, input2_data, output_shape,
                          output_data, implFuncs.first, implFuncs.second);
// Float elementwise divide (input1 / input2) with fused activation for
// shape-matched inputs. On aarch64 the vectorized path (vdivq_f32) is used;
// on other targets the scalar reference implementation runs — note the
// #ifdef __aarch64__ split ending at the #endif below.
inline void Div(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
                const float *input1_data, const Shape &input2_shape, const float *input2_data,
                const Shape &output_shape, float *output_data)
  // aarch64 branch: activation-specialized vector kernel.
  const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
  auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncDivFloat>(params);
  (*implFuncs.first)(flat_size, params, input1_data, input2_data, output_data);
  // Non-aarch64 branch: scalar reference kernel.
  const std::function<float(const float &, const float &)> fn =
    [](const float &a, const float &b) -> float { return a / b; };
  reference::BinaryArithmeticOp(params, input1_shape, input1_data, input2_shape, input2_data,
                                output_shape, output_data, fn);
#endif // __aarch64__
// Dispatches float broadcast divide. Div is non-commutative, so when the
// second input is the fast-broadcasting one, the fivefold loop runs with
// switched inputs and the operand order is restored via BinaryOpFuncSwapArgs.
// The fast paths exist only on aarch64 (vdivq_f32); everything else — and all
// non-aarch64 builds — uses the slow reference kernel (see the #endif below).
inline void BroadcastDivDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
                                 const float *input1_data, const Shape &input2_shape,
                                 const float *input2_data, const Shape &output_shape,
  if (params.broadcast_category == BroadcastableOpCategory::kFirstInputBroadcastsFast)
    auto implFuncs = getBinaryOpWithActivationImplFloat<BinaryOpFuncDivFloat>(params);
    BinaryBroadcastFiveFold(params, false, input1_shape, input1_data, input2_shape, input2_data,
                            output_shape, output_data, implFuncs.first, implFuncs.second);
  else if (params.broadcast_category == BroadcastableOpCategory::kSecondInputBroadcastsFast)
    // Inputs are switched (true below), so SwapArgs re-reverses the operands
    // and a / b is still computed correctly.
      getBinaryOpWithActivationImplFloat<BinaryOpFuncSwapArgs<BinaryOpFuncDivFloat>>(params);
    BinaryBroadcastFiveFold(params, true, input1_shape, input1_data, input2_shape, input2_data,
                            output_shape, output_data, implFuncs.first, implFuncs.second);
#endif // __aarch64__
  // Generic fallback through the slow reference kernel.
  const std::function<float(const float &, const float &)> fn =
    [](const float &a, const float &b) -> float { return a / b; };
  reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
                                             input2_data, output_shape, output_data, fn);
798 } // namespace optimized
802 #endif // __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__