/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "CompareLayer.h"

#include "OperationUtils.h"

#include <cker/operation/Comparison.h>

#include <algorithm>
#include <cmath>
#include <stdexcept>

using namespace nnfw::cker;

namespace onert
{
namespace backend
{
namespace cpu
{
namespace ops
{

namespace
{

using OpType = onert::ir::operation::Comparison::ComparisonType;
using namespace onert::backend::cpu;

// The dispatch tables below assume the enum values appear in exactly this order.
static_assert(static_cast<int>(OpType::Equal) == 0, "An OpType value has changed!");
static_assert(static_cast<int>(OpType::NotEqual) == 1, "An OpType value has changed!");
static_assert(static_cast<int>(OpType::Greater) == 2, "An OpType value has changed!");
static_assert(static_cast<int>(OpType::GreaterEqual) == 3, "An OpType value has changed!");
static_assert(static_cast<int>(OpType::Less) == 4, "An OpType value has changed!");
static_assert(static_cast<int>(OpType::LessEqual) == 5, "An OpType value has changed!");
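
// Compares two asymmetric-quantized tensors: both inputs are rescaled to a common
// scale before the element-wise comparison is applied.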
template <typename T>
void compareQuant8(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
                   OpType op_type)
{
  nnfw::cker::ComparisonParams params;
  params.left_shift = 8;
  params.input1_offset = -lhs->data_zero_point();
  params.input2_offset = -rhs->data_zero_point();
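  // Normalize both scales by twice the larger one so each adjusted scale stays below 1.0,
  // as required by QuantizeMultiplierSmallerThanOneExp.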
  const double norm_max_scale =
    2 * std::max(std::abs(lhs->data_scale()), std::abs(rhs->data_scale()));
  const double adjusted_lhs_scale = lhs->data_scale() / norm_max_scale;
  const double adjusted_rhs_scale = rhs->data_scale() / norm_max_scale;
  QuantizeMultiplierSmallerThanOneExp(adjusted_lhs_scale, &params.input1_multiplier,
                                      &params.input1_shift);
  QuantizeMultiplierSmallerThanOneExp(adjusted_rhs_scale, &params.input2_multiplier,
                                      &params.input2_shift);
  params.is_broadcast = !HaveSameShapes(lhs, rhs);
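
  // Dispatch tables indexed by the integer value of op_type; the static_asserts at the
  // top of this file guarantee that the enum order matches the table order.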
  using CompareFunction = void (*)(
    ComparisonParams &params, const Shape &input1_shape, const T *input1_data,
    const Shape &input2_shape, const T *input2_data, const Shape &output_shape, bool *output_data);

  static const CompareFunction broadcast_fns[] = {
    Broadcast4DSlowEqualWithScaling,   Broadcast4DSlowNotEqualWithScaling,
    Broadcast4DSlowGreaterWithScaling, Broadcast4DSlowGreaterEqualWithScaling,
    Broadcast4DSlowLessWithScaling,    Broadcast4DSlowLessEqualWithScaling,
  };
  static const CompareFunction non_broadcast_fns[] = {
    EqualWithScaling,        NotEqualWithScaling, GreaterWithScaling,
    GreaterEqualWithScaling, LessWithScaling,     LessEqualWithScaling,
  };

  static_assert(sizeof(broadcast_fns) == sizeof(non_broadcast_fns),
                "Sizes of broadcast_fns and non_broadcast_fns must match!");
  auto index = static_cast<int>(op_type);
  if (index < 0 || index >= static_cast<int>(sizeof(broadcast_fns) / sizeof(broadcast_fns[0])))
    throw std::runtime_error{"Invalid OpType for CompareLayer"};

  CompareFunction fn = (params.is_broadcast ? broadcast_fns[index] : non_broadcast_fns[index]);

  fn(params, getExtendedTensorShape(lhs), getBuffer<T>(lhs), getExtendedTensorShape(rhs),
     getBuffer<T>(rhs), getExtendedTensorShape(output), getBuffer<bool>(output));
}
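
// Compares two same-typed tensors element-wise without any rescaling; used for float,
// integer and boolean inputs.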
template <typename T>
void compareScalar(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
                   OpType op_type)
{
  bool requires_broadcast = !HaveSameShapes(lhs, rhs);

  using CompareFunction =
    void (*)(const Shape &input1_shape, const T *input1_data, const Shape &input2_shape,
             const T *input2_data, const Shape &output_shape, bool *output_data);
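
  // Same dispatch scheme as compareQuant8: one table for broadcasting inputs and one for
  // same-shape inputs, both indexed by op_type.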
  static const CompareFunction broadcast_fns[] = {
    Broadcast4DSlowEqual,        Broadcast4DSlowNotEqual, Broadcast4DSlowGreater,
    Broadcast4DSlowGreaterEqual, Broadcast4DSlowLess,     Broadcast4DSlowLessEqual,
  };
  static const CompareFunction non_broadcast_fns[] = {
    EqualNoScaling,        NotEqualNoScaling, GreaterNoScaling,
    GreaterEqualNoScaling, LessNoScaling,     LessEqualNoScaling,
  };

  static_assert(sizeof(broadcast_fns) == sizeof(non_broadcast_fns),
                "Sizes of broadcast_fns and non_broadcast_fns must match!");
  auto index = static_cast<int>(op_type);
  if (index < 0 || index >= static_cast<int>(sizeof(broadcast_fns) / sizeof(broadcast_fns[0])))
    throw std::runtime_error{"Invalid OpType for CompareLayer"};

  CompareFunction fn = (requires_broadcast ? broadcast_fns[index] : non_broadcast_fns[index]);

  fn(getExtendedTensorShape(lhs), getBuffer<T>(lhs), getExtendedTensorShape(rhs), getBuffer<T>(rhs),
     getExtendedTensorShape(output), getBuffer<bool>(output));
}

} // namespace

CompareLayer::CompareLayer()
  : _lhs(nullptr), _rhs(nullptr), _output(nullptr),
    _op_type(ir::operation::Comparison::ComparisonType::Equal)
{
  // DO NOTHING
}

void CompareLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs,
                             const OpType op_type, IPortableTensor *output)
{
  _lhs = lhs;
  _rhs = rhs;
  _op_type = op_type;
  _output = output;
}
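
// Dispatches on the data type of the first input; both inputs are expected to share
// the same data type.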
void CompareLayer::run()
{
  if (_lhs->data_type() == OperandType::FLOAT32)
  {
    compareScalar<float>(_lhs, _rhs, _output, _op_type);
  }
  else if (_lhs->data_type() == OperandType::INT32)
  {
    compareScalar<int32_t>(_lhs, _rhs, _output, _op_type);
  }
  else if (_lhs->data_type() == OperandType::INT64)
  {
    compareScalar<int64_t>(_lhs, _rhs, _output, _op_type);
  }
  else if (_lhs->data_type() == OperandType::BOOL8)
  {
    compareScalar<uint8_t>(_lhs, _rhs, _output, _op_type);
  }
  else if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
  {
    compareQuant8<uint8_t>(_lhs, _rhs, _output, _op_type);
  }
  else
  {
    throw std::runtime_error{"Compare: unsupported data type"};
  }
}

} // namespace ops
} // namespace cpu
} // namespace backend
} // namespace onert