/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef __NNFW_CKER_BINARY_ARITHMETIC_OPS_H__
#define __NNFW_CKER_BINARY_ARITHMETIC_OPS_H__
#include "cker/operation/optimized/BinaryArithmeticOps.h"
#include "cker/operation/reference/BinaryArithmeticOps.h"
#include "cker/Shape.h"
#include "cker/Types.h"
#include "cker/Utils.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <stdexcept>
#include <type_traits>

namespace nnfw
{
namespace cker
{
template <BinaryArithmeticOpType op_type, typename T>
const std::function<T(const T &, const T &)> GetBinaryArithmeticFn()
{
  switch (op_type)
  {
    case BinaryArithmeticOpType::ADD:
      return [](const T &a, const T &b) -> T { return a + b; };
    case BinaryArithmeticOpType::MUL:
      return [](const T &a, const T &b) -> T { return a * b; };
    case BinaryArithmeticOpType::SUB:
      return [](const T &a, const T &b) -> T { return a - b; };
    case BinaryArithmeticOpType::DIV:
      if (std::is_floating_point<T>::value)
      {
        return [](const T &a, const T &b) -> T { return a / b; };
      }
      // Integral division must reject a zero divisor at runtime.
      return [](const T &a, const T &b) -> T {
        if (b == 0)
        {
          throw std::runtime_error("Divide by zero");
        }
        return a / b;
      };
    case BinaryArithmeticOpType::POW:
      return [](const T &a, const T &b) -> T { return std::pow(a, b); };
    default:
      throw std::runtime_error("GetBinaryArithmeticFn: unsupported op type");
  }
}
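
// Illustrative only, not part of the original header: the functor returned
// above can be applied elementwise, e.g.
//
//   const auto add_fn = GetBinaryArithmeticFn<BinaryArithmeticOpType::ADD, float>();
//   const float sum = add_fn(1.5f, 2.5f); // sum == 4.0f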
// Consolidates dimensions in broadcast inputs, checks for five-fold pattern.
//
// For example, if sequence of dimensions of one input is
// ..., 1, 3, 1, 7, 9, 5,... and the other is ..., 2, 3, 1, 7, 1, 1, ...
// we can consolidate these as
// ..., 1, 3*7, 9*5, ... and 2, 3*7, 1.
//
// The category is updated in the less-frequent case of shapes that are
// not suited to a fivefold-loop broadcast.
//
// Falls back to generic pattern when it does not know how to process properly.
//
// Returns true iff there is some sort of broadcast, which includes five-fold
// patterns and falling back to generic broadcast.
inline bool ProcessBroadcastShapes(const Shape &shape0, const Shape &shape1,
                                   BinaryArithmeticOpParam *params)
{
  const int dims_count = std::max(shape0.DimensionsCount(), shape1.DimensionsCount());

  params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
  Shape scalar_shape(dims_count, 1);

  auto extended_shape0 = Shape::ExtendedShape(dims_count, shape0);
  auto extended_shape1 = Shape::ExtendedShape(dims_count, shape1);

  // Check for "exact" match, implicitly accepting any scalar shapes.
  if (extended_shape0 == extended_shape1)
  {
    params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
    return false;
  }

  for (int i = dims_count - 1; i >= 0; --i)
  {
    if (extended_shape0.Dims(i) == extended_shape1.Dims(i))
    {
      continue;
    }
    else if (extended_shape0.Dims(i) == 1)
    {
      params->broadcast_category = BroadcastableOpCategory::kFirstInputBroadcastsFast;
      break;
    }
    else if (extended_shape1.Dims(i) == 1)
    {
      params->broadcast_category = BroadcastableOpCategory::kSecondInputBroadcastsFast;
      break;
    }
    else
    {
      // This case is erroneous: there is a dimension that does not match and
      // is not a broadcast from one shape to the other.
      params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
      return true;
    }
  }

  if (params->broadcast_category != BroadcastableOpCategory::kFirstInputBroadcastsFast &&
      params->broadcast_category != BroadcastableOpCategory::kSecondInputBroadcastsFast)
  {
    return false;
  }

  // From this point it is assumed contractually that corresponding dimensions
  // in shape0 and shape1 are either (a) equal or (b) one or the other equals 1.
  const bool swap_inputs =
      params->broadcast_category == BroadcastableOpCategory::kSecondInputBroadcastsFast;
  const Shape *shape_a = swap_inputs ? &extended_shape1 : &extended_shape0;
  const Shape *shape_b = swap_inputs ? &extended_shape0 : &extended_shape1;

  int i = dims_count - 1;
  params->broadcast_shape[0] = 1;
  params->broadcast_shape[1] = 1;
  params->broadcast_shape[2] = 1;
  params->broadcast_shape[3] = 1;
  params->broadcast_shape[4] = 1;
  // y_0 is greedy: include dims if both or neither equal 1; in other words,
  // test for equality rather than (shape_a->Dims(i) != 1).
  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i))
  {
    params->broadcast_shape[4] *= shape_b->Dims(i);
    --i;
  }
  // Here either input_a or input_b has a dim of 1 (if i >= 0). If it is input_b
  // that has the unit dimension, the next two loops are not entered.
  while (i >= 0 && shape_a->Dims(i) == 1)
  {
    params->broadcast_shape[3] *= shape_b->Dims(i);
    --i;
  }
  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i))
  {
    params->broadcast_shape[2] *= shape_a->Dims(i);
    --i;
  }
  // Here either input_a or input_b has a dim of 1 (if i >= 0).
  while (i >= 0 && shape_b->Dims(i) == 1)
  {
    params->broadcast_shape[1] *= shape_a->Dims(i);
    --i;
  }
  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i))
  {
    params->broadcast_shape[0] *= shape_b->Dims(i);
    --i;
  }

  // Rarer case is when the broadcast dimensions cannot be handled by a
  // fivefold loop.
  if (i >= 0)
  {
    params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
  }
  return true;
}
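
// Illustrative sketch, not part of the original header: a typical caller first
// runs ProcessBroadcastShapes to pick between the plain and broadcast kernels.
// `DispatchFloat` is a hypothetical helper shown only to clarify that flow, and
// the float_activation_* field names are assumed to match cker's Types.h.
//
//   template <BinaryArithmeticOpType op_type>
//   void DispatchFloat(const Shape &lhs_shape, const float *lhs_data,
//                      const Shape &rhs_shape, const float *rhs_data,
//                      const Shape &out_shape, float *out_data)
//   {
//     BinaryArithmeticOpParam params{};
//     params.float_activation_min = std::numeric_limits<float>::lowest();
//     params.float_activation_max = std::numeric_limits<float>::max();
//     if (ProcessBroadcastShapes(lhs_shape, rhs_shape, &params))
//     {
//       // Some broadcast is needed; broadcast_category and broadcast_shape
//       // were filled in above.
//       BroadcastBinaryArithmeticOp<op_type>(params, lhs_shape, lhs_data, rhs_shape,
//                                            rhs_data, out_shape, out_data);
//     }
//     else
//     {
//       // Shapes match element-for-element; no broadcast required.
//       BinaryArithmeticOp<op_type>(params, lhs_shape, lhs_data, rhs_shape, rhs_data,
//                                   out_shape, out_data);
//     }
//   }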
template <BinaryArithmeticOpType op_type, typename T>
inline void BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
                               const T *input1_data, const Shape &input2_shape,
                               const T *input2_data, const Shape &output_shape, T *output_data)
{
  reference::BinaryArithmeticOp(params, input1_shape, input1_data, input2_shape, input2_data,
                                output_shape, output_data, GetBinaryArithmeticFn<op_type, T>());
}
template <BinaryArithmeticOpType op_type>
inline void BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
                               const uint8_t *input1_data, const Shape &input2_shape,
                               const uint8_t *input2_data, const Shape &output_shape,
                               uint8_t *output_data)
{
  switch (op_type)
  {
    case nnfw::cker::BinaryArithmeticOpType::ADD:
    case nnfw::cker::BinaryArithmeticOpType::SUB:
      optimized::AddQuant8(params, input1_shape, input1_data, input2_shape, input2_data,
                           output_shape, output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::MUL:
      optimized::MulQuant8(params, input1_shape, const_cast<uint8_t *>(input1_data), input2_shape,
                           const_cast<uint8_t *>(input2_data), output_shape, output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::DIV:
      throw std::runtime_error{"Quant8 Asymm NYI"};
    default:
      throw std::runtime_error{"BinaryArithmeticOp: unsupported op type"};
  }
}
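
// Note, hedged: for the quantized path above, `params` must already carry the
// quantization parameters (zero-point offsets, scaled multipliers/shifts, and
// the clamped activation range). Assuming cker mirrors TFLite's
// ArithmeticParams, these are fields such as input1_offset, input2_offset,
// output_offset, output_multiplier, and output_shift; the kernels do not
// derive them from the shapes.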
template <BinaryArithmeticOpType op_type>
inline void BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
                               const float *input1_data, const Shape &input2_shape,
                               const float *input2_data, const Shape &output_shape,
                               float *output_data)
{
  // Only float is supported for now.
  switch (op_type)
  {
    case nnfw::cker::BinaryArithmeticOpType::ADD:
      optimized::Add(params, input1_shape, input1_data, input2_shape, input2_data, output_shape,
                     output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::MUL:
      optimized::Mul(params, input1_shape, input1_data, input2_shape, input2_data, output_shape,
                     output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::SUB:
      optimized::Sub(params, input1_shape, input1_data, input2_shape, input2_data, output_shape,
                     output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::DIV:
      reference::BinaryArithmeticOp<float>(params, input1_shape, input1_data, input2_shape,
                                           input2_data, output_shape, output_data,
                                           GetBinaryArithmeticFn<op_type, float>());
      break;
    default:
      throw std::runtime_error{"BinaryArithmeticOp: unsupported op type"};
  }
}
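
// Illustrative sketch, not part of the original header: elementwise ADD of two
// same-shape float tensors. Shape's initializer-list constructor and the
// float_activation_* fields are assumptions about cker's API; <limits> is
// needed for std::numeric_limits.
//
//   const Shape shape({1, 2, 2, 1});
//   const float lhs[] = {1.f, 2.f, 3.f, 4.f};
//   const float rhs[] = {10.f, 20.f, 30.f, 40.f};
//   float out[4];
//   BinaryArithmeticOpParam params{};
//   // Without an explicit unbounded activation range, a value-initialized
//   // params would clamp every output to zero.
//   params.float_activation_min = std::numeric_limits<float>::lowest();
//   params.float_activation_max = std::numeric_limits<float>::max();
//   BinaryArithmeticOp<BinaryArithmeticOpType::ADD>(params, shape, lhs, shape, rhs, shape, out);
//   // out == {11.f, 22.f, 33.f, 44.f}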
template <BinaryArithmeticOpType op_type, typename T>
inline void BroadcastBinaryArithmeticOp(BinaryArithmeticOpParam &params, const Shape &input1_shape,
                                        const T *input1_data, const Shape &input2_shape,
                                        const T *input2_data, const Shape &output_shape,
                                        T *output_data)
{
  reference::BroadcastBinaryArithmeticOpSlow(params, input1_shape, input1_data, input2_shape,
                                             input2_data, output_shape, output_data,
                                             GetBinaryArithmeticFn<op_type, T>());
}
template <BinaryArithmeticOpType op_type>
inline void BroadcastBinaryArithmeticOp(BinaryArithmeticOpParam &params, const Shape &input1_shape,
                                        const uint8_t *input1_data, const Shape &input2_shape,
                                        const uint8_t *input2_data, const Shape &output_shape,
                                        uint8_t *output_data)
{
  switch (op_type)
  {
    case nnfw::cker::BinaryArithmeticOpType::ADD:
    case nnfw::cker::BinaryArithmeticOpType::SUB:
      optimized::BroadcastAddDispatchQuant8(params, input1_shape, input1_data, input2_shape,
                                            input2_data, output_shape, output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::MUL:
      optimized::BroadcastMulDispatchQuant8(
          params, input1_shape, const_cast<uint8_t *>(input1_data), input2_shape,
          const_cast<uint8_t *>(input2_data), output_shape, output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::DIV:
    case nnfw::cker::BinaryArithmeticOpType::POW:
      throw std::runtime_error{"Quant8 Asymm NYI"};
    default:
      throw std::runtime_error{"BroadcastBinaryArithmeticOp: unsupported op type"};
  }
}
template <BinaryArithmeticOpType op_type>
inline void BroadcastBinaryArithmeticOp(BinaryArithmeticOpParam &params, const Shape &input1_shape,
                                        const float *input1_data, const Shape &input2_shape,
                                        const float *input2_data, const Shape &output_shape,
                                        float *output_data)
{
  // Only float is supported for now.
  switch (op_type)
  {
    case nnfw::cker::BinaryArithmeticOpType::ADD:
      optimized::BroadcastAddDispatch(params, input1_shape, input1_data, input2_shape, input2_data,
                                      output_shape, output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::MUL:
      optimized::BroadcastMulDispatch(params, input1_shape, input1_data, input2_shape, input2_data,
                                      output_shape, output_data);
      break;
    case nnfw::cker::BinaryArithmeticOpType::SUB:
    case nnfw::cker::BinaryArithmeticOpType::DIV:
    case nnfw::cker::BinaryArithmeticOpType::POW:
      reference::BroadcastBinaryArithmeticOpSlow<float>(
          params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data,
          GetBinaryArithmeticFn<op_type, float>());
      break;
    default:
      throw std::runtime_error{"BroadcastBinaryArithmeticOp: unsupported op type"};
  }
}
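
// Illustrative sketch, not part of the original header: ADD with a broadcast
// second input (a {1, 1, 1, 2} "bias" applied across a {1, 1, 2, 2} tensor).
// The Dispatch kernels branch on params.broadcast_category, so
// ProcessBroadcastShapes must have filled `params` first. Shape's
// initializer-list constructor and the float_activation_* fields are assumed
// to match cker's actual API.
//
//   const Shape lhs_shape({1, 1, 2, 2});
//   const Shape rhs_shape({1, 1, 1, 2});
//   const float lhs_data[] = {1.f, 2.f, 3.f, 4.f};
//   const float rhs_data[] = {10.f, 20.f};
//   float out_data[4];
//   BinaryArithmeticOpParam params{};
//   params.float_activation_min = std::numeric_limits<float>::lowest();
//   params.float_activation_max = std::numeric_limits<float>::max();
//   if (ProcessBroadcastShapes(lhs_shape, rhs_shape, &params))
//   {
//     BroadcastBinaryArithmeticOp<BinaryArithmeticOpType::ADD>(
//         params, lhs_shape, lhs_data, rhs_shape, rhs_data, lhs_shape, out_data);
//     // out_data == {11.f, 22.f, 13.f, 24.f}
//   }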
} // namespace cker
} // namespace nnfw

#endif // __NNFW_CKER_BINARY_ARITHMETIC_OPS_H__