Imported Upstream version 1.15.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ops / CompareLayer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "CompareLayer.h"
17
18 #include "OperationUtils.h"
19
20 #include <assert.h>
21 #include <cker/operation/Comparison.h>
22 using namespace nnfw::cker;
23 namespace onert
24 {
25 namespace backend
26 {
27 namespace cpu
28 {
29 namespace ops
30 {
31
32 namespace
33 {
34
35 using OpType = onert::ir::operation::Comparison::ComparisonType;
36 using namespace onert::backend::cpu;
37
38 // Assumes these enum values to be in the order like this
39 static_assert(static_cast<int>(OpType::Equal) == 0, "An OpType value has changed!");
40 static_assert(static_cast<int>(OpType::NotEqual) == 1, "An OpType value has changed!");
41 static_assert(static_cast<int>(OpType::Greater) == 2, "An OpType value has changed!");
42 static_assert(static_cast<int>(OpType::GreaterEqual) == 3, "An OpType value has changed!");
43 static_assert(static_cast<int>(OpType::Less) == 4, "An OpType value has changed!");
44 static_assert(static_cast<int>(OpType::LessEqual) == 5, "An OpType value has changed!");
45
46 template <typename T>
47 void compareQuant8(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
48                    OpType op_type)
49 {
50   nnfw::cker::ComparisonParams params;
51   params.left_shift = 8;
52   params.input1_offset = -lhs->data_zero_point();
53   params.input2_offset = -rhs->data_zero_point();
54   const double norm_max_scale =
55     2 * std::max(std::abs(lhs->data_scale()), std::abs(rhs->data_scale()));
56   const double adjusted_lhs_scale = lhs->data_scale() / norm_max_scale;
57   const double adjusted_rhs_scale = rhs->data_scale() / norm_max_scale;
58   QuantizeMultiplierSmallerThanOneExp(adjusted_lhs_scale, &params.input1_multiplier,
59                                       &params.input1_shift);
60   QuantizeMultiplierSmallerThanOneExp(adjusted_rhs_scale, &params.input2_multiplier,
61                                       &params.input2_shift);
62   params.is_broadcast = !HaveSameShapes(lhs, rhs);
63
64   using CompareFunction = void (*)(
65     ComparisonParams & params, const Shape &input1_shape, const T *input1_data,
66     const Shape &input2_shape, const T *input2_data, const Shape &output_shape, bool *output_data);
67
68   static const CompareFunction broadcast_fns[] = {
69     Broadcast4DSlowEqualWithScaling,   Broadcast4DSlowNotEqualWithScaling,
70     Broadcast4DSlowGreaterWithScaling, Broadcast4DSlowGreaterEqualWithScaling,
71     Broadcast4DSlowLessWithScaling,    Broadcast4DSlowLessEqualWithScaling,
72   };
73   static const CompareFunction non_broadcast_fns[] = {
74     EqualWithScaling,        NotEqualWithScaling, GreaterWithScaling,
75     GreaterEqualWithScaling, LessWithScaling,     LessEqualWithScaling,
76   };
77
78   static_assert(sizeof(broadcast_fns) == sizeof(non_broadcast_fns),
79                 "Sizes of broadcast_fns and non_broadcast_fns must match!");
80
81   auto index = static_cast<int>(op_type);
82   if (index < 0 || index >= static_cast<int>(sizeof(broadcast_fns) / sizeof(broadcast_fns[0])))
83     throw std::runtime_error{"Invalid OpType for CompareLayer"};
84
85   CompareFunction fn = (params.is_broadcast ? broadcast_fns[index] : non_broadcast_fns[index]);
86
87   fn(params, getExtendedTensorShape(lhs), getBuffer<T>(lhs), getExtendedTensorShape(rhs),
88      getBuffer<T>(rhs), getExtendedTensorShape(output), getBuffer<bool>(output));
89 }
90
91 template <typename T>
92 void compareScalar(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
93                    OpType op_type)
94 {
95   bool requires_broadcast = !HaveSameShapes(lhs, rhs);
96
97   using CompareFunction =
98     void (*)(const Shape &input1_shape, const T *input1_data, const Shape &input2_shape,
99              const T *input2_data, const Shape &output_shape, bool *output_data);
100
101   static const CompareFunction broadcast_fns[] = {
102     Broadcast4DSlowEqual,        Broadcast4DSlowNotEqual, Broadcast4DSlowGreater,
103     Broadcast4DSlowGreaterEqual, Broadcast4DSlowLess,     Broadcast4DSlowLessEqual,
104   };
105   static const CompareFunction non_broadcast_fns[] = {
106     EqualNoScaling,        NotEqualNoScaling, GreaterNoScaling,
107     GreaterEqualNoScaling, LessNoScaling,     LessEqualNoScaling,
108   };
109
110   static_assert(sizeof(broadcast_fns) == sizeof(non_broadcast_fns),
111                 "Sizes of broadcast_fns and non_broadcast_fns must match!");
112
113   auto index = static_cast<int>(op_type);
114   if (index < 0 || index >= static_cast<int>(sizeof(broadcast_fns) / sizeof(broadcast_fns[0])))
115     throw std::runtime_error{"Invalid OpType for CompareLayer"};
116
117   CompareFunction fn = (requires_broadcast ? broadcast_fns[index] : non_broadcast_fns[index]);
118
119   fn(getExtendedTensorShape(lhs), getBuffer<T>(lhs), getExtendedTensorShape(rhs), getBuffer<T>(rhs),
120      getExtendedTensorShape(output), getBuffer<bool>(output));
121 }
122
123 } // namespace
124
125 CompareLayer::CompareLayer()
126   : _lhs(nullptr), _rhs(nullptr), _output(nullptr),
127     _op_type(ir::operation::Comparison::ComparisonType::Equal)
128 {
129   // DO NOTHING
130 }
131
132 void CompareLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs,
133                              const OpType op_type, IPortableTensor *output)
134 {
135   _lhs = lhs;
136   _rhs = rhs;
137   _op_type = op_type;
138   _output = output;
139 }
140
141 void CompareLayer::run()
142 {
143   if (_lhs->data_type() == OperandType::FLOAT32)
144   {
145     compareScalar<float>(_lhs, _rhs, _output, _op_type);
146   }
147   else if (_lhs->data_type() == OperandType::INT32)
148   {
149     compareScalar<int32_t>(_lhs, _rhs, _output, _op_type);
150   }
151   else if (_lhs->data_type() == OperandType::INT64)
152   {
153     compareScalar<int64_t>(_lhs, _rhs, _output, _op_type);
154   }
155   else if (_lhs->data_type() == OperandType::BOOL8)
156   {
157     compareScalar<uint8_t>(_lhs, _rhs, _output, _op_type);
158   }
159   else if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
160   {
161     compareQuant8<uint8_t>(_lhs, _rhs, _output, _op_type);
162   }
163   else
164   {
165     throw std::runtime_error{"Compare: unsupported data type"};
166   }
167 }
168
169 } // namespace ops
170 } // namespace cpu
171 } // namespace backend
172 } // namespace onert