arm_compute v17.09
[platform/upstream/armcl.git] / arm_compute / core / FixedPoint.inl
1 /*
2  * Copyright (c) 2017 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/core/Error.h"
25
26 #include <cmath>
27 #include <limits>
28
29 namespace
30 {
31 template <typename TpIn, typename TpSat>
32 inline TpSat saturate_convert(TpIn a)
33 {
34     if(a > std::numeric_limits<TpSat>::max())
35     {
36         a = std::numeric_limits<TpSat>::max();
37     }
38     if(a < std::numeric_limits<TpSat>::min())
39     {
40         a = std::numeric_limits<TpSat>::min();
41     }
42     return static_cast<TpSat>(a);
43 }
44 } // namespace
45
46 namespace arm_compute
47 {
48 inline qint8_t sqshl_qs8(qint8_t a, int shift)
49 {
50     qint16_t tmp = static_cast<qint16_t>(a) << shift;
51
52     // Saturate the result in case of overflow and cast to qint8_t
53     return saturate_convert<qint16_t, qint8_t>(tmp);
54 }
55
56 inline qint16_t sqshl_qs16(qint16_t a, int shift)
57 {
58     qint32_t tmp = static_cast<qint32_t>(a) << shift;
59
60     // Saturate the result in case of overflow and cast to qint16_t
61     return saturate_convert<qint32_t, qint16_t>(tmp);
62 }
63
64 inline qint8_t sshr_qs8(qint8_t a, int shift)
65 {
66     ARM_COMPUTE_ERROR_ON_MSG(shift == 0, "Shift should not be zero");
67     const qint8_t round_val = 1 << (shift - 1);
68     return sqadd_qs8(a, round_val) >> shift;
69 }
70
71 inline qint16_t sshr_qs16(qint16_t a, int shift)
72 {
73     ARM_COMPUTE_ERROR_ON_MSG(shift == 0, "Shift should not be zero");
74     const qint16_t round_val = 1 << (shift - 1);
75     return sqadd_qs16(a, round_val) >> shift;
76 }
77
78 inline qint8_t sabs_qs8(qint8_t a)
79 {
80     return (a < 0) ? (a == std::numeric_limits<int8_t>::min()) ? std::numeric_limits<int8_t>::max() : -a : a;
81 }
82
83 inline qint16_t sabs_qs16(qint16_t a)
84 {
85     return (a < 0) ? (a == std::numeric_limits<int16_t>::min()) ? std::numeric_limits<int16_t>::max() : -a : a;
86 }
87
88 inline qint8_t sadd_qs8(qint8_t a, qint8_t b)
89 {
90     return a + b;
91 }
92
93 inline qint16_t sadd_qs16(qint16_t a, qint16_t b)
94 {
95     return a + b;
96 }
97
98 inline qint8_t sqadd_qs8(qint8_t a, qint8_t b)
99 {
100     // We need to store the temporary result in qint16_t otherwise we cannot evaluate the overflow
101     qint16_t tmp = (static_cast<qint16_t>(a) + static_cast<qint16_t>(b));
102
103     // Saturate the result in case of overflow and cast to qint8_t
104     return saturate_convert<qint16_t, qint8_t>(tmp);
105 }
106
107 inline qint16_t sqadd_qs16(qint16_t a, qint16_t b)
108 {
109     // We need to store the temporary result in qint32_t otherwise we cannot evaluate the overflow
110     qint32_t tmp = (static_cast<qint32_t>(a) + static_cast<qint32_t>(b));
111
112     // Saturate the result in case of overflow and cast to qint16_t
113     return saturate_convert<qint32_t, qint16_t>(tmp);
114 }
115
116 inline qint32_t sqadd_qs32(qint32_t a, qint32_t b)
117 {
118     // We need to store the temporary result in qint64_t otherwise we cannot evaluate the overflow
119     qint64_t tmp = (static_cast<qint64_t>(a) + static_cast<qint64_t>(b));
120
121     // Saturate the result in case of overflow and cast to qint32_t
122     return saturate_convert<qint64_t, qint32_t>(tmp);
123 }
124
125 inline qint8_t ssub_qs8(qint8_t a, qint8_t b)
126 {
127     return a - b;
128 }
129
130 inline qint16_t ssub_qs16(qint16_t a, qint16_t b)
131 {
132     return a - b;
133 }
134
135 inline qint8_t sqsub_qs8(qint8_t a, qint8_t b)
136 {
137     // We need to store the temporary result in uint16_t otherwise we cannot evaluate the overflow
138     qint16_t tmp = static_cast<qint16_t>(a) - static_cast<qint16_t>(b);
139
140     // Saturate the result in case of overflow and cast to qint8_t
141     return saturate_convert<qint16_t, qint8_t>(tmp);
142 }
143
144 inline qint16_t sqsub_qs16(qint16_t a, qint16_t b)
145 {
146     // We need to store the temporary result in qint32_t otherwise we cannot evaluate the overflow
147     qint32_t tmp = static_cast<qint32_t>(a) - static_cast<qint32_t>(b);
148
149     // Saturate the result in case of overflow and cast to qint16_t
150     return saturate_convert<qint32_t, qint16_t>(tmp);
151 }
152
153 inline qint8_t smul_qs8(qint8_t a, qint8_t b, int fixed_point_position)
154 {
155     const qint16_t round_up_const = (1 << (fixed_point_position - 1));
156
157     qint16_t tmp = static_cast<qint16_t>(a) * static_cast<qint16_t>(b);
158
159     // Rounding up
160     tmp += round_up_const;
161
162     return static_cast<qint8_t>(tmp >> fixed_point_position);
163 }
164
165 inline qint16_t smul_qs16(qint16_t a, qint16_t b, int fixed_point_position)
166 {
167     const qint32_t round_up_const = (1 << (fixed_point_position - 1));
168
169     qint32_t tmp = static_cast<qint32_t>(a) * static_cast<qint32_t>(b);
170
171     // Rounding up
172     tmp += round_up_const;
173
174     return static_cast<qint16_t>(tmp >> fixed_point_position);
175 }
176
177 inline qint8_t sqmul_qs8(qint8_t a, qint8_t b, int fixed_point_position)
178 {
179     const qint16_t round_up_const = (1 << (fixed_point_position - 1));
180
181     qint16_t tmp = static_cast<qint16_t>(a) * static_cast<qint16_t>(b);
182
183     // Rounding up
184     tmp += round_up_const;
185
186     return saturate_convert<qint16_t, qint8_t>(tmp >> fixed_point_position);
187 }
188
189 inline qint16_t sqmul_qs16(qint16_t a, qint16_t b, int fixed_point_position)
190 {
191     const qint32_t round_up_const = (1 << (fixed_point_position - 1));
192
193     qint32_t tmp = static_cast<qint32_t>(a) * static_cast<qint32_t>(b);
194
195     // Rounding up
196     tmp += round_up_const;
197
198     return saturate_convert<qint32_t, qint16_t>(tmp >> fixed_point_position);
199 }
200
201 inline qint16_t sqmull_qs8(qint8_t a, qint8_t b, int fixed_point_position)
202 {
203     const qint16_t round_up_const = (1 << (fixed_point_position - 1));
204
205     qint16_t tmp = static_cast<qint16_t>(a) * static_cast<qint16_t>(b);
206
207     // Rounding up
208     tmp += round_up_const;
209
210     return tmp >> fixed_point_position;
211 }
212
213 inline qint32_t sqmull_qs16(qint16_t a, qint16_t b, int fixed_point_position)
214 {
215     const qint32_t round_up_const = (1 << (fixed_point_position - 1));
216
217     qint32_t tmp = static_cast<qint32_t>(a) * static_cast<qint32_t>(b);
218
219     // Rounding up
220     tmp += round_up_const;
221
222     return tmp >> fixed_point_position;
223 }
224
225 inline qint8_t sinvsqrt_qs8(qint8_t a, int fixed_point_position)
226 {
227     const qint8_t shift = 8 - (fixed_point_position + (__builtin_clz(a) - 24));
228
229     const qint8_t const_three = (3 << fixed_point_position);
230     qint8_t       temp        = shift < 0 ? (a << -shift) : (a >> shift);
231     qint8_t       x2          = temp;
232
233     // We need three iterations to find the result
234     for(int i = 0; i < 3; ++i)
235     {
236         qint8_t three_minus_dx = ssub_qs8(const_three, smul_qs8(temp, smul_qs8(x2, x2, fixed_point_position), fixed_point_position));
237         x2                     = (smul_qs8(x2, three_minus_dx, fixed_point_position) >> 1);
238     }
239
240     temp = shift < 0 ? (x2 << (-shift >> 1)) : (x2 >> (shift >> 1));
241
242     return temp;
243 }
244
245 inline qint16_t sinvsqrt_qs16(qint16_t a, int fixed_point_position)
246 {
247     const qint16_t shift = 16 - (fixed_point_position + (__builtin_clz(a) - 16));
248
249     const qint16_t const_three = (3 << fixed_point_position);
250     qint16_t       temp        = shift < 0 ? (a << -shift) : (a >> shift);
251     qint16_t       x2          = temp;
252
253     // We need three iterations to find the result
254     for(int i = 0; i < 3; ++i)
255     {
256         qint16_t three_minus_dx = ssub_qs16(const_three, smul_qs16(temp, smul_qs16(x2, x2, fixed_point_position), fixed_point_position));
257         x2                      = smul_qs16(x2, three_minus_dx, fixed_point_position) >> 1;
258     }
259
260     temp = shift < 0 ? (x2 << ((-shift) >> 1)) : (x2 >> (shift >> 1));
261
262     return temp;
263 }
264
265 inline qint8_t sdiv_qs8(qint8_t a, qint8_t b, int fixed_point_position)
266 {
267     const qint16_t temp = a << fixed_point_position;
268     return static_cast<qint8_t>(temp / b);
269 }
270
271 inline qint16_t sdiv_qs16(qint16_t a, qint16_t b, int fixed_point_position)
272 {
273     const qint32_t temp = a << fixed_point_position;
274     return static_cast<qint16_t>(temp / b);
275 }
276
277 inline qint8_t sqexp_qs8(qint8_t a, int fixed_point_position)
278 {
279     // Constants
280     const qint8_t const_one = (1 << fixed_point_position);
281     const qint8_t ln2       = ((0x58 >> (6 - fixed_point_position)) + 1) >> 1;
282     const qint8_t inv_ln2   = (((0x38 >> (6 - fixed_point_position)) + 1) >> 1) | const_one;
283     const qint8_t A         = ((0x7F >> (6 - fixed_point_position)) + 1) >> 1;
284     const qint8_t B         = ((0x3F >> (6 - fixed_point_position)) + 1) >> 1;
285     const qint8_t C         = ((0x16 >> (6 - fixed_point_position)) + 1) >> 1;
286     const qint8_t D         = ((0x05 >> (6 - fixed_point_position)) + 1) >> 1;
287
288     // Polynomial expansion
289     const int     dec_a = (sqmul_qs8(a, inv_ln2, fixed_point_position) >> fixed_point_position);
290     const qint8_t alpha = sabs_qs8(sqsub_qs8(a, sqmul_qs8(ln2, sqshl_qs8(dec_a, fixed_point_position), fixed_point_position)));
291     qint8_t       sum   = sqadd_qs8(sqmul_qs8(alpha, D, fixed_point_position), C);
292     sum                 = sqadd_qs8(sqmul_qs8(alpha, sum, fixed_point_position), B);
293     sum                 = sqadd_qs8(sqmul_qs8(alpha, sum, fixed_point_position), A);
294     sum                 = sqmul_qs8(alpha, sum, fixed_point_position);
295     sum                 = sqadd_qs8(sum, const_one);
296
297     return (dec_a < 0) ? (sum >> -dec_a) : sqshl_qs8(sum, dec_a);
298 }
299
300 inline qint16_t sqexp_qs16(qint16_t a, int fixed_point_position)
301 {
302     // Constants
303     const qint16_t const_one = (1 << fixed_point_position);
304     const qint16_t ln2       = ((0x58B9 >> (14 - fixed_point_position)) + 1) >> 1;
305     const qint16_t inv_ln2   = (((0x38AA >> (14 - fixed_point_position)) + 1) >> 1) | const_one;
306     const qint16_t A         = ((0x7FBA >> (14 - fixed_point_position)) + 1) >> 1;
307     const qint16_t B         = ((0x3FE9 >> (14 - fixed_point_position)) + 1) >> 1;
308     const qint16_t C         = ((0x1693 >> (14 - fixed_point_position)) + 1) >> 1;
309     const qint16_t D         = ((0x0592 >> (14 - fixed_point_position)) + 1) >> 1;
310
311     // Polynomial expansion
312     const int      dec_a = (sqmul_qs16(a, inv_ln2, fixed_point_position) >> fixed_point_position);
313     const qint16_t alpha = sabs_qs16(sqsub_qs16(a, sqmul_qs16(ln2, sqshl_qs16(dec_a, fixed_point_position), fixed_point_position)));
314     qint16_t       sum   = sqadd_qs16(sqmul_qs16(alpha, D, fixed_point_position), C);
315     sum                  = sqadd_qs16(sqmul_qs16(alpha, sum, fixed_point_position), B);
316     sum                  = sqadd_qs16(sqmul_qs16(alpha, sum, fixed_point_position), A);
317     sum                  = sqmul_qs16(alpha, sum, fixed_point_position);
318     sum                  = sqadd_qs16(sum, const_one);
319
320     return (dec_a < 0) ? (sum >> -dec_a) : sqshl_qs16(sum, dec_a);
321 }
322
323 inline qint8_t slog_qs8(qint8_t a, int fixed_point_position)
324 {
325     // Constants
326     qint8_t const_one = (1 << fixed_point_position);
327     qint8_t ln2       = (0x58 >> (7 - fixed_point_position));
328     qint8_t A         = (0x5C >> (7 - fixed_point_position - 1));
329     qint8_t B         = -(0x56 >> (7 - fixed_point_position));
330     qint8_t C         = (0x29 >> (7 - fixed_point_position));
331     qint8_t D         = -(0x0A >> (7 - fixed_point_position));
332
333     if((const_one == a) || (a < 0))
334     {
335         return 0;
336     }
337     else if(a < const_one)
338     {
339         return -slog_qs8(sdiv_qs8(const_one, a, fixed_point_position), fixed_point_position);
340     }
341
342     // Remove even powers of 2
343     qint8_t shift_val = 31 - __builtin_clz(a >> fixed_point_position);
344     a >>= shift_val;
345     a = ssub_qs8(a, const_one);
346
347     // Polynomial expansion
348     qint8_t sum = sqadd_qs8(sqmul_qs8(a, D, fixed_point_position), C);
349     sum         = sqadd_qs8(sqmul_qs8(a, sum, fixed_point_position), B);
350     sum         = sqadd_qs8(sqmul_qs8(a, sum, fixed_point_position), A);
351     sum         = sqmul_qs8(a, sum, fixed_point_position);
352
353     return smul_qs8(sadd_qs8(sum, shift_val << fixed_point_position), ln2, fixed_point_position);
354 }
355
356 inline qint16_t slog_qs16(qint16_t a, int fixed_point_position)
357 {
358     // Constants
359     qint16_t const_one = (1 << fixed_point_position);
360     qint16_t ln2       = (0x58B9 >> (7 - fixed_point_position));
361     qint16_t A         = (0x5C0F >> (7 - fixed_point_position - 1));
362     qint16_t B         = -(0x56AE >> (7 - fixed_point_position));
363     qint16_t C         = (0x2933 >> (7 - fixed_point_position));
364     qint16_t D         = -(0x0AA7 >> (7 - fixed_point_position));
365
366     if((const_one == a) || (a < 0))
367     {
368         return 0;
369     }
370     else if(a < const_one)
371     {
372         return -slog_qs16(sdiv_qs16(const_one, a, fixed_point_position), fixed_point_position);
373     }
374
375     // Remove even powers of 2
376     qint16_t shift_val = 31 - __builtin_clz(a >> fixed_point_position);
377     a >>= shift_val;
378     a = ssub_qs16(a, const_one);
379
380     // Polynomial expansion
381     qint16_t sum = sqadd_qs16(sqmul_qs16(a, D, fixed_point_position), C);
382     sum          = sqadd_qs16(sqmul_qs16(a, sum, fixed_point_position), B);
383     sum          = sqadd_qs16(sqmul_qs16(a, sum, fixed_point_position), A);
384     sum          = sqmul_qs16(a, sum, fixed_point_position);
385
386     return smul_qs16(sadd_qs16(sum, shift_val << fixed_point_position), ln2, fixed_point_position);
387 }
388
389 inline float scvt_f32_qs8(qint8_t a, int fixed_point_position)
390 {
391     return static_cast<float>(a) / (1 << fixed_point_position);
392 }
393
394 inline qint8_t sqcvt_qs8_f32(float a, int fixed_point_position)
395 {
396     // round_nearest_integer(a * 2^(fixed_point_position))
397     return saturate_convert<float, qint8_t>(a * (1 << fixed_point_position) + ((a >= 0) ? 0.5 : -0.5));
398 }
399
400 inline float scvt_f32_qs16(qint16_t a, int fixed_point_position)
401 {
402     return static_cast<float>(a) / (1 << fixed_point_position);
403 }
404
405 inline qint16_t sqcvt_qs16_f32(float a, int fixed_point_position)
406 {
407     // round_nearest_integer(a * 2^(fixed_point_position))
408     return saturate_convert<float, qint16_t>(a * (1 << fixed_point_position) + ((a >= 0) ? 0.5 : -0.5));
409 }
410
411 inline qint8_t sqmovn_qs16(qint16_t a)
412 {
413     // Saturate the result in case of overflow and cast to qint8_t
414     return saturate_convert<qint16_t, qint8_t>(a);
415 }
416
417 inline qint16_t sqmovn_qs32(qint32_t a)
418 {
419     // Saturate the result in case of overflow and cast to qint16_t
420     return saturate_convert<qint32_t, qint16_t>(a);
421 }
422 }