/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "CompareLayer.h"

#include "OperationUtils.h"

#include <cker/operation/Comparison.h>

#include <algorithm>
#include <cmath>
#include <stdexcept>
using namespace nnfw::cker;

using OpType = onert::ir::operation::Comparison::ComparisonType;
using namespace onert::backend::cpu;
38 void compareQuant8(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
41 nnfw::cker::ComparisonParams params;
42 params.left_shift = 8;
43 params.input1_offset = -lhs->data_offset();
44 params.input2_offset = -rhs->data_offset();
45 const double norm_max_scale =
46 2 * std::max(std::abs(lhs->data_scale()), std::abs(rhs->data_scale()));
47 const double adjusted_lhs_scale = lhs->data_scale() / norm_max_scale;
48 const double adjusted_rhs_scale = rhs->data_scale() / norm_max_scale;
49 QuantizeMultiplierSmallerThanOneExp(adjusted_lhs_scale, ¶ms.input1_multiplier,
50 ¶ms.input1_shift);
51 QuantizeMultiplierSmallerThanOneExp(adjusted_rhs_scale, ¶ms.input2_multiplier,
52 ¶ms.input2_shift);
53 params.is_broadcast = !HaveSameShapes(lhs, rhs);
55 if (params.is_broadcast)
60 Broadcast4DSlowEqualWithScaling(
61 params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
62 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
63 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
65 case OpType::NotEqual:
66 Broadcast4DSlowNotEqualWithScaling(
67 params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
68 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
69 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
72 Broadcast4DSlowGreaterWithScaling(
73 params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
74 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
75 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
77 case OpType::GreaterEqual:
78 Broadcast4DSlowGreaterEqualWithScaling(
79 params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
80 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
81 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
84 Broadcast4DSlowLessWithScaling(
85 params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
86 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
87 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
89 case OpType::LessEqual:
90 Broadcast4DSlowLessEqualWithScaling(
91 params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
92 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
93 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
96 throw std::runtime_error{"Invalid OpType for CompareLayer"};
99 else // if (requires_broadcast == false)
104 EqualWithScaling(params, getExtendedTensorShape(lhs),
105 reinterpret_cast<const T *>(lhs->buffer()), getExtendedTensorShape(rhs),
106 reinterpret_cast<const T *>(rhs->buffer()), getExtendedTensorShape(output),
107 reinterpret_cast<bool *>(output->buffer()));
109 case OpType::NotEqual:
111 params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
112 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
113 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
115 case OpType::Greater:
117 params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
118 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
119 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
121 case OpType::GreaterEqual:
122 GreaterEqualWithScaling(
123 params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
124 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
125 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
128 LessWithScaling(params, getExtendedTensorShape(lhs),
129 reinterpret_cast<const T *>(lhs->buffer()), getExtendedTensorShape(rhs),
130 reinterpret_cast<const T *>(rhs->buffer()), getExtendedTensorShape(output),
131 reinterpret_cast<bool *>(output->buffer()));
133 case OpType::LessEqual:
134 LessEqualWithScaling(
135 params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
136 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
137 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
140 throw std::runtime_error{"Invalid OpType for CompareLayer"};
146 template <typename T>
147 void compareScalar(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
150 bool requires_broadcast = !HaveSameShapes(lhs, rhs);
152 if (requires_broadcast)
157 Broadcast4DSlowEqual(
158 getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
159 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
160 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
162 case OpType::NotEqual:
163 Broadcast4DSlowNotEqual(
164 getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
165 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
166 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
168 case OpType::Greater:
169 Broadcast4DSlowGreater(
170 getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
171 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
172 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
174 case OpType::GreaterEqual:
175 Broadcast4DSlowGreaterEqual(
176 getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
177 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
178 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
181 Broadcast4DSlowLess(getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
182 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
183 getExtendedTensorShape(output),
184 reinterpret_cast<bool *>(output->buffer()));
186 case OpType::LessEqual:
187 Broadcast4DSlowLessEqual(
188 getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
189 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
190 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
193 throw std::runtime_error{"Invalid OpType for CompareLayer"};
196 else // if (requires_broadcast == false)
201 EqualNoScaling(getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
202 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
203 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
205 case OpType::NotEqual:
206 NotEqualNoScaling(getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
207 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
208 getExtendedTensorShape(output),
209 reinterpret_cast<bool *>(output->buffer()));
211 case OpType::Greater:
212 GreaterNoScaling(getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
213 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
214 getExtendedTensorShape(output),
215 reinterpret_cast<bool *>(output->buffer()));
217 case OpType::GreaterEqual:
218 GreaterEqualNoScaling(
219 getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
220 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
221 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
224 LessNoScaling(getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
225 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
226 getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
228 case OpType::LessEqual:
229 LessEqualNoScaling(getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
230 getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
231 getExtendedTensorShape(output),
232 reinterpret_cast<bool *>(output->buffer()));
235 throw std::runtime_error{"Invalid OpType for CompareLayer"};
242 CompareLayer::CompareLayer()
243 : _lhs(nullptr), _rhs(nullptr), _output(nullptr),
244 _op_type(ir::operation::Comparison::ComparisonType::Equal)
249 void CompareLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs,
250 const OpType op_type, IPortableTensor *output)
258 void CompareLayer::run()
260 if (_lhs->data_type() == OperandType::FLOAT32)
262 compareScalar<float>(_lhs, _rhs, _output, _op_type);
264 else if (_lhs->data_type() == OperandType::INT32)
266 compareScalar<int32_t>(_lhs, _rhs, _output, _op_type);
268 else if (_lhs->data_type() == OperandType::INT64)
270 compareScalar<int64_t>(_lhs, _rhs, _output, _op_type);
272 else if (_lhs->data_type() == OperandType::BOOL8)
274 compareScalar<uint8_t>(_lhs, _rhs, _output, _op_type);
276 else if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
278 compareQuant8<uint8_t>(_lhs, _rhs, _output, _op_type);
282 throw std::runtime_error{"Compare: unsupported data type"};
288 } // namespace backend