From 4ed183d9d471ca04cf3961610a027136298c1788 Mon Sep 17 00:00:00 2001 From: Yuanzhong Xu Date: Thu, 15 Feb 2018 12:08:36 -0800 Subject: [PATCH] [XLA] Fix priority queue in HLO scheduling. The priority of an HLO can change during the scheduling. Use immutable values in priority queue entries, and reinsert an entry if its priority goes up. PiperOrigin-RevId: 185878562 --- .../compiler/xla/service/hlo_scheduling.cc | 84 ++++++++++++++----- 1 file changed, 61 insertions(+), 23 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 5f5a930dad..8dc4d4f7ba 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -101,7 +101,7 @@ class ListScheduler { // LogicalBuffer is in an operand of the instruction as indicated by // points-to analysis. for (auto* instruction : computation.instructions()) { - std::unordered_set instr_uses; + tensorflow::gtl::FlatSet instr_uses; for (auto* operand : instruction->operands()) { for (const LogicalBuffer* buffer : points_to_analysis.GetBuffersDefinedByInstruction(operand)) { @@ -151,10 +151,8 @@ class ListScheduler { int64 bytes_defined; // For each buffer B used by this instruction, we keep a pair (B, U), where - // U is the number of uses of B that have not yet been scheduled. This pair - // is a pointer into the unscheduled_use_count_ map, so it gets updated for - // free when we update counts in the map. - std::vector*> + // U is the number of uses of B that have not yet been scheduled. + std::vector> used_buffer_unscheduled_use_counts; }; @@ -177,8 +175,8 @@ class ListScheduler { } auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer); CHECK(unscheduled_use_count_it != unscheduled_use_count_.end()); - entry.used_buffer_unscheduled_use_counts.push_back( - &*unscheduled_use_count_it); + entry.used_buffer_unscheduled_use_counts.emplace_back( + unscheduled_use_count_it->first, unscheduled_use_count_it->second); } return entry; } @@ -187,8 +185,8 @@ class ListScheduler { int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { int64 freed_bytes = 0; for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { - auto buffer = kv->first; - auto use_count = kv->second; + auto buffer = kv.first; + auto use_count = kv.second; if (use_count == 1) { freed_bytes += size_function_(*buffer); } @@ -206,7 +204,8 @@ class ListScheduler { // Populate the ready list with instructions which have no operands or // control predecessors. - std::unordered_map unscheduled_pred_count; + tensorflow::gtl::FlatMap + unscheduled_pred_count; for (auto* instruction : computation_.instructions()) { // TODO(b/34466113): Replace this and above with successors() or // predecessors() when these methods are added to HloInstruction. @@ -218,33 +217,57 @@ class ListScheduler { } } - auto priority_comparator = [this](const ReadyListEntry& lhs, - const ReadyListEntry& rhs) { - return GetPriority(lhs) < GetPriority(rhs); - }; - std::priority_queue, + auto priority_comparator = + [this](const std::pair& lhs, + const std::pair& rhs) { + return lhs.first < rhs.first; + }; + std::priority_queue, + std::vector>, decltype(priority_comparator)> ready_queue(priority_comparator); + + // Set of instructions in the ready list. + tensorflow::gtl::FlatSet ready_instructions; + + auto add_to_ready_queue = [&](HloInstruction* inst) { + auto entry = MakeReadyListEntry(inst); + ready_queue.emplace(GetPriority(entry), std::move(entry)); + ready_instructions.insert(inst); + }; + for (auto* instruction : computation_.instructions()) { // Instruction with no operands or control predecessors will // not be in the map. if (unscheduled_pred_count.count(instruction) == 0) { - ready_queue.emplace(MakeReadyListEntry(instruction)); + add_to_ready_queue(instruction); } } while (!ready_queue.empty()) { // Remove the selected instruction from the ready list and add it to the // schedule. - const HloInstruction* best = ready_queue.top().instruction; + const HloInstruction* best = ready_queue.top().second.instruction; ready_queue.pop(); + // We may have duplicates in the priority queue, because when a ready + // instruction's priority goes up, we reinsert it to the priority queue. + // Skip the duplicate. + if (scheduled_instructions_.find(best) != scheduled_instructions_.end()) { + continue; + } + ready_instructions.erase(best); schedule.push_back(best); scheduled_instructions_.insert(best); + bool adjust_ready_queue = false; // Update the unscheduled uses of the logical buffers. for (const LogicalBuffer* buffer : buffer_uses_.at(best)) { - CHECK_GT(unscheduled_use_count_.at(buffer), 0); - --unscheduled_use_count_[buffer]; + int64& count = unscheduled_use_count_[buffer]; + CHECK_GT(count, 0); + --count; + if (count == 1) { + adjust_ready_queue = true; + } } // Add new instructions to ready list. @@ -252,7 +275,7 @@ class ListScheduler { int64 pred_count = --unscheduled_pred_count.at(inst); CHECK_GE(pred_count, 0); if (pred_count == 0) { - ready_queue.emplace(MakeReadyListEntry(inst)); + add_to_ready_queue(inst); } }; // TODO(b/34466113): Replace this and above with successors() or @@ -263,6 +286,20 @@ class ListScheduler { for (HloInstruction* succ : best->control_successors()) { update_pred_count(succ); } + // The unscheduled use count for a buffer has changed to 1, so the + // priorities of some ready instructions may go up. We reinsert them to + // the priority queue, so that they can appear earlier. The old entries + // will become duplicates and will be skipped. + if (adjust_ready_queue) { + for (HloInstruction* operand : best->operands()) { + for (HloInstruction* operand_user : operand->users()) { + if (ready_instructions.find(operand_user) != + ready_instructions.end()) { + add_to_ready_queue(operand_user); + } + } + } + } } CHECK_EQ(schedule.size(), computation_.instruction_count()); CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count()); @@ -275,15 +312,16 @@ class ListScheduler { const LogicalBuffer::SizeFunction& size_function_; // A map containing the LogicalBuffers that each instruction uses. - std::unordered_map> + tensorflow::gtl::FlatMap> buffer_uses_; // A map containing the count of unscheduled HLOs which using a particular // LogicalBuffer. We rely on iterator stability in this map. - std::unordered_map unscheduled_use_count_; + tensorflow::gtl::FlatMap unscheduled_use_count_; // Set of instructions which have been scheduled. - std::unordered_set scheduled_instructions_; + tensorflow::gtl::FlatSet scheduled_instructions_; }; int64 SumLogicalBufferSizes( -- 2.34.1