/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
#define __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__

#include "cker/Shape.h"
#include "cker/Types.h"
#include "cker/Utils.h"
#include "cker/neon/neon_check.h"
#include "cker/operation/reference/BinaryArithmeticOps.h"

#include "fixedpoint/fixedpoint.h"

#include <algorithm>
#include <cstdint>
#include <functional>
36 template <typename ElementwiseF, typename ScalarBroadcastF, typename T>
37 inline void BinaryBroadcastFiveFold(const BinaryArithmeticOpParam ¶ms,
38 const Shape & /* unswitched_input1_shape */,
39 const T *unswitched_input1_data,
40 const Shape & /* unswitched_input2_shape */,
41 const T *unswitched_input2_data,
42 const Shape & /* output_shape */, T *output_data,
43 ElementwiseF elementwise_f, ScalarBroadcastF scalar_broadcast_f)
45 const bool use_unswitched =
46 params.broadcast_category == BroadcastableOpCategory::kFirstInputBroadcastsFast;
48 const T *input1_data = use_unswitched ? unswitched_input1_data : unswitched_input2_data;
49 const T *input2_data = use_unswitched ? unswitched_input2_data : unswitched_input1_data;
51 // Fivefold nested loops. The second input resets its position for each
52 // iteration of the second loop. The first input resets its position at the
53 // beginning of the fourth loop. The innermost loop is an elementwise add of
54 // sections of the arrays.
55 T *output_data_ptr = output_data;
56 const T *input1_data_ptr = input1_data;
57 const T *input2_data_reset = input2_data;
58 // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
59 // between input shapes. y3 for input 1 is always broadcast, and so the
60 // dimension there is 1, whereas optionally y1 might be broadcast for input 2.
62 // input1.shape.FlatSize = y0 * y1 * y2 * y4,
63 // input2.shape.FlatSize = y0 * y2 * y3 * y4.
64 int y0 = params.broadcast_shape[0];
65 int y1 = params.broadcast_shape[1];
66 int y2 = params.broadcast_shape[2];
67 int y3 = params.broadcast_shape[3];
68 int y4 = params.broadcast_shape[4];
71 // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
73 for (int i0 = 0; i0 < y0; ++i0)
75 const T *input2_data_ptr = nullptr;
76 for (int i1 = 0; i1 < y1; ++i1)
78 input2_data_ptr = input2_data_reset;
79 for (int i2 = 0; i2 < y2; ++i2)
81 for (int i3 = 0; i3 < y3; ++i3)
83 elementwise_f(y4, params, input1_data_ptr, input2_data_ptr, output_data_ptr);
84 input2_data_ptr += y4;
85 output_data_ptr += y4;
87 // We have broadcast y4 of input1 data y3 times, and now move on.
88 input1_data_ptr += y4;
91 // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
92 input2_data_reset = input2_data_ptr;
97 // Special case of y4 == 1, in which the innermost loop is a single element
98 // and can be combined with the next (y3) as an inner broadcast.
100 // Note that this handles the case of pure scalar broadcast when
101 // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
102 // broadcast with batch (as y2 > 1).
104 // NOTE The process is the same as the above general case except simplified
105 // for y4 == 1 and the loop over y3 is contained within the
106 // AddScalarBroadcast function.
107 for (int i0 = 0; i0 < y0; ++i0)
109 const T *input2_data_ptr = nullptr;
110 for (int i1 = 0; i1 < y1; ++i1)
112 input2_data_ptr = input2_data_reset;
113 for (int i2 = 0; i2 < y2; ++i2)
115 scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr, output_data_ptr);
116 input2_data_ptr += y3;
117 output_data_ptr += y3;
118 input1_data_ptr += 1;
121 input2_data_reset = input2_data_ptr;
126 inline int32_t quant8_sum(const BinaryArithmeticOpParam ¶ms, const uint8_t input1_data,
127 const uint8_t input2_data)
129 const int32_t input1_val = params.input1_offset + input1_data;
130 const int32_t input2_val = params.input2_offset + input2_data;
131 const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
132 const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
133 const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
134 shifted_input1_val, params.input1_multiplier, params.input1_shift);
135 const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
136 shifted_input2_val, params.input2_multiplier, params.input2_shift);
137 const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
138 const int32_t raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp(
139 raw_sum, params.output_multiplier, params.output_shift) +
140 params.output_offset;
141 const int32_t clamped_output = std::min(params.quantized_activation_max,
142 std::max(params.quantized_activation_min, raw_output));
143 return clamped_output;
146 inline void AddElementwiseQuant8(int size, const BinaryArithmeticOpParam ¶ms,
147 const uint8_t *input1_data, const uint8_t *input2_data,
148 uint8_t *output_data)
153 const uint8x8_t output_activation_min_vector = vdup_n_u8(params.quantized_activation_min);
154 const uint8x8_t output_activation_max_vector = vdup_n_u8(params.quantized_activation_max);
155 for (; i <= size - 8; i += 8)
157 const uint8x8_t input1_val_original = vld1_u8(input1_data + i);
158 const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
159 const int16x8_t input1_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
160 const int16x8_t input2_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
161 const int16x8_t input1_val = vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
162 const int16x8_t input2_val = vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
163 const int16x4_t input1_val_high = vget_high_s16(input1_val);
164 const int16x4_t input1_val_low = vget_low_s16(input1_val);
165 const int16x4_t input2_val_high = vget_high_s16(input2_val);
166 const int16x4_t input2_val_low = vget_low_s16(input2_val);
167 int32x4_t x11 = vmovl_s16(input1_val_low);
168 int32x4_t x12 = vmovl_s16(input1_val_high);
169 int32x4_t x21 = vmovl_s16(input2_val_low);
170 int32x4_t x22 = vmovl_s16(input2_val_high);
171 const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
172 x11 = vshlq_s32(x11, left_shift_dup);
173 x12 = vshlq_s32(x12, left_shift_dup);
174 x21 = vshlq_s32(x21, left_shift_dup);
175 x22 = vshlq_s32(x22, left_shift_dup);
176 x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
177 x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
178 x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
179 x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
180 const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
181 const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
182 x11 = vshlq_s32(x11, input1_shift_dup);
183 x12 = vshlq_s32(x12, input1_shift_dup);
184 x21 = vshlq_s32(x21, input2_shift_dup);
185 x22 = vshlq_s32(x22, input2_shift_dup);
186 int32x4_t s1 = vaddq_s32(x11, x21);
187 int32x4_t s2 = vaddq_s32(x12, x22);
188 s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
189 s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
190 using gemmlowp::RoundingDivideByPOT;
191 s1 = RoundingDivideByPOT(s1, -params.output_shift);
192 s2 = RoundingDivideByPOT(s2, -params.output_shift);
193 const int16x4_t s1_narrowed = vmovn_s32(s1);
194 const int16x4_t s2_narrowed = vmovn_s32(s2);
196 vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), vdupq_n_s16(params.output_offset));
197 const uint8x8_t clamped = vmax_u8(output_activation_min_vector,
198 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
199 vst1_u8(output_data + i, clamped);
202 for (; i < size; ++i)
204 const int32_t input1_val = params.input1_offset + input1_data[i];
205 const int32_t input2_val = params.input2_offset + input2_data[i];
206 const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
207 const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
208 const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
209 shifted_input1_val, params.input1_multiplier, params.input1_shift);
210 const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
211 shifted_input2_val, params.input2_multiplier, params.input2_shift);
212 const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
213 const int32_t raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp(
214 raw_sum, params.output_multiplier, params.output_shift) +
215 params.output_offset;
216 const int32_t clamped_output = std::min(params.quantized_activation_max,
217 std::max(params.quantized_activation_min, raw_output));
218 output_data[i] = static_cast<uint8_t>(clamped_output);
222 inline void AddElementwise(int size, const BinaryArithmeticOpParam ¶ms,
223 const float *input1_data, const float *input2_data, float *output_data)
228 const auto activation_min = vdupq_n_f32(params.float_activation_min);
229 const auto activation_max = vdupq_n_f32(params.float_activation_max);
230 for (; i <= size - 16; i += 16)
232 auto a10 = vld1q_f32(input1_data + i);
233 auto a11 = vld1q_f32(input1_data + i + 4);
234 auto a12 = vld1q_f32(input1_data + i + 8);
235 auto a13 = vld1q_f32(input1_data + i + 12);
236 auto a20 = vld1q_f32(input2_data + i);
237 auto a21 = vld1q_f32(input2_data + i + 4);
238 auto a22 = vld1q_f32(input2_data + i + 8);
239 auto a23 = vld1q_f32(input2_data + i + 12);
240 auto x0 = vaddq_f32(a10, a20);
241 auto x1 = vaddq_f32(a11, a21);
242 auto x2 = vaddq_f32(a12, a22);
243 auto x3 = vaddq_f32(a13, a23);
244 x0 = vmaxq_f32(activation_min, x0);
245 x1 = vmaxq_f32(activation_min, x1);
246 x2 = vmaxq_f32(activation_min, x2);
247 x3 = vmaxq_f32(activation_min, x3);
248 x0 = vminq_f32(activation_max, x0);
249 x1 = vminq_f32(activation_max, x1);
250 x2 = vminq_f32(activation_max, x2);
251 x3 = vminq_f32(activation_max, x3);
252 vst1q_f32(output_data + i, x0);
253 vst1q_f32(output_data + i + 4, x1);
254 vst1q_f32(output_data + i + 8, x2);
255 vst1q_f32(output_data + i + 12, x3);
257 for (; i <= size - 4; i += 4)
259 auto a1 = vld1q_f32(input1_data + i);
260 auto a2 = vld1q_f32(input2_data + i);
261 auto x = vaddq_f32(a1, a2);
262 x = vmaxq_f32(activation_min, x);
263 x = vminq_f32(activation_max, x);
264 vst1q_f32(output_data + i, x);
267 for (; i < size; i++)
269 auto x = input1_data[i] + input2_data[i];
270 output_data[i] = ActivationFunctionWithMinMax<float>(x, params.float_activation_min,
271 params.float_activation_max);
275 inline void AddQuant8(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
276 const uint8_t *input1_data, const Shape &input2_shape,
277 const uint8_t *input2_data, const Shape &output_shape, uint8_t *output_data)
279 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
280 AddElementwiseQuant8(flat_size, params, input1_data, input2_data, output_data);
283 inline void Add(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
284 const float *input1_data, const Shape &input2_shape, const float *input2_data,
285 const Shape &output_shape, float *output_data)
287 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
288 AddElementwise(flat_size, params, input1_data, input2_data, output_data);
291 // Scalar-broadcast add that can be used for inner loop of more general
292 // broadcast add, so that, for example, scalar-broadcast with batch will still
294 inline void AddScalarBroadcastQuant8(int size, const BinaryArithmeticOpParam ¶ms,
295 uint8_t broadcast_value, const uint8_t *input2_data,
296 uint8_t *output_data)
299 int32_t clamped_output;
300 for (; i < size; ++i)
302 clamped_output = quant8_sum(params, broadcast_value, input2_data[i]);
303 output_data[i] = static_cast<uint8_t>(clamped_output);
307 inline void AddScalarBroadcast(int size, const BinaryArithmeticOpParam ¶ms,
308 float broadcast_value, const float *input2_data, float *output_data)
312 const float32x4_t output_activation_min_vector = vdupq_n_f32(params.float_activation_min);
313 const float32x4_t output_activation_max_vector = vdupq_n_f32(params.float_activation_max);
314 const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
315 for (; i <= size - 4; i += 4)
317 const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
319 const float32x4_t output = vaddq_f32(input2_val_original, broadcast_value_dup);
321 const float32x4_t clamped =
322 vmaxq_f32(output_activation_min_vector, vminq_f32(output_activation_max_vector, output));
323 vst1q_f32(output_data + i, clamped);
326 for (; i < size; ++i)
328 auto x = broadcast_value + input2_data[i];
329 output_data[i] = ActivationFunctionWithMinMax<float>(x, params.float_activation_min,
330 params.float_activation_max);
334 inline void BroadcastAddDispatchQuant8(const BinaryArithmeticOpParam ¶ms,
335 const Shape &input1_shape, const uint8_t *input1_data,
336 const Shape &input2_shape, const uint8_t *input2_data,
337 const Shape &output_shape, uint8_t *output_data)
339 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
341 const std::function<uint8_t(const BinaryArithmeticOpParam &, const uint8_t &, const uint8_t &)>
342 fn = [](const BinaryArithmeticOpParam ¶ms, const uint8_t &a,
343 const uint8_t &b) -> uint8_t {
344 return static_cast<uint8_t>(quant8_sum(params, a, b));
346 reference::BroadcastBinaryArithmeticOpSlowQuant8(params, input1_shape, input1_data,
347 input2_shape, input2_data, output_shape,
352 BinaryBroadcastFiveFold(
353 params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
354 static_cast<void (*)(int, const BinaryArithmeticOpParam &, const uint8_t *, const uint8_t *,
355 uint8_t *)>(AddElementwiseQuant8),
356 static_cast<void (*)(int, const BinaryArithmeticOpParam &, uint8_t, const uint8_t *,
357 uint8_t *)>(AddScalarBroadcastQuant8));
361 inline void BroadcastAddDispatch(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
362 const float *input1_data, const Shape &input2_shape,
363 const float *input2_data, const Shape &output_shape,
366 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
368 const std::function<float(const float &, const float &)> fn =
369 [](const float &a, const float &b) -> float { return a + b; };
370 reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
371 input2_data, output_shape, output_data, fn);
375 BinaryBroadcastFiveFold(
376 params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
377 static_cast<void (*)(int, const BinaryArithmeticOpParam &, const float *, const float *,
378 float *)>(AddElementwise),
379 static_cast<void (*)(int, const BinaryArithmeticOpParam &, float, const float *, float *)>(
380 AddScalarBroadcast));
384 inline void Sub(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
385 const float *input1_data, const Shape &input2_shape, const float *input2_data,
386 const Shape &output_shape, float *output_data)
389 const int size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
391 const auto activation_min = vdupq_n_f32(params.float_activation_min);
392 const auto activation_max = vdupq_n_f32(params.float_activation_max);
393 for (; i <= size - 16; i += 16)
395 auto a10 = vld1q_f32(input1_data + i);
396 auto a11 = vld1q_f32(input1_data + i + 4);
397 auto a12 = vld1q_f32(input1_data + i + 8);
398 auto a13 = vld1q_f32(input1_data + i + 12);
399 auto a20 = vld1q_f32(input2_data + i);
400 auto a21 = vld1q_f32(input2_data + i + 4);
401 auto a22 = vld1q_f32(input2_data + i + 8);
402 auto a23 = vld1q_f32(input2_data + i + 12);
403 auto x0 = vsubq_f32(a10, a20);
404 auto x1 = vsubq_f32(a11, a21);
405 auto x2 = vsubq_f32(a12, a22);
406 auto x3 = vsubq_f32(a13, a23);
407 x0 = vmaxq_f32(activation_min, x0);
408 x1 = vmaxq_f32(activation_min, x1);
409 x2 = vmaxq_f32(activation_min, x2);
410 x3 = vmaxq_f32(activation_min, x3);
411 x0 = vminq_f32(activation_max, x0);
412 x1 = vminq_f32(activation_max, x1);
413 x2 = vminq_f32(activation_max, x2);
414 x3 = vminq_f32(activation_max, x3);
415 vst1q_f32(output_data + i, x0);
416 vst1q_f32(output_data + i + 4, x1);
417 vst1q_f32(output_data + i + 8, x2);
418 vst1q_f32(output_data + i + 12, x3);
420 for (; i <= size - 4; i += 4)
422 auto a1 = vld1q_f32(input1_data + i);
423 auto a2 = vld1q_f32(input2_data + i);
424 auto x = vsubq_f32(a1, a2);
425 x = vmaxq_f32(activation_min, x);
426 x = vminq_f32(activation_max, x);
427 vst1q_f32(output_data + i, x);
431 for (; i < size; i++)
433 auto x = input1_data[i] - input2_data[i];
435 ActivationFunctionWithMinMax(x, params.float_activation_min, params.float_activation_max);
439 inline int32_t quant8_mul(const BinaryArithmeticOpParam ¶ms, const uint8_t input1_data,
440 const uint8_t input2_data)
442 const int32_t input1_val = params.input1_offset + input1_data;
443 const int32_t input2_val = params.input2_offset + input2_data;
444 const int32_t unclamped_result =
445 params.output_offset + MultiplyByQuantizedMultiplier(input1_val * input2_val,
446 params.output_multiplier,
447 params.output_shift);
448 const int32_t clamped_output = std::min(
449 params.quantized_activation_max, std::max(params.quantized_activation_min, unclamped_result));
451 return clamped_output;
454 inline void MulElementwiseQuant8(int size, const BinaryArithmeticOpParam ¶ms,
455 const uint8_t *input1_data, const uint8_t *input2_data,
456 uint8_t *output_data)
461 const auto input1_offset_vector = vdupq_n_s16(params.input1_offset);
462 const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
463 const auto output_offset_vector = vdupq_n_s16(params.output_offset);
464 const auto output_activation_min_vector = vdup_n_u8(params.quantized_activation_min);
465 const auto output_activation_max_vector = vdup_n_u8(params.quantized_activation_max);
466 const int left_shift = std::max(0, params.output_shift);
467 const int right_shift = std::max(0, -params.output_shift);
468 const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
469 for (; i <= size - 8; i += 8)
471 // We load / store 8 at a time, multiplying as two sets of 4 int32s.
472 const auto input1_val_original = vld1_u8(input1_data + i);
473 const auto input2_val_original = vld1_u8(input2_data + i);
474 const auto input1_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
475 const auto input2_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
476 const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
477 const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
479 const auto input1_val_low = vget_low_s16(input1_val);
480 const auto input1_val_high = vget_high_s16(input1_val);
481 const auto input2_val_low = vget_low_s16(input2_val);
482 const auto input2_val_high = vget_high_s16(input2_val);
484 auto p1 = vmull_s16(input2_val_low, input1_val_low);
485 auto p2 = vmull_s16(input2_val_high, input1_val_high);
487 p1 = vshlq_s32(p1, left_shift_vec);
488 p2 = vshlq_s32(p2, left_shift_vec);
489 p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
490 p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
491 using gemmlowp::RoundingDivideByPOT;
492 p1 = RoundingDivideByPOT(p1, right_shift);
493 p2 = RoundingDivideByPOT(p2, right_shift);
495 const auto p1_narrowed = vqmovn_s32(p1);
496 const auto p2_narrowed = vqmovn_s32(p2);
497 const auto p = vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
498 const auto clamped = vmax_u8(output_activation_min_vector,
499 vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
500 vst1_u8(output_data + i, clamped);
504 for (; i < size; ++i)
506 const int32_t input1_val = params.input1_offset + input1_data[i];
507 const int32_t input2_val = params.input2_offset + input2_data[i];
508 const int32_t unclamped_result =
509 params.output_offset + MultiplyByQuantizedMultiplier(input1_val * input2_val,
510 params.output_multiplier,
511 params.output_shift);
512 const int32_t clamped_output =
513 std::min(params.quantized_activation_max,
514 std::max(params.quantized_activation_min, unclamped_result));
515 output_data[i] = static_cast<uint8_t>(clamped_output);
519 inline void MulElementwise(int size, const BinaryArithmeticOpParam ¶ms,
520 const float *input1_data, const float *input2_data, float *output_data)
525 const auto activation_min = vdupq_n_f32(params.float_activation_min);
526 const auto activation_max = vdupq_n_f32(params.float_activation_max);
527 for (; i <= size - 16; i += 16)
529 auto a10 = vld1q_f32(input1_data + i);
530 auto a11 = vld1q_f32(input1_data + i + 4);
531 auto a12 = vld1q_f32(input1_data + i + 8);
532 auto a13 = vld1q_f32(input1_data + i + 12);
533 auto a20 = vld1q_f32(input2_data + i);
534 auto a21 = vld1q_f32(input2_data + i + 4);
535 auto a22 = vld1q_f32(input2_data + i + 8);
536 auto a23 = vld1q_f32(input2_data + i + 12);
537 auto x0 = vmulq_f32(a10, a20);
538 auto x1 = vmulq_f32(a11, a21);
539 auto x2 = vmulq_f32(a12, a22);
540 auto x3 = vmulq_f32(a13, a23);
541 x0 = vmaxq_f32(activation_min, x0);
542 x1 = vmaxq_f32(activation_min, x1);
543 x2 = vmaxq_f32(activation_min, x2);
544 x3 = vmaxq_f32(activation_min, x3);
545 x0 = vminq_f32(activation_max, x0);
546 x1 = vminq_f32(activation_max, x1);
547 x2 = vminq_f32(activation_max, x2);
548 x3 = vminq_f32(activation_max, x3);
549 vst1q_f32(output_data + i, x0);
550 vst1q_f32(output_data + i + 4, x1);
551 vst1q_f32(output_data + i + 8, x2);
552 vst1q_f32(output_data + i + 12, x3);
554 for (; i <= size - 4; i += 4)
556 auto a1 = vld1q_f32(input1_data + i);
557 auto a2 = vld1q_f32(input2_data + i);
558 auto x = vmulq_f32(a1, a2);
559 x = vmaxq_f32(activation_min, x);
560 x = vminq_f32(activation_max, x);
561 vst1q_f32(output_data + i, x);
565 for (; i < size; i++)
567 auto x = input1_data[i] * input2_data[i];
569 ActivationFunctionWithMinMax(x, params.float_activation_min, params.float_activation_max);
573 inline void MulQuant8(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
574 const uint8_t *input1_data, const Shape &input2_shape,
575 const uint8_t *input2_data, const Shape &output_shape, uint8_t *output_data)
577 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
578 MulElementwiseQuant8(flat_size, params, input1_data, input2_data, output_data);
581 inline void Mul(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
582 const float *input1_data, const Shape &input2_shape, const float *input2_data,
583 const Shape &output_shape, float *output_data)
585 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
586 MulElementwise(flat_size, params, input1_data, input2_data, output_data);
589 inline void MulSimpleBroadcastQuant8(int size, const BinaryArithmeticOpParam ¶ms,
590 const uint8_t broadcast_value, const uint8_t *input2_data,
591 uint8_t *output_data)
594 int32_t clamped_output;
595 for (; i < size; ++i)
597 clamped_output = quant8_mul(params, broadcast_value, input2_data[i]);
598 output_data[i] = static_cast<uint8_t>(clamped_output);
602 // Broadcast mul that can often be used for inner loop of broadcast Mul.
603 // This function will handle scalar_value (LHS) * vector_values (RHS).
604 // Since it's a float function, input params does not matter here.
605 inline void MulSimpleBroadcast(int size, const BinaryArithmeticOpParam ¶ms,
606 const float broadcast_value, const float *input2_data,
611 const float32x4_t output_activation_min_vector = vdupq_n_f32(params.float_activation_min);
612 const float32x4_t output_activation_max_vector = vdupq_n_f32(params.float_activation_max);
613 const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
614 for (; i <= size - 4; i += 4)
616 const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
618 const float32x4_t output = vmulq_f32(input2_val_original, broadcast_value_dup);
620 const float32x4_t clamped =
621 vmaxq_f32(output_activation_min_vector, vminq_f32(output_activation_max_vector, output));
622 vst1q_f32(output_data + i, clamped);
626 for (; i < size; ++i)
628 float x = broadcast_value * input2_data[i];
630 ActivationFunctionWithMinMax(x, params.float_activation_min, params.float_activation_max);
634 inline void BroadcastMulDispatchQuant8(const BinaryArithmeticOpParam ¶ms,
635 const Shape &input1_shape, const uint8_t *input1_data,
636 const Shape &input2_shape, const uint8_t *input2_data,
637 const Shape &output_shape, uint8_t *output_data)
639 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
641 const std::function<uint8_t(const BinaryArithmeticOpParam &, const uint8_t &, const uint8_t &)>
642 fn = [](const BinaryArithmeticOpParam ¶ms, const uint8_t &a,
643 const uint8_t &b) -> uint8_t {
644 return static_cast<uint8_t>(quant8_mul(params, a, b));
646 reference::BroadcastBinaryArithmeticOpSlowQuant8(params, input1_shape, input1_data,
647 input2_shape, input2_data, output_shape,
651 BinaryBroadcastFiveFold(
652 params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
653 static_cast<void (*)(int, const BinaryArithmeticOpParam &, const uint8_t *, const uint8_t *,
654 uint8_t *)>(MulElementwiseQuant8),
655 static_cast<void (*)(int, const BinaryArithmeticOpParam &, uint8_t, const uint8_t *,
656 uint8_t *)>(MulSimpleBroadcastQuant8));
659 inline void BroadcastMulDispatch(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape,
660 const float *input1_data, const Shape &input2_shape,
661 const float *input2_data, const Shape &output_shape,
664 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
666 // TODO: Use GetBinaryArithmeticFn
667 const std::function<float(const float &, const float &)> fn =
668 [](const float &a, const float &b) -> float { return a * b; };
669 reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
670 input2_data, output_shape, output_data, fn);
673 BinaryBroadcastFiveFold(
674 params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
675 static_cast<void (*)(int, const BinaryArithmeticOpParam &, const float *, const float *,
676 float *)>(MulElementwise),
677 static_cast<void (*)(int, const BinaryArithmeticOpParam &, float, const float *, float *)>(
678 MulSimpleBroadcast));
681 } // namespace optimized
685 #endif // __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__