[flang] Continue conversion to value semantics
authorpeter klausler <pklausler@nvidia.com>
Thu, 31 May 2018 17:45:41 +0000 (10:45 -0700)
committerpeter klausler <pklausler@nvidia.com>
Thu, 14 Jun 2018 20:52:21 +0000 (13:52 -0700)
Original-commit: flang-compiler/f18@03fe2666110ce464da49190e83c703031edee932
Reviewed-on: https://github.com/flang-compiler/f18/pull/101
Tree-same-pre-rewrite: false

flang/lib/evaluate/fixed-point.h
flang/test/evaluate/fixed-point-test.cc
flang/test/evaluate/testing.cc
flang/test/evaluate/testing.h

index 57cf58b..a2e082a 100644 (file)
@@ -47,6 +47,8 @@ static constexpr Ordering Reverse(Ordering ordering) {
 // To facilitate exhaustive testing of what would otherwise be more rare
 // edge cases, this class template may be configured to use other part
 // types &/or partial fields in the parts.
+// Member functions that correspond to Fortran intrinsic functions are
+// named accordingly.
 template<int BITS, int PARTBITS = 32, typename PART = std::uint32_t,
     typename BIGPART = std::uint64_t, bool LITTLE_ENDIAN = true>
 class FixedPoint {
@@ -73,35 +75,72 @@ private:
   static constexpr Part topPartMask{static_cast<Part>(~0) >> extraTopPartBits};
 
 public:
+  // Constructors and value-generating static functions
   constexpr FixedPoint() { Clear(); }  // default constructor: zero
   constexpr FixedPoint(const FixedPoint &) = default;
   constexpr FixedPoint(std::uint64_t n) {
     for (int j{0}; j + 1 < parts; ++j) {
-      LEPart(j) = n & partMask;
+      SetLEPart(j, n);
       if constexpr (partBits < 64) {
         n >>= partBits;
       } else {
         n = 0;
       }
     }
-    LEPart(parts - 1) = n & topPartMask;
+    SetLEPart(parts - 1, n);
   }
   constexpr FixedPoint(std::int64_t n) {
     std::int64_t signExtension{-(n < 0)};
     signExtension <<= partBits;
     for (int j{0}; j + 1 < parts; ++j) {
-      LEPart(j) = n & partMask;
+      SetLEPart(j, n);
       if constexpr (partBits < 64) {
         n = (n >> partBits) | signExtension;
       } else {
         n = signExtension;
       }
     }
-    LEPart(parts - 1) = n & topPartMask;
+    SetLEPart(parts - 1, n);
+  }
+
+  // Right-justified mask (e.g., MASKR(1) == 1, MASKR(2) == 3, &c.)
+  static constexpr FixedPoint MASKR(int places) {
+    FixedPoint result{nullptr};
+    int j{0};
+    for (; j + 1 < parts && places >= partBits; ++j, places -= partBits) {
+      result.LEPart(j) = partMask;
+    }
+    if (places > 0) {
+      if (j + 1 < parts) {
+        result.LEPart(j++) = partMask >> (partBits - places);
+      } else if (j + 1 == parts) {
+        if (places >= topPartBits) {
+          result.LEPart(j++) = topPartMask;
+        } else {
+          result.LEPart(j++) = topPartMask >> (topPartBits - places);
+        }
+      }
+    }
+    for (; j < parts; ++j) {
+      result.LEPart(j) = 0;
+    }
+    return result;
+  }
+
+  // Left-justified mask (e.g., MASKL(1) has only its sign bit set)
+  static constexpr FixedPoint MASKL(int places) {
+    if (places < 0) {
+      return {};
+    } else if (places >= bits) {
+      return MASKR(bits);
+    } else {
+      return MASKR(bits - places).NOT();
+    }
   }
 
   constexpr FixedPoint &operator=(const FixedPoint &) = default;
 
+  // Predicates and comparisons
   constexpr bool IsZero() const {
     for (int j{0}; j < parts; ++j) {
       if (part_[j] != 0) {
@@ -162,27 +201,32 @@ public:
     return signExtended;
   }
 
-  // NOT
-  constexpr FixedPoint OnesComplement() const {
+  // Ones'-complement (i.e., C's ~)
+  constexpr FixedPoint NOT() const {
     FixedPoint result{nullptr};
-    for (int j{0}; j + 1 < parts; ++j) {
-      result.LEPart(j) = ~LEPart(j) & partMask;
+    for (int j{0}; j < parts; ++j) {
+      result.SetLEPart(j, ~LEPart(j));
     }
-    result.LEPart(parts - 1) = ~LEPart(parts - 1) & topPartMask;
     return result;
   }
 
-  // Returns true on overflow (i.e., negating the most negative signed number)
-  constexpr bool TwosComplement() {
+  // Two's-complement negation (-x = ~x + 1).
+  struct ValueWithOverflow {
+    FixedPoint value;
+    bool overflow;  // true when operand was MASKL(1), the most negative number
+  };
+  constexpr ValueWithOverflow Negate() const {
+    FixedPoint result;
     Part carry{1};
     for (int j{0}; j + 1 < parts; ++j) {
       Part newCarry{LEPart(j) == 0 && carry};
-      LEPart(j) = (~LEPart(j) + carry) & partMask;
+      result.SetLEPart(j, ~LEPart(j) + carry);
       carry = newCarry;
     }
-    Part before{LEPart(parts - 1)};
-    LEPart(parts - 1) = (~before + carry) & topPartMask;
-    return before != 0 && LEPart(parts - 1) == before;
+    Part top{LEPart(parts - 1)};
+    result.SetLEPart(parts - 1, ~top + carry);
+    bool overflow{top != 0 && result.LEPart(parts - 1) == top};
+    return {result, overflow};
   }
 
   // LEADZ intrinsic
@@ -216,19 +260,18 @@ public:
       int j{parts - 1};
       if (bitShift == 0) {
         for (; j >= shiftParts; --j) {
-          LEPart(j) = LEPart(j - shiftParts) & PartMask(j);
+          SetLEPart(j, LEPart(j - shiftParts));
         }
         for (; j >= 0; --j) {
           LEPart(j) = 0;
         }
       } else {
         for (; j > shiftParts; --j) {
-          LEPart(j) = ((LEPart(j - shiftParts) << bitShift) |
-                         (LEPart(j - shiftParts - 1) >> (partBits - bitShift))) &
-              PartMask(j);
+          SetLEPart(j, ((LEPart(j - shiftParts) << bitShift) |
+                       (LEPart(j - shiftParts - 1) >> (partBits - bitShift))));
         }
         if (j == shiftParts) {
-          LEPart(j) = (LEPart(0) << bitShift) & PartMask(j);
+          SetLEPart(j, LEPart(0) << bitShift);
           --j;
         }
         for (; j >= 0; --j) {
@@ -259,9 +302,8 @@ public:
         }
       } else {
         for (; j + shiftParts + 1 < parts; ++j) {
-          LEPart(j) = ((LEPart(j + shiftParts) >> bitShift) |
-                         (LEPart(j + shiftParts + 1) << (partBits - bitShift))) &
-              partMask;
+          SetLEPart(j, (LEPart(j + shiftParts) >> bitShift) |
+                       (LEPart(j + shiftParts + 1) << (partBits - bitShift)));
         }
         if (j + shiftParts + 1 == parts) {
           LEPart(j++) = LEPart(parts - 1) >> bitShift;
@@ -282,9 +324,7 @@ public:
       bool fill{IsNegative()};
       ShiftRightLogical(count);
       if (fill) {
-        FixedPoint signs;
-        signs.LeftMask(count);
-        Or(signs);
+        Or(MASKL(count));
       }
     }
   }
@@ -316,12 +356,12 @@ public:
     for (int j{0}; j + 1 < parts; ++j) {
       carry += LEPart(j);
       carry += y.LEPart(j);
-      LEPart(j) = carry & partMask;
+      SetLEPart(j, carry);
       carry >>= partBits;
     }
     carry += LEPart(parts - 1);
     carry += y.LEPart(parts - 1);
-    LEPart(parts - 1) = carry & topPartMask;
+    SetLEPart(parts - 1, carry);
     return carry > topPartMask;
   }
 
@@ -337,9 +377,7 @@ public:
   constexpr bool SubtractSigned(const FixedPoint &y) {
     bool isNegative{IsNegative()};
     bool sameSign{isNegative == y.IsNegative()};
-    FixedPoint minusy{y};
-    minusy.TwosComplement();
-    AddUnsigned(minusy);
+    AddUnsigned(y.Negate().value);
     return !sameSign && IsNegative() != isNegative;
   }
 
@@ -377,16 +415,16 @@ public:
     bool yIsNegative{y.IsNegative()};
     FixedPoint yprime{y};
     if (yIsNegative) {
-      yprime.TwosComplement();
+      yprime = y.Negate().value;
     }
     bool isNegative{IsNegative()};
     if (isNegative) {
-      TwosComplement();
+      *this = Negate().value;
     }
     MultiplyUnsigned(yprime, upper);
     if (isNegative != yIsNegative) {
-      *this = OnesComplement();
-      upper = upper.OnesComplement();
+      *this = NOT();
+      upper = upper.NOT();
       FixedPoint one{std::uint64_t{1}};
       if (AddUnsigned(one)) {
         upper.AddUnsigned(one);
@@ -399,7 +437,7 @@ public:
       const FixedPoint &divisor, FixedPoint &remainder) {
     remainder.Clear();
     if (divisor.IsZero()) {
-      RightMask(bits);
+      *this = MASKR(bits);
       return true;
     }
     FixedPoint top{*this};
@@ -428,10 +466,11 @@ public:
     Ordering divisorOrdering{divisor.CompareToZeroSigned()};
     if (divisorOrdering == Ordering::Less) {
       negateQuotient = !negateQuotient;
-      if (divisor.TwosComplement()) {
+      auto negated{divisor.Negate()};
+      if (negated.overflow) {
         // divisor was (and is) the most negative number
         if (CompareUnsigned(divisor) == Ordering::Equal) {
-          RightMask(1);
+          *this = MASKR(1);
           remainder.Clear();
           return bits <= 1;  // edge case: 1-bit signed numbers overflow on 1!
         } else {
@@ -440,18 +479,20 @@ public:
           return false;
         }
       }
+      divisor = negated.value;
     } else if (divisorOrdering == Ordering::Equal) {
       // division by zero
       remainder.Clear();
       if (dividendIsNegative) {
-        LeftMask(1);  // most negative signed number
+        *this = MASKL(1);  // most negative signed number
       } else {
-        RightMask(bits - 1);  // most positive signed number
+        *this = MASKR(bits - 1);  // most positive signed number
       }
       return true;
     }
     if (dividendIsNegative) {
-      if (TwosComplement()) {
+      auto negated{Negate()};
+      if (negated.overflow) {
         // Dividend was (and remains) the most negative number.
         // See whether the original divisor was -1 (if so, it's 1 now).
         if (divisorOrdering == Ordering::Less &&
@@ -461,16 +502,18 @@ public:
           remainder.Clear();
           return true;
         }
+      } else {
+        *this = negated.value;
       }
     }
     // Overflow is not possible, and both the dividend (*this) and divisor
     // are now positive.
     DivideUnsigned(divisor, remainder);
     if (negateQuotient) {
-      TwosComplement();
+      *this = Negate().value;
     }
     if (dividendIsNegative) {
-      remainder.TwosComplement();
+      remainder = remainder.Negate().value;
     }
     return false;
   }
@@ -489,40 +532,6 @@ public:
     return overflow;
   }
 
-  // MASKR intrinsic
-  constexpr void RightMask(int places) {
-    int j{0};
-    for (; j + 1 < parts && places >= partBits; ++j, places -= partBits) {
-      LEPart(j) = partMask;
-    }
-    if (places > 0) {
-      if (j + 1 < parts) {
-        LEPart(j++) = partMask >> (partBits - places);
-      } else if (j + 1 == parts) {
-        if (places >= topPartBits) {
-          LEPart(j++) = topPartMask;
-        } else {
-          LEPart(j++) = topPartMask >> (topPartBits - places);
-        }
-      }
-    }
-    for (; j < parts; ++j) {
-      LEPart(j) = 0;
-    }
-  }
-
-  // MASKL intrinsic
-  constexpr void LeftMask(int places) {
-    if (places < 0) {
-      Clear();
-    } else if (places >= bits) {
-      RightMask(bits);
-    } else {
-      RightMask(bits - places);
-      *this = OnesComplement();
-    }
-  }
-
 private:
   constexpr FixedPoint(std::nullptr_t) {}  // does not initialize
 
@@ -543,6 +552,10 @@ private:
     }
   }
 
+  constexpr void SetLEPart(int part, Part x) {
+    LEPart(part) = x & PartMask(part);
+  }
+
   static constexpr Part PartMask(int part) {
     return part == parts - 1 ? topPartMask : partMask;
   }
index 06697aa..496fe83 100644 (file)
@@ -32,24 +32,23 @@ template<int BITS, typename FP = FixedPoint<BITS>> void exhaustiveTesting() {
   TEST(zero.IsZero())(desc);
   for (std::uint64_t x{0}; x <= maxUnsignedValue; ++x) {
     FP a{x};
-    COMPARE(x, ==, a.ToUInt64())(desc);
+    MATCH(x, a.ToUInt64())(desc);
     FP copy{a};
-    COMPARE(x, ==, copy.ToUInt64())(desc);
+    MATCH(x, copy.ToUInt64())(desc);
     copy = a;
-    COMPARE(x, ==, copy.ToUInt64())(desc);
-    COMPARE(x == 0, ==, a.IsZero())("%s, x=0x%llx", desc, x);
-    FP t{a.OnesComplement()};
-    COMPARE(x ^ maxUnsignedValue, ==, t.ToUInt64())("%s, x=0x%llx", desc, x);
-    copy = a;
-    bool over{copy.TwosComplement()};
-    COMPARE(over, ==, x == std::uint64_t{1} << (BITS - 1))
-    ("%s, x=0x%llx", desc, x);
-    COMPARE(-x & maxUnsignedValue, ==, copy.ToUInt64())
-    ("%s, x=0x%llx", desc, x);
+    MATCH(x, copy.ToUInt64())(desc);
+    MATCH(x == 0, a.IsZero())("%s, x=0x%llx", desc, x);
+    FP t{a.NOT()};
+    MATCH(x ^ maxUnsignedValue, t.ToUInt64())("%s, x=0x%llx", desc, x);
+    auto negated{a.Negate()};
+    MATCH(x == std::uint64_t{1} << (BITS - 1), negated.overflow)
+      ("%s, x=0x%llx", desc, x);
+    MATCH(negated.value.ToUInt64(), -x & maxUnsignedValue)
+      ("%s, x=0x%llx", desc, x);
     int lzbc{a.LeadingZeroBitCount()};
     COMPARE(lzbc, >=, 0)("%s, x=0x%llx", desc, x);
     COMPARE(lzbc, <=, BITS)("%s, x=0x%llx", desc, x);
-    COMPARE(x == 0, ==, lzbc == BITS)("%s, x=0x%llx, lzbc=%d", desc, x, lzbc);
+    MATCH(x == 0, lzbc == BITS)("%s, x=0x%llx, lzbc=%d", desc, x, lzbc);
     std::uint64_t lzcheck{std::uint64_t{1} << (BITS - lzbc)};
     COMPARE(x, <, lzcheck)("%s, x=0x%llx, lzbc=%d", desc, x, lzbc);
     COMPARE(x + x + !x, >=, lzcheck)("%s, x=0x%llx, lzbc=%d", desc, x, lzbc);
@@ -74,26 +73,26 @@ template<int BITS, typename FP = FixedPoint<BITS>> void exhaustiveTesting() {
     for (int count{0}; count <= BITS + 1; ++count) {
       copy = a;
       copy.ShiftLeft(count);
-      COMPARE((x << count) & maxUnsignedValue, ==, copy.ToUInt64())
-      ("%s, x=0x%llx, count=%d", desc, x, count);
+      MATCH((x << count) & maxUnsignedValue, copy.ToUInt64())
+        ("%s, x=0x%llx, count=%d", desc, x, count);
       copy = a;
       copy.ShiftRightLogical(count);
-      COMPARE(x >> count, ==, copy.ToUInt64())
-      ("%s, x=0x%llx, count=%d", desc, x, count);
+      MATCH(x >> count, copy.ToUInt64())
+        ("%s, x=0x%llx, count=%d", desc, x, count);
       copy = a;
       copy.ShiftLeft(-count);
-      COMPARE(x >> count, ==, copy.ToUInt64())
-      ("%s, x=0x%llx, count=%d", desc, x, count);
+      MATCH(x >> count, copy.ToUInt64())
+        ("%s, x=0x%llx, count=%d", desc, x, count);
       copy = a;
       copy.ShiftRightLogical(-count);
-      COMPARE((x << count) & maxUnsignedValue, ==, copy.ToUInt64())
-      ("%s, x=0x%llx, count=%d", desc, x, count);
+      MATCH((x << count) & maxUnsignedValue, copy.ToUInt64())
+        ("%s, x=0x%llx, count=%d", desc, x, count);
       copy = a;
       copy.ShiftRightArithmetic(count);
       std::uint64_t fill{-(x >> (BITS-1))};
       std::uint64_t sra{count >= BITS ? fill : (x >> count) | (fill << (BITS-count))};
-      COMPARE(sra, ==, copy.ToInt64())
-      ("%s, x=0x%llx, count=%d", desc, x, count);
+      MATCH(sra, copy.ToInt64())
+        ("%s, x=0x%llx, count=%d", desc, x, count);
     }
     for (std::uint64_t y{0}; y <= maxUnsignedValue; ++y) {
       std::int64_t sy = y;
@@ -117,84 +116,82 @@ template<int BITS, typename FP = FixedPoint<BITS>> void exhaustiveTesting() {
         ord = Ordering::Equal;
       }
       TEST(a.CompareSigned(b) == ord)
-      ("%s, x=0x%llx %lld %d, y=0x%llx %lld %d", desc, x, sx, a.IsNegative(), y,
-          sy, b.IsNegative());
+        ("%s, x=0x%llx %lld %d, y=0x%llx %lld %d", desc, x, sx,
+         a.IsNegative(), y, sy, b.IsNegative());
       copy = a;
       copy.And(b);
-      COMPARE(x & y, ==, copy.ToUInt64())("%s, x=0x%llx, y=0x%llx", desc, x, y);
+      MATCH(x & y, copy.ToUInt64())("%s, x=0x%llx, y=0x%llx", desc, x, y);
       copy = a;
       copy.Or(b);
-      COMPARE(x | y, ==, copy.ToUInt64())("%s, x=0x%llx, y=0x%llx", desc, x, y);
+      MATCH(x | y, copy.ToUInt64())("%s, x=0x%llx, y=0x%llx", desc, x, y);
       copy = a;
       copy.Xor(b);
-      COMPARE(x ^ y, ==, copy.ToUInt64())("%s, x=0x%llx, y=0x%llx", desc, x, y);
+      MATCH(x ^ y, copy.ToUInt64())("%s, x=0x%llx, y=0x%llx", desc, x, y);
       copy = a;
       bool carry{copy.AddUnsigned(b)};
       COMPARE(x + y, ==, copy.ToUInt64() + (std::uint64_t{carry} << BITS))
-      ("%s, x=0x%llx, y=0x%llx, carry=%d", desc, x, y, carry);
+        ("%s, x=0x%llx, y=0x%llx, carry=%d", desc, x, y, carry);
       copy = a;
-      over = copy.AddSigned(b);
-      COMPARE((sx + sy) & maxUnsignedValue, ==, copy.ToUInt64())
-      ("%s, x=0x%llx, y=0x%llx", desc, x, y);
-      COMPARE(over, ==,
+      bool over{copy.AddSigned(b)};
+      MATCH((sx + sy) & maxUnsignedValue, copy.ToUInt64())
+        ("%s, x=0x%llx, y=0x%llx", desc, x, y);
+      MATCH(over,
           sx + sy < mostNegativeSignedValue || sx + sy > maxPositiveSignedValue)
-      ("%s, x=0x%llx, y=0x%llx", desc, x, y);
+        ("%s, x=0x%llx, y=0x%llx", desc, x, y);
       copy = a;
       over = copy.SubtractSigned(b);
-      COMPARE((sx - sy) & maxUnsignedValue, ==, copy.ToUInt64())
-      ("%s, x=0x%llx, y=0x%llx", desc, x, y);
-      COMPARE(over, ==,
+      MATCH((sx - sy) & maxUnsignedValue, copy.ToUInt64())
+        ("%s, x=0x%llx, y=0x%llx", desc, x, y);
+      MATCH(over,
           sx - sy < mostNegativeSignedValue || sx - sy > maxPositiveSignedValue)
-      ("%s, x=0x%llx, y=0x%llx", desc, x, y);
+        ("%s, x=0x%llx, y=0x%llx", desc, x, y);
       copy = a;
       FP upper;
       copy.MultiplyUnsigned(b, upper);
-      COMPARE(x * y, ==, (upper.ToUInt64() << BITS) ^ copy.ToUInt64())
-      ("%s, x=0x%llx, y=0x%llx, lower=0x%llx, upper=0x%llx", desc, x, y,
+      MATCH(x * y, (upper.ToUInt64() << BITS) ^ copy.ToUInt64())
+        ("%s, x=0x%llx, y=0x%llx, lower=0x%llx, upper=0x%llx", desc, x, y,
           copy.ToUInt64(), upper.ToUInt64());
       copy = a;
       copy.MultiplySigned(b, upper);
-      COMPARE((sx * sy) & maxUnsignedValue, ==, copy.ToUInt64())
-      ("%s, x=0x%llx, y=0x%llx", desc, x, y);
-      COMPARE(((sx * sy) >> BITS) & maxUnsignedValue, ==, upper.ToUInt64())
-      ("%s, x=0x%llx, y=0x%llx", desc, x, y);
+      MATCH((sx * sy) & maxUnsignedValue, copy.ToUInt64())
+        ("%s, x=0x%llx, y=0x%llx", desc, x, y);
+      MATCH(((sx * sy) >> BITS) & maxUnsignedValue, upper.ToUInt64())
+        ("%s, x=0x%llx, y=0x%llx", desc, x, y);
       copy = a;
       FP rem;
-      COMPARE(y == 0, ==, copy.DivideUnsigned(b, rem))
-      ("%s, x=0x%llx, y=0x%llx", desc, x, y);
-      if (y == 0) {
-        COMPARE(maxUnsignedValue, ==, copy.ToUInt64())
+      MATCH(y == 0, copy.DivideUnsigned(b, rem))
         ("%s, x=0x%llx, y=0x%llx", desc, x, y);
-        COMPARE(0, ==, rem.ToUInt64())("%s, x=0x%llx, y=0x%llx", desc, x, y);
+      if (y == 0) {
+        MATCH(maxUnsignedValue, copy.ToUInt64())
+          ("%s, x=0x%llx, y=0x%llx", desc, x, y);
+        MATCH(0, rem.ToUInt64())("%s, x=0x%llx, y=0x%llx", desc, x, y);
       } else {
-        COMPARE(x / y, ==, copy.ToUInt64())
-        ("%s, x=0x%llx, y=0x%llx", desc, x, y);
-        COMPARE(x % y, ==, rem.ToUInt64())
-        ("%s, x=0x%llx, y=0x%llx", desc, x, y);
+        MATCH(x / y, copy.ToUInt64())("%s, x=0x%llx, y=0x%llx", desc, x, y);
+        MATCH(x % y, rem.ToUInt64())("%s, x=0x%llx, y=0x%llx", desc, x, y);
       }
       copy = a;
       bool badCase{sx == mostNegativeSignedValue &&
           ((sy == -1 && sx != sy) || (BITS == 1 && sx == sy))};
-      COMPARE(y == 0 || badCase, ==, copy.DivideSigned(b, rem))
-      ("%s, x=0x%llx, y=0x%llx", desc, x, y);
+      MATCH(y == 0 || badCase, copy.DivideSigned(b, rem))
+        ("%s, x=0x%llx, y=0x%llx", desc, x, y);
       if (y == 0) {
         if (sx >= 0) {
-          COMPARE(maxPositiveSignedValue, ==, copy.ToInt64())
+          MATCH(maxPositiveSignedValue, copy.ToInt64())
           ("%s, x=0x%llx, y=0x%llx", desc, x, y);
         } else {
-          COMPARE(mostNegativeSignedValue, ==, copy.ToInt64())
+          MATCH(mostNegativeSignedValue, copy.ToInt64())
           ("%s, x=0x%llx, y=0x%llx", desc, x, y);
         }
-        COMPARE(0, ==, rem.ToUInt64())("%s, x=0x%llx, y=0x%llx", desc, x, y);
+        MATCH(0, rem.ToUInt64())("%s, x=0x%llx, y=0x%llx", desc, x, y);
       } else if (badCase) {
-        COMPARE(x, ==, copy.ToUInt64())("%s, x=0x%llx, y=0x%llx", desc, x, y);
-        COMPARE(0, ==, rem.ToUInt64())("%s, x=0x%llx, y=0x%llx", desc, x, y);
+        MATCH(x, copy.ToUInt64())("%s, x=0x%llx, y=0x%llx", desc, x, y);
+        MATCH(0, rem.ToUInt64())("%s, x=0x%llx, y=0x%llx", desc, x, y);
       } else {
-        COMPARE(sx / sy, ==, copy.ToInt64())
-        ("%s, x=0x%llx %lld, y=0x%llx %lld; unsigned 0x%llx", desc, x, sx, y,
+        MATCH(sx / sy, copy.ToInt64())
+          ("%s, x=0x%llx %lld, y=0x%llx %lld; unsigned 0x%llx", desc, x, sx, y,
             sy, copy.ToUInt64());
-        COMPARE(sx - sy * (sx / sy), ==, rem.ToInt64())
-        ("%s, x=0x%llx, y=0x%llx", desc, x, y);
+        MATCH(sx - sy * (sx / sy), rem.ToInt64())
+          ("%s, x=0x%llx, y=0x%llx", desc, x, y);
       }
     }
   }
index 7d9af8a..0b8148f 100644 (file)
@@ -48,6 +48,19 @@ FailureDetailPrinter Test(
   }
 }
 
+FailureDetailPrinter Match(const char *file, int line,
+    unsigned long long want, const char *gots, unsigned long long got) {
+  if (want == got) {
+    ++passes;
+    return BitBucket;
+  } else {
+    ++failures;
+    fprintf(stderr, "%s:%d: FAIL: %s == 0x%llx, not 0x%llx\n", file, line, gots,
+        got, want);
+    return PrintFailureDetails;
+  }
+}
+
 FailureDetailPrinter Compare(const char *file, int line, const char *xs,
     const char *rel, const char *ys, unsigned long long x,
     unsigned long long y) {
@@ -77,7 +90,7 @@ FailureDetailPrinter Compare(const char *file, int line, const char *xs,
     return BitBucket;
   } else {
     ++failures;
-    fprintf(stderr, "%s:%d: FAIL %s[0x%llx] %s %s[0x%llx]:\n", file, line, xs,
+    fprintf(stderr, "%s:%d: FAIL %s[0x%llx] %s %s[0x%llx]\n", file, line, xs,
         x, rel, ys, y);
     return PrintFailureDetails;
   }
index 87ef9d3..db1e7b4 100644 (file)
@@ -27,6 +27,8 @@ int Complete();
 // will also print z after the usual failure message if x != y.
 #define TEST(predicate) \
   testing::Test(__FILE__, __LINE__, #predicate, (predicate))
+#define MATCH(want, got) \
+  testing::Match(__FILE__, __LINE__, (want), #got, (got))
 #define COMPARE(x, rel, y) \
   testing::Compare(__FILE__, __LINE__, #x, #rel, #y, (x), (y))
 
@@ -34,6 +36,8 @@ int Complete();
 using FailureDetailPrinter = void (*)(const char *, ...);
 FailureDetailPrinter Test(
     const char *file, int line, const char *predicate, bool pass);
+FailureDetailPrinter Match(const char *file, int line,
+    unsigned long long want, const char *gots, unsigned long long got);
 FailureDetailPrinter Compare(const char *file, int line, const char *xs,
     const char *rel, const char *ys, unsigned long long x,
     unsigned long long y);