3acb22e93aa9fccaf8625ceb886443aca645d7b4
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Tanh.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/Tanh.h"
19
20 #include "kernels/Utils.h"
21
22 #include <tensorflow/lite/kernels/internal/reference/tanh.h>
23
24 namespace luci_interpreter
25 {
26 namespace kernels
27 {
28
29 Tanh::Tanh(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
30
31 void Tanh::configure()
32 {
33   LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
34   if (input()->element_type() == DataType::U8)
35   {
36     populateLookupTable();
37   }
38   // TODO: enable it only if kernel with dynamic shapes
39   output()->resize(input()->shape());
40 }
41
42 void Tanh::execute() const
43 {
44   switch (input()->element_type())
45   {
46     case DataType::FLOAT32:
47       evalFloat();
48       break;
49     case DataType::U8:
50       evalQuantized();
51       break;
52     default:
53       assert(false && "Unsupported type.");
54   }
55 }
56
57 void Tanh::evalFloat() const
58 {
59   tflite::reference_ops::Tanh(getTensorShape(input()), getTensorData<float>(input()),
60                               getTensorShape(output()), getTensorData<float>(output()));
61 }
62
63 void Tanh::evalQuantized() const
64 {
65   const int size = tflite::MatchingFlatSize(getTensorShape(input()), getTensorShape(output()));
66   uint8_t *output_data = getTensorData<uint8_t>(output());
67   const uint8_t *input_data = getTensorData<uint8_t>(input());
68   for (int i = 0; i < size; ++i)
69   {
70     output_data[i] = getTableValue(input_data[i]);
71   }
72 }
73
74 void Tanh::populateLookupTable()
75 {
76   const auto input_scale = static_cast<double>(input()->scale());
77   const auto input_zero_point = static_cast<int32_t>(input()->zero_point());
78   const auto output_scale = static_cast<double>(output()->scale());
79   const auto output_zero_point = static_cast<int32_t>(output()->zero_point());
80   const float inverse_scale = 1 / output_scale;
81   int32_t maxval = std::numeric_limits<uint8_t>::max();
82   int32_t minval = std::numeric_limits<uint8_t>::min();
83   for (int32_t val = minval; val <= maxval; ++val)
84   {
85     const float dequantized = input_scale * (val - input_zero_point);
86     const float transformed = std::tanh(dequantized);
87     const float rescaled = std::round(transformed * inverse_scale);
88     const int32_t quantized = static_cast<int32_t>(rescaled + output_zero_point);
89     setTableValue(static_cast<uint8_t>(std::max(std::min(maxval, quantized), minval)),
90                   static_cast<uint8_t>(val));
91   }
92 }
93
94 } // namespace kernels
95 } // namespace luci_interpreter