// LogicalBuffer is in an operand of the instruction as indicated by
// points-to analysis.
for (auto* instruction : computation.instructions()) {
- std::unordered_set<const LogicalBuffer*> instr_uses;
+ tensorflow::gtl::FlatSet<const LogicalBuffer*> instr_uses;
for (auto* operand : instruction->operands()) {
for (const LogicalBuffer* buffer :
points_to_analysis.GetBuffersDefinedByInstruction(operand)) {
int64 bytes_defined;
// For each buffer B used by this instruction, we keep a pair (B, U), where
- // U is the number of uses of B that have not yet been scheduled. This pair
- // is a pointer into the unscheduled_use_count_ map, so it gets updated for
- // free when we update counts in the map.
- std::vector<const std::pair<const LogicalBuffer* const, int64>*>
+ // U is the number of uses of B that have not yet been scheduled.
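+ // The pair is stored by value rather than as a pointer into
+ // unscheduled_use_count_, because FlatMap, unlike std::unordered_map, does
+ // not guarantee pointer stability for its elements.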
+ std::vector<std::pair<const LogicalBuffer* const, int64>>
used_buffer_unscheduled_use_counts;
};
}
auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer);
CHECK(unscheduled_use_count_it != unscheduled_use_count_.end());
- entry.used_buffer_unscheduled_use_counts.push_back(
- &*unscheduled_use_count_it);
+ entry.used_buffer_unscheduled_use_counts.emplace_back(
+ unscheduled_use_count_it->first, unscheduled_use_count_it->second);
}
return entry;
}
int64 BytesFreedIfScheduled(const ReadyListEntry& entry) {
int64 freed_bytes = 0;
for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
- auto buffer = kv->first;
- auto use_count = kv->second;
+ auto buffer = kv.first;
+ auto use_count = kv.second;
if (use_count == 1) {
freed_bytes += size_function_(*buffer);
}
// Populate the ready list with instructions which have no operands or
// control predecessors.
- std::unordered_map<const HloInstruction*, int64> unscheduled_pred_count;
+ tensorflow::gtl::FlatMap<const HloInstruction*, int64>
+ unscheduled_pred_count;
for (auto* instruction : computation_.instructions()) {
// TODO(b/34466113): Replace this and above with successors() or
// predecessors() when these methods are added to HloInstruction.
}
}
- auto priority_comparator = [this](const ReadyListEntry& lhs,
- const ReadyListEntry& rhs) {
- return GetPriority(lhs) < GetPriority(rhs);
- };
- std::priority_queue<ReadyListEntry, std::vector<ReadyListEntry>,
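+ // Each queue entry carries its priority, computed once when the entry is
+ // (re)inserted, so the comparator compares cached priorities instead of
+ // calling GetPriority on every comparison.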
+ auto priority_comparator =
+ [this](const std::pair<Priority, ReadyListEntry>& lhs,
+ const std::pair<Priority, ReadyListEntry>& rhs) {
+ return lhs.first < rhs.first;
+ };
+ std::priority_queue<std::pair<Priority, ReadyListEntry>,
+ std::vector<std::pair<Priority, ReadyListEntry>>,
decltype(priority_comparator)>
ready_queue(priority_comparator);
+
+ // Set of instructions in the ready list.
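+ // Used to find ready users of a buffer whose entries should be reinserted
+ // when that buffer's unscheduled use count drops to one.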
+ tensorflow::gtl::FlatSet<const HloInstruction*> ready_instructions;
+
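+ // Creates a ready-list entry for the instruction, tags it with the
+ // instruction's current priority, and records the instruction as ready.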
+ auto add_to_ready_queue = [&](HloInstruction* inst) {
+ auto entry = MakeReadyListEntry(inst);
+ ready_queue.emplace(GetPriority(entry), std::move(entry));
+ ready_instructions.insert(inst);
+ };
+
for (auto* instruction : computation_.instructions()) {
// Instruction with no operands or control predecessors will
// not be in the map.
if (unscheduled_pred_count.count(instruction) == 0) {
- ready_queue.emplace(MakeReadyListEntry(instruction));
+ add_to_ready_queue(instruction);
}
}
while (!ready_queue.empty()) {
// Remove the selected instruction from the ready list and add it to the
// schedule.
- const HloInstruction* best = ready_queue.top().instruction;
+ const HloInstruction* best = ready_queue.top().second.instruction;
ready_queue.pop();
+ // We may have duplicates in the priority queue, because when a ready
+ // instruction's priority goes up, we reinsert it into the priority queue.
+ // Skip the duplicate.
+ if (scheduled_instructions_.find(best) != scheduled_instructions_.end()) {
+ continue;
+ }
+ ready_instructions.erase(best);
schedule.push_back(best);
scheduled_instructions_.insert(best);
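+ // Set to true if scheduling `best` drops some buffer's unscheduled use
+ // count to one, since that may raise the priority of ready users of that
+ // buffer.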
+ bool adjust_ready_queue = false;
// Update the unscheduled uses of the logical buffers.
for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
- CHECK_GT(unscheduled_use_count_.at(buffer), 0);
- --unscheduled_use_count_[buffer];
+ int64& count = unscheduled_use_count_[buffer];
+ CHECK_GT(count, 0);
+ --count;
+ if (count == 1) {
+ adjust_ready_queue = true;
+ }
}
    // Add new instructions to ready list.
    auto update_pred_count = [&](HloInstruction* inst) {
      int64 pred_count = --unscheduled_pred_count.at(inst);
CHECK_GE(pred_count, 0);
if (pred_count == 0) {
- ready_queue.emplace(MakeReadyListEntry(inst));
+ add_to_ready_queue(inst);
}
};
    // TODO(b/34466113): Replace this and above with successors() or
    // predecessors() when these methods are added to HloInstruction.
for (HloInstruction* succ : best->control_successors()) {
update_pred_count(succ);
}
+ // The unscheduled use count for a buffer has dropped to 1, so the
+ // priorities of some ready instructions may go up. We reinsert them into
+ // the priority queue so that they can be popped earlier. The old entries
+ // become duplicates and are skipped when popped.
+ if (adjust_ready_queue) {
+ for (HloInstruction* operand : best->operands()) {
+ for (HloInstruction* operand_user : operand->users()) {
+ if (ready_instructions.find(operand_user) !=
+ ready_instructions.end()) {
+ add_to_ready_queue(operand_user);
+ }
+ }
+ }
+ }
}
CHECK_EQ(schedule.size(), computation_.instruction_count());
CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count());
const LogicalBuffer::SizeFunction& size_function_;
// A map containing the LogicalBuffers that each instruction uses.
- std::unordered_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
+ tensorflow::gtl::FlatMap<const HloInstruction*,
+ std::vector<const LogicalBuffer*>>
buffer_uses_;
- // A map containing the count of unscheduled HLOs which using a particular
- // LogicalBuffer. We rely on iterator stability in this map.
- std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_;
+ // A map containing the count of unscheduled HLOs which use a particular
+ // LogicalBuffer.
+ tensorflow::gtl::FlatMap<const LogicalBuffer*, int64> unscheduled_use_count_;
// Set of instructions which have been scheduled.
- std::unordered_set<const HloInstruction*> scheduled_instructions_;
+ tensorflow::gtl::FlatSet<const HloInstruction*> scheduled_instructions_;
};
int64 SumLogicalBufferSizes(