return result.heap_size;
}
-StatusOr<std::vector<const HloInstruction*>> DFSMemorySchedulerImpl(
+StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function, bool reverse_heuristics) {
+ const LogicalBuffer::SizeFunction& size_function,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ memory_by_computation) {
// This ordering is based on DFS post-order, with a heuristic to decide which
// operand to visit first. The heuristic is based on 'extra_users', which is
// simply users-1 for each instruction. By subtracting 1, we're saying that
return Status::OK();
});
TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder(
- &visitor, [&extra_users, &total_sizes, reverse_heuristics](
- const HloInstruction* a, const HloInstruction* b) {
- auto lhs = std::tuple<int64, int64, string>(extra_users[a],
- total_sizes[a], b->name());
- auto rhs = std::tuple<int64, int64, string>(extra_users[b],
- total_sizes[b], a->name());
-
- // Reverse heuristics. This helps some cases as a different starting
- // point of gradient descent, see b/78906799 for more context.
- return reverse_heuristics ? rhs > lhs : lhs > rhs;
+ &visitor, [&extra_users, &total_sizes](const HloInstruction* a,
+ const HloInstruction* b) {
+ if (extra_users[a] != extra_users[b]) {  // Primary key: a precedes b when its extra_users is larger.
+ return extra_users[a] > extra_users[b];
+ }
+ if (total_sizes[a] != total_sizes[b]) {  // Secondary key: a precedes b when its total_sizes is larger.
+ return total_sizes[a] > total_sizes[b];
+ }
+ return a->name() < b->name();  // Final tie-break: ascending name; same result as the removed tuple compare with reverse_heuristics == false. NOTE(review): presumably "precedes" means visited earlier -- confirm against AcceptWithOperandOrder.
}));
CHECK_EQ(sequence.size(), computation.instruction_count());
return sequence;
post_order.end()};
}
-StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
- const HloComputation& computation,
- const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
- memory_by_computation) {
- return DFSMemorySchedulerImpl(computation, points_to_analysis, size_function,
- /*reverse_heuristics=*/false);
-}
-
-StatusOr<std::vector<const HloInstruction*>> DFSMemorySchedulerReverse(
- const HloComputation& computation,
- const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
- memory_by_computation) {
- return DFSMemorySchedulerImpl(computation, points_to_analysis, size_function,
- /*reverse_heuristics=*/true);
-}
-
StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
VLOG(2) << "Min-memory post order sequence: "
<< HumanReadableNumBytes(post_order_memory);
- TF_ASSIGN_OR_RETURN(
- std::vector<const HloInstruction*> reverse_dfs,
- DFSMemorySchedulerReverse(computation, points_to_analysis, size_function,
- memory_by_computation));
- TF_ASSIGN_OR_RETURN(
- const int64 reverse_dfs_memory,
- MinimumMemoryForComputation(computation, reverse_dfs, points_to_analysis,
- size_function));
- VLOG(2) << "Min-memory reverse_dfs sequence: "
- << HumanReadableNumBytes(reverse_dfs_memory);
- auto min_memory = std::min(
- {dfs_memory, post_order_memory, reverse_dfs_memory, list_memory});
+ auto min_memory = std::min({dfs_memory, post_order_memory, list_memory});
if (min_memory == list_memory) {
VLOG(2) << "Chose min-memory list sequence: "
VLOG(2) << "Chose min-memory dfs sequence: "
<< HumanReadableNumBytes(dfs_memory);
return dfs_sequence;
- } else if (min_memory == reverse_dfs_memory) {
- VLOG(2) << "Chose min-memory reverse_dfs memory: "
- << HumanReadableNumBytes(reverse_dfs_memory);
- return reverse_dfs;
} else {
VLOG(2) << "Chose min-memory post_order sequence: "
<< HumanReadableNumBytes(post_order_memory);