From d6802581700bdf65fc6d002d3d5667295044f0b3 Mon Sep 17 00:00:00 2001 From: Tue Ly Date: Wed, 7 Dec 2022 11:53:18 -0500 Subject: [PATCH] [libc] Implement a high-precision floating point class. Implement a high-precision floating point class using UInt<> as its mantissa. This will be used in accurate pass for double precision math functions. Reviewed By: sivachandra Differential Revision: https://reviews.llvm.org/D136799 --- libc/src/__support/FPUtil/CMakeLists.txt | 12 ++ libc/src/__support/FPUtil/dyadic_float.h | 212 +++++++++++++++++++++ libc/src/__support/UInt.h | 14 ++ libc/test/src/__support/CMakeLists.txt | 1 + libc/test/src/__support/FPUtil/CMakeLists.txt | 12 ++ .../src/__support/FPUtil/dyadic_float_test.cpp | 67 +++++++ libc/utils/UnitTest/CMakeLists.txt | 2 +- libc/utils/UnitTest/FPMatcher.h | 16 +- 8 files changed, 334 insertions(+), 2 deletions(-) create mode 100644 libc/src/__support/FPUtil/dyadic_float.h create mode 100644 libc/test/src/__support/FPUtil/CMakeLists.txt create mode 100644 libc/test/src/__support/FPUtil/dyadic_float_test.cpp diff --git a/libc/src/__support/FPUtil/CMakeLists.txt b/libc/src/__support/FPUtil/CMakeLists.txt index 3377fab..0686c01 100644 --- a/libc/src/__support/FPUtil/CMakeLists.txt +++ b/libc/src/__support/FPUtil/CMakeLists.txt @@ -178,4 +178,16 @@ add_header_library( ROUND_OPT ) +add_header_library( + dyadic_float + HDRS + dyadic_float.h + DEPENDS + .float_properties + .fp_bits + .multiply_add + libc.src.__support.common + libc.src.__support.uint +) + add_subdirectory(generic) diff --git a/libc/src/__support/FPUtil/dyadic_float.h b/libc/src/__support/FPUtil/dyadic_float.h new file mode 100644 index 0000000..8b76725 --- /dev/null +++ b/libc/src/__support/FPUtil/dyadic_float.h @@ -0,0 +1,212 @@ +//===-- A class to store high precision floating point numbers --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_DYADIC_FLOAT_H +#define LLVM_LIBC_SRC_SUPPORT_FPUTIL_DYADIC_FLOAT_H + +#include "FPBits.h" +#include "FloatProperties.h" +#include "multiply_add.h" +#include "src/__support/CPP/type_traits.h" +#include "src/__support/UInt.h" + +#include + +namespace __llvm_libc::fputil { + +// A generic class to perform comuptations of high precision floating points. +// We store the value in dyadic format, including 3 fields: +// sign : boolean value - false means positive, true means negative +// exponent: the exponent value of the least significant bit of the mantissa. +// mantissa: unsigned integer of length `Bits`. +// So the real value that is stored is: +// real value = (-1)^sign * 2^exponent * (mantissa as unsigned integer) +// The stored data is normal if for non-zero mantissa, the leading bit is 1. +// The outputs of the constructors and most functions will be normalized. +// To simplify and improve the efficiency, many functions will assume that the +// inputs are normal. +template struct DyadicFloat { + using MantissaType = __llvm_libc::cpp::UInt; + + bool sign = false; + int exponent = 0; + MantissaType mantissa = MantissaType(0); + + DyadicFloat() = default; + + template && + (FloatProperties::MANTISSA_WIDTH < Bits), + int> = 0> + DyadicFloat(T x) { + FPBits x_bits(x); + sign = x_bits.get_sign(); + exponent = x_bits.get_exponent() - FloatProperties::MANTISSA_WIDTH; + mantissa = MantissaType(x_bits.get_explicit_mantissa()); + normalize(); + } + + DyadicFloat(bool s, int e, MantissaType m) + : sign(s), exponent(e), mantissa(m) { + normalize(); + }; + + // Normalizing the mantissa, bringing the leading 1 bit to the most + // significant bit. + DyadicFloat &normalize() { + if (!mantissa.is_zero()) { + int shift_length = static_cast(mantissa.clz()); + exponent -= shift_length; + mantissa.shift_left(static_cast(shift_length)); + } + return *this; + } + + // Used for aligning exponents. Output might not be normalized. + DyadicFloat &shift_left(int shift_length) { + exponent -= shift_length; + mantissa <<= static_cast(shift_length); + return *this; + } + + // Used for aligning exponents. Output might not be normalized. + DyadicFloat &shift_right(int shift_length) { + exponent += shift_length; + mantissa >>= static_cast(shift_length); + return *this; + } + + // Assume that it is already normalized and output is also normal. + // Output is rounded correctly with respect to the current rounding mode. + // TODO(lntue): Test or add support for denormal output. + // TODO(lntue): Test or add specialization for x86 long double. + template && + (FloatProperties::MANTISSA_WIDTH < Bits), + void>> + explicit operator T() const { + // TODO(lntue): Do we need to treat signed zeros properly? + if (mantissa.is_zero()) + return 0.0; + + // Assume that it is normalized, and output is also normal. + constexpr size_t PRECISION = FloatProperties::MANTISSA_WIDTH + 1; + using output_bits_t = typename FPBits::UIntType; + + MantissaType m_hi(mantissa >> (Bits - PRECISION)); + auto d_hi = FPBits::create_value( + sign, exponent + (Bits - 1) + FloatProperties::EXPONENT_BIAS, + output_bits_t(m_hi) & FloatProperties::MANTISSA_MASK); + + const MantissaType ROUND_MASK = MantissaType(1) << (Bits - PRECISION - 1); + const MantissaType STICKY_MASK = ROUND_MASK - MantissaType(1); + + bool round_bit = !(mantissa & ROUND_MASK).is_zero(); + bool sticky_bit = !(mantissa & STICKY_MASK).is_zero(); + int round_and_sticky = int(round_bit) * 2 + int(sticky_bit); + auto d_lo = FPBits::create_value(sign, + exponent + (Bits - PRECISION - 2) + + FloatProperties::EXPONENT_BIAS, + output_bits_t(0)); + + // Still correct without FMA instructions if `d_lo` is not underflow. + return multiply_add(d_lo.get_val(), T(round_and_sticky), d_hi.get_val()); + } +}; + +// Quick add - Add 2 dyadic floats with rounding toward 0 and then normalize the +// output: +// - Align the exponents so that: +// new a.exponent = new b.exponent = max(a.exponent, b.exponent) +// - Add or subtract the mantissas depending on the signs. +// - Normalize the result. +// The absolute errors compared to the mathematical sum is bounded by: +// | quick_add(a, b) - (a + b) | < MSB(a + b) * 2^(-Bits + 2), +// i.e., errors are up to 2 ULPs. +// Assume inputs are normalized (by constructors or other functions) so that we +// don't need to normalize the inputs again in this function. If the inputs are +// not normalized, the results might lose precision significantly. +template +constexpr DyadicFloat quick_add(DyadicFloat a, + DyadicFloat b) { + if (unlikely(a.mantissa.is_zero())) + return b; + if (unlikely(b.mantissa.is_zero())) + return a; + + // Align exponents + if (a.exponent > b.exponent) + b.shift_right(a.exponent - b.exponent); + else if (b.exponent > a.exponent) + a.shift_right(b.exponent - a.exponent); + + DyadicFloat result; + + if (a.sign == b.sign) { + // Addition + result.sign = a.sign; + result.exponent = a.exponent; + result.mantissa = a.mantissa; + if (result.mantissa.add(b.mantissa)) { + // Mantissa addition overflow. + result.shift_right(1); + result.mantissa.val[DyadicFloat::MantissaType::WordCount - 1] |= + (uint64_t(1) << 63); + } + // Result is already normalized. + return result; + } + + // Subtraction + if (a.mantissa >= b.mantissa) { + result.sign = a.sign; + result.exponent = a.exponent; + result.mantissa = a.mantissa - b.mantissa; + } else { + result.sign = b.sign; + result.exponent = b.exponent; + result.mantissa = b.mantissa - a.mantissa; + } + + return result.normalize(); +} + +// Quick Mul - Slightly less accurate but efficient multiplication of 2 dyadic +// floats with rounding toward 0 and then normalize the output: +// result.exponent = a.exponent + b.exponent + Bits, +// result.mantissa = quick_mul_hi(a.mantissa + b.mantissa) +// ~ (full product a.mantissa * b.mantissa) >> Bits. +// The errors compared to the mathematical product is bounded by: +// 2 * errors of quick_mul_hi = 2 * (UInt::WordCount - 1) in ULPs. +// Assume inputs are normalized (by constructors or other functions) so that we +// don't need to normalize the inputs again in this function. If the inputs are +// not normalized, the results might lose precision significantly. +template +constexpr DyadicFloat quick_mul(DyadicFloat a, + DyadicFloat b) { + DyadicFloat result; + result.sign = (a.sign != b.sign); + result.exponent = a.exponent + b.exponent + int(Bits); + + if (!(a.mantissa.is_zero() || b.mantissa.is_zero())) { + result.mantissa = a.mantissa.quick_mul_hi(b.mantissa); + // Check the leading bit directly, should be faster than using clz in + // normalize(). + if (result.mantissa.val[DyadicFloat::MantissaType::WordCount - 1] >> + 63 == + 0) + result.shift_left(1); + } else { + result.mantissa = (typename DyadicFloat::MantissaType)(0); + } + return result; +} + +} // namespace __llvm_libc::fputil + +#endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_DYADIC_FLOAT_H diff --git a/libc/src/__support/UInt.h b/libc/src/__support/UInt.h index 6068ab6..e4d8a8f5 100644 --- a/libc/src/__support/UInt.h +++ b/libc/src/__support/UInt.h @@ -14,6 +14,7 @@ #include "src/__support/CPP/optional.h" #include "src/__support/CPP/type_traits.h" #include "src/__support/builtin_wrappers.h" +#include "src/__support/common.h" #include "src/__support/integer_utils.h" #include "src/__support/number_pair.h" @@ -95,6 +96,14 @@ template struct UInt { return *this; } + constexpr bool is_zero() const { + for (size_t i = 0; i < WordCount; ++i) { + if (val[i] != 0) + return false; + } + return true; + } + // Add x to this number and store the result in this number. // Returns the carry value produced by the addition operation. constexpr uint64_t add(const UInt &x) { @@ -356,6 +365,9 @@ template struct UInt { return; } #endif // __SIZEOF_INT128__ + if (unlikely(s == 0)) + return; + const size_t drop = s / 64; // Number of words to drop const size_t shift = s % 64; // Bits to shift in the remaining words. size_t i = WordCount; @@ -402,6 +414,8 @@ template struct UInt { } #endif // __SIZEOF_INT128__ + if (unlikely(s == 0)) + return; const size_t drop = s / 64; // Number of words to drop const size_t shift = s % 64; // Bit shift in the remaining words. diff --git a/libc/test/src/__support/CMakeLists.txt b/libc/test/src/__support/CMakeLists.txt index c1b0426..b835812 100644 --- a/libc/test/src/__support/CMakeLists.txt +++ b/libc/test/src/__support/CMakeLists.txt @@ -111,3 +111,4 @@ add_custom_command(TARGET libc_str_to_float_comparison_test add_subdirectory(CPP) add_subdirectory(File) add_subdirectory(OSUtil) +add_subdirectory(FPUtil) diff --git a/libc/test/src/__support/FPUtil/CMakeLists.txt b/libc/test/src/__support/FPUtil/CMakeLists.txt new file mode 100644 index 0000000..47d88b1 --- /dev/null +++ b/libc/test/src/__support/FPUtil/CMakeLists.txt @@ -0,0 +1,12 @@ +add_libc_testsuite(libc_fputil_unittests) + +add_fp_unittest( + dyadic_float_test + NEED_MPFR + SUITE + libc_fputil_unittests + SRCS + dyadic_float_test.cpp + DEPENDS + libc.src.__support.FPUtil.dyadic_float +) diff --git a/libc/test/src/__support/FPUtil/dyadic_float_test.cpp b/libc/test/src/__support/FPUtil/dyadic_float_test.cpp new file mode 100644 index 0000000..530b1b1 --- /dev/null +++ b/libc/test/src/__support/FPUtil/dyadic_float_test.cpp @@ -0,0 +1,67 @@ +//===-- Unittests for the DyadicFloat class -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/__support/FPUtil/dyadic_float.h" +#include "src/__support/UInt.h" +#include "utils/MPFRWrapper/MPFRUtils.h" +#include "utils/UnitTest/FPMatcher.h" +#include "utils/UnitTest/Test.h" + +using Float128 = __llvm_libc::fputil::DyadicFloat<128>; +using Float192 = __llvm_libc::fputil::DyadicFloat<192>; +using Float256 = __llvm_libc::fputil::DyadicFloat<256>; + +TEST(LlvmLibcDyadicFloatTest, BasicConversions) { + Float128 x(/*sign*/ false, /*exponent*/ 0, + /*mantissa*/ Float128::MantissaType(1)); + volatile float xf = float(x); + volatile double xd = double(x); + ASSERT_FP_EQ(1.0f, xf); + ASSERT_FP_EQ(1.0, xd); + + Float128 y(0x1.0p-53); + volatile float yf = float(y); + volatile double yd = double(y); + ASSERT_FP_EQ(0x1.0p-53f, yf); + ASSERT_FP_EQ(0x1.0p-53, yd); + + Float128 z = quick_add(x, y); + + EXPECT_FP_EQ_ALL_ROUNDING(xf + yf, float(z)); + EXPECT_FP_EQ_ALL_ROUNDING(xd + yd, double(z)); +} + +TEST(LlvmLibcDyadicFloatTest, QuickAdd) { + Float192 x(/*sign*/ false, /*exponent*/ 0, + /*mantissa*/ Float192::MantissaType(0x123456)); + volatile double xd = double(x); + ASSERT_FP_EQ(0x1.23456p20, xd); + + Float192 y(0x1.abcdefp-20); + volatile double yd = double(y); + ASSERT_FP_EQ(0x1.abcdefp-20, yd); + + Float192 z = quick_add(x, y); + + EXPECT_FP_EQ_ALL_ROUNDING(xd + yd, (volatile double)(z)); +} + +TEST(LlvmLibcDyadicFloatTest, QuickMul) { + Float256 x(/*sign*/ false, /*exponent*/ 0, + /*mantissa*/ Float256::MantissaType(0x123456)); + volatile double xd = double(x); + ASSERT_FP_EQ(0x1.23456p20, xd); + + Float256 y(0x1.abcdefp-25); + volatile double yd = double(y); + ASSERT_FP_EQ(0x1.abcdefp-25, yd); + + Float256 z = quick_mul(x, y); + + EXPECT_FP_EQ_ALL_ROUNDING(xd * yd, double(z)); +} diff --git a/libc/utils/UnitTest/CMakeLists.txt b/libc/utils/UnitTest/CMakeLists.txt index f2ec4e0..4303930 100644 --- a/libc/utils/UnitTest/CMakeLists.txt +++ b/libc/utils/UnitTest/CMakeLists.txt @@ -33,7 +33,7 @@ add_library( FPMatcher.h ) target_include_directories(LibcFPTestHelpers PUBLIC ${LIBC_SOURCE_DIR}) -target_link_libraries(LibcFPTestHelpers LibcUnitTest) +target_link_libraries(LibcFPTestHelpers LibcUnitTest libc_test_utils) add_dependencies( LibcFPTestHelpers LibcUnitTest diff --git a/libc/utils/UnitTest/FPMatcher.h b/libc/utils/UnitTest/FPMatcher.h index 1c2b540..a7c0971 100644 --- a/libc/utils/UnitTest/FPMatcher.h +++ b/libc/utils/UnitTest/FPMatcher.h @@ -1,4 +1,4 @@ -//===-- TestMatchers.h ------------------------------------------*- C++ -*-===// +//===-- FPMatchers.h --------------------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -12,6 +12,7 @@ #include "src/__support/FPUtil/FEnvImpl.h" #include "src/__support/FPUtil/FPBits.h" #include "utils/UnitTest/Test.h" +#include "utils/testutils/RoundingModeUtils.h" #include #include @@ -132,4 +133,17 @@ FPMatcher getMatcher(T expectedValue) { } \ } while (0) +#define EXPECT_FP_EQ_ALL_ROUNDING(expected, actual) \ + do { \ + using namespace __llvm_libc::testutils; \ + ForceRoundingMode __r1(RoundingMode::Nearest); \ + EXPECT_FP_EQ((expected), (actual)); \ + ForceRoundingMode __r2(RoundingMode::Upward); \ + EXPECT_FP_EQ((expected), (actual)); \ + ForceRoundingMode __r3(RoundingMode::Downward); \ + EXPECT_FP_EQ((expected), (actual)); \ + ForceRoundingMode __r4(RoundingMode::TowardZero); \ + EXPECT_FP_EQ((expected), (actual)); \ + } while (0) + #endif // LLVM_LIBC_UTILS_UNITTEST_FPMATCHER_H -- 2.7.4