Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / FloorMod.cpp
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/FloorMod.h"
19 #include "kernels/Utils.h"
20
21 #include <tensorflow/lite/kernels/internal/reference/binary_function.h>
22 #include <cmath>
23
24 namespace
25 {
26
27 template <typename T> T FloorDivFunc(T input1, T input2)
28 {
29   struct FloatMod
30   {
31     float operator()(const float lhs, const float rhs) const { return std::fmod(lhs, rhs); }
32   };
33   using ModFunc =
34     typename std::conditional<std::is_integral<T>::value, std::modulus<T>, FloatMod>::type;
35   ModFunc mod_func;
36   T trunc_mod = mod_func(input1, input2);
37   return (trunc_mod != 0) && ((input2 < 0) != (trunc_mod < 0)) ? (trunc_mod + input2) : trunc_mod;
38 }
39
40 } // namespace
41
42 namespace luci_interpreter
43 {
44
45 namespace kernels
46 {
47
48 FloorMod::FloorMod(const Tensor *x, const Tensor *y, Tensor *output) : Kernel({x, y}, {output}) {}
49
50 void FloorMod::configure()
51 {
52   LUCI_INTERPRETER_CHECK(x()->element_type() == output()->element_type());
53   LUCI_INTERPRETER_CHECK(y()->element_type() == output()->element_type());
54
55   output()->resize(calculateShapeForBroadcast(x()->shape(), y()->shape()));
56 }
57
58 void FloorMod::execute() const
59 {
60   switch (x()->element_type())
61   {
62     case DataType::FLOAT32:
63       evalFloat();
64       break;
65     case DataType::S8:
66       evalInteger<int8_t>();
67       break;
68     case DataType::S16:
69       evalInteger<int16_t>();
70       break;
71     case DataType::S32:
72       evalInteger<int32_t>();
73       break;
74     case DataType::S64:
75       evalInteger<int64_t>();
76       break;
77     default:
78       throw std::runtime_error("Unsupported type.");
79   }
80 }
81
82 void FloorMod::evalFloat() const
83 {
84   const auto x_data = getTensorData<float>(x());
85   const auto y_data = getTensorData<float>(y());
86
87   if (x()->shape() != y()->shape())
88   {
89     tflite::reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
90       getTensorShape(x()), x_data, getTensorShape(y()), y_data, getTensorShape(output()),
91       getTensorData<float>(output()), FloorDivFunc<float>);
92   }
93   else
94   {
95     tflite::reference_ops::BinaryFunction<float, float, float>(
96       getTensorShape(x()), x_data, getTensorShape(y()), y_data, getTensorShape(output()),
97       getTensorData<float>(output()), FloorDivFunc<float>);
98   }
99 }
100
101 template <typename T> void FloorMod::evalInteger() const
102 {
103   const auto x_data = getTensorData<T>(x());
104   const auto y_data = getTensorData<T>(y());
105
106   // Check the denominator
107   const auto y_data_type = y()->element_type();
108   if (y_data_type == DataType::S8 || y_data_type == DataType::S16 || y_data_type == DataType::S32 ||
109       y_data_type == DataType::S64)
110   {
111     for (int i = 0; i < getTensorShape(y()).FlatSize(); ++i)
112     {
113       LUCI_INTERPRETER_CHECK(y_data[i] != 0);
114     }
115   }
116
117   if (x()->shape() != y()->shape())
118   {
119     tflite::reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
120       getTensorShape(x()), x_data, getTensorShape(y()), y_data, getTensorShape(output()),
121       getTensorData<T>(output()), FloorDivFunc<T>);
122   }
123   else
124   {
125     tflite::reference_ops::BinaryFunction<T, T, T>(getTensorShape(x()), x_data, getTensorShape(y()),
126                                                    y_data, getTensorShape(output()),
127                                                    getTensorData<T>(output()), FloorDivFunc<T>);
128   }
129 }
130
131 } // namespace kernels
132 } // namespace luci_interpreter