// simply users-1 for each instruction. By subtracting 1, we're saying that
// instructions with no users or a single user don't count; instructions with
// lots of fan-out will be visited earlier.
- int64 cumulative_total_size = 0;
tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users;
tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes;
for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
int64 logical_buffer_size = SumLogicalBufferSizes(
points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
total_sizes[hlo] = logical_buffer_size;
- cumulative_total_size += logical_buffer_size;
tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands(
hlo->operands().begin(), hlo->operands().end());
for (const HloInstruction* operand : unique_operands) {
extra_users[hlo] += extra_users[operand];
total_sizes[hlo] += total_sizes[operand];
}
- total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size);
}
CHECK_EQ(extra_users.size(), computation.instruction_count());
CHECK_EQ(total_sizes.size(), computation.instruction_count());
const LogicalBuffer::SizeFunction& size_function,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
memory_by_computation) {
- // We try both a list-scheduler based ordering and a DFS based ordering, and
- // choose whichever returns a lower min-memory, not accounting for
- // fragmentation.
- //
- // Note that this is just a heuristic. One obvious inaccuracy is that the
- // memory required for sub-computations might be different when considered
- // within the caller's context. But it's good enough for now.
+ // We try a few schedulers and choose whichever returns a lower min-memory,
+ // not accounting for fragmentation.
+ // - List is a scheduler that uses greedy heuristics.
+ // - DFS visits HLOs in postorder, with a heuristic to decide the order of
+ // children.
+ // - Postorder does not use any heuristics.
+ // List wins for most of our benchmarks; postorder-based schedulers win for
+ // some RNNs.
TF_ASSIGN_OR_RETURN(
std::vector<const HloInstruction*> list_sequence,
ListMemoryScheduler(computation, points_to_analysis, size_function,