arm_compute v18.02
[platform/upstream/armcl.git] / arm_compute / core / FixedPoint.inl
1 /*
2  * Copyright (c) 2017-2018 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/core/Error.h"
25 #include "arm_compute/core/utils/misc/utility.h"
26
27 #include <cmath>
28 #include <limits>
29
30 namespace arm_compute
31 {
32 inline qint8_t sqshl_qs8(qint8_t a, int shift)
33 {
34     qint16_t tmp = static_cast<qint16_t>(a) << shift;
35
36     // Saturate the result in case of overflow and cast to qint8_t
37     return utility::saturate_cast<qint8_t>(tmp);
38 }
39
40 inline qint16_t sqshl_qs16(qint16_t a, int shift)
41 {
42     qint32_t tmp = static_cast<qint32_t>(a) << shift;
43
44     // Saturate the result in case of overflow and cast to qint16_t
45     return utility::saturate_cast<qint16_t>(tmp);
46 }
47
48 inline qint8_t sshr_qs8(qint8_t a, int shift)
49 {
50     ARM_COMPUTE_ERROR_ON_MSG(shift == 0, "Shift should not be zero");
51     const qint8_t round_val = 1 << (shift - 1);
52     return sqadd_qs8(a, round_val) >> shift;
53 }
54
55 inline qint16_t sshr_qs16(qint16_t a, int shift)
56 {
57     ARM_COMPUTE_ERROR_ON_MSG(shift == 0, "Shift should not be zero");
58     const qint16_t round_val = 1 << (shift - 1);
59     return sqadd_qs16(a, round_val) >> shift;
60 }
61
62 inline qint8_t sabs_qs8(qint8_t a)
63 {
64     return (a < 0) ? (a == std::numeric_limits<int8_t>::min()) ? std::numeric_limits<int8_t>::max() : -a : a;
65 }
66
67 inline qint16_t sabs_qs16(qint16_t a)
68 {
69     return (a < 0) ? (a == std::numeric_limits<int16_t>::min()) ? std::numeric_limits<int16_t>::max() : -a : a;
70 }
71
72 inline qint8_t sadd_qs8(qint8_t a, qint8_t b)
73 {
74     return a + b;
75 }
76
77 inline qint16_t sadd_qs16(qint16_t a, qint16_t b)
78 {
79     return a + b;
80 }
81
82 inline qint8_t sqadd_qs8(qint8_t a, qint8_t b)
83 {
84     // We need to store the temporary result in qint16_t otherwise we cannot evaluate the overflow
85     qint16_t tmp = (static_cast<qint16_t>(a) + static_cast<qint16_t>(b));
86
87     // Saturate the result in case of overflow and cast to qint8_t
88     return utility::saturate_cast<qint8_t>(tmp);
89 }
90
91 inline qint16_t sqadd_qs16(qint16_t a, qint16_t b)
92 {
93     // We need to store the temporary result in qint32_t otherwise we cannot evaluate the overflow
94     qint32_t tmp = (static_cast<qint32_t>(a) + static_cast<qint32_t>(b));
95
96     // Saturate the result in case of overflow and cast to qint16_t
97     return utility::saturate_cast<qint16_t>(tmp);
98 }
99
100 inline qint32_t sqadd_qs32(qint32_t a, qint32_t b)
101 {
102     // We need to store the temporary result in qint64_t otherwise we cannot evaluate the overflow
103     qint64_t tmp = (static_cast<qint64_t>(a) + static_cast<qint64_t>(b));
104
105     // Saturate the result in case of overflow and cast to qint32_t
106     return utility::saturate_cast<qint32_t>(tmp);
107 }
108
109 inline qint8_t ssub_qs8(qint8_t a, qint8_t b)
110 {
111     return a - b;
112 }
113
114 inline qint16_t ssub_qs16(qint16_t a, qint16_t b)
115 {
116     return a - b;
117 }
118
119 inline qint8_t sqsub_qs8(qint8_t a, qint8_t b)
120 {
121     // We need to store the temporary result in uint16_t otherwise we cannot evaluate the overflow
122     qint16_t tmp = static_cast<qint16_t>(a) - static_cast<qint16_t>(b);
123
124     // Saturate the result in case of overflow and cast to qint8_t
125     return utility::saturate_cast<qint8_t>(tmp);
126 }
127
128 inline qint16_t sqsub_qs16(qint16_t a, qint16_t b)
129 {
130     // We need to store the temporary result in qint32_t otherwise we cannot evaluate the overflow
131     qint32_t tmp = static_cast<qint32_t>(a) - static_cast<qint32_t>(b);
132
133     // Saturate the result in case of overflow and cast to qint16_t
134     return utility::saturate_cast<qint16_t>(tmp);
135 }
136
137 inline qint8_t smul_qs8(qint8_t a, qint8_t b, int fixed_point_position)
138 {
139     const qint16_t round_up_const = (1 << (fixed_point_position - 1));
140
141     qint16_t tmp = static_cast<qint16_t>(a) * static_cast<qint16_t>(b);
142
143     // Rounding up
144     tmp += round_up_const;
145
146     return static_cast<qint8_t>(tmp >> fixed_point_position);
147 }
148
149 inline qint16_t smul_qs16(qint16_t a, qint16_t b, int fixed_point_position)
150 {
151     const qint32_t round_up_const = (1 << (fixed_point_position - 1));
152
153     qint32_t tmp = static_cast<qint32_t>(a) * static_cast<qint32_t>(b);
154
155     // Rounding up
156     tmp += round_up_const;
157
158     return static_cast<qint16_t>(tmp >> fixed_point_position);
159 }
160
161 inline qint8_t sqmul_qs8(qint8_t a, qint8_t b, int fixed_point_position)
162 {
163     const qint16_t round_up_const = (1 << (fixed_point_position - 1));
164
165     qint16_t tmp = static_cast<qint16_t>(a) * static_cast<qint16_t>(b);
166
167     // Rounding up
168     tmp += round_up_const;
169
170     return utility::saturate_cast<qint8_t>(tmp >> fixed_point_position);
171 }
172
173 inline qint16_t sqmul_qs16(qint16_t a, qint16_t b, int fixed_point_position)
174 {
175     const qint32_t round_up_const = (1 << (fixed_point_position - 1));
176
177     qint32_t tmp = static_cast<qint32_t>(a) * static_cast<qint32_t>(b);
178
179     // Rounding up
180     tmp += round_up_const;
181
182     return utility::saturate_cast<qint16_t>(tmp >> fixed_point_position);
183 }
184
185 inline qint16_t sqmull_qs8(qint8_t a, qint8_t b, int fixed_point_position)
186 {
187     const qint16_t round_up_const = (1 << (fixed_point_position - 1));
188
189     qint16_t tmp = static_cast<qint16_t>(a) * static_cast<qint16_t>(b);
190
191     // Rounding up
192     tmp += round_up_const;
193
194     return tmp >> fixed_point_position;
195 }
196
197 inline qint32_t sqmull_qs16(qint16_t a, qint16_t b, int fixed_point_position)
198 {
199     const qint32_t round_up_const = (1 << (fixed_point_position - 1));
200
201     qint32_t tmp = static_cast<qint32_t>(a) * static_cast<qint32_t>(b);
202
203     // Rounding up
204     tmp += round_up_const;
205
206     return tmp >> fixed_point_position;
207 }
208
209 inline qint8_t sinvsqrt_qs8(qint8_t a, int fixed_point_position)
210 {
211     const qint8_t shift = 8 - (fixed_point_position + (__builtin_clz(a) - 24));
212
213     const qint8_t const_three = (3 << fixed_point_position);
214     qint8_t       temp        = shift < 0 ? (a << -shift) : (a >> shift);
215     qint8_t       x2          = temp;
216
217     // We need three iterations to find the result
218     for(int i = 0; i < 3; ++i)
219     {
220         qint8_t three_minus_dx = ssub_qs8(const_three, smul_qs8(temp, smul_qs8(x2, x2, fixed_point_position), fixed_point_position));
221         x2                     = (smul_qs8(x2, three_minus_dx, fixed_point_position) >> 1);
222     }
223
224     temp = shift < 0 ? (x2 << (-shift >> 1)) : (x2 >> (shift >> 1));
225
226     return temp;
227 }
228
229 inline qint16_t sinvsqrt_qs16(qint16_t a, int fixed_point_position)
230 {
231     const qint16_t shift = 16 - (fixed_point_position + (__builtin_clz(a) - 16));
232
233     const qint16_t const_three = (3 << fixed_point_position);
234     qint16_t       temp        = shift < 0 ? (a << -shift) : (a >> shift);
235     qint16_t       x2          = temp;
236
237     // We need three iterations to find the result
238     for(int i = 0; i < 3; ++i)
239     {
240         qint16_t three_minus_dx = ssub_qs16(const_three, smul_qs16(temp, smul_qs16(x2, x2, fixed_point_position), fixed_point_position));
241         x2                      = smul_qs16(x2, three_minus_dx, fixed_point_position) >> 1;
242     }
243
244     temp = shift < 0 ? (x2 << ((-shift) >> 1)) : (x2 >> (shift >> 1));
245
246     return temp;
247 }
248
249 inline qint8_t sdiv_qs8(qint8_t a, qint8_t b, int fixed_point_position)
250 {
251     const qint16_t temp = a << fixed_point_position;
252     return static_cast<qint8_t>(temp / b);
253 }
254
255 inline qint16_t sdiv_qs16(qint16_t a, qint16_t b, int fixed_point_position)
256 {
257     const qint32_t temp = a << fixed_point_position;
258     return static_cast<qint16_t>(temp / b);
259 }
260
261 inline qint8_t sqexp_qs8(qint8_t a, int fixed_point_position)
262 {
263     // Constants
264     const qint8_t const_one = (1 << fixed_point_position);
265     const qint8_t ln2       = ((0x58 >> (6 - fixed_point_position)) + 1) >> 1;
266     const qint8_t inv_ln2   = (((0x38 >> (6 - fixed_point_position)) + 1) >> 1) | const_one;
267     const qint8_t A         = ((0x7F >> (6 - fixed_point_position)) + 1) >> 1;
268     const qint8_t B         = ((0x3F >> (6 - fixed_point_position)) + 1) >> 1;
269     const qint8_t C         = ((0x16 >> (6 - fixed_point_position)) + 1) >> 1;
270     const qint8_t D         = ((0x05 >> (6 - fixed_point_position)) + 1) >> 1;
271
272     // Polynomial expansion
273     const int     dec_a = (sqmul_qs8(a, inv_ln2, fixed_point_position) >> fixed_point_position);
274     const qint8_t alpha = sabs_qs8(sqsub_qs8(a, sqmul_qs8(ln2, sqshl_qs8(dec_a, fixed_point_position), fixed_point_position)));
275     qint8_t       sum   = sqadd_qs8(sqmul_qs8(alpha, D, fixed_point_position), C);
276     sum                 = sqadd_qs8(sqmul_qs8(alpha, sum, fixed_point_position), B);
277     sum                 = sqadd_qs8(sqmul_qs8(alpha, sum, fixed_point_position), A);
278     sum                 = sqmul_qs8(alpha, sum, fixed_point_position);
279     sum                 = sqadd_qs8(sum, const_one);
280
281     return (dec_a < 0) ? (sum >> -dec_a) : sqshl_qs8(sum, dec_a);
282 }
283
284 inline qint16_t sqexp_qs16(qint16_t a, int fixed_point_position)
285 {
286     // Constants
287     const qint16_t const_one = (1 << fixed_point_position);
288     const qint16_t ln2       = ((0x58B9 >> (14 - fixed_point_position)) + 1) >> 1;
289     const qint16_t inv_ln2   = (((0x38AA >> (14 - fixed_point_position)) + 1) >> 1) | const_one;
290     const qint16_t A         = ((0x7FBA >> (14 - fixed_point_position)) + 1) >> 1;
291     const qint16_t B         = ((0x3FE9 >> (14 - fixed_point_position)) + 1) >> 1;
292     const qint16_t C         = ((0x1693 >> (14 - fixed_point_position)) + 1) >> 1;
293     const qint16_t D         = ((0x0592 >> (14 - fixed_point_position)) + 1) >> 1;
294
295     // Polynomial expansion
296     const int      dec_a = (sqmul_qs16(a, inv_ln2, fixed_point_position) >> fixed_point_position);
297     const qint16_t alpha = sabs_qs16(sqsub_qs16(a, sqmul_qs16(ln2, sqshl_qs16(dec_a, fixed_point_position), fixed_point_position)));
298     qint16_t       sum   = sqadd_qs16(sqmul_qs16(alpha, D, fixed_point_position), C);
299     sum                  = sqadd_qs16(sqmul_qs16(alpha, sum, fixed_point_position), B);
300     sum                  = sqadd_qs16(sqmul_qs16(alpha, sum, fixed_point_position), A);
301     sum                  = sqmul_qs16(alpha, sum, fixed_point_position);
302     sum                  = sqadd_qs16(sum, const_one);
303
304     return (dec_a < 0) ? (sum >> -dec_a) : sqshl_qs16(sum, dec_a);
305 }
306
307 inline qint8_t slog_qs8(qint8_t a, int fixed_point_position)
308 {
309     // Constants
310     qint8_t const_one = (1 << fixed_point_position);
311     qint8_t ln2       = (0x58 >> (7 - fixed_point_position));
312     qint8_t A         = (0x5C >> (7 - fixed_point_position - 1));
313     qint8_t B         = -(0x56 >> (7 - fixed_point_position));
314     qint8_t C         = (0x29 >> (7 - fixed_point_position));
315     qint8_t D         = -(0x0A >> (7 - fixed_point_position));
316
317     if((const_one == a) || (a < 0))
318     {
319         return 0;
320     }
321     else if(a < const_one)
322     {
323         return -slog_qs8(sdiv_qs8(const_one, a, fixed_point_position), fixed_point_position);
324     }
325
326     // Remove even powers of 2
327     qint8_t shift_val = 31 - __builtin_clz(a >> fixed_point_position);
328     a >>= shift_val;
329     a = ssub_qs8(a, const_one);
330
331     // Polynomial expansion
332     qint8_t sum = sqadd_qs8(sqmul_qs8(a, D, fixed_point_position), C);
333     sum         = sqadd_qs8(sqmul_qs8(a, sum, fixed_point_position), B);
334     sum         = sqadd_qs8(sqmul_qs8(a, sum, fixed_point_position), A);
335     sum         = sqmul_qs8(a, sum, fixed_point_position);
336
337     return smul_qs8(sadd_qs8(sum, shift_val << fixed_point_position), ln2, fixed_point_position);
338 }
339
340 inline qint16_t slog_qs16(qint16_t a, int fixed_point_position)
341 {
342     // Constants
343     qint16_t const_one = (1 << fixed_point_position);
344     qint16_t ln2       = (0x58B9 >> (7 - fixed_point_position));
345     qint16_t A         = (0x5C0F >> (7 - fixed_point_position - 1));
346     qint16_t B         = -(0x56AE >> (7 - fixed_point_position));
347     qint16_t C         = (0x2933 >> (7 - fixed_point_position));
348     qint16_t D         = -(0x0AA7 >> (7 - fixed_point_position));
349
350     if((const_one == a) || (a < 0))
351     {
352         return 0;
353     }
354     else if(a < const_one)
355     {
356         return -slog_qs16(sdiv_qs16(const_one, a, fixed_point_position), fixed_point_position);
357     }
358
359     // Remove even powers of 2
360     qint16_t shift_val = 31 - __builtin_clz(a >> fixed_point_position);
361     a >>= shift_val;
362     a = ssub_qs16(a, const_one);
363
364     // Polynomial expansion
365     qint16_t sum = sqadd_qs16(sqmul_qs16(a, D, fixed_point_position), C);
366     sum          = sqadd_qs16(sqmul_qs16(a, sum, fixed_point_position), B);
367     sum          = sqadd_qs16(sqmul_qs16(a, sum, fixed_point_position), A);
368     sum          = sqmul_qs16(a, sum, fixed_point_position);
369
370     return smul_qs16(sadd_qs16(sum, shift_val << fixed_point_position), ln2, fixed_point_position);
371 }
372
373 inline float scvt_f32_qs8(qint8_t a, int fixed_point_position)
374 {
375     return static_cast<float>(a) / (1 << fixed_point_position);
376 }
377
378 inline qint8_t sqcvt_qs8_f32(float a, int fixed_point_position)
379 {
380     // round_nearest_integer(a * 2^(fixed_point_position))
381     return utility::saturate_cast<qint8_t>(a * (1 << fixed_point_position) + ((a >= 0) ? 0.5 : -0.5));
382 }
383
384 inline float scvt_f32_qs16(qint16_t a, int fixed_point_position)
385 {
386     return static_cast<float>(a) / (1 << fixed_point_position);
387 }
388
389 inline qint16_t sqcvt_qs16_f32(float a, int fixed_point_position)
390 {
391     // round_nearest_integer(a * 2^(fixed_point_position))
392     return utility::saturate_cast<qint16_t>(a * (1 << fixed_point_position) + ((a >= 0) ? 0.5 : -0.5));
393 }
394
395 inline qint8_t sqmovn_qs16(qint16_t a)
396 {
397     // Saturate the result in case of overflow and cast to qint8_t
398     return utility::saturate_cast<qint8_t>(a);
399 }
400
401 inline qint16_t sqmovn_qs32(qint32_t a)
402 {
403     // Saturate the result in case of overflow and cast to qint16_t
404     return utility::saturate_cast<qint16_t>(a);
405 }
406 }