[XLA] optimize NearComparator#ExpectLiteralsNear()
authorNick Desaulniers <ndesaulniers@google.com>
Fri, 26 Jan 2018 23:01:40 +0000 (15:01 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 26 Jan 2018 23:17:51 +0000 (15:17 -0800)
While tracking down the issue of timeouts when running THE ISOLATOR, it was
observed that NearComparator#ExpectLiteralsNear() could be optimized in the
case of matching layouts to not compute multi indexes.

In the process of tracking down timeouts in THE ISOLATOR, I had assumed that
time spent was dominated by either generating input data, executing the input
data on various backends, or comparing the data. Never assume you know where
the time is spent in a program; the profiler may surprise you.

After making that optimization and then profiling the code before and after, I
was surprised by the profile. Image the shock, horror, and disgust I
experienced when discovering that runs of THE ISOLATOR were dominated (45%) by
calls to Literal#ToString() in NearComparator#ExpectLiteralsNear() for huge
(>120 million elements) literals that failed comparisons.  No wonder passing
shards of THE ISOLATOR were fast, and failing shards were slow.

Further, computing multi indexes many times is expensive enough (18%) to show
up in profiles, so avoid calculating it until it is necessary.

The optimizations in this patch:
* Don't call Literal#ToString() on huge literals that are going to get written
  to disk anyways. The utility of printing said literal to stdout is suspect.
* Initialize NearComparator#miscompares_ to false, only update miscompares_ and
  other stats when miscompare occurs.
* Split NearComparator#ExpectLiteralsNear() into two, since we only need to log
  and update stats if an actual miscompare occurs.
* Add fast path in NearComparator#ExpectLiteralsNear() for case of matching
  layouts, being careful not to compute multi index unless mismatch actually
  occurs.

This optimized NearComparator#ExpectLiteralsNear() for the case of many element
literals, with few miscompares. For many miscompares, we cannot avoid
calculating multi indexes, but can fast path for equal layouts. For zero
miscompares, we can at least fast path in the case of matching layouts.

Before this CL, a run of THE ISOLATOR for a single literal with >120 million
elements and a few miscompares took 379s (6.3m). With this CL, the same test
case now takes 44s.

Beautiful flame graphs omitted from public commit message, regrettably.

PiperOrigin-RevId: 183451138

tensorflow/compiler/xla/literal_util.h
tensorflow/compiler/xla/tests/literal_test_util.cc

index e019650..2b68b8f 100644 (file)
@@ -486,6 +486,7 @@ class Literal {
       std::vector<std::unique_ptr<Literal>> elements);
 
   // Returns a string representation of the literal value.
+  // Warning: this function can take minutes for multi-million element Literals.
   string ToString(bool print_layout = false) const;
 
   // Invokes the "per cell" callback for each element in the provided
index f8205de..39c0729 100644 (file)
@@ -355,9 +355,9 @@ class NearComparator {
   // temporary files on failure. Returns true if  literals match.
   bool ExpectNear(const Literal& expected, const Literal& actual) {
     VLOG(1) << "expected:";
-    XLA_VLOG_LINES(1, expected.ToString());
+    XLA_VLOG_LINES(1, TruncateHugeLiteral(expected));
     VLOG(1) << "actual:";
-    XLA_VLOG_LINES(1, actual.ToString());
+    XLA_VLOG_LINES(1, TruncateHugeLiteral(actual));
 
     // If the shapes mismatch, we simply fail the expectation instead of
     // printing out data, as it's a type error rather than a value error.
@@ -377,6 +377,7 @@ class NearComparator {
     max_rel_err_ = 0.0;
     max_abs_err_ = 0.0;
     miscompares_ = Literal(ShapeUtil::ChangeElementType(actual.shape(), PRED));
+    miscompares_.PopulateWithValue(false);
     multi_index_.resize(expected.shape().dimensions_size(), 0);
 
     switch (expected.shape().element_type()) {
@@ -404,21 +405,33 @@ class NearComparator {
     if (num_miscompares_ > 0) {
       if (!VLOG_IS_ON(1)) {
         LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape())
-                  << " " << expected.ToString();
+                  << " " << TruncateHugeLiteral(expected);
         LOG(INFO) << "actual:   " << ShapeUtil::HumanString(actual.shape())
-                  << " " << actual.ToString();
+                  << " " << TruncateHugeLiteral(actual);
+        LOG(INFO) << "Dumping literals to temp files...";
+        WriteLiteralToTempFile(expected, "expected");
+        WriteLiteralToTempFile(actual, "actual");
+        WriteLiteralToTempFile(miscompares_, "miscompares");
       }
       EXPECT_TRUE(num_miscompares_ == 0)
           << "\nmax relative mismatch at index "
-          << LiteralTestUtil::MultiIndexAsString(max_rel_multi_index_)
+          << LiteralTestUtil::MultiIndexAsString(
+                 IndexUtil::LinearIndexToMultidimensionalIndex(
+                     actual.shape(), max_rel_linear_index_))
           << "\nmaximum relative error " << max_rel_err_
           << "\nmax absolute mismatch at index "
-          << LiteralTestUtil::MultiIndexAsString(max_abs_multi_index_)
+          << LiteralTestUtil::MultiIndexAsString(
+                 IndexUtil::LinearIndexToMultidimensionalIndex(
+                     actual.shape(), max_abs_linear_index_))
           << "\nmaximum absolute error " << max_abs_err_
           << "\nfirst mismatch at index "
-          << LiteralTestUtil::MultiIndexAsString(first_multi_index_)
+          << LiteralTestUtil::MultiIndexAsString(
+                 IndexUtil::LinearIndexToMultidimensionalIndex(
+                     actual.shape(), first_linear_index_))
           << "\nlast mismatch at index "
-          << LiteralTestUtil::MultiIndexAsString(last_multi_index_)
+          << LiteralTestUtil::MultiIndexAsString(
+                 IndexUtil::LinearIndexToMultidimensionalIndex(
+                     actual.shape(), last_linear_index_))
           << "\ntotal absolute error " << abs_diff_sum_
           << "\ntotal absolute error of miscompares "
           << abs_diff_miscompare_sum_ << "\ntotal relative error "
@@ -426,10 +439,6 @@ class NearComparator {
           << "\ntotal relative error of miscompares "
           << (abs_diff_miscompare_sum_ / abs_expected_miscompare_sum_)
           << "\nfailure count " << num_miscompares_;
-
-      WriteLiteralToTempFile(expected, "expected");
-      WriteLiteralToTempFile(actual, "actual");
-      WriteLiteralToTempFile(miscompares_, "miscompares");
     }
     return num_miscompares_ == 0;
   }
@@ -457,57 +466,93 @@ class NearComparator {
       return true;
     }
 
-    float abs_diff = std::abs(actual - expected);
-    float rel_err = abs_diff / std::abs(expected);
+    const float abs_diff = std::abs(actual - expected);
+    const float rel_err = abs_diff / std::abs(expected);
+    const bool nan_mismatch = NanMismatch<NativeT>(expected, actual);
+    const bool mismatch =
+        (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel));
+    return !mismatch;
+  }
+
+  // Assumes that expected vs actual fail ExpectValuesNear.
+  template <typename NativeT>
+  void UpdateAndLogMiscompares(const NativeT expected, const NativeT actual,
+                               const Shape& shape, const int64 linear_index) {
+    const float abs_diff = std::abs(actual - expected);
+    const float rel_err = abs_diff / std::abs(expected);
     abs_diff_sum_ += abs_diff;
     abs_expected_sum_ += std::abs(expected);
     if (rel_err > max_rel_err_) {
       max_rel_err_ = rel_err;
-      max_rel_multi_index_ = multi_index_;
+      max_rel_linear_index_ = linear_index;
     }
     if (abs_diff > max_abs_err_) {
       max_abs_err_ = abs_diff;
-      max_abs_multi_index_ = multi_index_;
+      max_abs_linear_index_ = linear_index;
     }
-    VLOG(10) << tensorflow::strings::Printf(
-        "index %s abs_diff %f rel_err %f",
-        LiteralTestUtil::MultiIndexAsString(multi_index_).c_str(), abs_diff,
-        rel_err);
-    bool nan_mismatch = NanMismatch<NativeT>(expected, actual);
-    bool mismatch =
-        (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel));
-    if (mismatch) {
-      abs_diff_miscompare_sum_ += abs_diff;
-      abs_expected_miscompare_sum_ += std::abs(expected);
-      const int64 kMaxFailures = 2;
-      if (num_miscompares_ < kMaxFailures) {
-        ::testing::Message msg;
-        msg << "mismatch at index "
-            << LiteralTestUtil::MultiIndexAsString(multi_index_) << " abs diff "
-            << abs_diff << " rel err " << rel_err << " failure #"
-            << num_miscompares_;
-        ExpectNear<NativeT>(expected, actual, msg);
-      } else if (num_miscompares_ == kMaxFailures) {
-        LOG(ERROR)
-            << "reached max 'loud' failure count; silently proceeding...";
-      }
-      if (num_miscompares_ == 0) {
-        first_multi_index_ = multi_index_;
-      }
-      num_miscompares_++;
-      last_multi_index_ = multi_index_;
+    if (VLOG_IS_ON(10)) {
+      VLOG(10) << tensorflow::strings::Printf(
+          "index %s abs_diff %f rel_err %f",
+          LiteralTestUtil::MultiIndexAsString(
+              IndexUtil::LinearIndexToMultidimensionalIndex(shape,
+                                                            linear_index))
+              .c_str(),
+          abs_diff, rel_err);
     }
-    return !mismatch;
+    abs_diff_miscompare_sum_ += abs_diff;
+    abs_expected_miscompare_sum_ += std::abs(expected);
+    const int64 kMaxFailures = 2;
+    if (num_miscompares_ < kMaxFailures) {
+      const auto multi_index =
+          IndexUtil::LinearIndexToMultidimensionalIndex(shape, linear_index);
+      ::testing::Message msg;
+      msg << "mismatch at index "
+          << LiteralTestUtil::MultiIndexAsString(multi_index) << " abs diff "
+          << abs_diff << " rel err " << rel_err << " failure #"
+          << num_miscompares_;
+      ExpectNear<NativeT>(expected, actual, msg);
+    } else if (num_miscompares_ == kMaxFailures) {
+      LOG(ERROR) << "reached max 'loud' failure count; silently proceeding...";
+    }
+    if (num_miscompares_ == 0) {
+      first_linear_index_ = linear_index;
+    }
+    num_miscompares_++;
+    last_linear_index_ = linear_index;
+    miscompares_.data<bool>()[linear_index] = true;
   }
 
   // Recursive function which compares the two given literals elementwise.
   template <typename NativeT>
   void ExpectLiteralsNear(const Literal& expected, const Literal& actual,
                           int64 dimension) {
+    // Fast path optimization for the case were layouts match.
+    if (LayoutUtil::Equal(actual.shape().layout(), expected.shape().layout())) {
+      tensorflow::gtl::ArraySlice<const NativeT> expected_data =
+          expected.data<NativeT>();
+      tensorflow::gtl::ArraySlice<const NativeT> actual_data =
+          actual.data<NativeT>();
+      const int64 len = expected_data.size();
+      for (int64 i = 0; i < len; ++i) {
+        const bool near = ExpectValuesNear(expected_data[i], actual_data[i]);
+        if (!near) {
+          UpdateAndLogMiscompares<NativeT>(expected_data[i], actual_data[i],
+                                           actual.shape(), i);
+        }
+      }
+      return;
+    }
+
     if (dimension == expected.shape().dimensions_size()) {
       bool near = ExpectValuesNear(expected.Get<NativeT>(multi_index_),
                                    actual.Get<NativeT>(multi_index_));
-      miscompares_.Set<bool>(multi_index_, !near);
+      if (!near) {
+        UpdateAndLogMiscompares<NativeT>(
+            expected.Get<NativeT>(multi_index_),
+            actual.Get<NativeT>(multi_index_), actual.shape(),
+            IndexUtil::MultidimensionalIndexToLinearIndex(actual.shape(),
+                                                          multi_index_));
+      }
     } else {
       for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
         multi_index_[dimension] = i;
@@ -528,6 +573,32 @@ class NearComparator {
     LOG(ERROR) << "wrote to " << name << " file: " << filename;
   }
 
+  // Gets the total element count.  For tuples, this is not the count of tuple
+  // elements, but the sum of elements of each tuple element.
+  int64 RecursiveElementCount(const Shape& shape) {
+    if (ShapeUtil::IsTuple(shape)) {
+      const int64 tuple_elements = ShapeUtil::TupleElementCount(shape);
+      int64 total = 0;
+      for (int64 i = 0; i < tuple_elements; ++i) {
+        total +=
+            RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i));
+      }
+      return total;
+    } else {
+      return ShapeUtil::ElementsIn(shape);
+    }
+  }
+
+  // Calling ToString on a literal with over 100 million elements takes around
+  // 3 minutes.  The utility of printing a literal with >1000 elements is
+  // questionable, especially when writing the Literal proto to disk is orders
+  // of magnitude faster.
+  string TruncateHugeLiteral(const Literal& literal) {
+    return RecursiveElementCount(literal.shape()) < 1000
+               ? literal.ToString()
+               : "[TRUNCATED, Literal with more than 1000 values]";
+  }
+
   ErrorSpec error_;
 
   // Number of element miscomparisons encountered so far.
@@ -548,10 +619,10 @@ class NearComparator {
   double abs_expected_miscompare_sum_;
   float max_rel_err_;
   float max_abs_err_;
-  std::vector<int64> first_multi_index_;
-  std::vector<int64> last_multi_index_;
-  std::vector<int64> max_rel_multi_index_;
-  std::vector<int64> max_abs_multi_index_;
+  int64 first_linear_index_;
+  int64 last_linear_index_;
+  int64 max_rel_linear_index_;
+  int64 max_abs_linear_index_;
 };
 
 template <>
@@ -584,6 +655,23 @@ bool NearComparator::ExpectValuesNear<half>(half expected, half actual) {
                           static_cast<float>(std::move(actual)));
 }
 
+template <>
+void NearComparator::UpdateAndLogMiscompares<bfloat16>(
+    const bfloat16 expected, const bfloat16 actual, const Shape& shape,
+    const int64 linear_index) {
+  UpdateAndLogMiscompares(static_cast<float>(expected),
+                          static_cast<float>(actual), shape, linear_index);
+}
+
+template <>
+void NearComparator::UpdateAndLogMiscompares<half>(half expected, half actual,
+                                                   const Shape& shape,
+                                                   const int64 linear_index) {
+  UpdateAndLogMiscompares(static_cast<float>(std::move(expected)),
+                          static_cast<float>(std::move(actual)), shape,
+                          linear_index);
+}
+
 }  // namespace
 
 /* static */ ::testing::AssertionResult LiteralTestUtil::Near(