2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "kernels/LogSoftmax.h"
19 #include "kernels/Utils.h"
21 #include <tensorflow/lite/kernels/internal/reference/log_softmax.h>
23 #include "PALLogSoftmax.h"
25 namespace luci_interpreter
30 LogSoftmax::LogSoftmax(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
32 void LogSoftmax::configure()
34 LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
35 if (input()->element_type() == DataType::U8)
37 LUCI_INTERPRETER_CHECK(output()->scale() == 16. / 256);
38 LUCI_INTERPRETER_CHECK(output()->zero_point() == 255);
40 tflite::SoftmaxParams params{};
42 params.table = _table;
44 luci_interpreter_pal::PopulateSoftmaxLookupTable(¶ms, input()->scale(), params.beta);
46 output()->resize(input()->shape());
49 void LogSoftmax::execute() const
51 switch (input()->element_type())
53 case DataType::FLOAT32:
60 throw std::runtime_error("Unsupported type.");
64 void LogSoftmax::evalFloat() const
66 tflite::SoftmaxParams params{};
67 tflite::reference_ops::LogSoftmax(params, getTensorShape(input()), getTensorData<float>(input()),
68 getTensorShape(output()), getTensorData<float>(output()));
71 void LogSoftmax::evalQuantized() const
73 const auto input_shape = getTensorShape(input());
74 const auto output_shape = getTensorShape(output());
75 const auto input_scale = input()->scale();
76 uint8_t *output_data = getTensorData<uint8_t>(output());
77 const uint8_t *input_data = getTensorData<uint8_t>(input());
78 const float beta = 1.0;
80 tflite::SoftmaxParams params{};
82 params.table = const_cast<float *>(_table);
83 params.zero_point = output()->zero_point();
84 params.scale = output()->scale();
86 luci_interpreter_pal::InitializeParams(¶ms, input_scale, beta);
87 luci_interpreter_pal::LogSoftmax(params, input_scale, input_shape, input_data, output_shape,
91 } // namespace kernels
92 } // namespace luci_interpreter