/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #include "kernels/Utils.h"
19 #include "TISOKernel.h"
21 #include "PALComparisons.h"
23 namespace luci_interpreter
29 void evalQuantized(const circle::Tensor *x, const circle::Tensor *y, const circle::Tensor *output,
30 BaseRuntimeGraph *runtime_graph)
32 auto x_data = kernels::getTensorData<uint8_t>(runtime_graph->getDataByTensor(x));
33 if (x_data == nullptr)
34 x_data = kernels::getTensorData<uint8_t>(runtime_graph->getConstDataByTensor(x));
36 assert(x_data != nullptr);
38 auto y_data = kernels::getTensorData<uint8_t>(runtime_graph->getDataByTensor(y));
39 if (y_data == nullptr)
40 y_data = kernels::getTensorData<uint8_t>(runtime_graph->getConstDataByTensor(y));
42 assert(y_data != nullptr);
44 auto output_data = kernels::getTensorData<bool>(runtime_graph->getDataByTensor(output));
52 kernels::quantizeMultiplierSmallerThanOneExp(Tensor::scale(x), &x_multiplier, &x_shift);
53 kernels::quantizeMultiplierSmallerThanOneExp(Tensor::scale(y), &y_multiplier, &y_shift);
55 luci_interpreter_pal::ComparisonParams op_params;
56 op_params.left_shift = 8;
57 op_params.input1_offset = -Tensor::zero_point(x); // Note the '-'
58 op_params.input1_shift = x_shift;
59 op_params.input1_multiplier = x_multiplier;
60 op_params.input2_offset = -Tensor::zero_point(y); // Note the '-'
61 op_params.input2_shift = y_shift;
62 op_params.input2_multiplier = y_multiplier;
63 op_params.is_broadcast = Tensor::num_elements(x) != Tensor::num_elements(y);
65 if (op_params.is_broadcast)
67 luci_interpreter_pal::BroadcastComparison4DSlowWithScaling<uint8_t>(
68 op_params, kernels::getTensorShape(x), x_data, kernels::getTensorShape(y), y_data,
69 kernels::getTensorShape(output), output_data, luci_interpreter_pal::LessFn);
73 const int64_t flat_size = kernels::getTensorShape(x).flatSize();
74 luci_interpreter_pal::ComparisonWithScaling<uint8_t>(op_params, flat_size, x_data, y_data,
75 output_data, luci_interpreter_pal::LessFn);
81 void evalGeneric(const circle::Tensor *x, const circle::Tensor *y, const circle::Tensor *output,
82 BaseRuntimeGraph *runtime_graph)
84 auto x_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(x));
85 if (x_data == nullptr)
86 x_data = kernels::getTensorData<T>(runtime_graph->getConstDataByTensor(x));
88 assert(x_data != nullptr);
90 auto y_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(y));
91 if (y_data == nullptr)
92 y_data = kernels::getTensorData<T>(runtime_graph->getConstDataByTensor(y));
94 assert(y_data != nullptr);
96 auto output_data = kernels::getTensorData<bool>(runtime_graph->getDataByTensor(output));
98 luci_interpreter_pal::ComparisonParams op_params;
99 op_params.is_broadcast = Tensor::num_elements(x) != Tensor::num_elements(y);
101 if (op_params.is_broadcast)
103 luci_interpreter_pal::BroadcastComparison4DSlowNoScaling<T>(
104 op_params, kernels::getTensorShape(x), x_data, kernels::getTensorShape(y), y_data,
105 kernels::getTensorShape(output), output_data, luci_interpreter_pal::LessFn);
109 const int64_t flat_size = kernels::getTensorShape(x).flatSize();
110 luci_interpreter_pal::ComparisonNoScaling<T>(flat_size, x_data, y_data, output_data,
111 luci_interpreter_pal::LessFn);
117 void configure_kernel_CircleLess(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
119 kernels::TISOKernel kernel(cur_op, runtime_graph);
121 LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
122 Tensor::element_type(kernel.input2()));
123 LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.output()) == DataType::BOOL);
126 void execute_kernel_CircleLess(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
128 kernels::TISOKernel kernel(cur_op, runtime_graph);
130 switch (Tensor::element_type(kernel.input1()))
133 evalGeneric<int64_t>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
136 evalGeneric<int32_t>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
140 evalQuantized(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
144 case DataType::FLOAT32:
145 evalGeneric<float>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
149 assert(false && "Unsupported type.");
153 } // namespace luci_interpreter