Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ops / CompareLayer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "CompareLayer.h"
17
18 #include "OperationUtils.h"
19
20 #include <cker/operation/Comparison.h>
21 using namespace nnfw::cker;
22 namespace onert
23 {
24 namespace backend
25 {
26 namespace cpu
27 {
28 namespace ops
29 {
30
31 namespace
32 {
33
34 using OpType = onert::ir::operation::Comparison::ComparisonType;
35 using namespace onert::backend::cpu;
36
37 template <typename T>
38 void compareQuant8(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
39                    OpType op_type)
40 {
41   nnfw::cker::ComparisonParams params;
42   params.left_shift = 8;
43   params.input1_offset = -lhs->data_offset();
44   params.input2_offset = -rhs->data_offset();
45   const double norm_max_scale =
46       2 * std::max(std::abs(lhs->data_scale()), std::abs(rhs->data_scale()));
47   const double adjusted_lhs_scale = lhs->data_scale() / norm_max_scale;
48   const double adjusted_rhs_scale = rhs->data_scale() / norm_max_scale;
49   QuantizeMultiplierSmallerThanOneExp(adjusted_lhs_scale, &params.input1_multiplier,
50                                       &params.input1_shift);
51   QuantizeMultiplierSmallerThanOneExp(adjusted_rhs_scale, &params.input2_multiplier,
52                                       &params.input2_shift);
53   params.is_broadcast = !HaveSameShapes(lhs, rhs);
54
55   if (params.is_broadcast)
56   {
57     switch (op_type)
58     {
59       case OpType::Equal:
60         Broadcast4DSlowEqualWithScaling(
61             params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
62             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
63             getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
64         break;
65       case OpType::NotEqual:
66         Broadcast4DSlowNotEqualWithScaling(
67             params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
68             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
69             getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
70         break;
71       case OpType::Greater:
72         Broadcast4DSlowGreaterWithScaling(
73             params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
74             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
75             getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
76         break;
77       case OpType::GreaterEqual:
78         Broadcast4DSlowGreaterEqualWithScaling(
79             params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
80             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
81             getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
82         break;
83       case OpType::Less:
84         Broadcast4DSlowLessWithScaling(
85             params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
86             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
87             getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
88         break;
89       case OpType::LessEqual:
90         Broadcast4DSlowLessEqualWithScaling(
91             params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
92             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
93             getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
94         break;
95       default:
96         throw std::runtime_error{"Invalid OpType for CompareLayer"};
97     }
98   }
99   else // if (requires_broadcast == false)
100   {
101     switch (op_type)
102     {
103       case OpType::Equal:
104         EqualWithScaling(params, getExtendedTensorShape(lhs),
105                          reinterpret_cast<const T *>(lhs->buffer()), getExtendedTensorShape(rhs),
106                          reinterpret_cast<const T *>(rhs->buffer()), getExtendedTensorShape(output),
107                          reinterpret_cast<bool *>(output->buffer()));
108         break;
109       case OpType::NotEqual:
110         NotEqualWithScaling(
111             params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
112             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
113             getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
114         break;
115       case OpType::Greater:
116         GreaterWithScaling(
117             params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
118             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
119             getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
120         break;
121       case OpType::GreaterEqual:
122         GreaterEqualWithScaling(
123             params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
124             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
125             getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
126         break;
127       case OpType::Less:
128         LessWithScaling(params, getExtendedTensorShape(lhs),
129                         reinterpret_cast<const T *>(lhs->buffer()), getExtendedTensorShape(rhs),
130                         reinterpret_cast<const T *>(rhs->buffer()), getExtendedTensorShape(output),
131                         reinterpret_cast<bool *>(output->buffer()));
132         break;
133       case OpType::LessEqual:
134         LessEqualWithScaling(
135             params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
136             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
137             getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
138         break;
139       default:
140         throw std::runtime_error{"Invalid OpType for CompareLayer"};
141     }
142   }
143   return;
144 }
145
146 template <typename T>
147 void compareScalar(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
148                    OpType op_type)
149 {
150   bool requires_broadcast = !HaveSameShapes(lhs, rhs);
151
152   if (requires_broadcast)
153   {
154     switch (op_type)
155     {
156       case OpType::Equal:
157         Broadcast4DSlowEqual(
158             getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
159             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
160             getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
161         break;
162       case OpType::NotEqual:
163         Broadcast4DSlowNotEqual(
164             getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
165             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
166             getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
167         break;
168       case OpType::Greater:
169         Broadcast4DSlowGreater(
170             getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
171             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
172             getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
173         break;
174       case OpType::GreaterEqual:
175         Broadcast4DSlowGreaterEqual(
176             getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
177             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
178             getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
179         break;
180       case OpType::Less:
181         Broadcast4DSlowLess(getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
182                             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
183                             getExtendedTensorShape(output),
184                             reinterpret_cast<bool *>(output->buffer()));
185         break;
186       case OpType::LessEqual:
187         Broadcast4DSlowLessEqual(
188             getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
189             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
190             getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
191         break;
192       default:
193         throw std::runtime_error{"Invalid OpType for CompareLayer"};
194     }
195   }
196   else // if (requires_broadcast == false)
197   {
198     switch (op_type)
199     {
200       case OpType::Equal:
201         EqualNoScaling(getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
202                        getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
203                        getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
204         break;
205       case OpType::NotEqual:
206         NotEqualNoScaling(getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
207                           getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
208                           getExtendedTensorShape(output),
209                           reinterpret_cast<bool *>(output->buffer()));
210         break;
211       case OpType::Greater:
212         GreaterNoScaling(getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
213                          getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
214                          getExtendedTensorShape(output),
215                          reinterpret_cast<bool *>(output->buffer()));
216         break;
217       case OpType::GreaterEqual:
218         GreaterEqualNoScaling(
219             getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
220             getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
221             getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
222         break;
223       case OpType::Less:
224         LessNoScaling(getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
225                       getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
226                       getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
227         break;
228       case OpType::LessEqual:
229         LessEqualNoScaling(getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
230                            getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
231                            getExtendedTensorShape(output),
232                            reinterpret_cast<bool *>(output->buffer()));
233         break;
234       default:
235         throw std::runtime_error{"Invalid OpType for CompareLayer"};
236     }
237   }
238   return;
239 }
240 } // namespace
241
242 CompareLayer::CompareLayer()
243     : _lhs(nullptr), _rhs(nullptr), _output(nullptr),
244       _op_type(ir::operation::Comparison::ComparisonType::Equal)
245 {
246   // DO NOTHING
247 }
248
249 void CompareLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs,
250                              const OpType op_type, IPortableTensor *output)
251 {
252   _lhs = lhs;
253   _rhs = rhs;
254   _op_type = op_type;
255   _output = output;
256 }
257
258 void CompareLayer::run()
259 {
260   if (_lhs->data_type() == OperandType::FLOAT32)
261   {
262     compareScalar<float>(_lhs, _rhs, _output, _op_type);
263   }
264   else if (_lhs->data_type() == OperandType::INT32)
265   {
266     compareScalar<int32_t>(_lhs, _rhs, _output, _op_type);
267   }
268   else if (_lhs->data_type() == OperandType::INT64)
269   {
270     compareScalar<int64_t>(_lhs, _rhs, _output, _op_type);
271   }
272   else if (_lhs->data_type() == OperandType::BOOL8)
273   {
274     compareScalar<uint8_t>(_lhs, _rhs, _output, _op_type);
275   }
276   else if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
277   {
278     compareQuant8<uint8_t>(_lhs, _rhs, _output, _op_type);
279   }
280   else
281   {
282     throw std::runtime_error{"Compare: unsupported data type"};
283   }
284 }
285
286 } // namespace ops
287 } // namespace cpu
288 } // namespace backend
289 } // namespace onert