/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __NNFW_CKER_BINARY_ARITHMETIC_OPS_H__
#define __NNFW_CKER_BINARY_ARITHMETIC_OPS_H__

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <functional>
#include <stdexcept>
#include <type_traits>

#include "cker/operation/optimized/BinaryArithmeticOps.h"
#include "cker/operation/reference/BinaryArithmeticOps.h"
#include "cker/Shape.h"
#include "cker/Types.h"
#include "cker/Utils.h"

namespace nnfw
{
namespace cker
{

namespace
{
// Returns a callable implementing the requested elementwise arithmetic
// operation for type T. Integral DIV checks for a zero divisor at runtime;
// the other operations map directly onto their C++ counterparts.
template <BinaryArithmeticOpType op_type, typename T>
const std::function<T(const T &, const T &)> GetBinaryArithmeticFn()
{
  switch (op_type)
  {
    case BinaryArithmeticOpType::ADD:
    {
      return [](const T &a, const T &b) -> T { return a + b; };
    }
    case BinaryArithmeticOpType::MUL:
    {
      return [](const T &a, const T &b) -> T { return a * b; };
    }
    case BinaryArithmeticOpType::SUB:
    {
      return [](const T &a, const T &b) -> T { return a - b; };
    }
    case BinaryArithmeticOpType::DIV:
    {
      if (std::is_floating_point<T>::value)
        return [](const T &a, const T &b) -> T { return a / b; };
      else
        return [](const T &a, const T &b) -> T {
          if (b == 0)
            throw std::runtime_error("Divide by zero");
          return a / b;
        };
    }
    case BinaryArithmeticOpType::POW:
    {
      return [](const T &a, const T &b) -> T { return std::pow(a, b); };
    }
    default:
    {
      assert(false);
      return nullptr;
    }
  }
}
} // namespace

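// Illustrative usage (not part of the public API; GetBinaryArithmeticFn lives
// in an unnamed namespace, so these calls are only valid within this header):
//
//   auto add_fn = GetBinaryArithmeticFn<BinaryArithmeticOpType::ADD, float>();
//   float sum = add_fn(2.0f, 3.0f); // sum == 5.0f
//
//   auto div_fn = GetBinaryArithmeticFn<BinaryArithmeticOpType::DIV, int32_t>();
//   div_fn(6, 0); // throws std::runtime_error("Divide by zero")
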
// Consolidates dimensions in broadcast inputs and checks for the five-fold
// pattern.
//
// For example, if the sequence of dimensions of one input is
// ..., 1, 3, 1, 7, 9, 5, ... and the other is ..., 2, 3, 1, 7, 1, 1, ...,
// these can be consolidated as
// ..., 1, 3*7, 9*5, ... and ..., 2, 3*7, 1, ....
//
// The category is updated in the less frequent case of shapes that are not
// suited to a fivefold-loop broadcast; the function then falls back to the
// generic broadcast pattern.
//
// Returns true iff there is some sort of broadcast, which includes five-fold
// patterns and falling back to generic broadcast.
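//
// Worked example (illustrative): for input shapes (8, 4, 1, 2) and
// (8, 4, 3, 2), the first mismatch scanning from the innermost dimension is
// at index 2, where the first input has a 1, so the category becomes
// kFirstInputBroadcastsFast and the fivefold loops below fill in
//   broadcast_shape = {1, 1, 32, 3, 2},
// i.e. 32 (= 8 * 4) outer steps, one dimension broadcast by a factor of 3,
// and contiguous runs of 2 elements copied as-is.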
inline bool ProcessBroadcastShapes(const Shape &shape0, const Shape &shape1,
                                   BinaryArithmeticOpParam *params)
{
  const int dims_count = std::max(shape0.DimensionsCount(), shape1.DimensionsCount());

  params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;

  auto extended_shape0 = Shape::ExtendedShape(dims_count, shape0);
  auto extended_shape1 = Shape::ExtendedShape(dims_count, shape1);

  // Check for "exact" match, implicitly accepting any scalar shapes.
  if (extended_shape0 == extended_shape1)
  {
    params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
    return false;
  }

  for (int i = dims_count - 1; i >= 0; --i)
  {
    if (extended_shape0.Dims(i) == extended_shape1.Dims(i))
    {
      continue;
    }
    else if (extended_shape0.Dims(i) == 1)
    {
      params->broadcast_category = BroadcastableOpCategory::kFirstInputBroadcastsFast;
      break;
    }
    else if (extended_shape1.Dims(i) == 1)
    {
      params->broadcast_category = BroadcastableOpCategory::kSecondInputBroadcastsFast;
      break;
    }
    else
    {
      // This case is erroneous: there is a dimension that does not match and
      // is not a broadcast from one shape to the other.
      params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
      return true;
    }
  }

  if (params->broadcast_category != BroadcastableOpCategory::kFirstInputBroadcastsFast &&
      params->broadcast_category != BroadcastableOpCategory::kSecondInputBroadcastsFast)
  {
    return false;
  }

  // From this point it is assumed contractually that corresponding dimensions
  // in shape0 and shape1 are either (a) equal or (b) one or the other equals 1.
  const bool swap_inputs =
    params->broadcast_category == BroadcastableOpCategory::kSecondInputBroadcastsFast;
  const Shape *shape_a = swap_inputs ? &extended_shape1 : &extended_shape0;
  const Shape *shape_b = swap_inputs ? &extended_shape0 : &extended_shape1;

  int i = dims_count - 1;
  params->broadcast_shape[0] = 1;
  params->broadcast_shape[1] = 1;
  params->broadcast_shape[2] = 1;
  params->broadcast_shape[3] = 1;
  params->broadcast_shape[4] = 1;
  // y_0 is greedy: include dims if both or neither equal 1: in other words,
  // test for equality rather than (shape_a->Dims(i) != 1).
  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i))
  {
    params->broadcast_shape[4] *= shape_b->Dims(i);
    --i;
  }
  // Here either input_a or input_b has a dim of 1 (if i >= 0). If it is
  // input_b that has the unit dimension, the next two loops are not entered.
  while (i >= 0 && shape_a->Dims(i) == 1)
  {
    params->broadcast_shape[3] *= shape_b->Dims(i);
    --i;
  }
  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i))
  {
    params->broadcast_shape[2] *= shape_a->Dims(i);
    --i;
  }
  // Here either input_a or input_b has a dim of 1 (if i >= 0).
  while (i >= 0 && shape_b->Dims(i) == 1)
  {
    params->broadcast_shape[1] *= shape_a->Dims(i);
    --i;
  }
  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i))
  {
    params->broadcast_shape[0] *= shape_b->Dims(i);
    --i;
  }

  // The rarer case is when the broadcast dimensions cannot be handled by a
  // fivefold loop.
  if (i >= 0)
  {
    params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
  }
  return true;
}

template <BinaryArithmeticOpType op_type, typename T>
inline void BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
                               const T *input1_data, const Shape &input2_shape,
                               const T *input2_data, const Shape &output_shape, T *output_data)
{
  reference::BinaryArithmeticOp(params, input1_shape, input1_data, input2_shape, input2_data,
                                output_shape, output_data, GetBinaryArithmeticFn<op_type, T>());
}

template <BinaryArithmeticOpType op_type>
inline void BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
                               const uint8_t *input1_data, const Shape &input2_shape,
                               const uint8_t *input2_data, const Shape &output_shape,
                               uint8_t *output_data)
{
  switch (op_type)
  {
    case nnfw::cker::BinaryArithmeticOpType::ADD:
    case nnfw::cker::BinaryArithmeticOpType::SUB:
      // SUB reuses the quantized ADD kernel; the subtraction is expected to be
      // folded into the quantization parameters by the caller.
      optimized::AddQuant8(params, input1_shape, input1_data, input2_shape, input2_data,
                           output_shape, output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::MUL:
      // const_cast is needed only because MulQuant8 takes non-const pointers.
      optimized::MulQuant8(params, input1_shape, const_cast<uint8_t *>(input1_data), input2_shape,
                           const_cast<uint8_t *>(input2_data), output_shape, output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::DIV:
      throw std::runtime_error{"Quant8 Asymm NYI"};
    default:
      assert(false);
      break;
  }
}

template <BinaryArithmeticOpType op_type>
inline void BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
                               const float *input1_data, const Shape &input2_shape,
                               const float *input2_data, const Shape &output_shape,
                               float *output_data)
{
  // Only float is supported by the optimized kernels for now.
  switch (op_type)
  {
    case nnfw::cker::BinaryArithmeticOpType::ADD:
      optimized::Add(params, input1_shape, input1_data, input2_shape, input2_data, output_shape,
                     output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::MUL:
      optimized::Mul(params, input1_shape, input1_data, input2_shape, input2_data, output_shape,
                     output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::SUB:
      optimized::Sub(params, input1_shape, input1_data, input2_shape, input2_data, output_shape,
                     output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::DIV:
      optimized::Div(params, input1_shape, input1_data, input2_shape, input2_data, output_shape,
                     output_data);
      break;
    default:
      assert(false);
      break;
  }
}
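
// Minimal usage sketch (illustrative only): assumes cker's Shape supports
// initializer-list construction and that BinaryArithmeticOpParam carries
// float_activation_min/max as defined in cker/Types.h.
//
//   BinaryArithmeticOpParam params{};
//   params.float_activation_min = std::numeric_limits<float>::lowest();
//   params.float_activation_max = std::numeric_limits<float>::max();
//   Shape shape{1, 2, 2, 1};
//   float lhs[] = {1.f, 2.f, 3.f, 4.f};
//   float rhs[] = {10.f, 20.f, 30.f, 40.f};
//   float out[4];
//   BinaryArithmeticOp<BinaryArithmeticOpType::ADD>(params, shape, lhs, shape, rhs, shape, out);
//   // out is expected to be {11.f, 22.f, 33.f, 44.f}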

template <BinaryArithmeticOpType op_type, typename T>
inline void BroadcastBinaryArithmeticOp(BinaryArithmeticOpParam &params, const Shape &input1_shape,
                                        const T *input1_data, const Shape &input2_shape,
                                        const T *input2_data, const Shape &output_shape,
                                        T *output_data)
{
  reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
                                             input2_data, output_shape, output_data,
                                             GetBinaryArithmeticFn<op_type, T>());
}

template <BinaryArithmeticOpType op_type>
inline void BroadcastBinaryArithmeticOp(BinaryArithmeticOpParam &params, const Shape &input1_shape,
                                        const uint8_t *input1_data, const Shape &input2_shape,
                                        const uint8_t *input2_data, const Shape &output_shape,
                                        uint8_t *output_data)
{
  switch (op_type)
  {
    case nnfw::cker::BinaryArithmeticOpType::ADD:
    case nnfw::cker::BinaryArithmeticOpType::SUB:
      optimized::BroadcastAddDispatchQuant8(params, input1_shape, input1_data, input2_shape,
                                            input2_data, output_shape, output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::MUL:
      optimized::BroadcastMulDispatchQuant8(
        params, input1_shape, const_cast<uint8_t *>(input1_data), input2_shape,
        const_cast<uint8_t *>(input2_data), output_shape, output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::DIV:
    case nnfw::cker::BinaryArithmeticOpType::POW:
      throw std::runtime_error{"Quant8 Asymm NYI"};
    default:
      assert(false);
      break;
  }
}

template <BinaryArithmeticOpType op_type>
inline void BroadcastBinaryArithmeticOp(BinaryArithmeticOpParam &params, const Shape &input1_shape,
                                        const float *input1_data, const Shape &input2_shape,
                                        const float *input2_data, const Shape &output_shape,
                                        float *output_data)
{
  // Only float is supported by the optimized kernels for now.
  switch (op_type)
  {
    case nnfw::cker::BinaryArithmeticOpType::ADD:
      optimized::BroadcastAddDispatch(params, input1_shape, input1_data, input2_shape, input2_data,
                                      output_shape, output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::MUL:
      optimized::BroadcastMulDispatch(params, input1_shape, input1_data, input2_shape, input2_data,
                                      output_shape, output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::SUB:
      optimized::BroadcastSubDispatch(params, input1_shape, input1_data, input2_shape, input2_data,
                                      output_shape, output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::DIV:
      optimized::BroadcastDivDispatch(params, input1_shape, input1_data, input2_shape, input2_data,
                                      output_shape, output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::POW:
      // POW has no optimized kernel yet; use the slow reference broadcast.
      reference::BroadcastBinaryArithmeticOpSlow<float>(
        params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
        GetBinaryArithmeticFn<op_type, float>());
      break;
    default:
      assert(false);
      break;
  }
}
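
// End-to-end sketch of the intended dispatch (illustrative only; the exact
// caller-side convention is an assumption): run ProcessBroadcastShapes first,
// then route to the broadcast entry point only when it reports a broadcast.
// Assumes initializer-list Shape construction and wide-open float activation
// bounds as above.
//
//   BinaryArithmeticOpParam params{};
//   params.float_activation_min = std::numeric_limits<float>::lowest();
//   params.float_activation_max = std::numeric_limits<float>::max();
//   Shape lhs_shape{2, 3};
//   Shape rhs_shape{1, 3}; // the single row is broadcast across the outer dim
//   Shape out_shape{2, 3};
//   float lhs[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
//   float rhs[] = {10.f, 20.f, 30.f};
//   float out[6];
//   if (ProcessBroadcastShapes(lhs_shape, rhs_shape, &params))
//     BroadcastBinaryArithmeticOp<BinaryArithmeticOpType::ADD>(
//       params, lhs_shape, lhs, rhs_shape, rhs, out_shape, out);
//   else
//     BinaryArithmeticOp<BinaryArithmeticOpType::ADD>(params, lhs_shape, lhs, rhs_shape, rhs,
//                                                     out_shape, out);
//   // out is expected to be {11.f, 22.f, 33.f, 14.f, 25.f, 36.f}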

} // namespace cker
} // namespace nnfw

#endif // __NNFW_CKER_BINARY_ARITHMETIC_OPS_H__