47eb6034cc5fa3ba41f53b84ee461f79841594c0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / Comparison.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_COMPARISON_H__
19 #define __NNFW_CKER_COMPARISON_H__
20
21 #include "cker/Shape.h"
22 #include "cker/Types.h"
23 #include "cker/Utils.h"
24
25 namespace nnfw
26 {
27 namespace cker
28 {
29
// Element-wise comparison predicates.
//
// Each predicate is a function template so that one of its specializations can be supplied
// as a non-type template argument of type ComparisonFn<T>.

// True when lhs == rhs.
template <typename T>
inline bool EqualFn(T lhs, T rhs)
{
  return lhs == rhs;
}

// True when lhs != rhs.
template <typename T>
inline bool NotEqualFn(T lhs, T rhs)
{
  return lhs != rhs;
}

// True when lhs > rhs.
template <typename T>
inline bool GreaterFn(T lhs, T rhs)
{
  return lhs > rhs;
}

// True when lhs >= rhs.
template <typename T>
inline bool GreaterEqualFn(T lhs, T rhs)
{
  return lhs >= rhs;
}

// True when lhs < rhs.
template <typename T>
inline bool LessFn(T lhs, T rhs)
{
  return lhs < rhs;
}

// True when lhs <= rhs.
template <typename T>
inline bool LessEqualFn(T lhs, T rhs)
{
  return lhs <= rhs;
}

// Pointer to a binary predicate that takes two values of type T by value.
template <typename T>
using ComparisonFn = bool (*)(T, T);
38
39 template <typename T, ComparisonFn<T> F>
40 inline void ComparisonImpl(const Shape &input1_shape, const T *input1_data,
41                            const Shape &input2_shape, const T *input2_data,
42                            const Shape &output_shape, bool *output_data)
43 {
44   const int64_t flatsize = // number of data....
45       MatchingFlatSize(input1_shape, input2_shape, output_shape);
46   for (int64_t i = 0; i < flatsize; ++i)
47   {
48     output_data[i] = F(input1_data[i], input2_data[i]);
49   }
50 }
51
52 template <ComparisonFn<float> F>
53 inline void Comparison(const Shape &input1_shape, const float *input1_data,
54                        const Shape &input2_shape, const float *input2_data,
55                        const Shape &output_shape, bool *output_data)
56 {
57   ComparisonImpl<float, F>(input1_shape, input1_data, input2_shape, input2_data, output_shape,
58                            output_data);
59 }
60
61 template <typename T, ComparisonFn<int32_t> F>
62 inline void ComparisonWithScaling(ComparisonParams &params, const Shape &input1_shape,
63                                   const T *input1_data, const Shape &input2_shape,
64                                   const T *input2_data, const Shape &output_shape,
65                                   bool *output_data)
66 {
67   int left_shift = params.left_shift;
68   int32_t input1_offset = params.input1_offset;
69   int32_t input1_multiplier = params.input1_multiplier;
70   int input1_shift = params.input1_shift;
71   int32_t input2_offset = params.input2_offset;
72   int32_t input2_multiplier = params.input2_multiplier;
73   int input2_shift = params.input2_shift;
74   const int64_t flatsize = MatchingFlatSize(input1_shape, input2_shape, output_shape);
75   for (int64_t i = 0; i < flatsize; ++i)
76   {
77     const int32_t input1_val = input1_offset + input1_data[i];
78     const int32_t input2_val = input2_offset + input2_data[i];
79     const int32_t shifted_input1_val = input1_val * (1 << left_shift);
80     const int32_t shifted_input2_val = input2_val * (1 << left_shift);
81     const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
82         shifted_input1_val, input1_multiplier, input1_shift);
83     const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
84         shifted_input2_val, input2_multiplier, input2_shift);
85     output_data[i] = F(scaled_input1_val, scaled_input2_val);
86   }
87 }
88
89 template <typename T, ComparisonFn<T> F>
90 inline void
91 BroadcastComparison4DSlowImpl(const Shape &unextended_input1_shape, const T *input1_data,
92                               const Shape &unextended_input2_shape, const T *input2_data,
93                               const Shape &unextended_output_shape, bool *output_data)
94 {
95   assert(unextended_input1_shape.DimensionsCount() <= 4);
96   assert(unextended_input2_shape.DimensionsCount() <= 4);
97   assert(unextended_output_shape.DimensionsCount() <= 4);
98   const Shape output_shape = Shape::ExtendedShape(4, unextended_output_shape);
99
100   NdArrayDesc<4> desc1;
101   NdArrayDesc<4> desc2;
102   NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, unextended_input2_shape, &desc1,
103                                       &desc2);
104
105   for (int b = 0; b < output_shape.Dims(0); ++b)
106   {
107     for (int y = 0; y < output_shape.Dims(1); ++y)
108     {
109       for (int x = 0; x < output_shape.Dims(2); ++x)
110       {
111         for (int c = 0; c < output_shape.Dims(3); ++c)
112         {
113           output_data[Offset(output_shape, b, y, x, c)] =
114               F(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
115                 input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
116         }
117       }
118     }
119   }
120 }
121
122 template <typename T, ComparisonFn<T> F>
123 inline void BroadcastComparison4DSlow(const Shape &input1_shape, const T *input1_data,
124                                       const Shape &input2_shape, const T *input2_data,
125                                       const Shape &output_shape, bool *output_data)
126 {
127   BroadcastComparison4DSlowImpl<T, F>(input1_shape, input1_data, input2_shape, input2_data,
128                                       output_shape, output_data);
129 }
130
131 template <typename T, ComparisonFn<int32_t> F>
132 inline void BroadcastComparison4DSlowWithScaling(ComparisonParams &params,
133                                                  const Shape &input1_shape, const T *input1_data,
134                                                  const Shape &input2_shape, const T *input2_data,
135                                                  const Shape &output_shape, bool *output_data)
136 {
137   assert(input1_shape.DimensionsCount() <= 4);
138   assert(input2_shape.DimensionsCount() <= 4);
139   assert(output_shape.DimensionsCount() <= 4);
140
141   int left_shift = params.left_shift;
142   int32_t input1_offset = params.input1_offset;
143   int32_t input1_multiplier = params.input1_multiplier;
144   int input1_shift = params.input1_shift;
145   int32_t input2_offset = params.input2_offset;
146   int32_t input2_multiplier = params.input2_multiplier;
147   int input2_shift = params.input2_shift;
148
149   NdArrayDesc<4> desc1;
150   NdArrayDesc<4> desc2;
151   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, &desc2);
152
153   for (int b = 0; b < output_shape.Dims(0); ++b)
154   {
155     for (int y = 0; y < output_shape.Dims(1); ++y)
156     {
157       for (int x = 0; x < output_shape.Dims(2); ++x)
158       {
159         for (int c = 0; c < output_shape.Dims(3); ++c)
160         {
161           const int32_t input1_val =
162               input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)];
163           const int32_t input2_val =
164               input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)];
165           const int32_t shifted_input1_val = input1_val * (1 << left_shift);
166           const int32_t shifted_input2_val = input2_val * (1 << left_shift);
167           const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
168               shifted_input1_val, input1_multiplier, input1_shift);
169           const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
170               shifted_input2_val, input2_multiplier, input2_shift);
171           output_data[Offset(output_shape, b, y, x, c)] = F(scaled_input1_val, scaled_input2_val);
172         }
173       }
174     }
175   }
176 }
177
178 #define TFLITE_COMPARISON_OP(name)                                                                \
179   template <typename T>                                                                           \
180   inline void name(const Shape &input1_shape, const T *input1_data, const Shape &input2_shape,    \
181                    const T *input2_data, const Shape &output_shape, bool *output_data)            \
182   {                                                                                               \
183     Comparison<name##Fn>(input1_shape, input1_data, input2_shape, input2_data, output_shape,      \
184                          output_data);                                                            \
185   }                                                                                               \
186   template <typename T>                                                                           \
187   inline void name##NoScaling(const Shape &input1_shape, const T *input1_data,                    \
188                               const Shape &input2_shape, const T *input2_data,                    \
189                               const Shape &output_shape, bool *output_data)                       \
190   {                                                                                               \
191     ComparisonImpl<T, name##Fn>(input1_shape, input1_data, input2_shape, input2_data,             \
192                                 output_shape, output_data);                                       \
193   }                                                                                               \
194   template <typename T>                                                                           \
195   inline void name##WithScaling(ComparisonParams &params, const Shape &input1_shape,              \
196                                 const T *input1_data, const Shape &input2_shape,                  \
197                                 const T *input2_data, const Shape &output_shape,                  \
198                                 bool *output_data)                                                \
199   {                                                                                               \
200     ComparisonWithScaling<T, name##Fn>(params, input1_shape, input1_data, input2_shape,           \
201                                        input2_data, output_shape, output_data);                   \
202   }                                                                                               \
203   template <typename T>                                                                           \
204   inline void Broadcast4DSlow##name##NoScaling(const Shape &input1_shape, const T *input1_data,   \
205                                                const Shape &input2_shape, const T *input2_data,   \
206                                                const Shape &output_shape, bool *output_data)      \
207   {                                                                                               \
208     BroadcastComparison4DSlowImpl<T, name##Fn>(input1_shape, input1_data, input2_shape,           \
209                                                input2_data, output_shape, output_data);           \
210   }                                                                                               \
211   template <typename T>                                                                           \
212   inline void Broadcast4DSlow##name(const Shape &input1_shape, const T *input1_data,              \
213                                     const Shape &input2_shape, const T *input2_data,              \
214                                     const Shape &output_shape, bool *output_data)                 \
215   {                                                                                               \
216     BroadcastComparison4DSlow<T, name##Fn>(input1_shape, input1_data, input2_shape, input2_data,  \
217                                            output_shape, output_data);                            \
218   }                                                                                               \
219   template <typename T>                                                                           \
220   inline void Broadcast4DSlow##name##WithScaling(ComparisonParams &params,                        \
221                                                  const Shape &input1_shape, const T *input1_data, \
222                                                  const Shape &input2_shape, const T *input2_data, \
223                                                  const Shape &output_shape, bool *output_data)    \
224   {                                                                                               \
225     BroadcastComparison4DSlowWithScaling<T, name##Fn>(                                            \
226         params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data); \
227   }
228
229 TFLITE_COMPARISON_OP(Equal);
230 TFLITE_COMPARISON_OP(NotEqual);
231 TFLITE_COMPARISON_OP(Greater);
232 TFLITE_COMPARISON_OP(GreaterEqual);
233 TFLITE_COMPARISON_OP(Less);
234 TFLITE_COMPARISON_OP(LessEqual);
235 #undef TFLITE_COMPARISON_OP
236
237 } // namespace cker
238 } // namespace nnfw
239
240 #endif