runtime/onert/backend/cpu/ops/CompareLayer.cc
/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "CompareLayer.h"

#include "OperationUtils.h"

#include <assert.h>
#include <cker/operation/Comparison.h>
using namespace nnfw::cker;
namespace onert
{
namespace backend
{
namespace cpu
{
namespace ops
{

namespace
{

using OpType = onert::ir::operation::Comparison::ComparisonType;
using namespace onert::backend::cpu;

// The dispatch tables in compareQuant8() and compareScalar() below are indexed by these enum
// values, so they are assumed to stay in exactly this order.
static_assert(static_cast<int>(OpType::Equal) == 0, "An OpType value has changed!");
static_assert(static_cast<int>(OpType::NotEqual) == 1, "An OpType value has changed!");
static_assert(static_cast<int>(OpType::Greater) == 2, "An OpType value has changed!");
static_assert(static_cast<int>(OpType::GreaterEqual) == 3, "An OpType value has changed!");
static_assert(static_cast<int>(OpType::Less) == 4, "An OpType value has changed!");
static_assert(static_cast<int>(OpType::LessEqual) == 5, "An OpType value has changed!");

template <typename T>
void compareQuant8(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
                   OpType op_type)
{
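  // Quantized path: the two inputs may use different scales/zero-points, so both are rescaled
  // to a common scale with fixed-point multipliers before comparing. Following the cker/TFLite
  // convention, the zero-point offsets are stored negated and left_shift widens the values
  // before the multipliers are applied.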
  nnfw::cker::ComparisonParams params;
  params.left_shift = 8;
  params.input1_offset = -lhs->data_offset();
  params.input2_offset = -rhs->data_offset();
  const double norm_max_scale =
      2 * std::max(std::abs(lhs->data_scale()), std::abs(rhs->data_scale()));
  const double adjusted_lhs_scale = lhs->data_scale() / norm_max_scale;
  const double adjusted_rhs_scale = rhs->data_scale() / norm_max_scale;
  QuantizeMultiplierSmallerThanOneExp(adjusted_lhs_scale, &params.input1_multiplier,
                                      &params.input1_shift);
  QuantizeMultiplierSmallerThanOneExp(adjusted_rhs_scale, &params.input2_multiplier,
                                      &params.input2_shift);
  params.is_broadcast = !HaveSameShapes(lhs, rhs);

  using CompareFunction =
      void (*)(ComparisonParams & params, const Shape &input1_shape, const T *input1_data,
               const Shape &input2_shape, const T *input2_data, const Shape &output_shape,
               bool *output_data);

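  // Dispatch tables indexed by OpType; their order must match the static_asserts above.
  // The Broadcast4DSlow* variants are used when the input shapes differ.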
  static const CompareFunction broadcast_fns[] = {
      Broadcast4DSlowEqualWithScaling,   Broadcast4DSlowNotEqualWithScaling,
      Broadcast4DSlowGreaterWithScaling, Broadcast4DSlowGreaterEqualWithScaling,
      Broadcast4DSlowLessWithScaling,    Broadcast4DSlowLessEqualWithScaling,
  };
  static const CompareFunction non_broadcast_fns[] = {
      EqualWithScaling,        NotEqualWithScaling, GreaterWithScaling,
      GreaterEqualWithScaling, LessWithScaling,     LessEqualWithScaling,
  };

  static_assert(sizeof(broadcast_fns) == sizeof(non_broadcast_fns),
                "Sizes of broadcast_fns and non_broadcast_fns must match!");

  auto index = static_cast<int>(op_type);
  if (index < 0 || index >= static_cast<int>(sizeof(broadcast_fns) / sizeof(broadcast_fns[0])))
    throw std::runtime_error{"Invalid OpType for CompareLayer"};

  CompareFunction fn = (params.is_broadcast ? broadcast_fns[index] : non_broadcast_fns[index]);

  fn(params, getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
     getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
     getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
}

template <typename T>
void compareScalar(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
                   OpType op_type)
{
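  // Non-quantized path: element values are compared directly, no rescaling is involved.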
  bool requires_broadcast = !HaveSameShapes(lhs, rhs);

  using CompareFunction =
      void (*)(const Shape &input1_shape, const T *input1_data, const Shape &input2_shape,
               const T *input2_data, const Shape &output_shape, bool *output_data);

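  // Same dispatch scheme as compareQuant8(), but using the unscaled comparison kernels.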
  static const CompareFunction broadcast_fns[] = {
      Broadcast4DSlowEqual,        Broadcast4DSlowNotEqual, Broadcast4DSlowGreater,
      Broadcast4DSlowGreaterEqual, Broadcast4DSlowLess,     Broadcast4DSlowLessEqual,
  };
  static const CompareFunction non_broadcast_fns[] = {
      EqualNoScaling,        NotEqualNoScaling, GreaterNoScaling,
      GreaterEqualNoScaling, LessNoScaling,     LessEqualNoScaling,
  };

  static_assert(sizeof(broadcast_fns) == sizeof(non_broadcast_fns),
                "Sizes of broadcast_fns and non_broadcast_fns must match!");

  auto index = static_cast<int>(op_type);
  if (index < 0 || index >= static_cast<int>(sizeof(broadcast_fns) / sizeof(broadcast_fns[0])))
    throw std::runtime_error{"Invalid OpType for CompareLayer"};

  CompareFunction fn = (requires_broadcast ? broadcast_fns[index] : non_broadcast_fns[index]);

  fn(getExtendedTensorShape(lhs), reinterpret_cast<const T *>(lhs->buffer()),
     getExtendedTensorShape(rhs), reinterpret_cast<const T *>(rhs->buffer()),
     getExtendedTensorShape(output), reinterpret_cast<bool *>(output->buffer()));
}

} // namespace

CompareLayer::CompareLayer()
    : _lhs(nullptr), _rhs(nullptr), _output(nullptr),
      _op_type(ir::operation::Comparison::ComparisonType::Equal)
{
  // DO NOTHING
}

void CompareLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs,
                             const OpType op_type, IPortableTensor *output)
{
  _lhs = lhs;
  _rhs = rhs;
  _op_type = op_type;
  _output = output;
}

void CompareLayer::run()
{
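  // Dispatch on the element type of the left-hand input; both inputs are expected to share the
  // same type. BOOL8 tensors are stored as uint8_t, and asymmetric uint8 tensors take the
  // rescaling path in compareQuant8().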
  if (_lhs->data_type() == OperandType::FLOAT32)
  {
    compareScalar<float>(_lhs, _rhs, _output, _op_type);
  }
  else if (_lhs->data_type() == OperandType::INT32)
  {
    compareScalar<int32_t>(_lhs, _rhs, _output, _op_type);
  }
  else if (_lhs->data_type() == OperandType::INT64)
  {
    compareScalar<int64_t>(_lhs, _rhs, _output, _op_type);
  }
  else if (_lhs->data_type() == OperandType::BOOL8)
  {
    compareScalar<uint8_t>(_lhs, _rhs, _output, _op_type);
  }
  else if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
  {
    compareQuant8<uint8_t>(_lhs, _rhs, _output, _op_type);
  }
  else
  {
    throw std::runtime_error{"Compare: unsupported data type"};
  }
}

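// Usage sketch (illustrative only, not part of this file): a kernel generator is expected to
// configure the layer once with the operand tensors and the comparison type, then call run()
// per inference, roughly as in
//
//   CompareLayer layer;
//   layer.configure(lhs, rhs, ir::operation::Comparison::ComparisonType::Greater, output);
//   layer.run();
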
} // namespace ops
} // namespace cpu
} // namespace backend
} // namespace onert