[flang] Refactor rounding code.
authorpeter klausler <pklausler@nvidia.com>
Fri, 8 Jun 2018 17:58:58 +0000 (10:58 -0700)
committerpeter klausler <pklausler@nvidia.com>
Thu, 14 Jun 2018 20:52:53 +0000 (13:52 -0700)
Original-commit: flang-compiler/f18@8ef2418791e62d8a90802a2e6754975a52412479
Reviewed-on: https://github.com/flang-compiler/f18/pull/101
Tree-same-pre-rewrite: false

flang/lib/evaluate/real.h
flang/test/evaluate/real.cc

index 1124215..c2a47dd 100644 (file)
@@ -63,9 +63,9 @@ public:
       result.flags |= result.value.Normalize(
           isNegative, exponent, fraction.SHIFTL(-bitsLost));
     } else {
-      RoundingBits roundingBits{GetRoundingBits(absN, bitsLost)};
       Fraction fraction{Fraction::Convert(absN.SHIFTR(bitsLost)).value};
       result.flags |= result.value.Normalize(isNegative, exponent, fraction);
+      RoundingBits roundingBits{absN, bitsLost};
       result.flags |= result.value.Round(rounding, roundingBits);
     }
     return result;
@@ -258,37 +258,24 @@ public:
     // of the opposite sign and greater magnitude.  So (x+y) will have the
     // same sign as x.
     Fraction yFraction{y.GetFraction()};
-    RoundingBits roundingBits;
     Fraction fraction{GetFraction()};
     int rshift = exponent - yExponent;
     if (exponent > 0 && yExponent == 0) {
       --rshift;  // correct overshift when only y is denormal
     }
-    roundingBits = GetRoundingBits(yFraction, rshift);
+    RoundingBits roundingBits{yFraction, rshift};
     yFraction = yFraction.SHIFTR(rshift);
     bool carry{false};
     if (isNegative != yIsNegative) {
-      // Opposite signs: subtract
-      carry = !roundingBits.sticky_;
-      if (carry) {
-        carry = !roundingBits.round_;
-        if (carry) {
-          carry = !roundingBits.guard_;
-        } else {
-          roundingBits.guard_ ^= true;
-        }
-      } else {
-        roundingBits.round_ ^= true;
-        roundingBits.guard_ ^= true;
-      }
+      // Opposite signs: subtract via addition of two's complement of y and
+      // the rounding bits.
       yFraction = yFraction.NOT();
+      carry = roundingBits.Negate();
     }
     auto sum = fraction.AddUnsigned(yFraction, carry);
     fraction = sum.value;
     if (isNegative == yIsNegative && sum.carry) {
-      roundingBits.sticky_ |= roundingBits.round_;
-      roundingBits.round_ = roundingBits.guard_;
-      roundingBits.guard_ = sum.value.BTEST(0);
+      roundingBits.ShiftRight(sum.value.BTEST(0));
       fraction = fraction.SHIFTR(1).IBSET(precision - 1);
       ++exponent;
     }
@@ -319,7 +306,7 @@ public:
         result.flags |=
             result.value.Normalize(isNegative, exponent, product.upper);
         result.flags |= result.value.Round(
-            rounding, GetRoundingBits(product.lower, precision));
+            rounding, RoundingBits{product.lower, precision});
       }
     }
     return result;
@@ -352,9 +339,8 @@ public:
           // To round, double the remainder and compare it to the divisor.
           auto doubled = qr.remainder.AddUnsigned(qr.remainder);
           Ordering drcmp{doubled.value.CompareUnsigned(y.GetFraction())};
-          RoundingBits roundingBits;
-          roundingBits.guard_ = drcmp != Ordering::Less;
-          roundingBits.round_ = drcmp != Ordering::Equal;
+          RoundingBits roundingBits{
+              drcmp != Ordering::Less, drcmp != Ordering::Equal};
           std::uint64_t exponent{Exponent() - y.Exponent() + exponentBias};
           result.flags |=
               result.value.Normalize(isNegative, exponent, qr.quotient);
@@ -369,10 +355,77 @@ private:
   using Fraction = Integer<precision>;  // all bits made explicit
   using Significand = Integer<significandBits>;  // no implicit bit
 
-  struct RoundingBits {
-    RoundingBits() {}
-    RoundingBits(const RoundingBits &) = default;
-    RoundingBits &operator=(const RoundingBits &) = default;
+  class RoundingBits {
+  public:
+    constexpr RoundingBits(
+        bool guard = false, bool round = false, bool sticky = false)
+      : guard_{guard}, round_{round}, sticky_{sticky} {}
+
+    template<typename FRACTION>
+    constexpr RoundingBits(const FRACTION &fraction, int rshift) {
+      if (rshift > 0 && rshift < fraction.bits + 1) {
+        guard_ = fraction.BTEST(rshift - 1);
+      }
+      if (rshift > 1 && rshift < fraction.bits + 2) {
+        round_ = fraction.BTEST(rshift - 2);
+      }
+      if (rshift > 2) {
+        if (rshift >= fraction.bits + 2) {
+          sticky_ = !fraction.IsZero();
+        } else {
+          auto mask = fraction.MASKR(rshift - 2);
+          sticky_ = !fraction.IAND(mask).IsZero();
+        }
+      }
+    }
+
+    constexpr bool Zero() const { return !(guard_ | round_ | sticky_); }
+
+    constexpr bool Negate() {
+      bool carry{!sticky_};
+      if (carry) {
+        carry = !round_;
+      } else {
+        round_ = !round_;
+      }
+      if (carry) {
+        carry = !guard_;
+      } else {
+        guard_ = !guard_;
+      }
+      return carry;
+    }
+
+    constexpr bool ShiftLeft() {
+      bool oldGuard{guard_};
+      guard_ = round_;
+      round_ = sticky_;
+      return oldGuard;
+    }
+
+    constexpr void ShiftRight(bool newGuard) {
+      sticky_ |= round_;
+      round_ = guard_;
+      guard_ = newGuard;
+    }
+
+    // Determines whether a value should be rounded by increasing its
+    // fraction, given a rounding mode and a summary of the lost bits.
+    constexpr bool MustRound(Rounding rounding, const Real &real) const {
+      bool round{false};  // to dodge bogus g++ warning about missing return
+      switch (rounding) {
+      case Rounding::TiesToEven:
+        round = guard_ && (round_ | sticky_ | real.RawBits().BTEST(0));
+        break;
+      case Rounding::ToZero: break;
+      case Rounding::Down: round = real.IsNegative() && !Zero(); break;
+      case Rounding::Up: round = !real.IsNegative() && !Zero(); break;
+      case Rounding::TiesAwayFromZero: round = guard_; break;
+      }
+      return round;
+    }
+
+  private:
     bool guard_{false};
     bool round_{false};
     bool sticky_{false};
@@ -396,27 +449,6 @@ private:
     }
   }
 
-  template<typename INT>
-  static constexpr RoundingBits GetRoundingBits(
-      const INT &fraction, int rshift) {
-    RoundingBits roundingBits;
-    if (rshift > 0 && rshift < fraction.bits + 1) {
-      roundingBits.guard_ = fraction.BTEST(rshift - 1);
-    }
-    if (rshift > 1 && rshift < fraction.bits + 2) {
-      roundingBits.round_ = fraction.BTEST(rshift - 2);
-    }
-    if (rshift > 2) {
-      if (rshift >= fraction.bits + 2) {
-        roundingBits.sticky_ = !fraction.IsZero();
-      } else {
-        auto mask = fraction.MASKR(rshift - 2);
-        roundingBits.sticky_ = !fraction.IAND(mask).IsZero();
-      }
-    }
-    return roundingBits;
-  }
-
   // TODO: Configurable NaN representations
   static constexpr Word NaNWord() {
     return Word{maxExponent}
@@ -452,11 +484,9 @@ private:
           word_ = word_.SHIFTL(lshift);
           if (roundingBits != nullptr) {
             for (; lshift > 0; --lshift) {
-              if (roundingBits->guard_) {
+              if (roundingBits->ShiftLeft()) {
                 word_ = word_.IBSET(lshift - 1);
               }
-              roundingBits->guard_ = roundingBits->round_;
-              roundingBits->round_ = roundingBits->sticky_;
             }
           }
         }
@@ -472,35 +502,14 @@ private:
     }
   }
 
-  // Determines whether a value should be rounded by increasing its
-  // fraction, given a rounding mode and a summary of the lost bits.
-  constexpr bool MustRound(Rounding rounding, const RoundingBits &bits) const {
-    bool round{false};  // to dodge bogus g++ warning about missing return
-    bool roundOrSticky{bits.round_ | bits.sticky_};
-    switch (rounding) {
-    case Rounding::TiesToEven:
-      round = bits.guard_ && (roundOrSticky || word_.BTEST(0));
-      break;
-    case Rounding::ToZero: break;
-    case Rounding::Down:
-      round = IsNegative() && (bits.guard_ || roundOrSticky);
-      break;
-    case Rounding::Up:
-      round = !IsNegative() && (bits.guard_ || roundOrSticky);
-      break;
-    case Rounding::TiesAwayFromZero: round = bits.guard_; break;
-    }
-    return round;
-  }
-
   // Rounds a result, if necessary.
   RealFlags Round(Rounding rounding, const RoundingBits &bits) {
     std::uint64_t exponent{Exponent()};
     RealFlags flags;
-    if (bits.guard_ | bits.round_ | bits.sticky_) {
+    if (!bits.Zero()) {
       flags.set(RealFlag::Inexact);
     }
-    if (exponent < maxExponent && MustRound(rounding, bits)) {
+    if (exponent < maxExponent && bits.MustRound(rounding, *this)) {
       typename Fraction::ValueWithCarry sum{
           GetFraction().AddUnsigned(Fraction{}, true)};
       if (sum.carry) {
index f5d0fd1..f01b117 100644 (file)
@@ -24,7 +24,7 @@ template<typename R> void tests() {
   char desc[64];
   using Word = typename R::Word;
   std::snprintf(desc, sizeof desc, "bits=%d, le=%d",
-               R::bits, Word::littleEndian);
+                R::bits, Word::littleEndian);
   R zero;
   TEST(!zero.IsNegative())(desc);
   TEST(!zero.IsNotANumber())(desc);
@@ -175,7 +175,6 @@ void subset32bit() {
         std::uint32_t check = sum.value.RawBits().ToUInt64();
         MATCH(rcheck, check)("0x%x + 0x%x", rj, rk);
       }
-#if 0
       { ValueWithRealFlags<RealKind4> diff{x.Subtract(y)};
         ScopedHostFloatingPointEnvironment fpenv;
         float fcheck{fj - fk};
@@ -184,6 +183,7 @@ void subset32bit() {
         std::uint32_t check = diff.value.RawBits().ToUInt64();
         MATCH(rcheck, check)("0x%x - 0x%x", rj, rk);
       }
+#if 0
       { ValueWithRealFlags<RealKind4> prod{x.Multiply(y)};
         ScopedHostFloatingPointEnvironment fpenv;
         float fcheck{fj * fk};