/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #ifndef __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
19 #define __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
#include "cker/neon/neon_check.h"
#include "cker/operation/reference/BinaryArithmeticOps.h"
#include "cker/Shape.h"
#include "cker/Types.h"
#include "cker/Utils.h"
#include "fixedpoint/fixedpoint.h"

#include <algorithm>
#include <cstdint>
#include <functional>
36 template <typename ElementwiseF, typename ScalarBroadcastF, typename T>
37 inline void BinaryBroadcastFiveFold(const BinaryArithmeticOpParam ¶ms,
38 const Shape & /* unswitched_input1_shape */,
39 const T *unswitched_input1_data,
40 const Shape & /* unswitched_input2_shape */,
41 const T *unswitched_input2_data,
42 const Shape & /* output_shape */, T *output_data,
43 ElementwiseF elementwise_f, ScalarBroadcastF scalar_broadcast_f)
45 const bool use_unswitched =
46 params.broadcast_category == BroadcastableOpCategory::kFirstInputBroadcastsFast;
48 const T *input1_data = use_unswitched ? unswitched_input1_data : unswitched_input2_data;
49 const T *input2_data = use_unswitched ? unswitched_input2_data : unswitched_input1_data;
51 // Fivefold nested loops. The second input resets its position for each
52 // iteration of the second loop. The first input resets its position at the
53 // beginning of the fourth loop. The innermost loop is an elementwise add of
54 // sections of the arrays.
55 T *output_data_ptr = output_data;
56 const T *input1_data_ptr = input1_data;
57 const T *input2_data_reset = input2_data;
58 // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
59 // between input shapes. y3 for input 1 is always broadcast, and so the
60 // dimension there is 1, whereas optionally y1 might be broadcast for input 2.
62 // input1.shape.FlatSize = y0 * y1 * y2 * y4,
63 // input2.shape.FlatSize = y0 * y2 * y3 * y4.
64 int y0 = params.broadcast_shape[0];
65 int y1 = params.broadcast_shape[1];
66 int y2 = params.broadcast_shape[2];
67 int y3 = params.broadcast_shape[3];
68 int y4 = params.broadcast_shape[4];
71 // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
73 for (int i0 = 0; i0 < y0; ++i0)
75 const T *input2_data_ptr = nullptr;
76 for (int i1 = 0; i1 < y1; ++i1)
78 input2_data_ptr = input2_data_reset;
79 for (int i2 = 0; i2 < y2; ++i2)
81 for (int i3 = 0; i3 < y3; ++i3)
83 elementwise_f(y4, params, input1_data_ptr, input2_data_ptr, output_data_ptr);
84 input2_data_ptr += y4;
85 output_data_ptr += y4;
87 // We have broadcast y4 of input1 data y3 times, and now move on.
88 input1_data_ptr += y4;
91 // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
92 input2_data_reset = input2_data_ptr;
97 // Special case of y4 == 1, in which the innermost loop is a single element
98 // and can be combined with the next (y3) as an inner broadcast.
100 // Note that this handles the case of pure scalar broadcast when
101 // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
102 // broadcast with batch (as y2 > 1).
104 // NOTE The process is the same as the above general case except simplified
105 // for y4 == 1 and the loop over y3 is contained within the
106 // AddScalarBroadcast function.
107 for (int i0 = 0; i0 < y0; ++i0)
109 const T *input2_data_ptr = nullptr;
110 for (int i1 = 0; i1 < y1; ++i1)
112 input2_data_ptr = input2_data_reset;
113 for (int i2 = 0; i2 < y2; ++i2)
115 scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr, output_data_ptr);
116 input2_data_ptr += y3;
117 output_data_ptr += y3;
118 input1_data_ptr += 1;
121 input2_data_reset = input2_data_ptr;
126 inline int32_t quant8_sum(const BinaryArithmeticOpParam ¶ms, const uint8_t input1_data,
127 const uint8_t input2_data)
129 const int32_t input1_val = params.input1_offset + input1_data;
130 const int32_t input2_val = params.input2_offset + input2_data;
131 const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
132 const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
133 const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
134 shifted_input1_val, params.input1_multiplier, params.input1_shift);
135 const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
136 shifted_input2_val, params.input2_multiplier, params.input2_shift);
137 const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
138 const int32_t raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp(
139 raw_sum, params.output_multiplier, params.output_shift) +
140 params.output_offset;
141 const int32_t clamped_output = std::min(params.quantized_activation_max,
142 std::max(params.quantized_activation_min, raw_output));
143 return clamped_output;
146 inline void AddElementwiseQuant8(int size, const BinaryArithmeticOpParam ¶ms,
147 const uint8_t *input1_data, const uint8_t *input2_data,
148 uint8_t *output_data)
151 for (; i < size; ++i)
153 int32_t clamped_output = quant8_sum(params, input1_data[i], input2_data[i]);
154 output_data[i] = static_cast<uint8_t>(clamped_output);
158 inline void AddElementwise(int size, const BinaryArithmeticOpParam ¶ms,
159 const float *input1_data, const float *input2_data, float *output_data)
164 const auto activation_min = vdupq_n_f32(params.float_activation_min);
165 const auto activation_max = vdupq_n_f32(params.float_activation_max);
166 for (; i <= size - 16; i += 16)
168 auto a10 = vld1q_f32(input1_data + i);
169 auto a11 = vld1q_f32(input1_data + i + 4);
170 auto a12 = vld1q_f32(input1_data + i + 8);
171 auto a13 = vld1q_f32(input1_data + i + 12);
172 auto a20 = vld1q_f32(input2_data + i);
173 auto a21 = vld1q_f32(input2_data + i + 4);
174 auto a22 = vld1q_f32(input2_data + i + 8);
175 auto a23 = vld1q_f32(input2_data + i + 12);
176 auto x0 = vaddq_f32(a10, a20);
177 auto x1 = vaddq_f32(a11, a21);
178 auto x2 = vaddq_f32(a12, a22);
179 auto x3 = vaddq_f32(a13, a23);
180 x0 = vmaxq_f32(activation_min, x0);
181 x1 = vmaxq_f32(activation_min, x1);
182 x2 = vmaxq_f32(activation_min, x2);
183 x3 = vmaxq_f32(activation_min, x3);
184 x0 = vminq_f32(activation_max, x0);
185 x1 = vminq_f32(activation_max, x1);
186 x2 = vminq_f32(activation_max, x2);
187 x3 = vminq_f32(activation_max, x3);
188 vst1q_f32(output_data + i, x0);
189 vst1q_f32(output_data + i + 4, x1);
190 vst1q_f32(output_data + i + 8, x2);
191 vst1q_f32(output_data + i + 12, x3);
193 for (; i <= size - 4; i += 4)
195 auto a1 = vld1q_f32(input1_data + i);
196 auto a2 = vld1q_f32(input2_data + i);
197 auto x = vaddq_f32(a1, a2);
198 x = vmaxq_f32(activation_min, x);
199 x = vminq_f32(activation_max, x);
200 vst1q_f32(output_data + i, x);
203 for (; i < size; i++)
205 auto x = input1_data[i] + input2_data[i];
206 output_data[i] = ActivationFunctionWithMinMax<float>(x, params.float_activation_min,
207 params.float_activation_max);
211 inline void AddQuant8(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
212 const uint8_t *input1_data, const Shape &input2_shape,
213 const uint8_t *input2_data, const Shape &output_shape, uint8_t *output_data)
215 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
216 AddElementwiseQuant8(flat_size, params, input1_data, input2_data, output_data);
219 inline void Add(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
220 const float *input1_data, const Shape &input2_shape, const float *input2_data,
221 const Shape &output_shape, float *output_data)
223 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
224 AddElementwise(flat_size, params, input1_data, input2_data, output_data);
227 // Scalar-broadcast add that can be used for inner loop of more general
228 // broadcast add, so that, for example, scalar-broadcast with batch will still
230 inline void AddScalarBroadcastQuant8(int size, const BinaryArithmeticOpParam ¶ms,
231 uint8_t broadcast_value, const uint8_t *input2_data,
232 uint8_t *output_data)
235 int32_t clamped_output;
236 for (; i < size; ++i)
238 clamped_output = quant8_sum(params, broadcast_value, input2_data[i]);
239 output_data[i] = static_cast<uint8_t>(clamped_output);
243 inline void AddScalarBroadcast(int size, const BinaryArithmeticOpParam ¶ms,
244 float broadcast_value, const float *input2_data, float *output_data)
248 const float32x4_t output_activation_min_vector = vdupq_n_f32(params.float_activation_min);
249 const float32x4_t output_activation_max_vector = vdupq_n_f32(params.float_activation_max);
250 const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
251 for (; i <= size - 4; i += 4)
253 const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
255 const float32x4_t output = vaddq_f32(input2_val_original, broadcast_value_dup);
257 const float32x4_t clamped =
258 vmaxq_f32(output_activation_min_vector, vminq_f32(output_activation_max_vector, output));
259 vst1q_f32(output_data + i, clamped);
262 for (; i < size; ++i)
264 auto x = broadcast_value + input2_data[i];
265 output_data[i] = ActivationFunctionWithMinMax<float>(x, params.float_activation_min,
266 params.float_activation_max);
270 inline void BroadcastAddDispatchQuant8(const BinaryArithmeticOpParam ¶ms,
271 const Shape &input1_shape, const uint8_t *input1_data,
272 const Shape &input2_shape, const uint8_t *input2_data,
273 const Shape &output_shape, uint8_t *output_data)
275 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
277 const std::function<uint8_t(const BinaryArithmeticOpParam &, const uint8_t &, const uint8_t &)>
278 fn = [](const BinaryArithmeticOpParam ¶ms, const uint8_t &a,
279 const uint8_t &b) -> uint8_t {
280 return static_cast<uint8_t>(quant8_sum(params, a, b));
282 reference::BroadcastBinaryArithmeticOpSlowQuant8(params, input1_shape, input1_data,
283 input2_shape, input2_data, output_shape,
288 BinaryBroadcastFiveFold(
289 params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
290 static_cast<void (*)(int, const BinaryArithmeticOpParam &, const uint8_t *, const uint8_t *,
291 uint8_t *)>(AddElementwiseQuant8),
292 static_cast<void (*)(int, const BinaryArithmeticOpParam &, uint8_t, const uint8_t *,
293 uint8_t *)>(AddScalarBroadcastQuant8));
297 inline void BroadcastAddDispatch(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
298 const float *input1_data, const Shape &input2_shape,
299 const float *input2_data, const Shape &output_shape,
302 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
304 const std::function<float(const float &, const float &)> fn =
305 [](const float &a, const float &b) -> float { return a + b; };
306 reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
307 input2_data, output_shape, output_data, fn);
311 BinaryBroadcastFiveFold(
312 params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
313 static_cast<void (*)(int, const BinaryArithmeticOpParam &, const float *, const float *,
314 float *)>(AddElementwise),
315 static_cast<void (*)(int, const BinaryArithmeticOpParam &, float, const float *, float *)>(
316 AddScalarBroadcast));
320 inline void Sub(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
321 const float *input1_data, const Shape &input2_shape, const float *input2_data,
322 const Shape &output_shape, float *output_data)
325 const int size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
327 const auto activation_min = vdupq_n_f32(params.float_activation_min);
328 const auto activation_max = vdupq_n_f32(params.float_activation_max);
329 for (; i <= size - 16; i += 16)
331 auto a10 = vld1q_f32(input1_data + i);
332 auto a11 = vld1q_f32(input1_data + i + 4);
333 auto a12 = vld1q_f32(input1_data + i + 8);
334 auto a13 = vld1q_f32(input1_data + i + 12);
335 auto a20 = vld1q_f32(input2_data + i);
336 auto a21 = vld1q_f32(input2_data + i + 4);
337 auto a22 = vld1q_f32(input2_data + i + 8);
338 auto a23 = vld1q_f32(input2_data + i + 12);
339 auto x0 = vsubq_f32(a10, a20);
340 auto x1 = vsubq_f32(a11, a21);
341 auto x2 = vsubq_f32(a12, a22);
342 auto x3 = vsubq_f32(a13, a23);
343 x0 = vmaxq_f32(activation_min, x0);
344 x1 = vmaxq_f32(activation_min, x1);
345 x2 = vmaxq_f32(activation_min, x2);
346 x3 = vmaxq_f32(activation_min, x3);
347 x0 = vminq_f32(activation_max, x0);
348 x1 = vminq_f32(activation_max, x1);
349 x2 = vminq_f32(activation_max, x2);
350 x3 = vminq_f32(activation_max, x3);
351 vst1q_f32(output_data + i, x0);
352 vst1q_f32(output_data + i + 4, x1);
353 vst1q_f32(output_data + i + 8, x2);
354 vst1q_f32(output_data + i + 12, x3);
356 for (; i <= size - 4; i += 4)
358 auto a1 = vld1q_f32(input1_data + i);
359 auto a2 = vld1q_f32(input2_data + i);
360 auto x = vsubq_f32(a1, a2);
361 x = vmaxq_f32(activation_min, x);
362 x = vminq_f32(activation_max, x);
363 vst1q_f32(output_data + i, x);
367 for (; i < size; i++)
369 auto x = input1_data[i] - input2_data[i];
371 ActivationFunctionWithMinMax(x, params.float_activation_min, params.float_activation_max);
375 inline int32_t quant8_mul(const BinaryArithmeticOpParam ¶ms, const uint8_t input1_data,
376 const uint8_t input2_data)
378 const int32_t input1_val = params.input1_offset + input1_data;
379 const int32_t input2_val = params.input2_offset + input2_data;
380 const int32_t unclamped_result =
381 params.output_offset + MultiplyByQuantizedMultiplier(input1_val * input2_val,
382 params.output_multiplier,
383 params.output_shift);
384 const int32_t clamped_output = std::min(
385 params.quantized_activation_max, std::max(params.quantized_activation_min, unclamped_result));
387 return clamped_output;
390 inline void MulElementwiseQuant8(int size, const BinaryArithmeticOpParam ¶ms,
391 const uint8_t *input1_data, const uint8_t *input2_data,
392 uint8_t *output_data)
395 int32_t clamped_output;
396 for (; i < size; i++)
398 clamped_output = quant8_mul(params, input1_data[i], input2_data[i]);
399 output_data[i] = static_cast<uint8_t>(clamped_output);
403 inline void MulElementwise(int size, const BinaryArithmeticOpParam ¶ms,
404 const float *input1_data, const float *input2_data, float *output_data)
409 const auto activation_min = vdupq_n_f32(params.float_activation_min);
410 const auto activation_max = vdupq_n_f32(params.float_activation_max);
411 for (; i <= size - 16; i += 16)
413 auto a10 = vld1q_f32(input1_data + i);
414 auto a11 = vld1q_f32(input1_data + i + 4);
415 auto a12 = vld1q_f32(input1_data + i + 8);
416 auto a13 = vld1q_f32(input1_data + i + 12);
417 auto a20 = vld1q_f32(input2_data + i);
418 auto a21 = vld1q_f32(input2_data + i + 4);
419 auto a22 = vld1q_f32(input2_data + i + 8);
420 auto a23 = vld1q_f32(input2_data + i + 12);
421 auto x0 = vmulq_f32(a10, a20);
422 auto x1 = vmulq_f32(a11, a21);
423 auto x2 = vmulq_f32(a12, a22);
424 auto x3 = vmulq_f32(a13, a23);
425 x0 = vmaxq_f32(activation_min, x0);
426 x1 = vmaxq_f32(activation_min, x1);
427 x2 = vmaxq_f32(activation_min, x2);
428 x3 = vmaxq_f32(activation_min, x3);
429 x0 = vminq_f32(activation_max, x0);
430 x1 = vminq_f32(activation_max, x1);
431 x2 = vminq_f32(activation_max, x2);
432 x3 = vminq_f32(activation_max, x3);
433 vst1q_f32(output_data + i, x0);
434 vst1q_f32(output_data + i + 4, x1);
435 vst1q_f32(output_data + i + 8, x2);
436 vst1q_f32(output_data + i + 12, x3);
438 for (; i <= size - 4; i += 4)
440 auto a1 = vld1q_f32(input1_data + i);
441 auto a2 = vld1q_f32(input2_data + i);
442 auto x = vmulq_f32(a1, a2);
443 x = vmaxq_f32(activation_min, x);
444 x = vminq_f32(activation_max, x);
445 vst1q_f32(output_data + i, x);
449 for (; i < size; i++)
451 auto x = input1_data[i] * input2_data[i];
453 ActivationFunctionWithMinMax(x, params.float_activation_min, params.float_activation_max);
457 inline void MulQuant8(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
458 const uint8_t *input1_data, const Shape &input2_shape,
459 const uint8_t *input2_data, const Shape &output_shape, uint8_t *output_data)
461 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
462 MulElementwiseQuant8(flat_size, params, input1_data, input2_data, output_data);
465 inline void Mul(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
466 const float *input1_data, const Shape &input2_shape, const float *input2_data,
467 const Shape &output_shape, float *output_data)
469 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
470 MulElementwise(flat_size, params, input1_data, input2_data, output_data);
473 inline void MulSimpleBroadcastQuant8(int size, const BinaryArithmeticOpParam ¶ms,
474 const uint8_t broadcast_value, const uint8_t *input2_data,
475 uint8_t *output_data)
478 int32_t clamped_output;
479 for (; i < size; ++i)
481 clamped_output = quant8_mul(params, broadcast_value, input2_data[i]);
482 output_data[i] = static_cast<uint8_t>(clamped_output);
486 // Broadcast mul that can often be used for inner loop of broadcast Mul.
487 // This function will handle scalar_value (LHS) * vector_values (RHS).
488 // Since it's a float function, input params does not matter here.
489 inline void MulSimpleBroadcast(int size, const BinaryArithmeticOpParam ¶ms,
490 const float broadcast_value, const float *input2_data,
495 const float32x4_t output_activation_min_vector = vdupq_n_f32(params.float_activation_min);
496 const float32x4_t output_activation_max_vector = vdupq_n_f32(params.float_activation_max);
497 const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
498 for (; i <= size - 4; i += 4)
500 const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
502 const float32x4_t output = vmulq_f32(input2_val_original, broadcast_value_dup);
504 const float32x4_t clamped =
505 vmaxq_f32(output_activation_min_vector, vminq_f32(output_activation_max_vector, output));
506 vst1q_f32(output_data + i, clamped);
510 for (; i < size; ++i)
512 float x = broadcast_value * input2_data[i];
514 ActivationFunctionWithMinMax(x, params.float_activation_min, params.float_activation_max);
518 inline void BroadcastMulDispatchQuant8(const BinaryArithmeticOpParam ¶ms,
519 const Shape &input1_shape, const uint8_t *input1_data,
520 const Shape &input2_shape, const uint8_t *input2_data,
521 const Shape &output_shape, uint8_t *output_data)
523 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
525 const std::function<uint8_t(const BinaryArithmeticOpParam &, const uint8_t &, const uint8_t &)>
526 fn = [](const BinaryArithmeticOpParam ¶ms, const uint8_t &a,
527 const uint8_t &b) -> uint8_t {
528 return static_cast<uint8_t>(quant8_mul(params, a, b));
530 reference::BroadcastBinaryArithmeticOpSlowQuant8(params, input1_shape, input1_data,
531 input2_shape, input2_data, output_shape,
535 BinaryBroadcastFiveFold(
536 params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
537 static_cast<void (*)(int, const BinaryArithmeticOpParam &, const uint8_t *, const uint8_t *,
538 uint8_t *)>(MulElementwiseQuant8),
539 static_cast<void (*)(int, const BinaryArithmeticOpParam &, uint8_t, const uint8_t *,
540 uint8_t *)>(MulSimpleBroadcastQuant8));
543 inline void BroadcastMulDispatch(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
544 const float *input1_data, const Shape &input2_shape,
545 const float *input2_data, const Shape &output_shape,
548 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
550 // TODO: Use GetBinaryArithmeticFn
551 const std::function<float(const float &, const float &)> fn =
552 [](const float &a, const float &b) -> float { return a * b; };
553 reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
554 input2_data, output_shape, output_data, fn);
557 BinaryBroadcastFiveFold(
558 params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
559 static_cast<void (*)(int, const BinaryArithmeticOpParam &, const float *, const float *,
560 float *)>(MulElementwise),
561 static_cast<void (*)(int, const BinaryArithmeticOpParam &, float, const float *, float *)>(
562 MulSimpleBroadcast));
565 } // namespace optimized
569 #endif // __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__