Add method for computing the maximal set of live LogicalBuffers in an allocation.
author    Mark Heffernan <meheff@google.com>
          Tue, 6 Mar 2018 02:07:12 +0000 (18:07 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Tue, 6 Mar 2018 02:11:09 +0000 (18:11 -0800)
PiperOrigin-RevId: 187954755

tensorflow/compiler/xla/service/buffer_assignment.cc
tensorflow/compiler/xla/service/buffer_assignment.h
tensorflow/compiler/xla/service/buffer_assignment_test.cc

index d44d3d7..0434c0a 100644
@@ -292,6 +292,112 @@ BufferAllocationProto BufferAllocation::ToProto() const {
   return proto;
 }
 
+std::pair<int64, std::vector<const LogicalBuffer*>>
+BufferAllocation::ComputePeakMemoryLogicalBuffers() const {
+  if (HeapTraces().empty()) {
+    // Just return the largest LogicalBuffer in the allocation.
+    const LogicalBuffer* largest_buffer = nullptr;
+    int64 largest_size = 0;
+    for (const auto& pair : assigned_buffers()) {
+      const LogicalBuffer* buffer = pair.first;
+      int64 size = pair.second.size;
+      if (largest_buffer == nullptr) {
+        largest_buffer = buffer;
+        largest_size = size;
+        continue;
+      }
+      // Tie-break by LogicalBuffer::Id so the result is deterministic and
+      // does not depend on pointer values.
+      if (size > largest_size ||
+          ((size == largest_size) && (largest_buffer->id() > buffer->id()))) {
+        largest_buffer = buffer;
+        largest_size = size;
+      }
+    }
+    CHECK(largest_buffer != nullptr)
+        << "No logical buffers in allocation: " << ToString();
+    return {largest_size, {largest_buffer}};
+  }
+
+  // Build maps from LogicalBuffer::Id to LogicalBuffer* and from buffer to
+  // size for the logical buffers in this allocation.
+  tensorflow::gtl::FlatMap<LogicalBuffer::Id, const LogicalBuffer*>
+      id_to_buffer;
+  tensorflow::gtl::FlatMap<const LogicalBuffer*, int64> buffer_sizes;
+  for (const auto& pair : assigned_buffers()) {
+    const LogicalBuffer* buffer = pair.first;
+    const OffsetSize& offset_size = pair.second;
+    id_to_buffer[buffer->id()] = buffer;
+    buffer_sizes[buffer] = offset_size.size;
+  }
+
+  // Returns how much the given event increases the total size of live
+  // buffers. Can be negative.
+  auto memory_delta = [&id_to_buffer, &buffer_sizes](
+                          const HeapSimulatorTrace::Event& event) -> int64 {
+    const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id());
+    const int64 buffer_size = buffer_sizes.at(buffer);
+    if (event.kind() == HeapSimulatorTrace::Event::ALLOC) {
+      return buffer_size;
+    } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
+      // Sharing a buffer does not change the live set size for the purposes of
+      // the heap simulator. Even though the shared-with buffer may be smaller,
+      // the entire allocation remains live.
+      return 0;
+    } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
+      return -1 * buffer_size;
+    }
+    LOG(FATAL) << "Unknown event kind: " << event.kind();
+  };
+
+  int64 total_max_live_size = 0;
+  std::vector<const LogicalBuffer*> live_buffers_vector;
+  for (const HeapSimulatorTrace& heap_trace : HeapTraces()) {
+    // First compute the size of the maximal live set.
+    int64 max_live_size = 0;
+    int64 live_size = 0;
+    for (const auto& event : heap_trace.events()) {
+      live_size += memory_delta(event);
+      if (max_live_size < live_size) {
+        max_live_size = live_size;
+      }
+    }
+
+    // Next, gather the set of logical buffers live at the earliest point at
+    // which the live set size is maximal.
+    tensorflow::gtl::FlatSet<const LogicalBuffer*> live_buffers;
+    live_size = 0;
+    for (const auto& event : heap_trace.events()) {
+      const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id());
+      if (event.kind() == HeapSimulatorTrace::Event::ALLOC) {
+        InsertOrDie(&live_buffers, buffer);
+      } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
+        // Nothing to do.
+      } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
+        CHECK(ContainsKey(live_buffers, buffer));
+        live_buffers.erase(buffer);
+      }
+
+      live_size += memory_delta(event);
+      if (live_size == max_live_size) {
+        break;
+      }
+    }
+    CHECK_EQ(live_size, max_live_size);
+    total_max_live_size += max_live_size;
+
+    live_buffers_vector.insert(live_buffers_vector.end(), live_buffers.begin(),
+                               live_buffers.end());
+  }
+
+  // Sort the live buffers by LogicalBuffer::Id for a deterministic order.
+  std::sort(live_buffers_vector.begin(), live_buffers_vector.end(),
+            [](const LogicalBuffer* a, const LogicalBuffer* b) {
+              return a->id() < b->id();
+            });
+  return {total_max_live_size, live_buffers_vector};
+}
+
 string BufferAllocation::ToString() const {
   string output;
   Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size());
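
ComputePeakMemoryLogicalBuffers() above makes two passes over each heap trace: the first pass computes the maximal live size over all prefixes of the event sequence, and the second replays the events until the earliest prefix that reaches that size. The core idea can be sketched independently of the XLA types; everything below (Event, Kind, PeakLiveSet) is a simplified, hypothetical stand-in for illustration, not the actual HeapSimulatorTrace API:

  #include <algorithm>
  #include <cstdint>
  #include <map>
  #include <set>
  #include <utility>
  #include <vector>

  // Simplified stand-ins for heap simulator trace events (hypothetical).
  enum class Kind { kAlloc, kFree, kShareWith };
  struct Event {
    Kind kind;
    int buffer_id;
  };

  // Returns {peak live size, buffer ids live at the earliest peak}.
  std::pair<int64_t, std::set<int>> PeakLiveSet(
      const std::vector<Event>& events,
      const std::map<int, int64_t>& sizes) {
    // How much an event changes the total size of live buffers.
    auto delta = [&](const Event& e) -> int64_t {
      switch (e.kind) {
        case Kind::kAlloc:
          return sizes.at(e.buffer_id);
        case Kind::kFree:
          return -sizes.at(e.buffer_id);
        case Kind::kShareWith:
          // Sharing leaves the live size unchanged; the shared-with
          // allocation remains live in its entirety.
          return 0;
      }
      return 0;
    };
    // Pass 1: the peak is the maximum live size over all prefixes.
    int64_t live = 0;
    int64_t peak = 0;
    for (const Event& e : events) {
      live += delta(e);
      peak = std::max(peak, live);
    }
    // Pass 2: replay until the earliest prefix that reaches the peak,
    // tracking which buffers are live at that point.
    std::set<int> live_set;
    live = 0;
    for (const Event& e : events) {
      if (e.kind == Kind::kAlloc) live_set.insert(e.buffer_id);
      if (e.kind == Kind::kFree) live_set.erase(e.buffer_id);
      live += delta(e);
      if (live == peak) break;
    }
    return {peak, live_set};
  }

For the trace {ALLOC a, ALLOC b, FREE a, ALLOC c} with unit-sized buffers, the prefixes ending at ALLOC b and at ALLOC c both reach the peak of 2; the sketch returns {a, b}, the earlier of the two, mirroring the "earliest maximal point" rule in the header comment.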
@@ -525,6 +631,7 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation,
 // Combines allocations of temporary buffers of the same color into one big
 // BufferAllocation.
 void BufferAssignment::CombineTempAllocations() {
+  VLOG(1) << "CombineTempAllocations()";
   FlatMap<LogicalBuffer::Color, BufferAllocation, LogicalBuffer::Color::Hasher>
       combined_allocation_map;
 
@@ -546,11 +653,16 @@ void BufferAssignment::CombineTempAllocations() {
       if (combined_it == combined_allocation_map.end()) {
         // We have found the first temp allocation of this color. Collect
         // the other temp allocations of the same color into it.
+        VLOG(1) << "Combined temp allocation for color " << color
+                << " is: " << temp_allocation;
         combined_allocation_map.emplace(color, temp_allocation);
         continue;
       }
 
       auto* combined_allocation = &combined_it->second;
+      VLOG(1) << "Combined allocation absorbing temp allocation: "
+              << temp_allocation;
+
       // Each temp allocation is placed end-to-end, accounting for alignment.
       // The offset of each buffer in the combined allocation is computed from
       // the base offset of the allocation.
@@ -564,6 +676,10 @@ void BufferAssignment::CombineTempAllocations() {
         const int64 size = buffer_offset_size.second.size;
         combined_allocation->AddAssignment(*buffer, base + offset, size);
       }
+      if (!temp_allocation.HeapTraces().empty()) {
+        CHECK_EQ(temp_allocation.HeapTraces().size(), 1);
+        combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front());
+      }
     }
     // Replace all existing temporary allocations with the new combined
     // allocations.
@@ -693,9 +809,9 @@ BufferAssignmentProto BufferAssignment::ToProto() const {
   for (const BufferAllocation& allocation : Allocations()) {
     BufferAllocationProto proto_allocation = allocation.ToProto();
     proto.add_buffer_allocations()->Swap(&proto_allocation);
-  }
-  for (const HeapSimulatorTrace& trace : heap_simulator_traces_) {
-    *proto.add_heap_simulator_traces() = trace;
+    for (const HeapSimulatorTrace& heap_trace : allocation.HeapTraces()) {
+      *proto.add_heap_simulator_traces() = heap_trace;
+    }
   }
   return proto;
 }
@@ -1131,7 +1247,8 @@ void BufferAssigner::AssignBuffersFromHeapSimulator(
     assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size);
   }
 
-  assignment->heap_simulator_traces_.push_back(result.debug_trace);
+  VLOG(1) << "Ran heap simulation for allocation: " << allocation->ToString();
+  allocation->AddHeapTrace(result.debug_trace);
 }
 
 // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
index 6b7fd00..3086d0e 100644
@@ -192,6 +192,37 @@ class BufferAllocation {
            !is_thread_local();
   }
 
+  // Adds a heap trace which was used to assign slices to logical buffers in
+  // this allocation. A single BufferAllocation may include multiple heap
+  // traces in the case of the combined temporary block, where there is one
+  // heap trace per computation.
+  void AddHeapTrace(const HeapSimulatorTrace& heap_trace) {
+    heap_traces_.push_back(heap_trace);
+  }
+
+  // Returns the heap traces used to assign slices to logical buffers in
+  // this allocation.
+  const std::vector<HeapSimulatorTrace>& HeapTraces() const {
+    return heap_traces_;
+  }
+
+  // Computes and returns the LogicalBuffers which are live at the point of
+  // peak memory usage for this allocation. The point of peak memory usage is
+  // the point at which the total size of all live logical buffers is
+  // maximal. If peak memory is reached at multiple points, the set of
+  // logical buffers live at the earliest maximal point is returned. The
+  // vector is stably sorted by LogicalBuffer::Id.
+  //
+  // The return value is a pair of total size of the logical buffers at peak,
+  // and the buffers themselves.
+  std::pair<int64, std::vector<const LogicalBuffer*>>
+  ComputePeakMemoryLogicalBuffers() const;
+
+  // Returns the number of bytes lost to fragmentation. This is equal to the
+  // difference between the size of the allocation and the size of the
+  // maximal live set.
+  int64 fragmentation_bytes() const { return fragmentation_bytes_; }
+
   bool operator==(const BufferAllocation& other) const {
     return index_ == other.index_;
   }
@@ -257,6 +288,9 @@ class BufferAllocation {
   // Mapping from the set of buffers assigned to this allocation to their
   // logical offsets and sizes.
   tensorflow::gtl::FlatMap<const LogicalBuffer*, OffsetSize> assigned_buffers_;
+
+  int64 fragmentation_bytes_ = 0;
+  std::vector<HeapSimulatorTrace> heap_traces_;
 };
 
 // Add stream operators for nicer output of CHECK/RET_CHECK failures.
@@ -441,7 +475,6 @@ class BufferAssignment {
   LogicalBuffer::AlignmentFunction color_alignment_;
 
   Stats stats_;
-  std::vector<HeapSimulatorTrace> heap_simulator_traces_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment);
 };
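
As a usage illustration, a caller holding a BufferAllocation can combine the accessors declared above roughly as follows. ReportPeakUsage is a hypothetical helper, not part of this change; it assumes the usual buffer_assignment.h, <tuple>, and logging includes from the surrounding codebase:

  // Hypothetical helper: log peak memory statistics for one allocation.
  void ReportPeakUsage(const xla::BufferAllocation& allocation) {
    int64 peak_size;
    std::vector<const xla::LogicalBuffer*> peak_buffers;
    std::tie(peak_size, peak_buffers) =
        allocation.ComputePeakMemoryLogicalBuffers();
    LOG(INFO) << "allocation size: " << allocation.size()
              << ", peak live size: " << peak_size << " bytes across "
              << peak_buffers.size() << " buffers";
    // The vector is sorted by LogicalBuffer::Id, so the output order is
    // deterministic across runs.
    for (const xla::LogicalBuffer* buffer : peak_buffers) {
      LOG(INFO) << "  live at peak: " << buffer->ToString();
    }
  }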
index cd73654..234c725 100644
@@ -42,9 +42,10 @@ limitations under the License.
 #include "tensorflow/core/platform/macros.h"
 
 namespace xla {
-
 namespace {
 
+using ::testing::UnorderedElementsAre;
+
 // DFS visitor that collects the instructions referenced by a computation
 // without descending into nested computations, i.e., only from the operands.
 class InstructionListVisitor : public DfsHloVisitorWithDefault {
@@ -101,6 +102,22 @@ class BufferAssignmentTest : public HloTestBase {
         .ConsumeValueOrDie();
   }
 
+  std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
+      HloModule* module,
+      tensorflow::gtl::ArraySlice<const HloInstruction*> instruction_sequence,
+      int64 alignment = 1) {
+    SequentialHloOrdering::HloModuleSequence module_sequence;
+    module_sequence[module->entry_computation()] =
+        std::vector<const HloInstruction*>(instruction_sequence.begin(),
+                                           instruction_sequence.end());
+    return BufferAssigner::Run(
+               module,
+               xla::MakeUnique<SequentialHloOrdering>(module, module_sequence),
+               backend().compiler()->BufferSizeBytesFunction(),
+               [alignment](LogicalBuffer::Color) { return alignment; })
+        .ConsumeValueOrDie();
+  }
+
   // Builds an x+1.0 computation to use in a Map.
   std::unique_ptr<HloComputation> BuildMapComputationPlus1(const string& name) {
     auto builder = HloComputation::Builder(name);
@@ -1370,7 +1387,7 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) {
   auto element_slices = assignment->GetAllSlices(select, /*index=*/{0});
   EXPECT_EQ(2, element_slices.size());
   EXPECT_THAT(element_slices,
-              ::testing::UnorderedElementsAre(
+              UnorderedElementsAre(
                   assignment->GetUniqueSlice(tuple_param0, /*index=*/{0})
                       .ConsumeValueOrDie(),
                   assignment->GetUniqueSlice(tuple_param1, /*index=*/{0})
@@ -1473,6 +1490,98 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) {
   }
 }
 
+TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
+  // paramscalar ------- (mul) -- (add) -- (sub)
+  //                     /        /        /
+  // param0[100] -------/        /        /
+  //                            /        /
+  // param1[100] --------------/--------/
+  auto builder = HloComputation::Builder(TestName());
+  auto paramscalar =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
+  auto param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32vec100_, ""));
+  auto param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32vec100_, ""));
+  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
+  auto add = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
+  builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec100_, HloOpcode::kSubtract, add, param1));
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
+
+  auto buffers = RunBufferAssignment(module.get());
+
+  // Trivially, the set of peak memory logical buffer(s) of an allocation with a
+  // single logical buffer should be exactly the logical buffer in that
+  // allocation.
+  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
+  int64 peak_size;
+  std::vector<const LogicalBuffer*> peak_buffers;
+
+  std::tie(peak_size, peak_buffers) =
+      mul_buffer.ComputePeakMemoryLogicalBuffers();
+  EXPECT_EQ(peak_size, ShapeUtil::ByteSizeOf(f32vec100_));
+  ASSERT_EQ(peak_buffers.size(), 1);
+  EXPECT_EQ(peak_buffers[0]->instruction(), mul);
+}
+
+TEST_F(BufferAssignmentTest, PeakBuffers) {
+  // Compute the peak liveness buffers of the following sequence:
+  //
+  //   %param = ...
+  //   %log = log(%param)
+  //   %rev = reverse(%log)
+  //   %neg = neg(%param)
+  //   %concat = concat(%rev, %neg)
+  //   ROOT %root = slice(concat)
+  //
+  // In the temporary block, the set of live buffers at peak memory use should
+  // be {%rev, %neg, %concat}. This occurs right at the concat itself.
+  auto builder = HloComputation::Builder(TestName());
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32vec100_, ""));
+  auto log = builder.AddInstruction(
+      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param));
+  auto rev = builder.AddInstruction(
+      HloInstruction::CreateReverse(f32vec100_, log, {0}));
+  auto neg = builder.AddInstruction(
+      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param));
+  const Shape concat_shape = ShapeUtil::MakeShape(F32, {200});
+  auto concat = builder.AddInstruction(
+      HloInstruction::CreateConcatenate(concat_shape, {rev, neg}, 0));
+  // Make the root tiny so no interior nodes can share its buffer.
+  auto root = builder.AddInstruction(HloInstruction::CreateSlice(
+      ShapeUtil::MakeShape(F32, {1}), concat, {0}, {1}, {1}));
+
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
+
+  auto buffers = RunBufferAssignmentWithInstructionSequence(
+      module.get(), {param, log, rev, neg, concat, root});
+
+  // The temporary allocation should hold the buffers of the four interior
+  // instructions.
+  const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat);
+  EXPECT_FALSE(buffer.IsInputOrOutput());
+  EXPECT_TRUE(buffer.IsPreallocatedTempBuffer());
+  ASSERT_EQ(buffer.assigned_buffers().size(), 4);
+
+  int64 peak_size;
+  std::vector<const LogicalBuffer*> peak_buffers;
+  std::tie(peak_size, peak_buffers) = buffer.ComputePeakMemoryLogicalBuffers();
+
+  // The peak live set should be concat and its inputs.
+  EXPECT_EQ(peak_size, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F32, {400})));
+  ASSERT_EQ(peak_buffers.size(), 3);
+  std::vector<const HloInstruction*> peak_instructions;
+  for (const LogicalBuffer* logical_buffer : peak_buffers) {
+    peak_instructions.push_back(logical_buffer->instruction());
+  }
+  EXPECT_THAT(peak_instructions, UnorderedElementsAre(rev, neg, concat));
+}
+
 class WhileBufferAssignmentTest : public HloTestBase {
  protected:
   std::unique_ptr<HloComputation> BuildWhileConditionComputation(