cfc1dfc6f01d5dd10aac393ee187eefae56e0c3a
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Softmax.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "Builders.h"
19 #include "kernels/Utils.h"
20
21 #include <tensorflow/lite/kernels/internal/reference/softmax.h>
22 #include "PALSoftmax.h"
23
24 namespace luci_interpreter
25 {
26
27 namespace
28 {
29
30 #ifndef DIS_FLOAT
31 void evalFloat(const circle::Tensor *input, const circle::Tensor *output,
32                const circle::SoftmaxOptions *options, BaseRuntimeGraph *runtime_graph)
33 {
34   const auto *input_data = runtime_graph->getDataByTensor(input);
35   auto *output_data = runtime_graph->getDataByTensor(output);
36
37   tflite::SoftmaxParams op_params{};
38   op_params.beta = options->beta();
39
40   tflite::reference_ops::Softmax(
41     op_params, kernels::getTensorShape(input), kernels::getTensorData<float>(input_data),
42     kernels::getTensorShape(output), kernels::getTensorData<float>(output_data));
43 }
44 #endif // DIS_FLOAT
45
46 #ifndef DIS_QUANT
47 template <typename T>
48 void evalQuantized(const circle::Tensor *input, const circle::Tensor *output,
49                    const circle::SoftmaxOptions *options, BaseRuntimeGraph *runtime_graph)
50 {
51   // TODO: Enable it
52   assert(false && "Not impl yet");
53
54   const auto *input_data = runtime_graph->getDataByTensor(input);
55   auto *output_data = runtime_graph->getDataByTensor(output);
56
57   tflite::SoftmaxParams op_params{};
58
59   luci_interpreter_pal::InitializeParams(&op_params, Tensor::scale(input), options->beta());
60   luci_interpreter_pal::Softmax(
61     op_params, kernels::getTensorShape(input), kernels::getTensorData<T>(input_data),
62     kernels::getTensorShape(output), kernels::getTensorData<T>(output_data));
63 }
64 #endif
65
66 } // namespace
67
68 void configure_kernel_CircleSoftmax(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
69 {
70   const auto input_index = cur_op->inputs()->operator[](0);
71   const auto output_index = cur_op->outputs()->operator[](0);
72
73   assert(input_index != -1);
74   assert(output_index != -1);
75
76   const auto input = runtime_graph->getCircleTensorByIndex(input_index);
77   auto output = runtime_graph->getCircleTensorByIndex(output_index);
78
79   assert(input != nullptr);
80   assert(output != nullptr);
81
82   LUCI_INTERPRETER_CHECK(Tensor::element_type(input) == Tensor::element_type(output));
83   LUCI_INTERPRETER_CHECK(Tensor::num_dims(input) >= 1);
84
85 #ifndef DIS_QUANT
86   if (Tensor::element_type(input) == DataType::U8 || Tensor::element_type(input) == DataType::S8)
87   {
88     LUCI_INTERPRETER_CHECK(Tensor::element_type(input) == DataType::S8 ||
89                            Tensor::zero_point(output) == 0);
90     LUCI_INTERPRETER_CHECK(Tensor::element_type(input) == DataType::U8 ||
91                            Tensor::zero_point(output) == std::numeric_limits<int8_t>::min());
92   }
93 #endif
94 }
95
96 void execute_kernel_CircleSoftmax(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph,
97                                   bool)
98 {
99   const auto input_index = cur_op->inputs()->operator[](0);
100   const auto output_index = cur_op->outputs()->operator[](0);
101
102   assert(input_index != -1);
103   assert(output_index != -1);
104
105   const auto input = runtime_graph->getCircleTensorByIndex(input_index);
106   auto output = runtime_graph->getCircleTensorByIndex(output_index);
107
108   assert(input != nullptr);
109   assert(output != nullptr);
110
111   const auto *options = cur_op->builtin_options_as_SoftmaxOptions();
112
113   switch (Tensor::element_type(input))
114   {
115 #ifndef DIS_FLOAT
116     case DataType::FLOAT32:
117       evalFloat(input, output, options, runtime_graph);
118       break;
119 #endif // DIS_FLOAT
120 #ifndef DIS_QUANT
121     case DataType::S8:
122       evalQuantized<int8_t>(input, output, options, runtime_graph);
123       break;
124     case DataType::U8:
125       evalQuantized<uint8_t>(input, output, options, runtime_graph);
126       break;
127 #endif // DIS_QUANT
128     default:
129       assert(false && "Unsupported type.");
130   }
131 }
132
133 } // namespace luci_interpreter