/* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule(
const HloRematerialization::ShapeSizeFunction& size_function,
int64 memory_limit_bytes, HloModule* hlo_module,
- SchedulerAlgorithm scheduler_algorithm,
+ MemorySchedulerAlgorithm scheduler_algorithm,
SequentialHloOrdering::HloModuleSequence* sequence,
RematerializationSizes* sizes) {
HloRematerialization remat(scheduler_algorithm, size_function);
// code generation.
static StatusOr<bool> RematerializeAndSchedule(
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
- HloModule* hlo_module, SchedulerAlgorithm scheduler_algorithm,
+ HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
SequentialHloOrdering::HloModuleSequence* sequence,
RematerializationSizes* sizes = nullptr);
protected:
- HloRematerialization(SchedulerAlgorithm scheduler_algorithm,
+ HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
const ShapeSizeFunction& size_function)
: scheduler_algorithm_(scheduler_algorithm),
size_function_(size_function) {}
const HloInstruction* instruction) const;
// Selects an algorithm to use for HLO scheduling.
- SchedulerAlgorithm scheduler_algorithm_;
+ MemorySchedulerAlgorithm scheduler_algorithm_;
// Function which computes the size of the top-level buffer of a shape.
const ShapeSizeFunction size_function_;
HloRematerialization::RematerializeAndSchedule(
ByteSizeOf,
/*memory_limit_bytes=*/14 * 1024, module.get(),
- SchedulerAlgorithm::kAuto, &sequence));
+ DefaultMemoryScheduler, &sequence));
EXPECT_TRUE(changed);
// Root should not have changed.
HloRematerialization::RematerializeAndSchedule(
ByteSizeOf,
/*memory_limit_bytes=*/20 * 1024, module.get(),
- SchedulerAlgorithm::kAuto, &sequence));
+ DefaultMemoryScheduler, &sequence));
// No instructions should have been rematerialized.
EXPECT_FALSE(changed);
HloRematerialization::RematerializeAndSchedule(
ByteSizeOf,
/*memory_limit_bytes=*/17 * 1024, module.get(),
- SchedulerAlgorithm::kAuto, &sequence));
+ DefaultMemoryScheduler, &sequence));
EXPECT_TRUE(changed);
// Only the entry computation should have a rematerialized instruction added.
HloRematerialization::RematerializeAndSchedule(
ByteSizeOf,
/*memory_limit_bytes=*/15 * 1024, module.get(),
- SchedulerAlgorithm::kAuto, &sequence));
+ DefaultMemoryScheduler, &sequence));
EXPECT_TRUE(changed);
// Both computations should have a rematerialized instruction added.
HloRematerialization::RematerializeAndSchedule(
ByteSizeOf,
/*memory_limit_bytes=*/13 * 1024, module.get(),
- SchedulerAlgorithm::kAuto, &sequence));
+ DefaultMemoryScheduler, &sequence));
EXPECT_TRUE(changed);
// All computations should have a rematerialized instruction added.
bool changed, HloRematerialization::RematerializeAndSchedule(
ByteSizeOf,
/*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_),
- module.get(), SchedulerAlgorithm::kAuto, &sequence));
+ module.get(), DefaultMemoryScheduler, &sequence));
EXPECT_TRUE(changed);
// The rng should not have been rematerialized.
EXPECT_EQ(count_rngs(entry_computation), 1);
HloRematerialization::RematerializeAndSchedule(
ByteSizeOf,
/*memory_limit_bytes=*/22 * 1024, module.get(),
- SchedulerAlgorithm::kAuto, &sequence));
+ DefaultMemoryScheduler, &sequence));
EXPECT_TRUE(changed);
// The broadcast should have been rematerialized 3 times.
HloRematerialization::RematerializeAndSchedule(
ByteSizeOf,
/*memory_limit_bytes=*/22 * 1024, module.get(),
- SchedulerAlgorithm::kAuto, &sequence));
+ DefaultMemoryScheduler, &sequence));
// Rematerialization should only occur if the rematerializable instruction has
// no indirect uses.
if (indirectly_used) {
return size;
}
-StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler(
+StatusOr<int64> MinimumMemoryForComputation(  // Heap size the simulator reports for running 'computation' in the order given by 'sequence'.
+ const HloComputation& computation,
+ const std::vector<const HloInstruction*>& sequence,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function) {
+ TF_ASSIGN_OR_RETURN(  // Binds 'result' on success; a failed simulation is handed back to the caller.
+ HeapSimulator::Result result,
+ HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,  // Simulate on a NoFragmentationStatsHeap.
+ sequence, points_to_analysis, size_function));
+ return result.heap_size;  // Only the heap size is returned; the rest of the simulation result is discarded.
+}
+
+StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(  // Schedules one computation with 'algorithm', or with DefaultMemoryScheduler when 'algorithm' is empty.
+ const HloComputation& computation,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function,
+ const MemorySchedulerAlgorithm& algorithm) {
+ VLOG(2) << "Computation: " << computation.name();
+ if (algorithm) {  // True only when the std::function holds a callable target.
+ return algorithm(computation, points_to_analysis, size_function);
+ }
+ return DefaultMemoryScheduler(computation, points_to_analysis, size_function);  // Fallback for an empty (default-constructed) algorithm.
+}
+
+} // namespace
+
+StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function) {
return sequence;
}
-StatusOr<int64> MinimumMemoryForComputation(
+StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
const HloComputation& computation,
- const std::vector<const HloInstruction*>& sequence,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function) {
- TF_ASSIGN_OR_RETURN(
- HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
- sequence, points_to_analysis, size_function));
- return result.heap_size;
+ return ListScheduler::Run(computation, points_to_analysis, size_function);
}
-StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
+StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function,
- SchedulerAlgorithm algorithm) {
- VLOG(2) << "Computation: " << computation.name();
- if (algorithm == SchedulerAlgorithm::kListSchedule) {
- return ListScheduler::Run(computation, points_to_analysis, size_function);
- }
- if (algorithm == SchedulerAlgorithm::kDfsSchedule) {
- return RunDFSMemoryScheduler(computation, points_to_analysis,
- size_function);
- }
-
+ const LogicalBuffer::SizeFunction& size_function) {
// We try both a list-scheduler based ordering and a DFS based ordering, and
// choose whichever returns a lower min-memory, not accounting for
// fragmentation.
// within the caller's context. But it's good enough for now.
TF_ASSIGN_OR_RETURN(
std::vector<const HloInstruction*> list_sequence,
- ListScheduler::Run(computation, points_to_analysis, size_function));
+ ListMemoryScheduler(computation, points_to_analysis, size_function));
TF_ASSIGN_OR_RETURN(
const int64 list_memory,
MinimumMemoryForComputation(computation, list_sequence,
TF_ASSIGN_OR_RETURN(
std::vector<const HloInstruction*> dfs_sequence,
- RunDFSMemoryScheduler(computation, points_to_analysis, size_function));
+ DFSMemoryScheduler(computation, points_to_analysis, size_function));
TF_ASSIGN_OR_RETURN(
const int64 dfs_memory,
MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis,
}
}
-} // namespace
-
StatusOr<SequentialHloOrdering::HloModuleSequence>
CreateMemoryMinimizingSequence(const HloModule& module,
const LogicalBuffer::SizeFunction& size_function,
- SchedulerAlgorithm algorithm) {
+ const MemorySchedulerAlgorithm& algorithm) {
SequentialHloOrdering::HloModuleSequence sequence;
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(&module));
StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function,
- SchedulerAlgorithm algorithm) {
+ const MemorySchedulerAlgorithm& algorithm) {
CHECK(!computation.IsFusionComputation());
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(computation.parent()));
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
const SequentialHloOrdering::HloModuleSequence& module_sequence,
const LogicalBuffer::SizeFunction& size_function);
-enum class SchedulerAlgorithm {
- kListSchedule,
- kDfsSchedule,
+// A memory scheduler computes an execution sequence for the HLO instructions in
+// 'computation' that minimizes peak memory, given a points-to analysis result
+// that describes buffer aliasing, together with a target-specific size function
+// that maps a tensor's logical size to its padded size.
+typedef std::function<StatusOr<std::vector<const HloInstruction*>>(
+ const HloComputation&, const TuplePointsToAnalysis&,
+ const LogicalBuffer::SizeFunction&)>
+ MemorySchedulerAlgorithm;
- // Selects the available scheduler algorithm that had the minimum memory in
- // the resulting sequence (a la MinimumMemoryForSequence).
- kAuto,
-};
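+// List scheduler. Matches the MemorySchedulerAlgorithm signature and produces
+// its sequence via ListScheduler::Run.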
+StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
+ const HloComputation& computation,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function);
+
+// DFS-order scheduler. Matches the MemorySchedulerAlgorithm signature.
+StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
+ const HloComputation& computation,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function);
+
+// The default scheduling algorithm. Runs both the list scheduler
+// and the DFS scheduler, and chooses whichever returns a lower min-memory,
+// not accounting for fragmentation.
+StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
+ const HloComputation& computation,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function);
// Returns an HloModuleSequence which seeks to minimize the memory required for
// the computation. size_function is the function returning the number of bytes
// required for a LogicalBuffer.
StatusOr<SequentialHloOrdering::HloModuleSequence>
-CreateMemoryMinimizingSequence(
- const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
- SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto);
+CreateMemoryMinimizingSequence(const HloModule& module,
+ const LogicalBuffer::SizeFunction& size_function,
+ const MemorySchedulerAlgorithm& algorithm = {});
// Overload of above that computes the sequence for a single computation.
StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function,
- SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto);
+ const MemorySchedulerAlgorithm& algorithm = {});
} // namespace xla
};
TF_ASSERT_OK_AND_ASSIGN(
SequentialHloOrdering::HloModuleSequence sequence,
- CreateMemoryMinimizingSequence(*module, size_fn,
- SchedulerAlgorithm::kListSchedule));
+ CreateMemoryMinimizingSequence(*module, size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence.
EXPECT_EQ(module->entry_computation()->instruction_count(),
sequence.at(module->entry_computation()).size());