2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef LUCI_INTERPRETER_KERNELS_BINARYOPUTILS_H
19 #define LUCI_INTERPRETER_KERNELS_BINARYOPUTILS_H
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"

#include <cassert>
24 namespace luci_interpreter
29 // Derived from tensorflow/lite/kernels/internal/reference/maximum_minimum.h (v2.3.0).
30 template <typename T, typename Op, int N = 5>
31 void BinaryOpBroadcastSlow(const tflite::RuntimeShape &unextended_input1_shape,
33 const tflite::RuntimeShape &unextended_input2_shape,
35 const tflite::RuntimeShape &unextended_output_shape, T *output_data,
38 if (unextended_input1_shape == unextended_input2_shape)
40 const int flat_size = tflite::MatchingElementsSize(
41 unextended_input1_shape, unextended_input2_shape, unextended_output_shape);
42 for (int i = 0; i < flat_size; ++i)
44 output_data[i] = op(input1_data[i], input2_data[i]);
49 assert(unextended_input1_shape.DimensionsCount() <= N);
50 assert(unextended_input2_shape.DimensionsCount() <= N);
51 assert(unextended_output_shape.DimensionsCount() <= N);
53 tflite::NdArrayDesc<N> desc1{};
54 tflite::NdArrayDesc<N> desc2{};
55 tflite::NdArrayDesc<N> output_desc{};
56 tflite::NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, unextended_input2_shape,
58 tflite::CopyDimsToDesc(tflite::RuntimeShape::ExtendedShape(N, unextended_output_shape),
61 auto fn = [&](int indexes[N]) {
62 output_data[SubscriptToIndex(output_desc, indexes)] =
63 op(input1_data[SubscriptToIndex(desc1, indexes)],
64 input2_data[SubscriptToIndex(desc2, indexes)]);
66 tflite::NDOpsHelper<N>(output_desc, fn);
70 } // namespace kernels
71 } // namespace luci_interpreter
73 #endif // LUCI_INTERPRETER_KERNELS_BINARYOPUTILS_H