[XLA] Fix priority queue in HLO scheduling.
authorYuanzhong Xu <yuanzx@google.com>
Thu, 15 Feb 2018 20:08:36 +0000 (12:08 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 15 Feb 2018 20:12:45 +0000 (12:12 -0800)
The priority of an HLO can change during the scheduling. Use immutable values in
priority queue entries, and reinsert an entry if its priority goes up.

PiperOrigin-RevId: 185878562

tensorflow/compiler/xla/service/hlo_scheduling.cc

index 5f5a930dad002c215a5332286ade97ef19cc67af..8dc4d4f7bac1b2007f2b9f60d126fa07e314dac9 100644 (file)
@@ -101,7 +101,7 @@ class ListScheduler {
     // LogicalBuffer is in an operand of the instruction as indicated by
     // points-to analysis.
     for (auto* instruction : computation.instructions()) {
-      std::unordered_set<const LogicalBuffer*> instr_uses;
+      tensorflow::gtl::FlatSet<const LogicalBuffer*> instr_uses;
       for (auto* operand : instruction->operands()) {
         for (const LogicalBuffer* buffer :
              points_to_analysis.GetBuffersDefinedByInstruction(operand)) {
@@ -151,10 +151,8 @@ class ListScheduler {
     int64 bytes_defined;
 
     // For each buffer B used by this instruction, we keep a pair (B, U), where
-    // U is the number of uses of B that have not yet been scheduled.  This pair
-    // is a pointer into the unscheduled_use_count_ map, so it gets updated for
-    // free when we update counts in the map.
-    std::vector<const std::pair<const LogicalBuffer* const, int64>*>
+    // U is the number of uses of B that have not yet been scheduled.
+    std::vector<std::pair<const LogicalBuffer* const, int64>>
         used_buffer_unscheduled_use_counts;
   };
 
@@ -177,8 +175,8 @@ class ListScheduler {
       }
       auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer);
       CHECK(unscheduled_use_count_it != unscheduled_use_count_.end());
-      entry.used_buffer_unscheduled_use_counts.push_back(
-          &*unscheduled_use_count_it);
+      entry.used_buffer_unscheduled_use_counts.emplace_back(
+          unscheduled_use_count_it->first, unscheduled_use_count_it->second);
     }
     return entry;
   }
@@ -187,8 +185,8 @@ class ListScheduler {
   int64 BytesFreedIfScheduled(const ReadyListEntry& entry) {
     int64 freed_bytes = 0;
     for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
-      auto buffer = kv->first;
-      auto use_count = kv->second;
+      auto buffer = kv.first;
+      auto use_count = kv.second;
       if (use_count == 1) {
         freed_bytes += size_function_(*buffer);
       }
@@ -206,7 +204,8 @@ class ListScheduler {
 
     // Populate the ready list with instructions which have no operands or
     // control predecessors.
-    std::unordered_map<const HloInstruction*, int64> unscheduled_pred_count;
+    tensorflow::gtl::FlatMap<const HloInstruction*, int64>
+        unscheduled_pred_count;
     for (auto* instruction : computation_.instructions()) {
       // TODO(b/34466113): Replace this and above with successors() or
       // predecessors() when these methods are added to HloInstruction.
@@ -218,33 +217,57 @@ class ListScheduler {
       }
     }
 
-    auto priority_comparator = [this](const ReadyListEntry& lhs,
-                                      const ReadyListEntry& rhs) {
-      return GetPriority(lhs) < GetPriority(rhs);
-    };
-    std::priority_queue<ReadyListEntry, std::vector<ReadyListEntry>,
+    auto priority_comparator =
+        [this](const std::pair<Priority, ReadyListEntry>& lhs,
+               const std::pair<Priority, ReadyListEntry>& rhs) {
+          return lhs.first < rhs.first;
+        };
+    std::priority_queue<std::pair<Priority, ReadyListEntry>,
+                        std::vector<std::pair<Priority, ReadyListEntry>>,
                         decltype(priority_comparator)>
         ready_queue(priority_comparator);
+
+    // Set of instructions in the ready list.
+    tensorflow::gtl::FlatSet<const HloInstruction*> ready_instructions;
+
+    auto add_to_ready_queue = [&](HloInstruction* inst) {
+      auto entry = MakeReadyListEntry(inst);
+      ready_queue.emplace(GetPriority(entry), std::move(entry));
+      ready_instructions.insert(inst);
+    };
+
     for (auto* instruction : computation_.instructions()) {
       // Instruction with no operands or control predecessors will
       // not be in the map.
       if (unscheduled_pred_count.count(instruction) == 0) {
-        ready_queue.emplace(MakeReadyListEntry(instruction));
+        add_to_ready_queue(instruction);
       }
     }
 
     while (!ready_queue.empty()) {
       // Remove the selected instruction from the ready list and add it to the
       // schedule.
-      const HloInstruction* best = ready_queue.top().instruction;
+      const HloInstruction* best = ready_queue.top().second.instruction;
       ready_queue.pop();
+      // We may have duplicates in the priority queue, because when a ready
+      // instruction's priority goes up, we reinsert it to the priority queue.
+      // Skip the duplicate.
+      if (scheduled_instructions_.find(best) != scheduled_instructions_.end()) {
+        continue;
+      }
+      ready_instructions.erase(best);
       schedule.push_back(best);
       scheduled_instructions_.insert(best);
 
+      bool adjust_ready_queue = false;
       // Update the unscheduled uses of the logical buffers.
       for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
-        CHECK_GT(unscheduled_use_count_.at(buffer), 0);
-        --unscheduled_use_count_[buffer];
+        int64& count = unscheduled_use_count_[buffer];
+        CHECK_GT(count, 0);
+        --count;
+        if (count == 1) {
+          adjust_ready_queue = true;
+        }
       }
 
       // Add new instructions to ready list.
@@ -252,7 +275,7 @@ class ListScheduler {
         int64 pred_count = --unscheduled_pred_count.at(inst);
         CHECK_GE(pred_count, 0);
         if (pred_count == 0) {
-          ready_queue.emplace(MakeReadyListEntry(inst));
+          add_to_ready_queue(inst);
         }
       };
       // TODO(b/34466113): Replace this and above with successors() or
@@ -263,6 +286,20 @@ class ListScheduler {
       for (HloInstruction* succ : best->control_successors()) {
         update_pred_count(succ);
       }
+      // The unscheduled use count for a buffer has changed to 1, so the
+      // priorities of some ready instructions may go up. We reinsert them to
+      // the priority queue, so that they can appear earlier. The old entries
+      // will become duplicates and will be skipped.
+      if (adjust_ready_queue) {
+        for (HloInstruction* operand : best->operands()) {
+          for (HloInstruction* operand_user : operand->users()) {
+            if (ready_instructions.find(operand_user) !=
+                ready_instructions.end()) {
+              add_to_ready_queue(operand_user);
+            }
+          }
+        }
+      }
     }
     CHECK_EQ(schedule.size(), computation_.instruction_count());
     CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count());
@@ -275,15 +312,16 @@ class ListScheduler {
   const LogicalBuffer::SizeFunction& size_function_;
 
   // A map containing the LogicalBuffers that each instruction uses.
-  std::unordered_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
+  tensorflow::gtl::FlatMap<const HloInstruction*,
+                           std::vector<const LogicalBuffer*>>
       buffer_uses_;
 
   // A map containing the count of unscheduled HLOs which using a particular
   // LogicalBuffer.  We rely on iterator stability in this map.
-  std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_;
+  tensorflow::gtl::FlatMap<const LogicalBuffer*, int64> unscheduled_use_count_;
 
   // Set of instructions which have been scheduled.
-  std::unordered_set<const HloInstruction*> scheduled_instructions_;
+  tensorflow::gtl::FlatSet<const HloInstruction*> scheduled_instructions_;
 };
 
 int64 SumLogicalBufferSizes(