onert-micro/luci-interpreter/src/kernels/Less.cpp
/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

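// Kernel implementation of the Circle LESS operator for luci-interpreter:
// element-wise x < y producing a BOOL tensor, with a generic path for
// S32/S64/FLOAT32 inputs and a quantized path for U8 inputs.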
#include "Builders.h"
#include "kernels/Utils.h"
#include "TISOKernel.h"

#include "PALComparisons.h"

namespace luci_interpreter
{

namespace
{
#ifndef DIS_QUANT
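// Quantized (uint8) evaluation path: operand data is read from the runtime
// tensor when present, otherwise from the constant tensor, and both operands
// are rescaled to a common domain before the element-wise Less comparison.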
void evalQuantized(const circle::Tensor *x, const circle::Tensor *y, const circle::Tensor *output,
                   BaseRuntimeGraph *runtime_graph)
{
  auto x_data = kernels::getTensorData<uint8_t>(runtime_graph->getDataByTensor(x));
  if (x_data == nullptr)
    x_data = kernels::getTensorData<uint8_t>(runtime_graph->getConstDataByTensor(x));

  assert(x_data != nullptr);

  auto y_data = kernels::getTensorData<uint8_t>(runtime_graph->getDataByTensor(y));
  if (y_data == nullptr)
    y_data = kernels::getTensorData<uint8_t>(runtime_graph->getConstDataByTensor(y));

  assert(y_data != nullptr);

  auto output_data = kernels::getTensorData<bool>(runtime_graph->getDataByTensor(output));

  int32_t x_multiplier;
  int x_shift;

  int32_t y_multiplier;
  int y_shift;

  kernels::quantizeMultiplierSmallerThanOneExp(Tensor::scale(x), &x_multiplier, &x_shift);
  kernels::quantizeMultiplierSmallerThanOneExp(Tensor::scale(y), &y_multiplier, &y_shift);

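  // TFLite-style scaled comparison: each operand is adjusted by its (negated)
  // zero point, shifted left by 'left_shift' bits for extra precision, and
  // rescaled with its quantized multiplier/shift before being compared.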
  luci_interpreter_pal::ComparisonParams op_params;
  op_params.left_shift = 8;
  op_params.input1_offset = -Tensor::zero_point(x); // Note the '-'
  op_params.input1_shift = x_shift;
  op_params.input1_multiplier = x_multiplier;
  op_params.input2_offset = -Tensor::zero_point(y); // Note the '-'
  op_params.input2_shift = y_shift;
  op_params.input2_multiplier = y_multiplier;
  op_params.is_broadcast = Tensor::num_elements(x) != Tensor::num_elements(y);

  if (op_params.is_broadcast)
  {
    luci_interpreter_pal::BroadcastComparison4DSlowWithScaling<uint8_t>(
      op_params, kernels::getTensorShape(x), x_data, kernels::getTensorShape(y), y_data,
      kernels::getTensorShape(output), output_data, luci_interpreter_pal::LessFn);
  }
  else
  {
    const int64_t flat_size = kernels::getTensorShape(x).flatSize();
    luci_interpreter_pal::ComparisonWithScaling<uint8_t>(op_params, flat_size, x_data, y_data,
                                                         output_data, luci_interpreter_pal::LessFn);
  }
}
#endif // DIS_QUANT

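// Generic evaluation path for non-quantized types: compares the raw values
// directly (no rescaling) and writes one bool per element. Broadcasting is
// assumed whenever the two inputs hold different numbers of elements.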
template <typename T>
void evalGeneric(const circle::Tensor *x, const circle::Tensor *y, const circle::Tensor *output,
                 BaseRuntimeGraph *runtime_graph)
{
  auto x_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(x));
  if (x_data == nullptr)
    x_data = kernels::getTensorData<T>(runtime_graph->getConstDataByTensor(x));

  assert(x_data != nullptr);

  auto y_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(y));
  if (y_data == nullptr)
    y_data = kernels::getTensorData<T>(runtime_graph->getConstDataByTensor(y));

  assert(y_data != nullptr);

  auto output_data = kernels::getTensorData<bool>(runtime_graph->getDataByTensor(output));

  luci_interpreter_pal::ComparisonParams op_params;
  op_params.is_broadcast = Tensor::num_elements(x) != Tensor::num_elements(y);

  if (op_params.is_broadcast)
  {
    luci_interpreter_pal::BroadcastComparison4DSlowNoScaling<T>(
      op_params, kernels::getTensorShape(x), x_data, kernels::getTensorShape(y), y_data,
      kernels::getTensorShape(output), output_data, luci_interpreter_pal::LessFn);
  }
  else
  {
    const int64_t flat_size = kernels::getTensorShape(x).flatSize();
    luci_interpreter_pal::ComparisonNoScaling<T>(flat_size, x_data, y_data, output_data,
                                                 luci_interpreter_pal::LessFn);
  }
}

} // namespace

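// Configure-time validation: both inputs must share the same element type and
// the output must be BOOL.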
void configure_kernel_CircleLess(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
  kernels::TISOKernel kernel(cur_op, runtime_graph);

  LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
                         Tensor::element_type(kernel.input2()));
  LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.output()) == DataType::BOOL);
}

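// Execution dispatches on the element type of the first input; the quantized
// and float paths can be compiled out with DIS_QUANT / DIS_FLOAT.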
void execute_kernel_CircleLess(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
  kernels::TISOKernel kernel(cur_op, runtime_graph);

  switch (Tensor::element_type(kernel.input1()))
  {
    case DataType::S64:
      evalGeneric<int64_t>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
      break;
    case DataType::S32:
      evalGeneric<int32_t>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
      break;
#ifndef DIS_QUANT
    case DataType::U8:
      evalQuantized(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
      break;
#endif // DIS_QUANT
#ifndef DIS_FLOAT
    case DataType::FLOAT32:
      evalGeneric<float>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
      break;
#endif // DIS_FLOAT
    default:
      assert(false && "Unsupported type.");
  }
}

} // namespace luci_interpreter