arm_compute v17.06
[platform/upstream/armcl.git] / arm_compute / core / FixedPoint.inl
1 /*
2  * Copyright (c) 2017 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include <cmath>
25 #include <limits>
26
27 namespace
28 {
29 template <typename TpIn, typename TpSat>
30 inline TpSat saturate_convert(TpIn a)
31 {
32     if(a > std::numeric_limits<TpSat>::max())
33     {
34         a = std::numeric_limits<TpSat>::max();
35     }
36     if(a < std::numeric_limits<TpSat>::min())
37     {
38         a = std::numeric_limits<TpSat>::min();
39     }
40     return static_cast<TpSat>(a);
41 }
42 } // namespace
43
44 namespace arm_compute
45 {
46 inline qint8_t sqshl_qs8(qint8_t a, int shift)
47 {
48     qint16_t tmp = static_cast<qint16_t>(a) << shift;
49     // Saturate the result in case of overflow and cast to qint8_t
50     return saturate_convert<qint16_t, qint8_t>(tmp);
51 }
52
53 inline qint8_t sabs_qs8(qint8_t a)
54 {
55     return a & 0x7F;
56 }
57
58 inline qint8_t sadd_qs8(qint8_t a, qint8_t b)
59 {
60     return a + b;
61 }
62
63 inline qint8_t sqadd_qs8(qint8_t a, qint8_t b)
64 {
65     // We need to store the temporary result in qint16_t otherwise we cannot evaluate the overflow
66     qint16_t tmp = (static_cast<qint16_t>(a) + static_cast<qint16_t>(b));
67
68     // Saturate the result in case of overflow and cast to qint8_t
69     return saturate_convert<qint16_t, qint8_t>(tmp);
70 }
71
72 inline qint16_t sqadd_qs16(qint16_t a, qint16_t b)
73 {
74     // We need to store the temporary result in qint16_t otherwise we cannot evaluate the overflow
75     qint32_t tmp = (static_cast<qint32_t>(a) + static_cast<qint32_t>(b));
76
77     // Saturate the result in case of overflow and cast to qint16_t
78     return saturate_convert<qint32_t, qint16_t>(tmp);
79 }
80
81 inline qint8_t ssub_qs8(qint8_t a, qint8_t b)
82 {
83     return a - b;
84 }
85
86 inline qint8_t sqsub_qs8(qint8_t a, qint8_t b)
87 {
88     // We need to store the temporary result in uint16_t otherwise we cannot evaluate the overflow
89     qint16_t tmp = static_cast<qint16_t>(a) - static_cast<qint16_t>(b);
90
91     // Saturate the result in case of overflow and cast to qint8_t
92     return saturate_convert<qint16_t, qint8_t>(tmp);
93 }
94
95 inline qint8_t smul_qs8(qint8_t a, qint8_t b, int fixed_point_position)
96 {
97     const qint16_t round_up_const = (1 << (fixed_point_position - 1));
98
99     qint16_t tmp = static_cast<qint16_t>(a) * static_cast<qint16_t>(b);
100
101     // Rounding up
102     tmp += round_up_const;
103
104     return static_cast<qint8_t>(tmp >> fixed_point_position);
105 }
106
107 inline qint8_t sqmul_qs8(qint8_t a, qint8_t b, int fixed_point_position)
108 {
109     const qint16_t round_up_const = (1 << (fixed_point_position - 1));
110
111     qint16_t tmp = static_cast<qint16_t>(a) * static_cast<qint16_t>(b);
112
113     // Rounding up
114     tmp += round_up_const;
115
116     return saturate_convert<qint16_t, qint8_t>(tmp >> fixed_point_position);
117 }
118
119 inline qint16_t sqmul_qs16(qint16_t a, qint16_t b, int fixed_point_position)
120 {
121     const qint32_t round_up_const = (1 << (fixed_point_position - 1));
122
123     qint32_t tmp = static_cast<qint32_t>(a) * static_cast<qint32_t>(b);
124
125     // Rounding up
126     tmp += round_up_const;
127
128     return saturate_convert<qint32_t, qint16_t>(tmp >> fixed_point_position);
129 }
130
131 inline qint16_t sqmull_qs8(qint8_t a, qint8_t b, int fixed_point_position)
132 {
133     const qint16_t round_up_const = (1 << (fixed_point_position - 1));
134
135     qint16_t tmp = static_cast<qint16_t>(a) * static_cast<qint16_t>(b);
136
137     // Rounding up
138     tmp += round_up_const;
139
140     return tmp >> fixed_point_position;
141 }
142
143 inline qint8_t sinvsqrt_qs8(qint8_t a, int fixed_point_position)
144 {
145     qint8_t shift = 8 - (fixed_point_position + (__builtin_clz(a) - 24));
146
147     qint8_t const_three = (3 << fixed_point_position);
148     qint8_t temp        = shift < 0 ? (a << -shift) : (a >> shift);
149     qint8_t x2          = temp;
150
151     // We need three iterations to find the result
152     for(int i = 0; i < 3; i++)
153     {
154         qint8_t three_minus_dx = ssub_qs8(const_three, smul_qs8(temp, smul_qs8(x2, x2, fixed_point_position), fixed_point_position));
155         x2                     = (smul_qs8(x2, three_minus_dx, fixed_point_position) >> 1);
156     }
157
158     temp = shift < 0 ? (x2 << (-shift >> 1)) : (x2 >> (shift >> 1));
159
160     return temp;
161 }
162
163 inline qint8_t sdiv_qs8(qint8_t a, qint8_t b, int fixed_point_position)
164 {
165     qint16_t temp = a << fixed_point_position;
166     return (qint8_t)(temp / b);
167 }
168
169 inline qint8_t sqexp_qs8(qint8_t a, int fixed_point_position)
170 {
171     // Constants
172     qint8_t const_one = (1 << fixed_point_position);
173     qint8_t ln2       = ((0x58 >> (6 - fixed_point_position)) + 1) >> 1;
174     qint8_t inv_ln2   = (((0x38 >> (6 - fixed_point_position)) + 1) >> 1) | const_one;
175     qint8_t A         = ((0x7F >> (6 - fixed_point_position)) + 1) >> 1;
176     qint8_t B         = ((0x3F >> (6 - fixed_point_position)) + 1) >> 1;
177     qint8_t C         = ((0x16 >> (6 - fixed_point_position)) + 1) >> 1;
178     qint8_t D         = ((0x05 >> (6 - fixed_point_position)) + 1) >> 1;
179
180     // Polynomial expansion
181     int     dec_a = (sqmul_qs8(a, inv_ln2, fixed_point_position) >> fixed_point_position);
182     qint8_t alpha = sabs_qs8(sqsub_qs8(a, sqmul_qs8(ln2, sqshl_qs8(dec_a, fixed_point_position), fixed_point_position)));
183     qint8_t sum   = sqadd_qs8(sqmul_qs8(alpha, D, fixed_point_position), C);
184     sum           = sqadd_qs8(sqmul_qs8(alpha, sum, fixed_point_position), B);
185     sum           = sqadd_qs8(sqmul_qs8(alpha, sum, fixed_point_position), A);
186     sum           = sqmul_qs8(alpha, sum, fixed_point_position);
187     sum           = sqadd_qs8(sum, const_one);
188
189     return (dec_a < 0) ? (sum >> -dec_a) : sqshl_qs8(sum, dec_a);
190 }
191
192 inline qint8_t slog_qs8(qint8_t a, int fixed_point_position)
193 {
194     // Constants
195     qint8_t const_one = (1 << fixed_point_position);
196     qint8_t ln2       = (0x58 >> (7 - fixed_point_position));
197     qint8_t A         = (0x5C >> (7 - fixed_point_position - 1));
198     qint8_t B         = -(0x56 >> (7 - fixed_point_position));
199     qint8_t C         = (0x29 >> (7 - fixed_point_position));
200     qint8_t D         = -(0x0A >> (7 - fixed_point_position));
201
202     if((const_one == a) || (a < 0))
203     {
204         return 0;
205     }
206     else if(a < const_one)
207     {
208         return -slog_qs8(sdiv_qs8(const_one, a, fixed_point_position), fixed_point_position);
209     }
210
211     // Remove even powers of 2
212     qint8_t shift_val = 31 - __builtin_clz(a >> fixed_point_position);
213     a >>= shift_val;
214     a = ssub_qs8(a, const_one);
215
216     // Polynomial expansion
217     auto sum = sqadd_qs8(sqmul_qs8(a, D, fixed_point_position), C);
218     sum      = sqadd_qs8(sqmul_qs8(a, sum, fixed_point_position), B);
219     sum      = sqadd_qs8(sqmul_qs8(a, sum, fixed_point_position), A);
220     sum      = sqmul_qs8(a, sum, fixed_point_position);
221
222     return smul_qs8(sadd_qs8(sum, shift_val << fixed_point_position), ln2, fixed_point_position);
223 }
224
225 inline float scvt_f32_qs8(qint8_t a, int fixed_point_position)
226 {
227     return static_cast<float>(a) / (1 << fixed_point_position);
228 }
229
230 inline qint8_t scvt_qs8_f32(float a, int fixed_point_position)
231 {
232     // round_nearest_integer(a * 2^(fixed_point_position))
233     return static_cast<qint8_t>(static_cast<float>(a) * (1 << fixed_point_position) + 0.5f);
234 }
235
236 inline float scvt_f32_qs16(qint16_t a, int fixed_point_position)
237 {
238     return static_cast<float>(a) / (1 << fixed_point_position);
239 }
240
241 inline qint8_t scvt_qs16_f32(float a, int fixed_point_position)
242 {
243     // round_nearest_integer(a * 2^(fixed_point_position))
244     return static_cast<qint16_t>(static_cast<float>(a) * (1 << fixed_point_position) + 0.5f);
245 }
246
247 inline qint8_t sqmovn_qs16(qint16_t a)
248 {
249     // Saturate the result in case of overflow and cast to qint8_t
250     return saturate_convert<qint16_t, qint8_t>(a);
251 }
252 }