Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / LogSoftmax.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "kernels/LogSoftmax.h"
18
19 #include "kernels/Utils.h"
20
21 #include <tensorflow/lite/kernels/internal/reference/log_softmax.h>
22
23 #include "PALLogSoftmax.h"
24
25 namespace luci_interpreter
26 {
27 namespace kernels
28 {
29
30 LogSoftmax::LogSoftmax(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
31
32 void LogSoftmax::configure()
33 {
34   LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
35   if (input()->element_type() == DataType::U8)
36   {
37     LUCI_INTERPRETER_CHECK(output()->scale() == 16. / 256);
38     LUCI_INTERPRETER_CHECK(output()->zero_point() == 255);
39
40     tflite::SoftmaxParams params{};
41
42     params.table = _table;
43     params.beta = 1.0;
44     luci_interpreter_pal::PopulateSoftmaxLookupTable(&params, input()->scale(), params.beta);
45   }
46   output()->resize(input()->shape());
47 }
48
49 void LogSoftmax::execute() const
50 {
51   switch (input()->element_type())
52   {
53     case DataType::FLOAT32:
54       evalFloat();
55       break;
56     case DataType::U8:
57       evalQuantized();
58       break;
59     default:
60       throw std::runtime_error("Unsupported type.");
61   }
62 }
63
64 void LogSoftmax::evalFloat() const
65 {
66   tflite::SoftmaxParams params{};
67   tflite::reference_ops::LogSoftmax(params, getTensorShape(input()), getTensorData<float>(input()),
68                                     getTensorShape(output()), getTensorData<float>(output()));
69 }
70
71 void LogSoftmax::evalQuantized() const
72 {
73   const auto input_shape = getTensorShape(input());
74   const auto output_shape = getTensorShape(output());
75   const auto input_scale = input()->scale();
76   uint8_t *output_data = getTensorData<uint8_t>(output());
77   const uint8_t *input_data = getTensorData<uint8_t>(input());
78   const float beta = 1.0;
79
80   tflite::SoftmaxParams params{};
81
82   params.table = const_cast<float *>(_table);
83   params.zero_point = output()->zero_point();
84   params.scale = output()->scale();
85
86   luci_interpreter_pal::InitializeParams(&params, input_scale, beta);
87   luci_interpreter_pal::LogSoftmax(params, input_scale, input_shape, input_data, output_shape,
88                                    output_data);
89 }
90
91 } // namespace kernels
92 } // namespace luci_interpreter