[flang] Use naive algorithm for folding complex division when it doesn't over/underflow
author Peter Klausler <pklausler@nvidia.com>
Wed, 17 Aug 2022 18:34:20 +0000 (11:34 -0700)
committer Peter Klausler <pklausler@nvidia.com>
Thu, 18 Aug 2022 22:11:34 +0000 (15:11 -0700)
f18 unconditionally uses a scaling algorithm for complex/complex division
that avoids needless overflows and underflows when computing the sum of
the squares of the components of the denominator -- but testing has shown
some 1 ULP differences relative to the naive calculation due to the
extra operations and roundings.  So use the scaling algorithm only when
the naive calculation actually would overflow or underflow.

Differential Revision: https://reviews.llvm.org/D132164

flang/lib/Evaluate/complex.cpp

index 73ed5b3..e683d7e 100644 (file)
@@ -47,11 +47,30 @@ template <typename R>
 ValueWithRealFlags<Complex<R>> Complex<R>::Divide(
     const Complex &that, Rounding rounding) const {
   // (a + ib)/(c + id) -> [(a+ib)*(c-id)] / [(c+id)*(c-id)]
-  //   -> [ac+bd+i(bc-ad)] / (cc+dd)
+  //   -> [ac+bd+i(bc-ad)] / (cc+dd)  -- note (cc+dd) is real
   //   -> ((ac+bd)/(cc+dd)) + i((bc-ad)/(cc+dd))
-  // but to avoid overflows, scale by d/c if c>=d, else c/d
-  Part scale; // <= 1.0
   RealFlags flags;
+  Part cc{that.re_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
+  Part dd{that.im_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
+  Part ccPdd{cc.Add(dd, rounding).AccumulateFlags(flags)};
+  if (!flags.test(RealFlag::Overflow) && !flags.test(RealFlag::Underflow)) {
+    // den = (cc+dd) did not overflow or underflow; try the naive
+    // sequence without scaling to avoid extra roundings.
+    Part ac{re_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
+    Part ad{re_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
+    Part bc{im_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
+    Part bd{im_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
+    Part acPbd{ac.Add(bd, rounding).AccumulateFlags(flags)};
+    Part bcSad{bc.Subtract(ad, rounding).AccumulateFlags(flags)};
+    Part re{acPbd.Divide(ccPdd, rounding).AccumulateFlags(flags)};
+    Part im{bcSad.Divide(ccPdd, rounding).AccumulateFlags(flags)};
+    if (!flags.test(RealFlag::Overflow) && !flags.test(RealFlag::Underflow)) {
+      return {Complex{re, im}, flags};
+    }
+  }
+  // Scale numerator and denominator by d/c (if c>=d) or c/d (if c<d)
+  flags.clear();
+  Part scale; // will be <= 1.0 in magnitude
   bool cGEd{that.re_.ABS().Compare(that.im_.ABS()) != Relation::Less};
   if (cGEd) {
     scale = that.im_.Divide(that.re_, rounding).AccumulateFlags(flags);