[XLA] Use a real priority queue in list scheduling
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 9 Feb 2018 23:26:13 +0000 (15:26 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 9 Feb 2018 23:30:35 +0000 (15:30 -0800)
PiperOrigin-RevId: 185201882

tensorflow/compiler/xla/service/hlo_scheduling.cc

index 2594c29..5f5a930 100644 (file)
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/hlo_scheduling.h"
 
+#include <queue>
 #include <utility>
 #include <vector>
 
@@ -217,32 +218,26 @@ class ListScheduler {
       }
     }
 
-    std::list<ReadyListEntry> ready_list;
+    auto priority_comparator = [this](const ReadyListEntry& lhs,
+                                      const ReadyListEntry& rhs) {
+      return GetPriority(lhs) < GetPriority(rhs);
+    };
+    std::priority_queue<ReadyListEntry, std::vector<ReadyListEntry>,
+                        decltype(priority_comparator)>
+        ready_queue(priority_comparator);
     for (auto* instruction : computation_.instructions()) {
       // Instruction with no operands or control predecessors will
       // not be in the map.
       if (unscheduled_pred_count.count(instruction) == 0) {
-        ready_list.push_back(MakeReadyListEntry(instruction));
+        ready_queue.emplace(MakeReadyListEntry(instruction));
       }
     }
 
-    while (!ready_list.empty()) {
-      // Select the highest priority HLO instruction from the ready list.
-      auto best_it = ready_list.begin();
-      Priority best_priority = GetPriority(*best_it);
-      for (auto ready_it = std::next(ready_list.begin());
-           ready_it != ready_list.end(); ++ready_it) {
-        Priority priority = GetPriority(*ready_it);
-        if (priority > best_priority) {
-          best_it = ready_it;
-          best_priority = priority;
-        }
-      }
-
+    while (!ready_queue.empty()) {
       // Remove the selected instruction from the ready list and add it to the
       // schedule.
-      const HloInstruction* best = best_it->instruction;
-      ready_list.erase(best_it);
+      const HloInstruction* best = ready_queue.top().instruction;
+      ready_queue.pop();
       schedule.push_back(best);
       scheduled_instructions_.insert(best);
 
@@ -257,7 +252,7 @@ class ListScheduler {
         int64 pred_count = --unscheduled_pred_count.at(inst);
         CHECK_GE(pred_count, 0);
         if (pred_count == 0) {
-          ready_list.push_back(MakeReadyListEntry(inst));
+          ready_queue.emplace(MakeReadyListEntry(inst));
         }
       };
       // TODO(b/34466113): Replace this and above with successors() or