[XLA] Make HLO memory schedulers pluggable. Introduce a typedef MemorySchedulerAlgori...
authorPeter Hawkins <phawkins@google.com>
Tue, 20 Mar 2018 23:11:23 +0000 (16:11 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Mar 2018 23:14:19 +0000 (16:14 -0700)
PiperOrigin-RevId: 189830685

tensorflow/compiler/xla/service/hlo_rematerialization.cc
tensorflow/compiler/xla/service/hlo_rematerialization.h
tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
tensorflow/compiler/xla/service/hlo_scheduling.cc
tensorflow/compiler/xla/service/hlo_scheduling.h
tensorflow/compiler/xla/service/hlo_scheduling_test.cc

index 98b8d34..b063244 100644 (file)
@@ -1320,7 +1320,7 @@ StatusOr<bool> HloRematerialization::Run(
 /* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule(
     const HloRematerialization::ShapeSizeFunction& size_function,
     int64 memory_limit_bytes, HloModule* hlo_module,
-    SchedulerAlgorithm scheduler_algorithm,
+    MemorySchedulerAlgorithm scheduler_algorithm,
     SequentialHloOrdering::HloModuleSequence* sequence,
     RematerializationSizes* sizes) {
   HloRematerialization remat(scheduler_algorithm, size_function);
index 5255343..2ee2dd0 100644 (file)
@@ -66,12 +66,12 @@ class HloRematerialization {
   // code generation.
   static StatusOr<bool> RematerializeAndSchedule(
       const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
-      HloModule* hlo_module, SchedulerAlgorithm scheduler_algorithm,
+      HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
       SequentialHloOrdering::HloModuleSequence* sequence,
       RematerializationSizes* sizes = nullptr);
 
  protected:
-  HloRematerialization(SchedulerAlgorithm scheduler_algorithm,
+  HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
                        const ShapeSizeFunction& size_function)
       : scheduler_algorithm_(scheduler_algorithm),
         size_function_(size_function) {}
@@ -108,7 +108,7 @@ class HloRematerialization {
       const HloInstruction* instruction) const;
 
   // Selects an algorithm to use for HLO scheduling.
-  SchedulerAlgorithm scheduler_algorithm_;
+  MemorySchedulerAlgorithm scheduler_algorithm_;
 
   // Function which computes the size of the top-level buffer of a shape.
   const ShapeSizeFunction size_function_;
index 1b7d26d..83de54f 100644 (file)
@@ -162,7 +162,7 @@ TEST_F(HloRematerializationTest, SingleComputation) {
                           HloRematerialization::RematerializeAndSchedule(
                               ByteSizeOf,
                               /*memory_limit_bytes=*/14 * 1024, module.get(),
-                              SchedulerAlgorithm::kAuto, &sequence));
+                              DefaultMemoryScheduler, &sequence));
   EXPECT_TRUE(changed);
 
   // Root should not have changed.
@@ -195,7 +195,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
                           HloRematerialization::RematerializeAndSchedule(
                               ByteSizeOf,
                               /*memory_limit_bytes=*/20 * 1024, module.get(),
-                              SchedulerAlgorithm::kAuto, &sequence));
+                              DefaultMemoryScheduler, &sequence));
 
   // No instructions should have been materialized.
   EXPECT_FALSE(changed);
@@ -236,7 +236,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
                           HloRematerialization::RematerializeAndSchedule(
                               ByteSizeOf,
                               /*memory_limit_bytes=*/17 * 1024, module.get(),
-                              SchedulerAlgorithm::kAuto, &sequence));
+                              DefaultMemoryScheduler, &sequence));
   EXPECT_TRUE(changed);
 
   // Only the entry computation should have a rematerialized instruction added.
@@ -272,7 +272,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
                           HloRematerialization::RematerializeAndSchedule(
                               ByteSizeOf,
                               /*memory_limit_bytes=*/15 * 1024, module.get(),
-                              SchedulerAlgorithm::kAuto, &sequence));
+                              DefaultMemoryScheduler, &sequence));
   EXPECT_TRUE(changed);
 
   // Both computations should have a rematerialized instruction added.
@@ -314,7 +314,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
                           HloRematerialization::RematerializeAndSchedule(
                               ByteSizeOf,
                               /*memory_limit_bytes=*/13 * 1024, module.get(),
-                              SchedulerAlgorithm::kAuto, &sequence));
+                              DefaultMemoryScheduler, &sequence));
   EXPECT_TRUE(changed);
 
   // All computations should have a rematerialized instruction added.
@@ -385,7 +385,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) {
       bool changed, HloRematerialization::RematerializeAndSchedule(
                         ByteSizeOf,
                         /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_),
-                        module.get(), SchedulerAlgorithm::kAuto, &sequence));
+                        module.get(), DefaultMemoryScheduler, &sequence));
   EXPECT_TRUE(changed);
   // The rng should not have been rematerialized.
   EXPECT_EQ(count_rngs(entry_computation), 1);
@@ -480,7 +480,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
                           HloRematerialization::RematerializeAndSchedule(
                               ByteSizeOf,
                               /*memory_limit_bytes=*/22 * 1024, module.get(),
-                              SchedulerAlgorithm::kAuto, &sequence));
+                              DefaultMemoryScheduler, &sequence));
   EXPECT_TRUE(changed);
 
   // The broadcast should have been rematerialized 3 times.
@@ -577,7 +577,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
                           HloRematerialization::RematerializeAndSchedule(
                               ByteSizeOf,
                               /*memory_limit_bytes=*/22 * 1024, module.get(),
-                              SchedulerAlgorithm::kAuto, &sequence));
+                              DefaultMemoryScheduler, &sequence));
   // Rematerialization should only occur if the rematerializable instruction has
   // no indirect uses.
   if (indirectly_used) {
index 099dd8d..1a76762 100644 (file)
@@ -340,7 +340,33 @@ int64 SumLogicalBufferSizes(
   return size;
 }
 
-StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler(
+StatusOr<int64> MinimumMemoryForComputation(
+    const HloComputation& computation,
+    const std::vector<const HloInstruction*>& sequence,
+    const TuplePointsToAnalysis& points_to_analysis,
+    const LogicalBuffer::SizeFunction& size_function) {
+  TF_ASSIGN_OR_RETURN(
+      HeapSimulator::Result result,
+      HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
+                         sequence, points_to_analysis, size_function));
+  return result.heap_size;
+}
+
+StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
+    const HloComputation& computation,
+    const TuplePointsToAnalysis& points_to_analysis,
+    const LogicalBuffer::SizeFunction& size_function,
+    const MemorySchedulerAlgorithm& algorithm) {
+  VLOG(2) << "Computation: " << computation.name();
+  if (algorithm) {
+    return algorithm(computation, points_to_analysis, size_function);
+  }
+  return DefaultMemoryScheduler(computation, points_to_analysis, size_function);
+}
+
+}  // namespace
+
+StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
     const HloComputation& computation,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function) {
@@ -397,32 +423,17 @@ StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler(
   return sequence;
 }
 
-StatusOr<int64> MinimumMemoryForComputation(
+StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
     const HloComputation& computation,
-    const std::vector<const HloInstruction*>& sequence,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function) {
-  TF_ASSIGN_OR_RETURN(
-      HeapSimulator::Result result,
-      HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
-                         sequence, points_to_analysis, size_function));
-  return result.heap_size;
+  return ListScheduler::Run(computation, points_to_analysis, size_function);
 }
 
-StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
+StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
     const HloComputation& computation,
     const TuplePointsToAnalysis& points_to_analysis,
-    const LogicalBuffer::SizeFunction& size_function,
-    SchedulerAlgorithm algorithm) {
-  VLOG(2) << "Computation: " << computation.name();
-  if (algorithm == SchedulerAlgorithm::kListSchedule) {
-    return ListScheduler::Run(computation, points_to_analysis, size_function);
-  }
-  if (algorithm == SchedulerAlgorithm::kDfsSchedule) {
-    return RunDFSMemoryScheduler(computation, points_to_analysis,
-                                 size_function);
-  }
-
+    const LogicalBuffer::SizeFunction& size_function) {
   // We try both a list-scheduler based ordering and a DFS based ordering, and
   // choose whichever returns a lower min-memory, not accounting for
   // fragmentation.
@@ -432,7 +443,7 @@ StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
   // within the caller's context. But it's good enough for now.
   TF_ASSIGN_OR_RETURN(
       std::vector<const HloInstruction*> list_sequence,
-      ListScheduler::Run(computation, points_to_analysis, size_function));
+      ListMemoryScheduler(computation, points_to_analysis, size_function));
   TF_ASSIGN_OR_RETURN(
       const int64 list_memory,
       MinimumMemoryForComputation(computation, list_sequence,
@@ -441,7 +452,7 @@ StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
 
   TF_ASSIGN_OR_RETURN(
       std::vector<const HloInstruction*> dfs_sequence,
-      RunDFSMemoryScheduler(computation, points_to_analysis, size_function));
+      DFSMemoryScheduler(computation, points_to_analysis, size_function));
   TF_ASSIGN_OR_RETURN(
       const int64 dfs_memory,
       MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis,
@@ -459,12 +470,10 @@ StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
   }
 }
 
-}  // namespace
-
 StatusOr<SequentialHloOrdering::HloModuleSequence>
 CreateMemoryMinimizingSequence(const HloModule& module,
                                const LogicalBuffer::SizeFunction& size_function,
-                               SchedulerAlgorithm algorithm) {
+                               const MemorySchedulerAlgorithm& algorithm) {
   SequentialHloOrdering::HloModuleSequence sequence;
   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
                       TuplePointsToAnalysis::Run(&module));
@@ -480,7 +489,7 @@ CreateMemoryMinimizingSequence(const HloModule& module,
 StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
     const HloComputation& computation,
     const LogicalBuffer::SizeFunction& size_function,
-    SchedulerAlgorithm algorithm) {
+    const MemorySchedulerAlgorithm& algorithm) {
   CHECK(!computation.IsFusionComputation());
   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
                       TuplePointsToAnalysis::Run(computation.parent()));
index 1d1eb1e..068e683 100644 (file)
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 
@@ -33,28 +34,48 @@ StatusOr<int64> MinimumMemoryForSequence(
     const SequentialHloOrdering::HloModuleSequence& module_sequence,
     const LogicalBuffer::SizeFunction& size_function);
 
-enum class SchedulerAlgorithm {
-  kListSchedule,
-  kDfsSchedule,
+// A memory scheduler computes an execution sequence for the HLO instructions in
+// 'computation' that minimizes peak memory, given a points-to analysis result
+// that describes buffer aliasing, together with a target-specific size function
+// that maps a tensor's logical size to its padded size.
+typedef std::function<StatusOr<std::vector<const HloInstruction*>>(
+    const HloComputation&, const TuplePointsToAnalysis&,
+    const LogicalBuffer::SizeFunction&)>
+    MemorySchedulerAlgorithm;
 
-  // Selects the available scheduler algorithm that had the minimum memory in
-  // the resulting sequence (a la MinimumMemoryForSequence).
-  kAuto,
-};
+// List scheduler
+StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
+    const HloComputation& computation,
+    const TuplePointsToAnalysis& points_to_analysis,
+    const LogicalBuffer::SizeFunction& size_function);
+
+// DFS-order scheduler
+StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
+    const HloComputation& computation,
+    const TuplePointsToAnalysis& points_to_analysis,
+    const LogicalBuffer::SizeFunction& size_function);
+
+// The default scheduling algorithm. Runs both the list scheduler
+// and the DFS scheduler, and chooses whichever returns a lower min-memory,
+// not accounting for fragmentation.
+StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
+    const HloComputation& computation,
+    const TuplePointsToAnalysis& points_to_analysis,
+    const LogicalBuffer::SizeFunction& size_function);
 
 // Returns an HloModuleSequence which seeks to minimize the memory required for
 // the computation. size_function is the function returning the number of bytes
 // required for a LogicalBuffer.
 StatusOr<SequentialHloOrdering::HloModuleSequence>
-CreateMemoryMinimizingSequence(
-    const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
-    SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto);
+CreateMemoryMinimizingSequence(const HloModule& module,
+                               const LogicalBuffer::SizeFunction& size_function,
+                               const MemorySchedulerAlgorithm& algorithm = {});
 
 // Overload of above that computes the sequence for a single computation.
 StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
     const HloComputation& computation,
     const LogicalBuffer::SizeFunction& size_function,
-    SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto);
+    const MemorySchedulerAlgorithm& algorithm = {});
 
 }  // namespace xla
 
index 2dd6e43..74544c4 100644 (file)
@@ -165,8 +165,7 @@ ENTRY root {
   };
   TF_ASSERT_OK_AND_ASSIGN(
       SequentialHloOrdering::HloModuleSequence sequence,
-      CreateMemoryMinimizingSequence(*module, size_fn,
-                                     SchedulerAlgorithm::kListSchedule));
+      CreateMemoryMinimizingSequence(*module, size_fn, ListMemoryScheduler));
   // Verify that all instructions are in the sequence.
   EXPECT_EQ(module->entry_computation()->instruction_count(),
             sequence.at(module->entry_computation()).size());