/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #ifndef __NNFW_CKER_COMPARISON_H__
19 #define __NNFW_CKER_COMPARISON_H__
21 #include "cker/Shape.h"
22 #include "cker/Types.h"
23 #include "cker/Utils.h"
/**
 * @brief Scalar predicates plugged into the comparison kernels of this header.
 *
 * Each predicate takes both operands by value and returns the result of the matching built-in
 * relational operator. Because the signature is exactly bool(T, T), the address of any of them
 * converts to ComparisonFn<T> and can be passed as a non-type template argument.
 */
template <typename T> inline bool EqualFn(T lhs, T rhs)
{
  return lhs == rhs;
}

template <typename T> inline bool NotEqualFn(T lhs, T rhs)
{
  return lhs != rhs;
}

template <typename T> inline bool GreaterFn(T lhs, T rhs)
{
  return lhs > rhs;
}

template <typename T> inline bool GreaterEqualFn(T lhs, T rhs)
{
  return lhs >= rhs;
}

template <typename T> inline bool LessFn(T lhs, T rhs)
{
  return lhs < rhs;
}

template <typename T> inline bool LessEqualFn(T lhs, T rhs)
{
  return lhs <= rhs;
}

/// Pointer-to-predicate type accepted by the kernels: bool(lhs, rhs) with operands by value.
template <typename T> using ComparisonFn = bool (*)(T lhs, T rhs);
39 template <typename T, ComparisonFn<T> F>
40 inline void ComparisonImpl(const Shape &input1_shape, const T *input1_data,
41 const Shape &input2_shape, const T *input2_data,
42 const Shape &output_shape, bool *output_data)
44 const int64_t flatsize = // number of data....
45 MatchingFlatSize(input1_shape, input2_shape, output_shape);
46 for (int64_t i = 0; i < flatsize; ++i)
48 output_data[i] = F(input1_data[i], input2_data[i]);
52 template <ComparisonFn<float> F>
53 inline void Comparison(const Shape &input1_shape, const float *input1_data,
54 const Shape &input2_shape, const float *input2_data,
55 const Shape &output_shape, bool *output_data)
57 ComparisonImpl<float, F>(input1_shape, input1_data, input2_shape, input2_data, output_shape,
61 template <typename T, ComparisonFn<int32_t> F>
62 inline void ComparisonWithScaling(ComparisonParams ¶ms, const Shape &input1_shape,
63 const T *input1_data, const Shape &input2_shape,
64 const T *input2_data, const Shape &output_shape,
67 int left_shift = params.left_shift;
68 int32_t input1_offset = params.input1_offset;
69 int32_t input1_multiplier = params.input1_multiplier;
70 int input1_shift = params.input1_shift;
71 int32_t input2_offset = params.input2_offset;
72 int32_t input2_multiplier = params.input2_multiplier;
73 int input2_shift = params.input2_shift;
74 const int64_t flatsize = MatchingFlatSize(input1_shape, input2_shape, output_shape);
75 for (int64_t i = 0; i < flatsize; ++i)
77 const int32_t input1_val = input1_offset + input1_data[i];
78 const int32_t input2_val = input2_offset + input2_data[i];
79 const int32_t shifted_input1_val = input1_val * (1 << left_shift);
80 const int32_t shifted_input2_val = input2_val * (1 << left_shift);
81 const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
82 shifted_input1_val, input1_multiplier, input1_shift);
83 const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
84 shifted_input2_val, input2_multiplier, input2_shift);
85 output_data[i] = F(scaled_input1_val, scaled_input2_val);
89 template <typename T, ComparisonFn<T> F>
91 BroadcastComparison4DSlowImpl(const Shape &unextended_input1_shape, const T *input1_data,
92 const Shape &unextended_input2_shape, const T *input2_data,
93 const Shape &unextended_output_shape, bool *output_data)
95 assert(unextended_input1_shape.DimensionsCount() <= 4);
96 assert(unextended_input2_shape.DimensionsCount() <= 4);
97 assert(unextended_output_shape.DimensionsCount() <= 4);
98 const Shape output_shape = Shape::ExtendedShape(4, unextended_output_shape);
100 NdArrayDesc<4> desc1;
101 NdArrayDesc<4> desc2;
102 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, unextended_input2_shape, &desc1,
105 for (int b = 0; b < output_shape.Dims(0); ++b)
107 for (int y = 0; y < output_shape.Dims(1); ++y)
109 for (int x = 0; x < output_shape.Dims(2); ++x)
111 for (int c = 0; c < output_shape.Dims(3); ++c)
113 output_data[Offset(output_shape, b, y, x, c)] =
114 F(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
115 input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
122 template <typename T, ComparisonFn<T> F>
123 inline void BroadcastComparison4DSlow(const Shape &input1_shape, const T *input1_data,
124 const Shape &input2_shape, const T *input2_data,
125 const Shape &output_shape, bool *output_data)
127 BroadcastComparison4DSlowImpl<T, F>(input1_shape, input1_data, input2_shape, input2_data,
128 output_shape, output_data);
131 template <typename T, ComparisonFn<int32_t> F>
132 inline void BroadcastComparison4DSlowWithScaling(ComparisonParams ¶ms,
133 const Shape &input1_shape, const T *input1_data,
134 const Shape &input2_shape, const T *input2_data,
135 const Shape &output_shape, bool *output_data)
137 assert(input1_shape.DimensionsCount() <= 4);
138 assert(input2_shape.DimensionsCount() <= 4);
139 assert(output_shape.DimensionsCount() <= 4);
141 int left_shift = params.left_shift;
142 int32_t input1_offset = params.input1_offset;
143 int32_t input1_multiplier = params.input1_multiplier;
144 int input1_shift = params.input1_shift;
145 int32_t input2_offset = params.input2_offset;
146 int32_t input2_multiplier = params.input2_multiplier;
147 int input2_shift = params.input2_shift;
149 NdArrayDesc<4> desc1;
150 NdArrayDesc<4> desc2;
151 NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, &desc2);
153 for (int b = 0; b < output_shape.Dims(0); ++b)
155 for (int y = 0; y < output_shape.Dims(1); ++y)
157 for (int x = 0; x < output_shape.Dims(2); ++x)
159 for (int c = 0; c < output_shape.Dims(3); ++c)
161 const int32_t input1_val =
162 input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)];
163 const int32_t input2_val =
164 input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)];
165 const int32_t shifted_input1_val = input1_val * (1 << left_shift);
166 const int32_t shifted_input2_val = input2_val * (1 << left_shift);
167 const int32_t scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
168 shifted_input1_val, input1_multiplier, input1_shift);
169 const int32_t scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp(
170 shifted_input2_val, input2_multiplier, input2_shift);
171 output_data[Offset(output_shape, b, y, x, c)] = F(scaled_input1_val, scaled_input2_val);
178 #define TFLITE_COMPARISON_OP(name) \
179 template <typename T> \
180 inline void name(const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, \
181 const T *input2_data, const Shape &output_shape, bool *output_data) \
183 Comparison<name##Fn>(input1_shape, input1_data, input2_shape, input2_data, output_shape, \
186 template <typename T> \
187 inline void name##NoScaling(const Shape &input1_shape, const T *input1_data, \
188 const Shape &input2_shape, const T *input2_data, \
189 const Shape &output_shape, bool *output_data) \
191 ComparisonImpl<T, name##Fn>(input1_shape, input1_data, input2_shape, input2_data, \
192 output_shape, output_data); \
194 template <typename T> \
195 inline void name##WithScaling(ComparisonParams ¶ms, const Shape &input1_shape, \
196 const T *input1_data, const Shape &input2_shape, \
197 const T *input2_data, const Shape &output_shape, \
200 ComparisonWithScaling<T, name##Fn>(params, input1_shape, input1_data, input2_shape, \
201 input2_data, output_shape, output_data); \
203 template <typename T> \
204 inline void Broadcast4DSlow##name##NoScaling(const Shape &input1_shape, const T *input1_data, \
205 const Shape &input2_shape, const T *input2_data, \
206 const Shape &output_shape, bool *output_data) \
208 BroadcastComparison4DSlowImpl<T, name##Fn>(input1_shape, input1_data, input2_shape, \
209 input2_data, output_shape, output_data); \
211 template <typename T> \
212 inline void Broadcast4DSlow##name(const Shape &input1_shape, const T *input1_data, \
213 const Shape &input2_shape, const T *input2_data, \
214 const Shape &output_shape, bool *output_data) \
216 BroadcastComparison4DSlow<T, name##Fn>(input1_shape, input1_data, input2_shape, input2_data, \
217 output_shape, output_data); \
219 template <typename T> \
220 inline void Broadcast4DSlow##name##WithScaling(ComparisonParams ¶ms, \
221 const Shape &input1_shape, const T *input1_data, \
222 const Shape &input2_shape, const T *input2_data, \
223 const Shape &output_shape, bool *output_data) \
225 BroadcastComparison4DSlowWithScaling<T, name##Fn>( \
226 params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data); \
229 TFLITE_COMPARISON_OP(Equal);
230 TFLITE_COMPARISON_OP(NotEqual);
231 TFLITE_COMPARISON_OP(Greater);
232 TFLITE_COMPARISON_OP(GreaterEqual);
233 TFLITE_COMPARISON_OP(Less);
234 TFLITE_COMPARISON_OP(LessEqual);
235 #undef TFLITE_COMPARISON_OP