[TF:XLA] Delete the reverseDFS scheduler. With recent improvements to the List scheduler ... [commit message truncated in this view]
author: Dimitris Vardoulakis <dimvar@google.com>
Fri, 18 May 2018 19:24:20 +0000 (12:24 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Fri, 18 May 2018 19:27:38 +0000 (12:27 -0700)
PiperOrigin-RevId: 197183727

tensorflow/compiler/xla/service/hlo_scheduling.cc
tensorflow/compiler/xla/service/hlo_scheduling.h

index 0254581..51c29d4 100644 (file)
@@ -426,10 +426,12 @@ StatusOr<int64> MinimumMemoryForComputation(
   return result.heap_size;
 }
 
-StatusOr<std::vector<const HloInstruction*>> DFSMemorySchedulerImpl(
+StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
     const HloComputation& computation,
     const TuplePointsToAnalysis& points_to_analysis,
-    const LogicalBuffer::SizeFunction& size_function, bool reverse_heuristics) {
+    const LogicalBuffer::SizeFunction& size_function,
+    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+        memory_by_computation) {
   // This ordering is based on DFS post-order, with a heuristic to decide which
   // operand to visit first.  The heuristic is based on 'extra_users', which is
   // simply users-1 for each instruction.  By subtracting 1, we're saying that
@@ -469,16 +471,15 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemorySchedulerImpl(
     return Status::OK();
   });
   TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder(
-      &visitor, [&extra_users, &total_sizes, reverse_heuristics](
-                    const HloInstruction* a, const HloInstruction* b) {
-        auto lhs = std::tuple<int64, int64, string>(extra_users[a],
-                                                    total_sizes[a], b->name());
-        auto rhs = std::tuple<int64, int64, string>(extra_users[b],
-                                                    total_sizes[b], a->name());
-
-        // Reverse heuristics. This helps some cases as a different starting
-        // point of gradient descent, see b/78906799 for more context.
-        return reverse_heuristics ? rhs > lhs : lhs > rhs;
+      &visitor, [&extra_users, &total_sizes](const HloInstruction* a,
+                                             const HloInstruction* b) {
+        if (extra_users[a] != extra_users[b]) {
+          return extra_users[a] > extra_users[b];
+        }
+        if (total_sizes[a] != total_sizes[b]) {
+          return total_sizes[a] > total_sizes[b];
+        }
+        return a->name() < b->name();
       }));
   CHECK_EQ(sequence.size(), computation.instruction_count());
   return sequence;
@@ -505,26 +506,6 @@ StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
                                             post_order.end()};
 }
 
-StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
-    const HloComputation& computation,
-    const TuplePointsToAnalysis& points_to_analysis,
-    const LogicalBuffer::SizeFunction& size_function,
-    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
-        memory_by_computation) {
-  return DFSMemorySchedulerImpl(computation, points_to_analysis, size_function,
-                                /*reverse_heuristics=*/false);
-}
-
-StatusOr<std::vector<const HloInstruction*>> DFSMemorySchedulerReverse(
-    const HloComputation& computation,
-    const TuplePointsToAnalysis& points_to_analysis,
-    const LogicalBuffer::SizeFunction& size_function,
-    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
-        memory_by_computation) {
-  return DFSMemorySchedulerImpl(computation, points_to_analysis, size_function,
-                                /*reverse_heuristics=*/true);
-}
-
 StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
     const HloComputation& computation,
     const TuplePointsToAnalysis& points_to_analysis,
@@ -568,18 +549,7 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
   VLOG(2) << "Min-memory post order sequence: "
           << HumanReadableNumBytes(post_order_memory);
 
-  TF_ASSIGN_OR_RETURN(
-      std::vector<const HloInstruction*> reverse_dfs,
-      DFSMemorySchedulerReverse(computation, points_to_analysis, size_function,
-                                memory_by_computation));
-  TF_ASSIGN_OR_RETURN(
-      const int64 reverse_dfs_memory,
-      MinimumMemoryForComputation(computation, reverse_dfs, points_to_analysis,
-                                  size_function));
-  VLOG(2) << "Min-memory reverse_dfs sequence: "
-          << HumanReadableNumBytes(reverse_dfs_memory);
-  auto min_memory = std::min(
-      {dfs_memory, post_order_memory, reverse_dfs_memory, list_memory});
+  auto min_memory = std::min({dfs_memory, post_order_memory, list_memory});
 
   if (min_memory == list_memory) {
     VLOG(2) << "Chose min-memory list sequence: "
@@ -589,10 +559,6 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
     VLOG(2) << "Chose min-memory dfs sequence: "
             << HumanReadableNumBytes(dfs_memory);
     return dfs_sequence;
-  } else if (min_memory == reverse_dfs_memory) {
-    VLOG(2) << "Chose min-memory reverse_dfs memory: "
-            << HumanReadableNumBytes(reverse_dfs_memory);
-    return reverse_dfs;
   } else {
     VLOG(2) << "Chose min-memory post_order sequence: "
             << HumanReadableNumBytes(post_order_memory);
index 0e5ac20..49b927e 100644 (file)
@@ -76,15 +76,6 @@ StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
     const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
         memory_by_computation);
 
-// DFS-order scheduler with reversed heuristics. This helps some cases (see
-// b/78906799).
-StatusOr<std::vector<const HloInstruction*>> DFSMemorySchedulerReverse(
-    const HloComputation& computation,
-    const TuplePointsToAnalysis& points_to_analysis,
-    const LogicalBuffer::SizeFunction& size_function,
-    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
-        memory_by_computation);
-
 // The default scheduling algorithm. Runs both the list scheduler
 // and the DFS scheduler, and chooses whichever returns a lower min-memory,
 // not accounting for fragmentation.