2d2842a9e8ec6c1df1af976913e8d2361dc5e611
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / BinaryOpCommon.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_KERNELS_BINARYOPUTILS_H
19 #define LUCI_INTERPRETER_KERNELS_BINARYOPUTILS_H
20
21 #include "tensorflow/lite/kernels/internal/common.h"
22 #include "tensorflow/lite/kernels/internal/types.h"
23
24 namespace luci_interpreter
25 {
26 namespace kernels
27 {
28
29 // Derived from tensorflow/lite/kernels/internal/reference/maximum_minimum.h (v2.3.0).
30 template <typename T, typename Op, int N = 5>
31 void BinaryOpBroadcastSlow(const tflite::RuntimeShape &unextended_input1_shape,
32                            const T *input1_data,
33                            const tflite::RuntimeShape &unextended_input2_shape,
34                            const T *input2_data,
35                            const tflite::RuntimeShape &unextended_output_shape, T *output_data,
36                            Op op)
37 {
38   if (unextended_input1_shape == unextended_input2_shape)
39   {
40     const int flat_size = tflite::MatchingElementsSize(
41       unextended_input1_shape, unextended_input2_shape, unextended_output_shape);
42     for (int i = 0; i < flat_size; ++i)
43     {
44       output_data[i] = op(input1_data[i], input2_data[i]);
45     }
46   }
47   else
48   {
49     assert(unextended_input1_shape.DimensionsCount() <= N);
50     assert(unextended_input2_shape.DimensionsCount() <= N);
51     assert(unextended_output_shape.DimensionsCount() <= N);
52
53     tflite::NdArrayDesc<N> desc1{};
54     tflite::NdArrayDesc<N> desc2{};
55     tflite::NdArrayDesc<N> output_desc{};
56     tflite::NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, unextended_input2_shape,
57                                                 &desc1, &desc2);
58     tflite::CopyDimsToDesc(tflite::RuntimeShape::ExtendedShape(N, unextended_output_shape),
59                            &output_desc);
60
61     auto fn = [&](int indexes[N]) {
62       output_data[SubscriptToIndex(output_desc, indexes)] =
63         op(input1_data[SubscriptToIndex(desc1, indexes)],
64            input2_data[SubscriptToIndex(desc2, indexes)]);
65     };
66     tflite::NDOpsHelper<N>(output_desc, fn);
67   }
68 }
69
70 } // namespace kernels
71 } // namespace luci_interpreter
72
73 #endif // LUCI_INTERPRETER_KERNELS_BINARYOPUTILS_H