From 0bd851b38810540034069d92a2f76a026429bced Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 20 Mar 2018 16:11:23 -0700 Subject: [PATCH] [XLA] Make HLO memory schedulers pluggable. Introduce a typedef MemorySchedulerAlgorithm which is a function instead of an enum to allow experimentation with non-standard schedulers. Refactoring only; no functional changes to the scheduling itself. PiperOrigin-RevId: 189830685 --- .../compiler/xla/service/hlo_rematerialization.cc | 2 +- .../compiler/xla/service/hlo_rematerialization.h | 6 +-- .../xla/service/hlo_rematerialization_test.cc | 16 +++--- tensorflow/compiler/xla/service/hlo_scheduling.cc | 61 +++++++++++++--------- tensorflow/compiler/xla/service/hlo_scheduling.h | 43 +++++++++++---- .../compiler/xla/service/hlo_scheduling_test.cc | 3 +- 6 files changed, 80 insertions(+), 51 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 98b8d34..b063244 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -1320,7 +1320,7 @@ StatusOr HloRematerialization::Run( /* static */ StatusOr HloRematerialization::RematerializeAndSchedule( const HloRematerialization::ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, - SchedulerAlgorithm scheduler_algorithm, + MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes) { HloRematerialization remat(scheduler_algorithm, size_function); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 5255343..2ee2dd0 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -66,12 +66,12 @@ class HloRematerialization { // code generation. static StatusOr RematerializeAndSchedule( const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, SchedulerAlgorithm scheduler_algorithm, + HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes = nullptr); protected: - HloRematerialization(SchedulerAlgorithm scheduler_algorithm, + HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, const ShapeSizeFunction& size_function) : scheduler_algorithm_(scheduler_algorithm), size_function_(size_function) {} @@ -108,7 +108,7 @@ class HloRematerialization { const HloInstruction* instruction) const; // Selects an algorithm to use for HLO scheduling. - SchedulerAlgorithm scheduler_algorithm_; + MemorySchedulerAlgorithm scheduler_algorithm_; // Function which computes the size of the top-level buffer of a shape. const ShapeSizeFunction size_function_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 1b7d26d..83de54f 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -162,7 +162,7 @@ TEST_F(HloRematerializationTest, SingleComputation) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/14 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // Root should not have changed. @@ -195,7 +195,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/20 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -236,7 +236,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/17 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -272,7 +272,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/15 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // Both computations should have a rematerialized instruction added. @@ -314,7 +314,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/13 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // All computations should have a rematerialized instruction added. @@ -385,7 +385,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), SchedulerAlgorithm::kAuto, &sequence)); + module.get(), DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -480,7 +480,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/22 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -577,7 +577,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/22 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 099dd8d..1a76762 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -340,7 +340,33 @@ int64 SumLogicalBufferSizes( return size; } -StatusOr> RunDFSMemoryScheduler( +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), computation, + sequence, points_to_analysis, size_function)); + return result.heap_size; +} + +StatusOr> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm) { + VLOG(2) << "Computation: " << computation.name(); + if (algorithm) { + return algorithm(computation, points_to_analysis, size_function); + } + return DefaultMemoryScheduler(computation, points_to_analysis, size_function); +} + +} // namespace + +StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function) { @@ -397,32 +423,17 @@ StatusOr> RunDFSMemoryScheduler( return sequence; } -StatusOr MinimumMemoryForComputation( +StatusOr> ListMemoryScheduler( const HloComputation& computation, - const std::vector& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function) { - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), computation, - sequence, points_to_analysis, size_function)); - return result.heap_size; + return ListScheduler::Run(computation, points_to_analysis, size_function); } -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm) { - VLOG(2) << "Computation: " << computation.name(); - if (algorithm == SchedulerAlgorithm::kListSchedule) { - return ListScheduler::Run(computation, points_to_analysis, size_function); - } - if (algorithm == SchedulerAlgorithm::kDfsSchedule) { - return RunDFSMemoryScheduler(computation, points_to_analysis, - size_function); - } - + const LogicalBuffer::SizeFunction& size_function) { // We try both a list-scheduler based ordering and a DFS based ordering, and // choose whichever returns a lower min-memory, not accounting for // fragmentation. @@ -432,7 +443,7 @@ StatusOr> CreateMemoryMinimizingSequence( // within the caller's context. But it's good enough for now. TF_ASSIGN_OR_RETURN( std::vector list_sequence, - ListScheduler::Run(computation, points_to_analysis, size_function)); + ListMemoryScheduler(computation, points_to_analysis, size_function)); TF_ASSIGN_OR_RETURN( const int64 list_memory, MinimumMemoryForComputation(computation, list_sequence, @@ -441,7 +452,7 @@ StatusOr> CreateMemoryMinimizingSequence( TF_ASSIGN_OR_RETURN( std::vector dfs_sequence, - RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); + DFSMemoryScheduler(computation, points_to_analysis, size_function)); TF_ASSIGN_OR_RETURN( const int64 dfs_memory, MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, @@ -459,12 +470,10 @@ StatusOr> CreateMemoryMinimizingSequence( } } -} // namespace - StatusOr CreateMemoryMinimizingSequence(const HloModule& module, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm) { + const MemorySchedulerAlgorithm& algorithm) { SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); @@ -480,7 +489,7 @@ CreateMemoryMinimizingSequence(const HloModule& module, StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm) { + const MemorySchedulerAlgorithm& algorithm) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index 1d1eb1e..068e683 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -33,28 +34,48 @@ StatusOr MinimumMemoryForSequence( const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function); -enum class SchedulerAlgorithm { - kListSchedule, - kDfsSchedule, +// A memory scheduler computes an execution sequence for the HLO instructions in +// 'computation' that minimizes peak memory, given a points-to analysis result +// that describes buffer aliasing, together with a target-specific size function +// that maps a tensor's logical size to its padded size. +typedef std::function>( + const HloComputation&, const TuplePointsToAnalysis&, + const LogicalBuffer::SizeFunction&)> + MemorySchedulerAlgorithm; - // Selects the available scheduler algorithm that had the minimum memory in - // the resulting sequence (a la MinimumMemoryForSequence). - kAuto, -}; +// List scheduler +StatusOr> ListMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); + +// DFS-order scheduler +StatusOr> DFSMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); + +// The default scheduling algorithm. Runs both the list scheduler +// and the DFS scheduler, and chooses whichever returns a lower min-memory, +// not accounting for fragmentation. +StatusOr> DefaultMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. StatusOr -CreateMemoryMinimizingSequence( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto); +CreateMemoryMinimizingSequence(const HloModule& module, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm = {}); // Overload of above that computes the sequence for a single computation. StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto); + const MemorySchedulerAlgorithm& algorithm = {}); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 2dd6e43..74544c4 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -165,8 +165,7 @@ ENTRY root { }; TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, size_fn, - SchedulerAlgorithm::kListSchedule)); + CreateMemoryMinimizingSequence(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.at(module->entry_computation()).size()); -- 2.7.4