Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / common / PALTanh.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_PAL_TANH_H
19 #define LUCI_INTERPRETER_PAL_TANH_H
20
21 #include "PALUtils.h"
22
23 namespace luci_interpreter_pal
24 {
25
26 inline void Tanh(const int flat_size, const float *input_data, float *output_data)
27 {
28   for (int i = 0; i < flat_size; i++)
29   {
30     float val = input_data[i];
31     float result = std::tanh(val);
32     output_data[i] = result;
33   }
34 }
35
36 inline void Tanh(int32_t input_multiplier, int32_t input_left_shift, const int flat_size,
37                  const int16_t *ptr_input_data, int16_t *ptr_output_data)
38 {
39   // We use the LUT for sigmoid and take into account, that
40   // tanh(x) = 2*sigmoid(2*x) - 1
41
42   // We scale by 3/4 to expand range [-8,8]->[-10.7,10.7].
43   // In case of general parameter scale, multiplier 3 is taken into account
44   // in TanhPrepare function and it is included in
45   // input_multiplier already.
46
47   if (input_multiplier == 0)
48   { // power of two case
49     input_multiplier = 3 << input_left_shift;
50     input_left_shift = 0;
51   }
52
53   int32_t round = (input_left_shift > 0) ? 1 << (input_left_shift - 1) : 0;
54
55   for (int i = 0; i < flat_size; ++i, ptr_input_data++, ptr_output_data++)
56   {
57     int32_t input_data = ((*ptr_input_data) * input_multiplier + round) >> input_left_shift;
58
59     uint32_t abs_input_data = abs(input_data);
60     uint32_t uh = abs_input_data >> 8;
61     int32_t result;
62
63     if (uh >= 255)
64     {
65       // Saturate to maximum.
66       result = 0xFFFF << 8;
67     }
68     else
69     {
70       uint32_t ua = sigmoid_table_uint16[uh];
71       uint32_t ub = sigmoid_table_uint16[uh + 1];
72
73       uint8_t ut = abs_input_data & 0xFF;
74
75       result = (ua << 8) + ut * (ub - ua);
76     }
77
78     result = (input_data >= 0) ? (result - (1 << (14 + 9)) + (1 << (9 - 2)))
79                                : (-result + (1 << (14 + 9)) + (1 << (9 - 2)) - 1);
80
81     // Convert back to 16-bit.
82     result >>= (9 - 1);
83
84     *ptr_output_data = result;
85   }
86 }
87
88 #if 0
89 inline void Tanh(int32_t input_zero_point, int32_t input_range_radius,
90                  int32_t input_multiplier, int32_t input_shift,
91                  const int flat_size, const int8_t* input_data, int8_t* output_data) {
92   // Integer bits must be in sync with Prepare() function.
93   static constexpr int32_t kInputIntegerBits = 4;
94   static constexpr int32_t kOutputScale = 7;
95   static constexpr int32_t kMinInt8 = std::numeric_limits<int8_t>::min();
96   static constexpr int32_t kMaxInt8 = std::numeric_limits<int8_t>::max();
97
98   for (int i = 0; i < flat_size; ++i) {
99     const int32_t input =
100       static_cast<int32_t>(input_data[i]) - input_zero_point;
101     if (input <= -input_range_radius) {
102       output_data[i] = kMinInt8;
103     } else if (input >= input_range_radius) {
104       output_data[i] = kMaxInt8;
105     } else {
106       const int32_t input_in_q4 =
107         multiplyByQuantizedMultiplier(input, input_multiplier, input_shift);
108       const int32_t output_in_q0 = std::tanh(input_in_q4);
109
110       int32_t output_in_q24 =
111         roundingDivideByPOT(output_in_q0, 31 - kOutputScale);
112       output_in_q24 = std::min(std::max(output_in_q24, kMinInt8), kMaxInt8);
113       output_data[i] = static_cast<int8_t>(output_in_q24);
114     }
115   }
116 }
117 #endif // 0
118
119 } // namespace luci_interpreter_pal
120
121 #endif // LUCI_INTERPRETER_PAL_TANH_H