Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / optimized / BinaryArithmeticOps.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
19 #define __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__
20
21 #include <functional>
22 #include "cker/neon/neon_check.h"
23 #include "cker/operation/reference/BinaryArithmeticOps.h"
24 #include "cker/Shape.h"
25 #include "cker/Types.h"
26 #include "cker/Utils.h"
27 #include "fixedpoint/fixedpoint.h"
28
29 namespace nnfw
30 {
31 namespace cker
32 {
33 namespace optimized
34 {
35
36 template <typename ElementwiseF, typename ScalarBroadcastF, typename T>
37 inline void BinaryBroadcastFiveFold(const BinaryArithmeticOpParam &params,
38                                     const Shape & /* unswitched_input1_shape */,
39                                     const T *unswitched_input1_data,
40                                     const Shape & /* unswitched_input2_shape */,
41                                     const T *unswitched_input2_data,
42                                     const Shape & /* output_shape */, T *output_data,
43                                     ElementwiseF elementwise_f, ScalarBroadcastF scalar_broadcast_f)
44 {
45   const bool use_unswitched =
46       params.broadcast_category == BroadcastableOpCategory::kFirstInputBroadcastsFast;
47
48   const T *input1_data = use_unswitched ? unswitched_input1_data : unswitched_input2_data;
49   const T *input2_data = use_unswitched ? unswitched_input2_data : unswitched_input1_data;
50
51   // Fivefold nested loops. The second input resets its position for each
52   // iteration of the second loop. The first input resets its position at the
53   // beginning of the fourth loop. The innermost loop is an elementwise add of
54   // sections of the arrays.
55   T *output_data_ptr = output_data;
56   const T *input1_data_ptr = input1_data;
57   const T *input2_data_reset = input2_data;
58   // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
59   // between input shapes. y3 for input 1 is always broadcast, and so the
60   // dimension there is 1, whereas optionally y1 might be broadcast for input 2.
61   // Put another way,
62   // input1.shape.FlatSize = y0 * y1 * y2 * y4,
63   // input2.shape.FlatSize = y0 * y2 * y3 * y4.
64   int y0 = params.broadcast_shape[0];
65   int y1 = params.broadcast_shape[1];
66   int y2 = params.broadcast_shape[2];
67   int y3 = params.broadcast_shape[3];
68   int y4 = params.broadcast_shape[4];
69   if (y4 > 1)
70   {
71     // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
72     // dimension.
73     for (int i0 = 0; i0 < y0; ++i0)
74     {
75       const T *input2_data_ptr = nullptr;
76       for (int i1 = 0; i1 < y1; ++i1)
77       {
78         input2_data_ptr = input2_data_reset;
79         for (int i2 = 0; i2 < y2; ++i2)
80         {
81           for (int i3 = 0; i3 < y3; ++i3)
82           {
83             elementwise_f(y4, params, input1_data_ptr, input2_data_ptr, output_data_ptr);
84             input2_data_ptr += y4;
85             output_data_ptr += y4;
86           }
87           // We have broadcast y4 of input1 data y3 times, and now move on.
88           input1_data_ptr += y4;
89         }
90       }
91       // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
92       input2_data_reset = input2_data_ptr;
93     }
94   }
95   else
96   {
97     // Special case of y4 == 1, in which the innermost loop is a single element
98     // and can be combined with the next (y3) as an inner broadcast.
99     //
100     // Note that this handles the case of pure scalar broadcast when
101     // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
102     // broadcast with batch (as y2 > 1).
103     //
104     // NOTE The process is the same as the above general case except simplified
105     // for y4 == 1 and the loop over y3 is contained within the
106     // AddScalarBroadcast function.
107     for (int i0 = 0; i0 < y0; ++i0)
108     {
109       const T *input2_data_ptr = nullptr;
110       for (int i1 = 0; i1 < y1; ++i1)
111       {
112         input2_data_ptr = input2_data_reset;
113         for (int i2 = 0; i2 < y2; ++i2)
114         {
115           scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr, output_data_ptr);
116           input2_data_ptr += y3;
117           output_data_ptr += y3;
118           input1_data_ptr += 1;
119         }
120       }
121       input2_data_reset = input2_data_ptr;
122     }
123   }
124 }
125
126 inline int32_t quant8_sum(const BinaryArithmeticOpParam &params, const uint8_t input1_data,
127                           const uint8_t input2_data)
128 {
129   const int32_t input1_val = params.input1_offset + input1_data;
130   const int32_t input2_val = params.input2_offset + input2_data;
131   const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
132   const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
133   const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
134       shifted_input1_val, params.input1_multiplier, params.input1_shift);
135   const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
136       shifted_input2_val, params.input2_multiplier, params.input2_shift);
137   const int32_t raw_sum = scaled_input1_val + scaled_input2_val;
138   const int32_t raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp(
139                                  raw_sum, params.output_multiplier, params.output_shift) +
140                              params.output_offset;
141   const int32_t clamped_output = std::min(params.quantized_activation_max,
142                                           std::max(params.quantized_activation_min, raw_output));
143   return clamped_output;
144 }
145
146 inline void AddElementwiseQuant8(int size, const BinaryArithmeticOpParam &params,
147                                  const uint8_t *input1_data, const uint8_t *input2_data,
148                                  uint8_t *output_data)
149 {
150   int i = 0;
151   for (; i < size; ++i)
152   {
153     int32_t clamped_output = quant8_sum(params, input1_data[i], input2_data[i]);
154     output_data[i] = static_cast<uint8_t>(clamped_output);
155   }
156 }
157
158 inline void AddElementwise(int size, const BinaryArithmeticOpParam &params,
159                            const float *input1_data, const float *input2_data, float *output_data)
160 {
161   int i = 0;
162
163 #ifdef USE_NEON
164   const auto activation_min = vdupq_n_f32(params.float_activation_min);
165   const auto activation_max = vdupq_n_f32(params.float_activation_max);
166   for (; i <= size - 16; i += 16)
167   {
168     auto a10 = vld1q_f32(input1_data + i);
169     auto a11 = vld1q_f32(input1_data + i + 4);
170     auto a12 = vld1q_f32(input1_data + i + 8);
171     auto a13 = vld1q_f32(input1_data + i + 12);
172     auto a20 = vld1q_f32(input2_data + i);
173     auto a21 = vld1q_f32(input2_data + i + 4);
174     auto a22 = vld1q_f32(input2_data + i + 8);
175     auto a23 = vld1q_f32(input2_data + i + 12);
176     auto x0 = vaddq_f32(a10, a20);
177     auto x1 = vaddq_f32(a11, a21);
178     auto x2 = vaddq_f32(a12, a22);
179     auto x3 = vaddq_f32(a13, a23);
180     x0 = vmaxq_f32(activation_min, x0);
181     x1 = vmaxq_f32(activation_min, x1);
182     x2 = vmaxq_f32(activation_min, x2);
183     x3 = vmaxq_f32(activation_min, x3);
184     x0 = vminq_f32(activation_max, x0);
185     x1 = vminq_f32(activation_max, x1);
186     x2 = vminq_f32(activation_max, x2);
187     x3 = vminq_f32(activation_max, x3);
188     vst1q_f32(output_data + i, x0);
189     vst1q_f32(output_data + i + 4, x1);
190     vst1q_f32(output_data + i + 8, x2);
191     vst1q_f32(output_data + i + 12, x3);
192   }
193   for (; i <= size - 4; i += 4)
194   {
195     auto a1 = vld1q_f32(input1_data + i);
196     auto a2 = vld1q_f32(input2_data + i);
197     auto x = vaddq_f32(a1, a2);
198     x = vmaxq_f32(activation_min, x);
199     x = vminq_f32(activation_max, x);
200     vst1q_f32(output_data + i, x);
201   }
202 #endif // NEON
203   for (; i < size; i++)
204   {
205     auto x = input1_data[i] + input2_data[i];
206     output_data[i] = ActivationFunctionWithMinMax<float>(x, params.float_activation_min,
207                                                          params.float_activation_max);
208   }
209 }
210
211 inline void AddQuant8(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
212                       const uint8_t *input1_data, const Shape &input2_shape,
213                       const uint8_t *input2_data, const Shape &output_shape, uint8_t *output_data)
214 {
215   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
216   AddElementwiseQuant8(flat_size, params, input1_data, input2_data, output_data);
217 }
218
219 inline void Add(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
220                 const float *input1_data, const Shape &input2_shape, const float *input2_data,
221                 const Shape &output_shape, float *output_data)
222 {
223   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
224   AddElementwise(flat_size, params, input1_data, input2_data, output_data);
225 }
226
227 // Scalar-broadcast add that can be used for inner loop of more general
228 // broadcast add, so that, for example, scalar-broadcast with batch will still
229 // be fast.
230 inline void AddScalarBroadcastQuant8(int size, const BinaryArithmeticOpParam &params,
231                                      uint8_t broadcast_value, const uint8_t *input2_data,
232                                      uint8_t *output_data)
233 {
234   int i = 0;
235   int32_t clamped_output;
236   for (; i < size; ++i)
237   {
238     clamped_output = quant8_sum(params, broadcast_value, input2_data[i]);
239     output_data[i] = static_cast<uint8_t>(clamped_output);
240   }
241 }
242
243 inline void AddScalarBroadcast(int size, const BinaryArithmeticOpParam &params,
244                                float broadcast_value, const float *input2_data, float *output_data)
245 {
246   int i = 0;
247 #ifdef USE_NEON
248   const float32x4_t output_activation_min_vector = vdupq_n_f32(params.float_activation_min);
249   const float32x4_t output_activation_max_vector = vdupq_n_f32(params.float_activation_max);
250   const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
251   for (; i <= size - 4; i += 4)
252   {
253     const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
254
255     const float32x4_t output = vaddq_f32(input2_val_original, broadcast_value_dup);
256
257     const float32x4_t clamped =
258         vmaxq_f32(output_activation_min_vector, vminq_f32(output_activation_max_vector, output));
259     vst1q_f32(output_data + i, clamped);
260   }
261 #endif // NEON
262   for (; i < size; ++i)
263   {
264     auto x = broadcast_value + input2_data[i];
265     output_data[i] = ActivationFunctionWithMinMax<float>(x, params.float_activation_min,
266                                                          params.float_activation_max);
267   }
268 }
269
270 inline void BroadcastAddDispatchQuant8(const BinaryArithmeticOpParam &params,
271                                        const Shape &input1_shape, const uint8_t *input1_data,
272                                        const Shape &input2_shape, const uint8_t *input2_data,
273                                        const Shape &output_shape, uint8_t *output_data)
274 {
275   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
276   {
277     const std::function<uint8_t(const BinaryArithmeticOpParam &, const uint8_t &, const uint8_t &)>
278         fn = [](const BinaryArithmeticOpParam &params, const uint8_t &a,
279                 const uint8_t &b) -> uint8_t {
280       return static_cast<uint8_t>(quant8_sum(params, a, b));
281     };
282     reference::BroadcastBinaryArithmeticOpSlowQuant8(params, input1_shape, input1_data,
283                                                      input2_shape, input2_data, output_shape,
284                                                      output_data, fn);
285   }
286   else
287   {
288     BinaryBroadcastFiveFold(
289         params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
290         static_cast<void (*)(int, const BinaryArithmeticOpParam &, const uint8_t *, const uint8_t *,
291                              uint8_t *)>(AddElementwiseQuant8),
292         static_cast<void (*)(int, const BinaryArithmeticOpParam &, uint8_t, const uint8_t *,
293                              uint8_t *)>(AddScalarBroadcastQuant8));
294   }
295 }
296
297 inline void BroadcastAddDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
298                                  const float *input1_data, const Shape &input2_shape,
299                                  const float *input2_data, const Shape &output_shape,
300                                  float *output_data)
301 {
302   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
303   {
304     const std::function<float(const float &, const float &)> fn =
305         [](const float &a, const float &b) -> float { return a + b; };
306     reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
307                                                input2_data, output_shape, output_data, fn);
308   }
309   else
310   {
311     BinaryBroadcastFiveFold(
312         params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
313         static_cast<void (*)(int, const BinaryArithmeticOpParam &, const float *, const float *,
314                              float *)>(AddElementwise),
315         static_cast<void (*)(int, const BinaryArithmeticOpParam &, float, const float *, float *)>(
316             AddScalarBroadcast));
317   }
318 }
319
320 inline void Sub(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
321                 const float *input1_data, const Shape &input2_shape, const float *input2_data,
322                 const Shape &output_shape, float *output_data)
323 {
324   int i = 0;
325   const int size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
326 #ifdef USE_NEON
327   const auto activation_min = vdupq_n_f32(params.float_activation_min);
328   const auto activation_max = vdupq_n_f32(params.float_activation_max);
329   for (; i <= size - 16; i += 16)
330   {
331     auto a10 = vld1q_f32(input1_data + i);
332     auto a11 = vld1q_f32(input1_data + i + 4);
333     auto a12 = vld1q_f32(input1_data + i + 8);
334     auto a13 = vld1q_f32(input1_data + i + 12);
335     auto a20 = vld1q_f32(input2_data + i);
336     auto a21 = vld1q_f32(input2_data + i + 4);
337     auto a22 = vld1q_f32(input2_data + i + 8);
338     auto a23 = vld1q_f32(input2_data + i + 12);
339     auto x0 = vsubq_f32(a10, a20);
340     auto x1 = vsubq_f32(a11, a21);
341     auto x2 = vsubq_f32(a12, a22);
342     auto x3 = vsubq_f32(a13, a23);
343     x0 = vmaxq_f32(activation_min, x0);
344     x1 = vmaxq_f32(activation_min, x1);
345     x2 = vmaxq_f32(activation_min, x2);
346     x3 = vmaxq_f32(activation_min, x3);
347     x0 = vminq_f32(activation_max, x0);
348     x1 = vminq_f32(activation_max, x1);
349     x2 = vminq_f32(activation_max, x2);
350     x3 = vminq_f32(activation_max, x3);
351     vst1q_f32(output_data + i, x0);
352     vst1q_f32(output_data + i + 4, x1);
353     vst1q_f32(output_data + i + 8, x2);
354     vst1q_f32(output_data + i + 12, x3);
355   }
356   for (; i <= size - 4; i += 4)
357   {
358     auto a1 = vld1q_f32(input1_data + i);
359     auto a2 = vld1q_f32(input2_data + i);
360     auto x = vsubq_f32(a1, a2);
361     x = vmaxq_f32(activation_min, x);
362     x = vminq_f32(activation_max, x);
363     vst1q_f32(output_data + i, x);
364   }
365 #endif // NEON
366
367   for (; i < size; i++)
368   {
369     auto x = input1_data[i] - input2_data[i];
370     output_data[i] =
371         ActivationFunctionWithMinMax(x, params.float_activation_min, params.float_activation_max);
372   }
373 }
374
375 inline int32_t quant8_mul(const BinaryArithmeticOpParam &params, const uint8_t input1_data,
376                           const uint8_t input2_data)
377 {
378   const int32_t input1_val = params.input1_offset + input1_data;
379   const int32_t input2_val = params.input2_offset + input2_data;
380   const int32_t unclamped_result =
381       params.output_offset + MultiplyByQuantizedMultiplier(input1_val * input2_val,
382                                                            params.output_multiplier,
383                                                            params.output_shift);
384   const int32_t clamped_output = std::min(
385       params.quantized_activation_max, std::max(params.quantized_activation_min, unclamped_result));
386
387   return clamped_output;
388 }
389
390 inline void MulElementwiseQuant8(int size, const BinaryArithmeticOpParam &params,
391                                  const uint8_t *input1_data, const uint8_t *input2_data,
392                                  uint8_t *output_data)
393 {
394   int i = 0;
395   int32_t clamped_output;
396   for (; i < size; i++)
397   {
398     clamped_output = quant8_mul(params, input1_data[i], input2_data[i]);
399     output_data[i] = static_cast<uint8_t>(clamped_output);
400   }
401 }
402
403 inline void MulElementwise(int size, const BinaryArithmeticOpParam &params,
404                            const float *input1_data, const float *input2_data, float *output_data)
405 {
406   int i = 0;
407
408 #ifdef USE_NEON
409   const auto activation_min = vdupq_n_f32(params.float_activation_min);
410   const auto activation_max = vdupq_n_f32(params.float_activation_max);
411   for (; i <= size - 16; i += 16)
412   {
413     auto a10 = vld1q_f32(input1_data + i);
414     auto a11 = vld1q_f32(input1_data + i + 4);
415     auto a12 = vld1q_f32(input1_data + i + 8);
416     auto a13 = vld1q_f32(input1_data + i + 12);
417     auto a20 = vld1q_f32(input2_data + i);
418     auto a21 = vld1q_f32(input2_data + i + 4);
419     auto a22 = vld1q_f32(input2_data + i + 8);
420     auto a23 = vld1q_f32(input2_data + i + 12);
421     auto x0 = vmulq_f32(a10, a20);
422     auto x1 = vmulq_f32(a11, a21);
423     auto x2 = vmulq_f32(a12, a22);
424     auto x3 = vmulq_f32(a13, a23);
425     x0 = vmaxq_f32(activation_min, x0);
426     x1 = vmaxq_f32(activation_min, x1);
427     x2 = vmaxq_f32(activation_min, x2);
428     x3 = vmaxq_f32(activation_min, x3);
429     x0 = vminq_f32(activation_max, x0);
430     x1 = vminq_f32(activation_max, x1);
431     x2 = vminq_f32(activation_max, x2);
432     x3 = vminq_f32(activation_max, x3);
433     vst1q_f32(output_data + i, x0);
434     vst1q_f32(output_data + i + 4, x1);
435     vst1q_f32(output_data + i + 8, x2);
436     vst1q_f32(output_data + i + 12, x3);
437   }
438   for (; i <= size - 4; i += 4)
439   {
440     auto a1 = vld1q_f32(input1_data + i);
441     auto a2 = vld1q_f32(input2_data + i);
442     auto x = vmulq_f32(a1, a2);
443     x = vmaxq_f32(activation_min, x);
444     x = vminq_f32(activation_max, x);
445     vst1q_f32(output_data + i, x);
446   }
447 #endif // NEON
448
449   for (; i < size; i++)
450   {
451     auto x = input1_data[i] * input2_data[i];
452     output_data[i] =
453         ActivationFunctionWithMinMax(x, params.float_activation_min, params.float_activation_max);
454   }
455 }
456
457 inline void MulQuant8(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
458                       const uint8_t *input1_data, const Shape &input2_shape,
459                       const uint8_t *input2_data, const Shape &output_shape, uint8_t *output_data)
460 {
461   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
462   MulElementwiseQuant8(flat_size, params, input1_data, input2_data, output_data);
463 }
464
465 inline void Mul(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
466                 const float *input1_data, const Shape &input2_shape, const float *input2_data,
467                 const Shape &output_shape, float *output_data)
468 {
469   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
470   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
471 }
472
473 inline void MulSimpleBroadcastQuant8(int size, const BinaryArithmeticOpParam &params,
474                                      const uint8_t broadcast_value, const uint8_t *input2_data,
475                                      uint8_t *output_data)
476 {
477   int i = 0;
478   int32_t clamped_output;
479   for (; i < size; ++i)
480   {
481     clamped_output = quant8_mul(params, broadcast_value, input2_data[i]);
482     output_data[i] = static_cast<uint8_t>(clamped_output);
483   }
484 }
485
486 // Broadcast mul that can often be used for inner loop of broadcast Mul.
487 // This function will handle scalar_value (LHS) * vector_values (RHS).
488 // Since it's a float function, input params does not matter here.
489 inline void MulSimpleBroadcast(int size, const BinaryArithmeticOpParam &params,
490                                const float broadcast_value, const float *input2_data,
491                                float *output_data)
492 {
493   int i = 0;
494 #ifdef USE_NEON
495   const float32x4_t output_activation_min_vector = vdupq_n_f32(params.float_activation_min);
496   const float32x4_t output_activation_max_vector = vdupq_n_f32(params.float_activation_max);
497   const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
498   for (; i <= size - 4; i += 4)
499   {
500     const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
501
502     const float32x4_t output = vmulq_f32(input2_val_original, broadcast_value_dup);
503
504     const float32x4_t clamped =
505         vmaxq_f32(output_activation_min_vector, vminq_f32(output_activation_max_vector, output));
506     vst1q_f32(output_data + i, clamped);
507   }
508 #endif // NEON
509
510   for (; i < size; ++i)
511   {
512     float x = broadcast_value * input2_data[i];
513     output_data[i] =
514         ActivationFunctionWithMinMax(x, params.float_activation_min, params.float_activation_max);
515   }
516 }
517
518 inline void BroadcastMulDispatchQuant8(const BinaryArithmeticOpParam &params,
519                                        const Shape &input1_shape, const uint8_t *input1_data,
520                                        const Shape &input2_shape, const uint8_t *input2_data,
521                                        const Shape &output_shape, uint8_t *output_data)
522 {
523   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
524   {
525     const std::function<uint8_t(const BinaryArithmeticOpParam &, const uint8_t &, const uint8_t &)>
526         fn = [](const BinaryArithmeticOpParam &params, const uint8_t &a,
527                 const uint8_t &b) -> uint8_t {
528       return static_cast<uint8_t>(quant8_mul(params, a, b));
529     };
530     reference::BroadcastBinaryArithmeticOpSlowQuant8(params, input1_shape, input1_data,
531                                                      input2_shape, input2_data, output_shape,
532                                                      output_data, fn);
533     return;
534   }
535   BinaryBroadcastFiveFold(
536       params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
537       static_cast<void (*)(int, const BinaryArithmeticOpParam &, const uint8_t *, const uint8_t *,
538                            uint8_t *)>(MulElementwiseQuant8),
539       static_cast<void (*)(int, const BinaryArithmeticOpParam &, uint8_t, const uint8_t *,
540                            uint8_t *)>(MulSimpleBroadcastQuant8));
541 }
542
543 inline void BroadcastMulDispatch(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
544                                  const float *input1_data, const Shape &input2_shape,
545                                  const float *input2_data, const Shape &output_shape,
546                                  float *output_data)
547 {
548   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast)
549   {
550     // TODO: Use GetBinaryArithmeticFn
551     const std::function<float(const float &, const float &)> fn =
552         [](const float &a, const float &b) -> float { return a * b; };
553     reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
554                                                input2_data, output_shape, output_data, fn);
555     return;
556   }
557   BinaryBroadcastFiveFold(
558       params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
559       static_cast<void (*)(int, const BinaryArithmeticOpParam &, const float *, const float *,
560                            float *)>(MulElementwise),
561       static_cast<void (*)(int, const BinaryArithmeticOpParam &, float, const float *, float *)>(
562           MulSimpleBroadcast));
563 }
564
565 } // namespace optimized
566 } // namespace cker
567 } // namespace nnfw
568
569 #endif // __NNFW_CKER_OPTIMIZED_BINARYARITHMETICOPS_H__