Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / LogSoftmax.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "kernels/LogSoftmax.h"
18
19 #include "kernels/Utils.h"
20
21 #include <tensorflow/lite/kernels/internal/reference/log_softmax.h>
22
23 #include "PALLogSoftmax.h"
24
25 namespace luci_interpreter
26 {
27 namespace kernels
28 {
29
30 LogSoftmax::LogSoftmax(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
31
32 void LogSoftmax::configure()
33 {
34   LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
35   if (input()->element_type() == DataType::U8)
36   {
37     LUCI_INTERPRETER_CHECK(output()->scale() == 16. / 256);
38     LUCI_INTERPRETER_CHECK(output()->zero_point() == 255);
39
40     tflite::SoftmaxParams params{};
41
42     params.table = _table;
43     params.beta = 1.0;
44     luci_interpreter_pal::PopulateSoftmaxLookupTable(&params, input()->scale(), params.beta);
45   }
46   // TODO: enable it only if kernel with dynamic shapes
47   output()->resize(input()->shape());
48 }
49
50 void LogSoftmax::execute() const
51 {
52   switch (input()->element_type())
53   {
54     case DataType::FLOAT32:
55       evalFloat();
56       break;
57     case DataType::U8:
58       evalQuantized();
59       break;
60     default:
61       assert(false && "Unsupported type.");
62   }
63 }
64
65 void LogSoftmax::evalFloat() const
66 {
67   tflite::SoftmaxParams params{};
68   tflite::reference_ops::LogSoftmax(params, getTensorShape(input()), getTensorData<float>(input()),
69                                     getTensorShape(output()), getTensorData<float>(output()));
70 }
71
72 void LogSoftmax::evalQuantized() const
73 {
74   const auto input_shape = getTensorShape(input());
75   const auto output_shape = getTensorShape(output());
76   const auto input_scale = input()->scale();
77   uint8_t *output_data = getTensorData<uint8_t>(output());
78   const uint8_t *input_data = getTensorData<uint8_t>(input());
79   const float beta = 1.0;
80
81   tflite::SoftmaxParams params{};
82
83   params.table = const_cast<float *>(_table);
84   params.zero_point = output()->zero_point();
85   params.scale = output()->scale();
86
87   luci_interpreter_pal::InitializeParams(&params, input_scale, beta);
88   luci_interpreter_pal::LogSoftmax(params, input_scale, input_shape, input_data, output_shape,
89                                    output_data);
90 }
91
92 } // namespace kernels
93 } // namespace luci_interpreter