/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #include "kernels/LogSoftmax.h"
20 #include "kernels/Utils.h"
22 #include <tensorflow/lite/kernels/internal/reference/log_softmax.h>
24 #include "PALLogSoftmax.h"
26 namespace luci_interpreter
31 LogSoftmax::LogSoftmax(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
33 void LogSoftmax::configure()
35 LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
36 if (input()->element_type() == DataType::U8)
38 LUCI_INTERPRETER_CHECK(output()->scale() == 16. / 256);
39 LUCI_INTERPRETER_CHECK(output()->zero_point() == 255);
41 tflite::SoftmaxParams params{};
43 params.table = _table;
45 luci_interpreter_pal::PopulateSoftmaxLookupTable(¶ms, input()->scale(), params.beta);
47 // TODO: enable it only if kernel with dynamic shapes
48 output()->resize(input()->shape());
51 void LogSoftmax::execute() const
53 switch (input()->element_type())
55 case DataType::FLOAT32:
62 assert(false && "Unsupported type.");
66 void LogSoftmax::evalFloat() const
68 tflite::SoftmaxParams params{};
69 tflite::reference_ops::LogSoftmax(params, getTensorShape(input()), getTensorData<float>(input()),
70 getTensorShape(output()), getTensorData<float>(output()));
73 void LogSoftmax::evalQuantized() const
75 const auto input_shape = getTensorShape(input());
76 const auto output_shape = getTensorShape(output());
77 const auto input_scale = input()->scale();
78 uint8_t *output_data = getTensorData<uint8_t>(output());
79 const uint8_t *input_data = getTensorData<uint8_t>(input());
80 const float beta = 1.0;
82 tflite::SoftmaxParams params{};
84 params.table = const_cast<float *>(_table);
85 params.zero_point = output()->zero_point();
86 params.scale = output()->scale();
88 luci_interpreter_pal::InitializeParams(¶ms, input_scale, beta);
89 luci_interpreter_pal::LogSoftmax(params, input_scale, input_shape, input_data, output_shape,
93 } // namespace kernels
94 } // namespace luci_interpreter