Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / common / PALComparisons.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_PAL_COMPARISONS_H
19 #define LUCI_INTERPRETER_PAL_COMPARISONS_H
20
21 #include "Params.h"
22 #include "ProcessBroadcastShapes.h"
23 #include "PALUtils.h"
24
25 namespace luci_interpreter_pal
26 {
27 namespace
28 {
29
30 struct BroadcastComparison4DSlowCommon
31 {
32   const luci_interpreter::RuntimeShape output_shape;
33   NdArrayDesc<4> desc1;
34   NdArrayDesc<4> desc2;
35 };
36
37 inline BroadcastComparison4DSlowCommon
38 BroadcastComparison4DSlowPreprocess(const luci_interpreter::RuntimeShape &unextended_input1_shape,
39                                     const luci_interpreter::RuntimeShape &unextended_input2_shape,
40                                     const luci_interpreter::RuntimeShape &unextended_output_shape)
41 {
42   NdArrayDesc<4> desc1;
43   NdArrayDesc<4> desc2;
44   NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, unextended_input2_shape, &desc1,
45                                       &desc2);
46   return {luci_interpreter::RuntimeShape::extendedShape(4, unextended_output_shape), desc1, desc2};
47 }
48
49 } // namespace
50
51 template <typename T> inline bool LessFn(T lhs, T rhs) { return lhs < rhs; }
52 template <typename T> inline bool LessEqualFn(T lhs, T rhs) { return lhs <= rhs; }
53 template <typename T> inline bool EqualFn(T lhs, T rhs) { return lhs == rhs; }
54 template <typename T> inline bool GreaterFn(T lhs, T rhs) { return lhs > rhs; }
55 template <typename T> inline bool GreaterEqualFn(T lhs, T rhs) { return lhs >= rhs; }
56 template <typename T> inline bool NotEqualFn(T lhs, T rhs) { return lhs != rhs; }
57
58 template <typename T>
59 inline void ComparisonNoScaling(const int64_t flat_size, const T *input1_data, const T *input2_data,
60                                 bool *output_data, bool F(T, T))
61 {
62   for (int64_t i = 0; i < flat_size; ++i)
63   {
64     output_data[i] = F(input1_data[i], input2_data[i]);
65   }
66 }
67
68 template <typename T>
69 inline void BroadcastComparison4DSlowWithScaling(
70   const ComparisonParams &op_params, const luci_interpreter::RuntimeShape &unextended_input1_shape,
71   const T *input1_data, const luci_interpreter::RuntimeShape &unextended_input2_shape,
72   const T *input2_data, const luci_interpreter::RuntimeShape &unextended_output_shape,
73   bool *output_data, bool F(T, T))
74 {
75   const BroadcastComparison4DSlowCommon dims = BroadcastComparison4DSlowPreprocess(
76     unextended_input1_shape, unextended_input2_shape, unextended_output_shape);
77
78   int left_shift = op_params.left_shift;
79   int32_t input1_offset = op_params.input1_offset;
80   int32_t input1_multiplier = op_params.input1_multiplier;
81   int input1_shift = op_params.input1_shift;
82   int32_t input2_offset = op_params.input2_offset;
83   int32_t input2_multiplier = op_params.input2_multiplier;
84   int input2_shift = op_params.input2_shift;
85
86   for (int b = 0; b < dims.output_shape.dims(0); ++b)
87   {
88     for (int y = 0; y < dims.output_shape.dims(1); ++y)
89     {
90       for (int x = 0; x < dims.output_shape.dims(2); ++x)
91       {
92         for (int c = 0; c < dims.output_shape.dims(3); ++c)
93         {
94           const int32_t input1_val =
95             input1_offset + input1_data[subscriptToIndex(dims.desc1, b, y, x, c)];
96           const int32_t input2_val =
97             input2_offset + input2_data[subscriptToIndex(dims.desc2, b, y, x, c)];
98           const int32_t shifted_input1_val = input1_val * (1 << left_shift);
99           const int32_t shifted_input2_val = input2_val * (1 << left_shift);
100           const int32_t scaled_input1_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
101             shifted_input1_val, input1_multiplier, input1_shift);
102           const int32_t scaled_input2_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
103             shifted_input2_val, input2_multiplier, input2_shift);
104
105           const int output_data_offset =
106             ((b * dims.output_shape.dims(1) + y) * dims.output_shape.dims(2) + x) *
107               dims.output_shape.dims(3) +
108             c;
109           output_data[output_data_offset] = F(scaled_input1_val, scaled_input2_val);
110         }
111       }
112     }
113   }
114 }
115
116 template <typename T>
117 inline void ComparisonWithScaling(const ComparisonParams &op_params, const int64_t flat_size,
118                                   const T *input1_data, const T *input2_data, bool *output_data,
119                                   bool F(T, T))
120 {
121   int left_shift = op_params.left_shift;
122   int32_t input1_offset = op_params.input1_offset;
123   int32_t input1_multiplier = op_params.input1_multiplier;
124   int input1_shift = op_params.input1_shift;
125   int32_t input2_offset = op_params.input2_offset;
126   int32_t input2_multiplier = op_params.input2_multiplier;
127   int input2_shift = op_params.input2_shift;
128
129   for (int64_t i = 0; i < flat_size; ++i)
130   {
131     const int32_t input1_val = input1_offset + input1_data[i];
132     const int32_t input2_val = input2_offset + input2_data[i];
133     const int32_t shifted_input1_val = input1_val * (1 << left_shift);
134     const int32_t shifted_input2_val = input2_val * (1 << left_shift);
135     const int32_t scaled_input1_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
136       shifted_input1_val, input1_multiplier, input1_shift);
137     const int32_t scaled_input2_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
138       shifted_input2_val, input2_multiplier, input2_shift);
139     output_data[i] = F(scaled_input1_val, scaled_input2_val);
140   }
141 }
142
143 template <typename T>
144 inline void BroadcastComparison4DSlowNoScaling(
145   const ComparisonParams &op_params, const luci_interpreter::RuntimeShape &unextended_input1_shape,
146   const T *input1_data, const luci_interpreter::RuntimeShape &unextended_input2_shape,
147   const T *input2_data, const luci_interpreter::RuntimeShape &unextended_output_shape,
148   bool *output_data, bool F(T, T))
149 {
150   const BroadcastComparison4DSlowCommon dims = BroadcastComparison4DSlowPreprocess(
151     unextended_input1_shape, unextended_input2_shape, unextended_output_shape);
152
153   for (int b = 0; b < dims.output_shape.dims(0); ++b)
154   {
155     for (int y = 0; y < dims.output_shape.dims(1); ++y)
156     {
157       for (int x = 0; x < dims.output_shape.dims(2); ++x)
158       {
159         for (int c = 0; c < dims.output_shape.dims(3); ++c)
160         {
161           const int output_data_offset =
162             ((b * dims.output_shape.dims(1) + y) * dims.output_shape.dims(2) + x) *
163               dims.output_shape.dims(3) +
164             c;
165           output_data[output_data_offset] =
166             F(input1_data[subscriptToIndex(dims.desc1, b, y, x, c)],
167               input2_data[subscriptToIndex(dims.desc2, b, y, x, c)]);
168         }
169       }
170     }
171   }
172 }
173
174 } // namespace luci_interpreter_pal
175
176 #endif // LUCI_INTERPRETER_PAL_COMPARISONS_H