24 #ifndef __ARM_COMPUTE_TEST_VALIDATION_FIXEDPOINT_H__ 25 #define __ARM_COMPUTE_TEST_VALIDATION_FIXEDPOINT_H__ 35 #include <type_traits> 41 namespace fixed_point_arithmetic
// Integer promotion traits.
//
// promote<T>::type names a wider integer type with the same signedness as T.
// The fixed point helpers alias it (promote_t) and use it for intermediate
// values such as sums, products and rescaled raw values, so those computations
// have headroom beyond T.
//
// NOTE(review): the primary template is expected to be defined elsewhere in this
// header. It is (re)declared here only so the explicit specializations below are
// well-formed on their own; repeating a plain template declaration is harmless.
template <typename T>
struct promote;

// 8-bit -> 16-bit
template <>
struct promote<uint8_t>
{
    using type = uint16_t; /**< Promoted type */
};
// Added for symmetry: every other width has both a signed and an unsigned mapping.
template <>
struct promote<int8_t>
{
    using type = int16_t; /**< Promoted type */
};

// 16-bit -> 32-bit
template <>
struct promote<uint16_t>
{
    using type = uint32_t; /**< Promoted type */
};
template <>
struct promote<int16_t>
{
    using type = int32_t; /**< Promoted type */
};

// 32-bit -> 64-bit
template <>
struct promote<uint32_t>
{
    using type = uint64_t; /**< Promoted type */
};
template <>
struct promote<int32_t>
{
    using type = int64_t; /**< Promoted type */
};

// 64-bit types have no wider standard integer type, so they "promote" to
// themselves. Intermediate arithmetic on 64-bit fixed point values therefore
// gets no extra headroom.
template <>
struct promote<uint64_t>
{
    using type = uint64_t; /**< Promoted type (no wider type available) */
};
template <>
struct promote<int64_t>
{
    using type = int64_t; /**< Promoted type (no wider type available) */
};
100 : _value(0), _fixed_point_position(p)
102 assert(p > 0 && p < std::numeric_limits<T>::digits);
105 if(std::numeric_limits<T>::digits < std::numeric_limits<U>::digits)
116 _value =
static_cast<T
>(v);
126 : _value(val << p), _fixed_point_position(p)
139 : _value(detail::constant_expr<T>::to_fixed(val, p)), _fixed_point_position(p)
141 assert(p > 0 && p < std::numeric_limits<T>::digits);
149 : _value(detail::constant_expr<T>::to_fixed(support::cpp11::
stof(str), p)), _fixed_point_position(p)
151 assert(p > 0 && p < std::numeric_limits<T>::digits);
166 operator float()
const 183 template <
typename U>
186 U val =
static_cast<U>(_value);
187 if(std::numeric_limits<U>::digits < std::numeric_limits<T>::digits)
200 template <
typename U>
213 template <
typename U>
235 return _fixed_point_position;
243 assert(p > 0 && p < std::numeric_limits<T>::digits);
246 promoted_T val = _value;
247 if(p > _fixed_point_position)
249 val <<= (p - _fixed_point_position);
251 else if(p < _fixed_point_position)
253 uint8_t pbar = _fixed_point_position - p;
254 val += (pbar != 0) ? (1 << (pbar - 1)) : 0;
259 _fixed_point_position = p;
264 uint8_t _fixed_point_position;
275 template <
typename T>
281 return __builtin_clz(value) - (32 - std::numeric_limits<unsigned_T>::digits);
284 template <
typename T>
305 return (1.0f / static_cast<float>(1 << p));
317 return static_cast<float>(val * fixed_step(p));
326 static constexpr T
to_int(T val, uint8_t p)
339 return static_cast<T
>(
saturate_cast<
float>(val * fixed_one(p) + ((val >= 0) ? 0.5 : -0.5)));
359 template <
typename U>
374 template <
typename T,
typename U,
typename traits>
377 return s << static_cast<float>(x);
385 template <
typename T>
388 return ((x.
raw() >> std::numeric_limits<T>::digits) != 0);
397 template <
typename T>
403 return (x.
raw() == y.
raw());
412 template <
typename T>
415 return !isequal(x, y);
424 template <
typename T>
430 return (x.
raw() > y.
raw());
439 template <
typename T>
445 return (x.
raw() >= y.
raw());
454 template <
typename T>
460 return (x.
raw() < y.
raw());
469 template <
typename T>
475 return (x.
raw() <= y.
raw());
484 template <
typename T>
487 return isnotequal(x, y);
497 template <
typename T>
508 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
512 promoted_T val = -x.
raw();
526 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
535 type val =
static_cast<type>(x.
raw()) + static_cast<type>(y.
raw());
551 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
560 type val =
static_cast<type>(x.
raw()) - static_cast<type>(y.
raw());
576 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
582 promoted_T round_factor = (1 << (p_max - 1));
583 promoted_T val = ((
static_cast<promoted_T
>(x.
raw()) * static_cast<promoted_T>(y.
raw())) + round_factor) >> p_max;
597 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
602 promoted_T denom =
static_cast<promoted_T
>(y.
raw());
625 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
629 promoted_T val =
static_cast<promoted_T
>(x.
raw()) << shift;
643 template <
typename T>
654 template <
typename T>
667 template <
typename T>
675 if(isequal(x, const_one) || islessequal(x,
fixed_point<T>(static_cast<T>(0), p)))
679 else if(isless(x, const_one))
685 T shift_val = 31 - __builtin_clz(x.
raw() >> p);
686 x = shift_right(x, shift_val);
687 x =
sub(x, const_one);
698 sum =
add(
mul(x, sum), B);
699 sum =
add(
mul(x, sum), A);
714 template <
typename T>
731 auto taylor =
add(
mul(frac_part, D), C);
732 taylor =
add(
mul(frac_part, taylor), B);
733 taylor =
add(
mul(frac_part, taylor), A);
734 taylor =
mul(frac_part, taylor);
735 taylor =
add(taylor, const_one);
738 if(static_cast<T>(
clz(taylor.raw())) <= scaled_int_part)
743 return (scaled_int_part < 0) ? shift_right(taylor, -scaled_int_part) : shift_left(taylor, scaled_int_part);
751 template <
typename T>
755 int8_t shift = std::numeric_limits<T>::digits - (p +
detail::clz(x.
raw()));
757 shift += std::numeric_limits<T>::is_signed ? 1 : 0;
760 volatile int8_t *shift_ptr = &shift;
763 auto a = (*shift_ptr < 0) ? shift_left(x, -(shift)) : shift_right(x, shift);
768 for(
int i = 0; i < num_iterations; ++i)
771 x2 = shift_right(
mul(x2, three_minus_dx), 1);
774 return (shift < 0) ? shift_left(x2, (-shift) >> 1) : shift_right(x2, shift >> 1);
782 template <
typename T>
790 auto exp2x =
exp(const_two * x);
791 auto num = exp2x - const_one;
792 auto den = exp2x + const_one;
793 auto tanh = num / den;
806 template <
typename T>
813 template <
typename T>
818 template <
typename T>
823 template <
typename T>
828 template <
typename T>
833 template <
typename T>
838 template <
typename T>
843 template <
typename T>
848 template <
typename T>
853 template <
typename T>
858 template <
typename T>
863 template <
typename T>
868 template <
typename T>
873 template <
typename T>
878 template <
typename T,
typename U,
typename traits>
879 std::basic_ostream<T, traits> &operator<<(std::basic_ostream<T, traits> &s,
fixed_point<U> x)
883 template <
typename T>
886 return x > y ? y : x;
888 template <
typename T>
891 return x > y ? x : y;
893 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
896 return functions::add<OP>(x, y);
898 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
901 return functions::sub<OP>(x, y);
903 template <OverflowPolicy OP = OverflowPolicy::SATURATE,
typename T>
906 return functions::mul<OP>(x, y);
908 template <
typename T>
913 template <
typename T>
918 template <
typename T>
923 template <
typename T>
928 template <
typename T>
933 template <
typename T>
938 template <
typename T>
943 template <
typename T>
951 using detail::operator==;
952 using detail::operator!=;
953 using detail::operator<;
954 using detail::operator>;
955 using detail::operator<=;
956 using detail::operator>=;
957 using detail::operator+;
958 using detail::operator-;
959 using detail::operator*;
960 using detail::operator/;
961 using detail::operator>>;
962 using detail::operator<<;
fixed_point< T > mul(fixed_point< T > x, fixed_point< T > y)
fixed_point< T > min(fixed_point< T > x, fixed_point< T > y)
RoundingPolicy
Strongly typed enum class representing the rounding policy.
static constexpr T clamp(T val, T min, T max)
Clamp a value to the closed range [min, max].
fixed_point< T > clamp(fixed_point< T > x, T min, T max)
fixed_point< T > operator>>(fixed_point< T > x, size_t shift)
static constexpr float to_float(T val, uint8_t p)
Convert a fixed point value to float given its precision.
static bool isgreater(fixed_point< T > x, fixed_point< T > y)
Checks if one fixed point is greater than the other.
static constexpr T to_int(T val, uint8_t p)
Convert a fixed point value to integer given its precision.
static bool signbit(fixed_point< T > x)
Signbit of a fixed point number.
fixed_point< T > exp(fixed_point< T > x)
DATA_TYPE sum(__global const DATA_TYPE *input)
Calculate sum of a vector.
static fixed_point< T > add(fixed_point< T > x, fixed_point< T > y)
Perform addition among two fixed point numbers.
typename promote< T >::type promote_t
static bool isequal(fixed_point< T > x, fixed_point< T > y)
Checks if two fixed point numbers are equal.
fixed_point< T > operator/(fixed_point< T > x, fixed_point< T > y)
bool operator==(const fixed_point< T > &lhs, const fixed_point< T > &rhs)
static bool isgreaterequal(fixed_point< T > x, fixed_point< T > y)
Checks if one fixed point number is greater than or equal to the other.
bool operator>=(const fixed_point< T > &lhs, const fixed_point< T > &rhs)
static std::basic_ostream< T, traits > & write(std::basic_ostream< T, traits > &s, fixed_point< U > &x)
Output stream operator.
static fixed_point< T > pow(fixed_point< T > x, fixed_point< T > a)
Calculate the a-th power of a fixed point number.
fixed_point(U val, uint8_t p, bool is_raw=false)
Constructor (from integer)
This file contains all available output stages for GEMMLowp on OpenCL.
static fixed_point< T > shift_left(fixed_point< T > x, size_t shift)
Shift left.
fixed_point< T > log(fixed_point< T > x)
fixed_point(std::string str, uint8_t p)
Constructor (from float string)
fixed_point< T > add(fixed_point< T > x, fixed_point< T > y)
static constexpr T fixed_one(uint8_t p)
Calculate representation of 1 in fixed point given a fixed point precision.
static constexpr T saturate_cast(U val)
Saturate given number.
static fixed_point< T > tanh(fixed_point< T > x)
Calculate the hyperbolic tangent of a fixed point number.
static fixed_point< T > sub(fixed_point< T > x, fixed_point< T > y)
Perform subtraction among two fixed point numbers.
static fixed_point< T > div(fixed_point< T > x, fixed_point< T > y)
Perform division among two fixed point numbers.
static fixed_point< T > shift_right(fixed_point< T > x, size_t shift)
Shift right.
fixed_point< T > & operator-=(const fixed_point< U > &rhs)
Arithmetic -= assignment operator.
static fixed_point< T > mul(fixed_point< T > x, fixed_point< T > y)
Perform multiplication among two fixed point numbers.
fixed_point(fixed_point< U > val, uint8_t p)
Constructor (from different fixed point type)
static fixed_point< T > negate(fixed_point< T > x)
Negate number.
Arbitrary fixed-point arithmetic class.
fixed_point< T > sub(fixed_point< T > x, fixed_point< T > y)
bool operator<(const fixed_point< T > &lhs, const fixed_point< T > &rhs)
static fixed_point< T > abs(fixed_point< T > x)
Calculate absolute value.
uint8_t precision() const
Precision accessor.
T raw() const
Raw value accessor.
static constexpr float fixed_step(uint8_t p)
Calculate fixed point precision step given a fixed point precision.
static fixed_point< T > clamp(fixed_point< T > x, T min, T max)
Clamp fixed point to specific range.
fixed_point< T > inv_sqrt(fixed_point< T > x)
static fixed_point< T > inv_sqrt(fixed_point< T > x)
Calculate the inverse square root of a fixed point number.
static bool isless(fixed_point< T > x, fixed_point< T > y)
Checks if one fixed point is less than the other.
static fixed_point< T > exp(fixed_point< T > x)
Calculate the exponential of a fixed point number.
fixed_point(float val, uint8_t p)
Constructor (from float)
fixed_point< T > & operator+=(const fixed_point< U > &rhs)
Arithmetic += assignment operator.
static bool islessgreater(fixed_point< T > x, fixed_point< T > y)
Checks if one fixed point number is less than or greater than the other.
void rescale(uint8_t p)
Rescale a fixed point to a new precision.
Rounds to nearest value; half rounds to nearest even.
bool operator!=(const fixed_point< T > &lhs, const fixed_point< T > &rhs)
fixed_point< T > operator+(const fixed_point< T > &lhs, const fixed_point< T > &rhs)
static bool islessequal(fixed_point< T > x, fixed_point< T > y)
Checks if one fixed point number is less than or equal to the other.
fixed_point< T > operator*(fixed_point< T > x, fixed_point< T > y)
T saturate_cast(T val)
Saturate a value of type T against the numeric limits of type U.
static bool isnotequal(fixed_point< T > x, fixed_point< T > y)
Checks if two fixed point numbers are not equal.
OverflowPolicy
Strongly typed enum class representing the overflow policy.
fixed_point< T > div(fixed_point< T > x, fixed_point< T > y)
fixed_point< T > pow(fixed_point< T > x, fixed_point< T > a)
constexpr int clz(T value)
Count the number of leading zero bits in the given value.
fixed_point< T > max(fixed_point< T > x, fixed_point< T > y)
static fixed_point< T > log(fixed_point< T > x)
Calculate the logarithm of a fixed point number.
fixed_point< T > tanh(fixed_point< T > x)
fixed_point< T > abs(fixed_point< T > x)
bool operator>(const fixed_point< T > &lhs, const fixed_point< T > &rhs)
fixed_point< T > operator-(const fixed_point< T > &lhs, const fixed_point< T > &rhs)
Truncates the least significant bits that are lost in operations.
static constexpr T to_fixed(float val, uint8_t p)
Convert a single-precision floating-point value to a fixed point representation given its precision.
int stof(Ts &&...args)
Convert string values to float.