/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "CompareLayer.h"

#include "OperationUtils.h"

#include <cker/operation/Comparison.h>
using namespace nnfw::cker;

using OpType = onert::ir::operation::Comparison::ComparisonType;
using namespace onert::backend::cpu;

// The dispatch tables in compareQuant8/compareScalar are indexed by the integer value of
// OpType, so these enum values must stay in exactly this order.
static_assert(static_cast<int>(OpType::Equal) == 0, "An OpType value has changed!");
static_assert(static_cast<int>(OpType::NotEqual) == 1, "An OpType value has changed!");
static_assert(static_cast<int>(OpType::Greater) == 2, "An OpType value has changed!");
static_assert(static_cast<int>(OpType::GreaterEqual) == 3, "An OpType value has changed!");
static_assert(static_cast<int>(OpType::Less) == 4, "An OpType value has changed!");
static_assert(static_cast<int>(OpType::LessEqual) == 5, "An OpType value has changed!");

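// Quantized (asymmetric uint8) comparison: both inputs are mapped onto a common scale
// before the element-wise comparison, so tensors quantized with different
// (scale, zero-point) parameters compare correctly.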
template <typename T>
void compareQuant8(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
                   OpType op_type)
{
  nnfw::cker::ComparisonParams params;
  params.left_shift = 8;
  params.input1_offset = -lhs->data_offset();
  params.input2_offset = -rhs->data_offset();
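  // Rescale both inputs to a shared scale (twice the larger input scale) so that each
  // adjusted scale is below one, as required by QuantizeMultiplierSmallerThanOneExp.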
  const double norm_max_scale =
    2 * std::max(std::abs(lhs->data_scale()), std::abs(rhs->data_scale()));
  const double adjusted_lhs_scale = lhs->data_scale() / norm_max_scale;
  const double adjusted_rhs_scale = rhs->data_scale() / norm_max_scale;
  QuantizeMultiplierSmallerThanOneExp(adjusted_lhs_scale, &params.input1_multiplier,
                                      &params.input1_shift);
  QuantizeMultiplierSmallerThanOneExp(adjusted_rhs_scale, &params.input2_multiplier,
                                      &params.input2_shift);
  params.is_broadcast = !HaveSameShapes(lhs, rhs);
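
  // When the input shapes differ, the slow 4-D broadcasting kernels are used instead of
  // the plain element-wise ones.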
  using CompareFunction =
    void (*)(ComparisonParams &params, const Shape &input1_shape, const T *input1_data,
             const Shape &input2_shape, const T *input2_data, const Shape &output_shape,
             bool *output_data);

  static const CompareFunction broadcast_fns[] = {
    Broadcast4DSlowEqualWithScaling,   Broadcast4DSlowNotEqualWithScaling,
    Broadcast4DSlowGreaterWithScaling, Broadcast4DSlowGreaterEqualWithScaling,
    Broadcast4DSlowLessWithScaling,    Broadcast4DSlowLessEqualWithScaling,
  };
  static const CompareFunction non_broadcast_fns[] = {
    EqualWithScaling,        NotEqualWithScaling, GreaterWithScaling,
    GreaterEqualWithScaling, LessWithScaling,     LessEqualWithScaling,
  };

  static_assert(sizeof(broadcast_fns) == sizeof(non_broadcast_fns),
                "Sizes of broadcast_fns and non_broadcast_fns must match!");
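
  // The integer value of op_type indexes both tables directly; out-of-range values are
  // rejected before the lookup.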
  auto index = static_cast<int>(op_type);
  if (index < 0 || index >= static_cast<int>(sizeof(broadcast_fns) / sizeof(broadcast_fns[0])))
    throw std::runtime_error{"Invalid OpType for CompareLayer"};

  CompareFunction fn = (params.is_broadcast ? broadcast_fns[index] : non_broadcast_fns[index]);

  fn(params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
     getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
     getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
}
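
// Non-quantized comparison (float, int32, int64, bool8): values are compared directly,
// without any rescaling.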
template <typename T>
void compareScalar(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
                   OpType op_type)
{
  bool requires_broadcast = !HaveSameShapes(lhs, rhs);

  using CompareFunction =
    void (*)(const Shape &input1_shape, const T *input1_data, const Shape &input2_shape,
             const T *input2_data, const Shape &output_shape, bool *output_data);

  static const CompareFunction broadcast_fns[] = {
    Broadcast4DSlowEqual,        Broadcast4DSlowNotEqual, Broadcast4DSlowGreater,
    Broadcast4DSlowGreaterEqual, Broadcast4DSlowLess,     Broadcast4DSlowLessEqual,
  };
  static const CompareFunction non_broadcast_fns[] = {
    EqualNoScaling,        NotEqualNoScaling, GreaterNoScaling,
    GreaterEqualNoScaling, LessNoScaling,     LessEqualNoScaling,
  };

  static_assert(sizeof(broadcast_fns) == sizeof(non_broadcast_fns),
                "Sizes of broadcast_fns and non_broadcast_fns must match!");

  auto index = static_cast<int>(op_type);
  if (index < 0 || index >= static_cast<int>(sizeof(broadcast_fns) / sizeof(broadcast_fns[0])))
    throw std::runtime_error{"Invalid OpType for CompareLayer"};

  CompareFunction fn = (requires_broadcast ? broadcast_fns[index] : non_broadcast_fns[index]);

  fn(getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
     getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
     getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
}
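
// CompareLayer glues the comparison kernels into the cpu backend: configure() captures the
// operand tensors and the comparison type, and run() dispatches to the typed kernel.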
CompareLayer::CompareLayer()
  : _lhs(nullptr), _rhs(nullptr), _output(nullptr),
    _op_type(ir::operation::Comparison::ComparisonType::Equal)
{
  // DO NOTHING
}

void CompareLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs,
                             const OpType op_type, IPortableTensor *output)
{
  _lhs = lhs;
  _rhs = rhs;
  _op_type = op_type;
  _output = output;
}

void CompareLayer::run()
{
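  // Dispatch on the left-hand input's element type; quantized uint8 takes the scaled
  // comparison path, all other supported types compare raw values.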
  if (_lhs->data_type() == OperandType::FLOAT32)
    compareScalar<float>(_lhs, _rhs, _output, _op_type);
  else if (_lhs->data_type() == OperandType::INT32)
    compareScalar<int32_t>(_lhs, _rhs, _output, _op_type);
  else if (_lhs->data_type() == OperandType::INT64)
    compareScalar<int64_t>(_lhs, _rhs, _output, _op_type);
  else if (_lhs->data_type() == OperandType::BOOL8)
    compareScalar<uint8_t>(_lhs, _rhs, _output, _op_type);
  else if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
    compareQuant8<uint8_t>(_lhs, _rhs, _output, _op_type);
  else
    throw std::runtime_error{"Compare: unsupported data type"};
}

} // namespace backend