2 * Copyright (c) 2017-2018 ARM Limited.
4 * SPDX-License-Identifier: MIT
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 #include "arm_compute/core/Error.h"
25 #include "arm_compute/core/utils/misc/utility.h"
32 inline qint8_t sqshl_qs8(qint8_t a, int shift)
34 qint16_t tmp = static_cast<qint16_t>(a) << shift;
36 // Saturate the result in case of overflow and cast to qint8_t
37 return utility::saturate_cast<qint8_t>(tmp);
40 inline qint16_t sqshl_qs16(qint16_t a, int shift)
42 qint32_t tmp = static_cast<qint32_t>(a) << shift;
44 // Saturate the result in case of overflow and cast to qint16_t
45 return utility::saturate_cast<qint16_t>(tmp);
48 inline qint8_t sshr_qs8(qint8_t a, int shift)
50 ARM_COMPUTE_ERROR_ON_MSG(shift == 0, "Shift should not be zero");
51 const qint8_t round_val = 1 << (shift - 1);
52 return sqadd_qs8(a, round_val) >> shift;
55 inline qint16_t sshr_qs16(qint16_t a, int shift)
57 ARM_COMPUTE_ERROR_ON_MSG(shift == 0, "Shift should not be zero");
58 const qint16_t round_val = 1 << (shift - 1);
59 return sqadd_qs16(a, round_val) >> shift;
62 inline qint8_t sabs_qs8(qint8_t a)
64 return (a < 0) ? (a == std::numeric_limits<int8_t>::min()) ? std::numeric_limits<int8_t>::max() : -a : a;
67 inline qint16_t sabs_qs16(qint16_t a)
69 return (a < 0) ? (a == std::numeric_limits<int16_t>::min()) ? std::numeric_limits<int16_t>::max() : -a : a;
72 inline qint8_t sadd_qs8(qint8_t a, qint8_t b)
77 inline qint16_t sadd_qs16(qint16_t a, qint16_t b)
82 inline qint8_t sqadd_qs8(qint8_t a, qint8_t b)
84 // We need to store the temporary result in qint16_t otherwise we cannot evaluate the overflow
85 qint16_t tmp = (static_cast<qint16_t>(a) + static_cast<qint16_t>(b));
87 // Saturate the result in case of overflow and cast to qint8_t
88 return utility::saturate_cast<qint8_t>(tmp);
91 inline qint16_t sqadd_qs16(qint16_t a, qint16_t b)
93 // We need to store the temporary result in qint32_t otherwise we cannot evaluate the overflow
94 qint32_t tmp = (static_cast<qint32_t>(a) + static_cast<qint32_t>(b));
96 // Saturate the result in case of overflow and cast to qint16_t
97 return utility::saturate_cast<qint16_t>(tmp);
100 inline qint32_t sqadd_qs32(qint32_t a, qint32_t b)
102 // We need to store the temporary result in qint64_t otherwise we cannot evaluate the overflow
103 qint64_t tmp = (static_cast<qint64_t>(a) + static_cast<qint64_t>(b));
105 // Saturate the result in case of overflow and cast to qint32_t
106 return utility::saturate_cast<qint32_t>(tmp);
109 inline qint8_t ssub_qs8(qint8_t a, qint8_t b)
114 inline qint16_t ssub_qs16(qint16_t a, qint16_t b)
119 inline qint8_t sqsub_qs8(qint8_t a, qint8_t b)
121 // We need to store the temporary result in uint16_t otherwise we cannot evaluate the overflow
122 qint16_t tmp = static_cast<qint16_t>(a) - static_cast<qint16_t>(b);
124 // Saturate the result in case of overflow and cast to qint8_t
125 return utility::saturate_cast<qint8_t>(tmp);
128 inline qint16_t sqsub_qs16(qint16_t a, qint16_t b)
130 // We need to store the temporary result in qint32_t otherwise we cannot evaluate the overflow
131 qint32_t tmp = static_cast<qint32_t>(a) - static_cast<qint32_t>(b);
133 // Saturate the result in case of overflow and cast to qint16_t
134 return utility::saturate_cast<qint16_t>(tmp);
137 inline qint8_t smul_qs8(qint8_t a, qint8_t b, int fixed_point_position)
139 const qint16_t round_up_const = (1 << (fixed_point_position - 1));
141 qint16_t tmp = static_cast<qint16_t>(a) * static_cast<qint16_t>(b);
144 tmp += round_up_const;
146 return static_cast<qint8_t>(tmp >> fixed_point_position);
149 inline qint16_t smul_qs16(qint16_t a, qint16_t b, int fixed_point_position)
151 const qint32_t round_up_const = (1 << (fixed_point_position - 1));
153 qint32_t tmp = static_cast<qint32_t>(a) * static_cast<qint32_t>(b);
156 tmp += round_up_const;
158 return static_cast<qint16_t>(tmp >> fixed_point_position);
161 inline qint8_t sqmul_qs8(qint8_t a, qint8_t b, int fixed_point_position)
163 const qint16_t round_up_const = (1 << (fixed_point_position - 1));
165 qint16_t tmp = static_cast<qint16_t>(a) * static_cast<qint16_t>(b);
168 tmp += round_up_const;
170 return utility::saturate_cast<qint8_t>(tmp >> fixed_point_position);
173 inline qint16_t sqmul_qs16(qint16_t a, qint16_t b, int fixed_point_position)
175 const qint32_t round_up_const = (1 << (fixed_point_position - 1));
177 qint32_t tmp = static_cast<qint32_t>(a) * static_cast<qint32_t>(b);
180 tmp += round_up_const;
182 return utility::saturate_cast<qint16_t>(tmp >> fixed_point_position);
185 inline qint16_t sqmull_qs8(qint8_t a, qint8_t b, int fixed_point_position)
187 const qint16_t round_up_const = (1 << (fixed_point_position - 1));
189 qint16_t tmp = static_cast<qint16_t>(a) * static_cast<qint16_t>(b);
192 tmp += round_up_const;
194 return tmp >> fixed_point_position;
197 inline qint32_t sqmull_qs16(qint16_t a, qint16_t b, int fixed_point_position)
199 const qint32_t round_up_const = (1 << (fixed_point_position - 1));
201 qint32_t tmp = static_cast<qint32_t>(a) * static_cast<qint32_t>(b);
204 tmp += round_up_const;
206 return tmp >> fixed_point_position;
209 inline qint8_t sinvsqrt_qs8(qint8_t a, int fixed_point_position)
211 const qint8_t shift = 8 - (fixed_point_position + (__builtin_clz(a) - 24));
213 const qint8_t const_three = (3 << fixed_point_position);
214 qint8_t temp = shift < 0 ? (a << -shift) : (a >> shift);
217 // We need three iterations to find the result
218 for(int i = 0; i < 3; ++i)
220 qint8_t three_minus_dx = ssub_qs8(const_three, smul_qs8(temp, smul_qs8(x2, x2, fixed_point_position), fixed_point_position));
221 x2 = (smul_qs8(x2, three_minus_dx, fixed_point_position) >> 1);
224 temp = shift < 0 ? (x2 << (-shift >> 1)) : (x2 >> (shift >> 1));
229 inline qint16_t sinvsqrt_qs16(qint16_t a, int fixed_point_position)
231 const qint16_t shift = 16 - (fixed_point_position + (__builtin_clz(a) - 16));
233 const qint16_t const_three = (3 << fixed_point_position);
234 qint16_t temp = shift < 0 ? (a << -shift) : (a >> shift);
237 // We need three iterations to find the result
238 for(int i = 0; i < 3; ++i)
240 qint16_t three_minus_dx = ssub_qs16(const_three, smul_qs16(temp, smul_qs16(x2, x2, fixed_point_position), fixed_point_position));
241 x2 = smul_qs16(x2, three_minus_dx, fixed_point_position) >> 1;
244 temp = shift < 0 ? (x2 << ((-shift) >> 1)) : (x2 >> (shift >> 1));
249 inline qint8_t sdiv_qs8(qint8_t a, qint8_t b, int fixed_point_position)
251 const qint16_t temp = a << fixed_point_position;
252 return static_cast<qint8_t>(temp / b);
255 inline qint16_t sdiv_qs16(qint16_t a, qint16_t b, int fixed_point_position)
257 const qint32_t temp = a << fixed_point_position;
258 return static_cast<qint16_t>(temp / b);
261 inline qint8_t sqexp_qs8(qint8_t a, int fixed_point_position)
264 const qint8_t const_one = (1 << fixed_point_position);
265 const qint8_t ln2 = ((0x58 >> (6 - fixed_point_position)) + 1) >> 1;
266 const qint8_t inv_ln2 = (((0x38 >> (6 - fixed_point_position)) + 1) >> 1) | const_one;
267 const qint8_t A = ((0x7F >> (6 - fixed_point_position)) + 1) >> 1;
268 const qint8_t B = ((0x3F >> (6 - fixed_point_position)) + 1) >> 1;
269 const qint8_t C = ((0x16 >> (6 - fixed_point_position)) + 1) >> 1;
270 const qint8_t D = ((0x05 >> (6 - fixed_point_position)) + 1) >> 1;
272 // Polynomial expansion
273 const int dec_a = (sqmul_qs8(a, inv_ln2, fixed_point_position) >> fixed_point_position);
274 const qint8_t alpha = sabs_qs8(sqsub_qs8(a, sqmul_qs8(ln2, sqshl_qs8(dec_a, fixed_point_position), fixed_point_position)));
275 qint8_t sum = sqadd_qs8(sqmul_qs8(alpha, D, fixed_point_position), C);
276 sum = sqadd_qs8(sqmul_qs8(alpha, sum, fixed_point_position), B);
277 sum = sqadd_qs8(sqmul_qs8(alpha, sum, fixed_point_position), A);
278 sum = sqmul_qs8(alpha, sum, fixed_point_position);
279 sum = sqadd_qs8(sum, const_one);
281 return (dec_a < 0) ? (sum >> -dec_a) : sqshl_qs8(sum, dec_a);
284 inline qint16_t sqexp_qs16(qint16_t a, int fixed_point_position)
287 const qint16_t const_one = (1 << fixed_point_position);
288 const qint16_t ln2 = ((0x58B9 >> (14 - fixed_point_position)) + 1) >> 1;
289 const qint16_t inv_ln2 = (((0x38AA >> (14 - fixed_point_position)) + 1) >> 1) | const_one;
290 const qint16_t A = ((0x7FBA >> (14 - fixed_point_position)) + 1) >> 1;
291 const qint16_t B = ((0x3FE9 >> (14 - fixed_point_position)) + 1) >> 1;
292 const qint16_t C = ((0x1693 >> (14 - fixed_point_position)) + 1) >> 1;
293 const qint16_t D = ((0x0592 >> (14 - fixed_point_position)) + 1) >> 1;
295 // Polynomial expansion
296 const int dec_a = (sqmul_qs16(a, inv_ln2, fixed_point_position) >> fixed_point_position);
297 const qint16_t alpha = sabs_qs16(sqsub_qs16(a, sqmul_qs16(ln2, sqshl_qs16(dec_a, fixed_point_position), fixed_point_position)));
298 qint16_t sum = sqadd_qs16(sqmul_qs16(alpha, D, fixed_point_position), C);
299 sum = sqadd_qs16(sqmul_qs16(alpha, sum, fixed_point_position), B);
300 sum = sqadd_qs16(sqmul_qs16(alpha, sum, fixed_point_position), A);
301 sum = sqmul_qs16(alpha, sum, fixed_point_position);
302 sum = sqadd_qs16(sum, const_one);
304 return (dec_a < 0) ? (sum >> -dec_a) : sqshl_qs16(sum, dec_a);
307 inline qint8_t slog_qs8(qint8_t a, int fixed_point_position)
310 qint8_t const_one = (1 << fixed_point_position);
311 qint8_t ln2 = (0x58 >> (7 - fixed_point_position));
312 qint8_t A = (0x5C >> (7 - fixed_point_position - 1));
313 qint8_t B = -(0x56 >> (7 - fixed_point_position));
314 qint8_t C = (0x29 >> (7 - fixed_point_position));
315 qint8_t D = -(0x0A >> (7 - fixed_point_position));
317 if((const_one == a) || (a < 0))
321 else if(a < const_one)
323 return -slog_qs8(sdiv_qs8(const_one, a, fixed_point_position), fixed_point_position);
326 // Remove even powers of 2
327 qint8_t shift_val = 31 - __builtin_clz(a >> fixed_point_position);
329 a = ssub_qs8(a, const_one);
331 // Polynomial expansion
332 qint8_t sum = sqadd_qs8(sqmul_qs8(a, D, fixed_point_position), C);
333 sum = sqadd_qs8(sqmul_qs8(a, sum, fixed_point_position), B);
334 sum = sqadd_qs8(sqmul_qs8(a, sum, fixed_point_position), A);
335 sum = sqmul_qs8(a, sum, fixed_point_position);
337 return smul_qs8(sadd_qs8(sum, shift_val << fixed_point_position), ln2, fixed_point_position);
340 inline qint16_t slog_qs16(qint16_t a, int fixed_point_position)
343 qint16_t const_one = (1 << fixed_point_position);
344 qint16_t ln2 = (0x58B9 >> (7 - fixed_point_position));
345 qint16_t A = (0x5C0F >> (7 - fixed_point_position - 1));
346 qint16_t B = -(0x56AE >> (7 - fixed_point_position));
347 qint16_t C = (0x2933 >> (7 - fixed_point_position));
348 qint16_t D = -(0x0AA7 >> (7 - fixed_point_position));
350 if((const_one == a) || (a < 0))
354 else if(a < const_one)
356 return -slog_qs16(sdiv_qs16(const_one, a, fixed_point_position), fixed_point_position);
359 // Remove even powers of 2
360 qint16_t shift_val = 31 - __builtin_clz(a >> fixed_point_position);
362 a = ssub_qs16(a, const_one);
364 // Polynomial expansion
365 qint16_t sum = sqadd_qs16(sqmul_qs16(a, D, fixed_point_position), C);
366 sum = sqadd_qs16(sqmul_qs16(a, sum, fixed_point_position), B);
367 sum = sqadd_qs16(sqmul_qs16(a, sum, fixed_point_position), A);
368 sum = sqmul_qs16(a, sum, fixed_point_position);
370 return smul_qs16(sadd_qs16(sum, shift_val << fixed_point_position), ln2, fixed_point_position);
373 inline float scvt_f32_qs8(qint8_t a, int fixed_point_position)
375 return static_cast<float>(a) / (1 << fixed_point_position);
378 inline qint8_t sqcvt_qs8_f32(float a, int fixed_point_position)
380 // round_nearest_integer(a * 2^(fixed_point_position))
381 return utility::saturate_cast<qint8_t>(a * (1 << fixed_point_position) + ((a >= 0) ? 0.5 : -0.5));
384 inline float scvt_f32_qs16(qint16_t a, int fixed_point_position)
386 return static_cast<float>(a) / (1 << fixed_point_position);
389 inline qint16_t sqcvt_qs16_f32(float a, int fixed_point_position)
391 // round_nearest_integer(a * 2^(fixed_point_position))
392 return utility::saturate_cast<qint16_t>(a * (1 << fixed_point_position) + ((a >= 0) ? 0.5 : -0.5));
395 inline qint8_t sqmovn_qs16(qint16_t a)
397 // Saturate the result in case of overflow and cast to qint8_t
398 return utility::saturate_cast<qint8_t>(a);
401 inline qint16_t sqmovn_qs32(qint32_t a)
403 // Saturate the result in case of overflow and cast to qint16_t
404 return utility::saturate_cast<qint16_t>(a);