24 #ifndef __ARM_COMPUTE_TEST_VALIDATION_FIXEDPOINT_H__ 25 #define __ARM_COMPUTE_TEST_VALIDATION_FIXEDPOINT_H__ 35 #include <type_traits> 41 namespace fixed_point_arithmetic
60 template <>
struct promote<uint8_t> {
using type = uint16_t; };
64 template <>
struct promote<uint16_t> {
using type = uint32_t; };
66 template <>
struct promote<int16_t> {
using type = int32_t; };
68 template <>
struct promote<uint32_t> {
using type = uint64_t; };
70 template <>
struct promote<int32_t> {
using type = int64_t; };
102 static_assert(std::is_integral<T>::value,
"Type is not an integer");
109 template <
typename U>
111 : _value(0), _fixed_point_position(p)
113 assert(p > 0 && p < std::numeric_limits<T>::digits);
116 if(std::numeric_limits<T>::digits < std::numeric_limits<U>::digits)
127 _value =
static_cast<T
>(v);
135 template <typename U, typename = typename std::enable_if<std::is_integral<U>::value>::type>
137 : _value(val << p), _fixed_point_position(p)
150 : _value(detail::constant_expr<T>::to_fixed(val, p)), _fixed_point_position(p)
152 assert(p > 0 && p < std::numeric_limits<T>::digits);
160 : _value(detail::constant_expr<T>::to_fixed(support::cpp11::
stof(str), p)), _fixed_point_position(p)
162 assert(p > 0 && p < std::numeric_limits<T>::digits);
177 operator float()
const 185 template <typename U, typename = typename std::enable_if<std::is_integral<T>::value>::type>
194 template <
typename U>
197 U val =
static_cast<U>(_value);
198 if(std::numeric_limits<U>::digits < std::numeric_limits<T>::digits)
211 template <
typename U>
224 template <
typename U>
246 return _fixed_point_position;
254 assert(p > 0 && p < std::numeric_limits<T>::digits);
257 promoted_T val = _value;
258 if(p > _fixed_point_position)
260 val <<= (p - _fixed_point_position);
262 else if(p < _fixed_point_position)
264 uint8_t pbar = _fixed_point_position - p;
265 val += (pbar != 0) ? (1 << (pbar - 1)) : 0;
270 _fixed_point_position = p;
275 uint8_t _fixed_point_position;
/** Count the number of leading zero bits in the given value.
 *
 * The count is expressed relative to the bit width of the unsigned counterpart
 * of @p T, not to the width of the builtin's operand type.
 *
 * @param[in] value Integer value to inspect. Signed values are reinterpreted as
 *                  their unsigned counterpart of the same width, so a negative
 *                  value reports 0 leading zeros.
 *
 * @return Number of leading zero bits; the full bit width of @p T when
 *         @p value is 0.
 */
template <typename T>
constexpr int clz(T value)
{
    using unsigned_T = typename std::make_unsigned<T>::type;
    // The builtin's result is undefined for a zero operand, so zero is answered
    // directly. The operand is first converted to unsigned_T so that narrow
    // negative values are not sign-extended, then widened to unsigned long long
    // so that 64-bit values are not truncated. The surplus leading zeros added by
    // the widening are subtracted to match the width of T.
    // Kept as a single return statement so the function stays usable as a
    // C++11 constexpr function.
    return (value == 0) ?
           std::numeric_limits<unsigned_T>::digits :
           __builtin_clzll(static_cast<unsigned long long>(static_cast<unsigned_T>(value)))
           - (std::numeric_limits<unsigned long long>::digits - std::numeric_limits<unsigned_T>::digits);
}
296 template <
typename T>
317 return (1.0f / static_cast<float>(1 << p));
329 return static_cast<float>(val * fixed_step(p));
338 static constexpr T
to_int(T val, uint8_t p)
351 return static_cast<T
>(
saturate_cast<
float>(val * fixed_one(p) + ((val >= 0) ? 0.5 : -0.5)));
371 template <
typename U>
387 template <
typename T,
typename U,
typename traits>
390 return s << static_cast<float>(x);
398 template <
typename T>
401 return ((x.
raw() >> std::numeric_limits<T>::digits) != 0);
410 template <
typename T>
416 return (x.
raw() == y.
raw());
425 template <
typename T>
428 return !isequal(x, y);
437 template <
typename T>
443 return (x.
raw() > y.
raw());
452 template <
typename T>
458 return (x.
raw() >= y.
raw());
467 template <
typename T>
473 return (x.
raw() < y.
raw());
482 template <
typename T>
488 return (x.
raw() <= y.
raw());
497 template <
typename T>
500 return isnotequal(x, y);
510 template <
typename T>
521 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
525 promoted_T val = -x.
raw();
539 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
548 type val =
static_cast<type
>(x.
raw()) + static_cast<type>(y.
raw());
564 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
573 type val =
static_cast<type
>(x.
raw()) - static_cast<type>(y.
raw());
589 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
595 promoted_T round_factor = (1 << (p_max - 1));
596 promoted_T val = ((
static_cast<promoted_T
>(x.
raw()) * static_cast<promoted_T>(y.
raw())) + round_factor) >> p_max;
610 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
615 promoted_T denom =
static_cast<promoted_T
>(y.
raw());
638 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
642 promoted_T val =
static_cast<promoted_T
>(x.
raw()) << shift;
656 template <
typename T>
667 template <
typename T>
680 template <
typename T>
688 if(isequal(x, const_one) || islessequal(x,
fixed_point<T>(static_cast<T>(0), p)))
692 else if(isless(x, const_one))
698 T shift_val = 31 - __builtin_clz(x.
raw() >> p);
699 x = shift_right(x, shift_val);
700 x =
sub(x, const_one);
711 sum =
add(
mul(x, sum), B);
712 sum =
add(
mul(x, sum), A);
727 template <
typename T>
744 auto taylor =
add(
mul(frac_part, D), C);
745 taylor =
add(
mul(frac_part, taylor), B);
746 taylor =
add(
mul(frac_part, taylor), A);
747 taylor =
mul(frac_part, taylor);
748 taylor =
add(taylor, const_one);
751 if(static_cast<T>(
clz(taylor.raw())) <= scaled_int_part)
756 return (scaled_int_part < 0) ? shift_right(taylor, -scaled_int_part) : shift_left(taylor, scaled_int_part);
764 template <
typename T>
768 int8_t shift = std::numeric_limits<T>::digits - (p +
detail::clz(x.
raw()));
770 shift += std::numeric_limits<T>::is_signed ? 1 : 0;
773 volatile int8_t *shift_ptr = &shift;
776 auto a = (*shift_ptr < 0) ? shift_left(x, -(shift)) : shift_right(x, shift);
780 constexpr
int num_iterations = std::is_same<T, int8_t>::value ? 3 : 5;
781 for(
int i = 0; i < num_iterations; ++i)
784 x2 = shift_right(
mul(x2, three_minus_dx), 1);
787 return (shift < 0) ? shift_left(x2, (-shift) >> 1) : shift_right(x2, shift >> 1);
795 template <
typename T>
803 auto exp2x =
exp(const_two * x);
804 auto num = exp2x - const_one;
805 auto den = exp2x + const_one;
806 auto tanh = num / den;
819 template <
typename T>
826 template <
typename T>
831 template <
typename T>
836 template <
typename T>
841 template <
typename T>
846 template <
typename T>
851 template <
typename T>
856 template <
typename T>
861 template <
typename T>
866 template <
typename T>
871 template <
typename T>
876 template <
typename T>
881 template <
typename T>
886 template <
typename T>
891 template <
typename T,
typename U,
typename traits>
892 std::basic_ostream<T, traits> &operator<<(std::basic_ostream<T, traits> &s,
fixed_point<U> x)
896 template <
typename T>
899 return x > y ? y : x;
901 template <
typename T>
904 return x > y ? x : y;
906 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
909 return functions::add<OP>(x, y);
911 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
914 return functions::sub<OP>(x, y);
916 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
919 return functions::mul<OP>(x, y);
921 template <
typename T>
926 template <
typename T>
931 template <
typename T>
936 template <
typename T>
941 template <
typename T>
946 template <
typename T>
951 template <
typename T>
956 template <
typename T>
964 using detail::operator==;
965 using detail::operator!=;
966 using detail::operator<;
967 using detail::operator>;
968 using detail::operator<=;
969 using detail::operator>=;
970 using detail::operator+;
971 using detail::operator-;
972 using detail::operator*;
973 using detail::operator/;
974 using detail::operator>>;
975 using detail::operator<<;
fixed_point< T > mul(fixed_point< T > x, fixed_point< T > y)
fixed_point< T > min(fixed_point< T > x, fixed_point< T > y)
RoundingPolicy
Strongly typed enum class representing the rounding policy.
static constexpr T clamp(T val, T min, T max)
Clamp a value between a lower and an upper bound.
fixed_point< T > clamp(fixed_point< T > x, T min, T max)
fixed_point< T > operator>>(fixed_point< T > x, size_t shift)
static constexpr float to_float(T val, uint8_t p)
Convert a fixed point value to float given its precision.
static bool isgreater(fixed_point< T > x, fixed_point< T > y)
Checks if one fixed point is greater than the other.
static constexpr T to_int(T val, uint8_t p)
Convert a fixed point value to integer given its precision.
static bool signbit(fixed_point< T > x)
Signbit of a fixed point number.
fixed_point< T > exp(fixed_point< T > x)
DATA_TYPE sum(__global const DATA_TYPE *input)
Calculate sum of a vector.
static fixed_point< T > add(fixed_point< T > x, fixed_point< T > y)
Perform addition among two fixed point numbers.
typename promote< T >::type promote_t
Get promoted type.
static bool isequal(fixed_point< T > x, fixed_point< T > y)
Checks if two fixed point numbers are equal.
half_float::half half
16-bit floating point type
fixed_point< T > operator/(fixed_point< T > x, fixed_point< T > y)
bool operator==(const fixed_point< T > &lhs, const fixed_point< T > &rhs)
static bool isgreaterequal(fixed_point< T > x, fixed_point< T > y)
Checks if one fixed point is greater than or equal to the other.
bool operator>=(const fixed_point< T > &lhs, const fixed_point< T > &rhs)
static std::basic_ostream< T, traits > & write(std::basic_ostream< T, traits > &s, fixed_point< U > &x)
Output stream operator.
static fixed_point< T > pow(fixed_point< T > x, fixed_point< T > a)
Calculate the a-th power of a fixed point number.
fixed_point(U val, uint8_t p, bool is_raw=false)
Constructor (from integer)
This file contains all available output stages for GEMMLowp on OpenCL.
static fixed_point< T > shift_left(fixed_point< T > x, size_t shift)
Shift left.
fixed_point< T > log(fixed_point< T > x)
fixed_point(std::string str, uint8_t p)
Constructor (from float string)
fixed_point< T > add(fixed_point< T > x, fixed_point< T > y)
static constexpr T fixed_one(uint8_t p)
Calculate representation of 1 in fixed point given a fixed point precision.
static constexpr T saturate_cast(U val)
Saturate given number.
static fixed_point< T > tanh(fixed_point< T > x)
Calculate the hyperbolic tangent of a fixed point number.
static fixed_point< T > sub(fixed_point< T > x, fixed_point< T > y)
Perform subtraction among two fixed point numbers.
static fixed_point< T > div(fixed_point< T > x, fixed_point< T > y)
Perform division among two fixed point numbers.
static fixed_point< T > shift_right(fixed_point< T > x, size_t shift)
Shift right.
fixed_point< T > & operator-=(const fixed_point< U > &rhs)
Arithmetic -= assignment operator.
static fixed_point< T > mul(fixed_point< T > x, fixed_point< T > y)
Perform multiplication among two fixed point numbers.
fixed_point(fixed_point< U > val, uint8_t p)
Constructor (from different fixed point type)
static fixed_point< T > negate(fixed_point< T > x)
Negate number.
Arbitrary fixed-point arithmetic class.
fixed_point< T > sub(fixed_point< T > x, fixed_point< T > y)
bool operator<(const fixed_point< T > &lhs, const fixed_point< T > &rhs)
static fixed_point< T > abs(fixed_point< T > x)
Calculate absolute value.
uint8_t precision() const
Precision accessor.
T raw() const
Raw value accessor.
static constexpr float fixed_step(uint8_t p)
Calculate fixed point precision step given a fixed point precision.
static fixed_point< T > clamp(fixed_point< T > x, T min, T max)
Clamp fixed point to specific range.
fixed_point< T > inv_sqrt(fixed_point< T > x)
static fixed_point< T > inv_sqrt(fixed_point< T > x)
Calculate the inverse square root of a fixed point number.
static bool isless(fixed_point< T > x, fixed_point< T > y)
Checks if one fixed point is less than the other.
static fixed_point< T > exp(fixed_point< T > x)
Calculate the exponential of a fixed point number.
fixed_point(float val, uint8_t p)
Constructor (from float)
fixed_point< T > & operator+=(const fixed_point< U > &rhs)
Arithmetic += assignment operator.
static bool islessgreater(fixed_point< T > x, fixed_point< T > y)
Checks if one fixed point is less or greater than the other.
void rescale(uint8_t p)
Rescale a fixed point to a new precision.
Rounds to nearest value; half rounds to nearest even.
bool operator!=(const fixed_point< T > &lhs, const fixed_point< T > &rhs)
fixed_point< T > operator+(const fixed_point< T > &lhs, const fixed_point< T > &rhs)
static bool islessequal(fixed_point< T > x, fixed_point< T > y)
Checks if one fixed point is less than or equal to the other.
fixed_point< T > operator*(fixed_point< T > x, fixed_point< T > y)
T saturate_cast(T val)
Saturate a value of type T against the numeric limits of type U.
static bool isnotequal(fixed_point< T > x, fixed_point< T > y)
Checks if two fixed point numbers are not equal.
OverflowPolicy
Strongly typed enum class representing the overflow policy.
fixed_point< T > div(fixed_point< T > x, fixed_point< T > y)
fixed_point< T > pow(fixed_point< T > x, fixed_point< T > a)
constexpr int clz(T value)
Count the number of leading zero bits in the given value.
fixed_point< T > max(fixed_point< T > x, fixed_point< T > y)
static fixed_point< T > log(fixed_point< T > x)
Calculate the logarithm of a fixed point number.
fixed_point< T > tanh(fixed_point< T > x)
fixed_point< T > abs(fixed_point< T > x)
bool operator>(const fixed_point< T > &lhs, const fixed_point< T > &rhs)
fixed_point< T > operator-(const fixed_point< T > &lhs, const fixed_point< T > &rhs)
Truncates the least significant values that are lost in operations.
static constexpr T to_fixed(float val, uint8_t p)
Convert a single precision floating point value to a fixed point representation given its precision...
int stof(Ts &&...args)
Convert string values to float.