Bug fix in buffer assignment for colocated buffers
author    HyoukJoong Lee <hyouklee@google.com>
          Thu, 8 Feb 2018 20:30:28 +0000 (12:30 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Thu, 8 Feb 2018 20:37:30 +0000 (12:37 -0800)
PiperOrigin-RevId: 185034095

tensorflow/compiler/xla/service/buffer_assignment.cc
tensorflow/compiler/xla/service/buffer_assignment.h
tensorflow/compiler/xla/service/buffer_assignment_test.cc
tensorflow/compiler/xla/service/buffer_liveness.cc

diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 774b114..f0a9de5 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -1122,140 +1122,6 @@ void BufferAssigner::AddSetToColocatedBufferSets(
   }
 }
 
-// Conceptually the same as AddSetToColocatedBufferSets, but specific to the
-// colocated buffers for while instructions. 'colocated_set' contains the
-// buffers for a single while instruction that must be colocated. The idea here
-// is to apply a memory-saving heuristic for separate while instructions whose
-// buffers are disjoint in liveness, by using the colocation mechanism to force
-// buffer sharing. This often reduces memory for multi-layer RNNs.
-//
-// TODO(b/32491382): We should be able to remove this heuristic after we
-// implement module-level liveness analysis, which would let us directly detect
-// buffer sharing opportunities between the while instruction buffer and the
-// buffers from the predicate and body computation, as well as sharing across
-// different while instructions.
-void BufferAssigner::AddWhileSetToColocatedBufferSets(
-    const std::vector<const LogicalBuffer*>& colocated_set,
-    const LogicalBuffer* while_init_buffer,
-    const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo,
-    const HloComputation& computation, const BufferLiveness& buffer_liveness,
-    const LogicalBuffer::SizeFunction& buffer_size,
-    std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
-  CHECK(!colocated_set.empty());
-  const TuplePointsToAnalysis& points_to_analysis =
-      buffer_liveness.points_to_analysis();
-
-  // Parallel while loops cannot safely share colocated buffer sets.
-  if (buffer_liveness.hlo_ordering().SequentialOrder(computation) == nullptr) {
-    AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
-    return;
-  }
-
-  // Scan 'colocated_buffer_sets' in reverse order for locality; colocated sets
-  // are added in postorder over computations and instructions.
-  const int64 init_buffer_size = buffer_size(*while_init_buffer);
-  const bool is_live_out = buffer_liveness.MaybeLiveOut(*while_result_buffer);
-  for (int i = colocated_buffer_sets->size() - 1; i >= 0; --i) {
-    const ColocatedBufferSet& predecessor_set = (*colocated_buffer_sets)[i];
-
-    // Skip predecessor sets not associated with while loops.
-    if (std::all_of(predecessor_set.begin(), predecessor_set.end(),
-                    [](const LogicalBuffer* buffer) {
-                      return buffer->instruction()->opcode() !=
-                             HloOpcode::kWhile;
-                    })) {
-      continue;
-    }
-
-    // Skip predecessor sets already associated with 'while_hlo'.
-    if (std::any_of(predecessor_set.begin(), predecessor_set.end(),
-                    [&while_hlo](const LogicalBuffer* buffer) {
-                      return buffer->instruction() == while_hlo;
-                    })) {
-      continue;
-    }
-
-    // Skip predecessor sets with entry parameter if the while result is live
-    // out.
-    if (is_live_out &&
-        std::any_of(predecessor_set.begin(), predecessor_set.end(),
-                    [](const LogicalBuffer* buffer) {
-                      auto* instruction = buffer->instruction();
-                      auto* computation = instruction->parent();
-                      auto* module = computation->parent();
-                      return instruction->opcode() == HloOpcode::kParameter &&
-                             computation == module->entry_computation();
-                    })) {
-      continue;
-    }
-
-    // Build vector of predecessor while result and init buffers, which are
-    // checked for liveness interference below. We must check both the result
-    // and init buffers because they're aliased together, but
-    // TuplePointsToAnalysis is unaware of this aliasing.
-    std::vector<const LogicalBuffer*> predecessor_while_buffers;
-    for (const LogicalBuffer* buffer : predecessor_set) {
-      const HloInstruction* instruction = buffer->instruction();
-      if (instruction->opcode() == HloOpcode::kWhile &&
-          buffer_size(*buffer) == init_buffer_size &&
-          instruction->parent() == &computation) {
-        predecessor_while_buffers.push_back(buffer);
-        // Add the init buffer at the same index, which must also exist in the
-        // predecessor set, and must be unambiguous.
-        const PointsToSet& init_points_to =
-            points_to_analysis.GetPointsToSet(instruction->operand(0));
-        const auto& init_buffers = init_points_to.element(buffer->index());
-        CHECK_EQ(init_buffers.size(), 1);
-        CHECK_GT(predecessor_set.count(init_buffers[0]), 0);
-        predecessor_while_buffers.push_back(init_buffers[0]);
-      }
-    }
-    if (predecessor_while_buffers.empty()) {
-      continue;
-    }
-
-    // Skip predecessor set if the live range of any predecessor
-    // buffers overlaps with 'while_init_buffer' or
-    // 'while_result_buffer' (we need to check both since they're
-    // aliased together, but the points-to analysis is unaware of this
-    // aliasing). Note that tuple element buffer forwarding can cause
-    // the same buffer to appear on both sides of the interference
-    // comparison below.
-    auto may_interfere_with_init_or_result = [&](const LogicalBuffer* buffer) {
-      if (while_init_buffer->id() != buffer->id() &&
-          buffer_liveness.MayInterfere(*while_init_buffer, *buffer)) {
-        return true;
-      }
-
-      if (while_result_buffer->id() != buffer->id() &&
-          buffer_liveness.MayInterfere(*while_result_buffer, *buffer)) {
-        return true;
-      }
-
-      return false;
-    };
-
-    if (std::any_of(predecessor_while_buffers.begin(),
-                    predecessor_while_buffers.end(),
-                    may_interfere_with_init_or_result)) {
-      continue;
-    }
-
-    // All our checks have passed; merge 'predecessor_set' with 'colocated_set',
-    // and add the merged set to 'colocated_buffer_sets'. This forces the
-    // colocation of buffers across different while instructions.
-    FlatSet<const LogicalBuffer*> unique;
-    unique.insert(predecessor_set.begin(), predecessor_set.end());
-    unique.insert(colocated_set.begin(), colocated_set.end());
-    std::vector<const LogicalBuffer*> merged_set(unique.begin(), unique.end());
-    AddSetToColocatedBufferSets(merged_set, colocated_buffer_sets);
-    return;
-  }
-
-  // Failed to merge into predecessor set; add 'colocated_set' as-is.
-  AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
-}
-
 namespace {
 
 // Checks that points-to set of 'instruction' is unambiguous and distinct
@@ -1272,8 +1138,130 @@ const LogicalBuffer* AddBufferToColocatedSet(
   return colocated_set->back();
 }
 
+// Given the interference map of a graph (the list of interfering node indices
+// for each node), perform graph coloring such that interfering nodes are
+// assigned to different colors. Returns the assigned color of each node,
+// where the colors are represented as integer values [0, color_count).
+std::vector<int64> ColorInterferenceGraph(
+    const std::vector<std::vector<int64>>& interference_map) {
+  const int64 node_count = interference_map.size();
+
+  // Sort the nodes so that nodes with more interference are assigned colors
+  // first. This relies on the common heuristic of coloring the most
+  // constrained node first, though other ordering heuristics are worth
+  // investigating too.
+  std::vector<int64> nodes(node_count);
+  std::iota(nodes.begin(), nodes.end(), 0);
+  std::sort(nodes.begin(), nodes.end(),
+            [&interference_map](const int64 i, const int64 j) {
+              return interference_map[i].size() > interference_map[j].size();
+            });
+
+  const int64 kColorUnassigned = -1;
+  std::vector<int64> assigned_colors(node_count, kColorUnassigned);
+  for (int64 node : nodes) {
+    // Mark the colors that are already assigned to the neighbors.
+    std::vector<bool> available_colors(node_count, true);
+    for (int64 neighbor : interference_map[node]) {
+      int64 color = assigned_colors[neighbor];
+      if (color != kColorUnassigned) {
+        available_colors[color] = false;
+      }
+    }
+
+    // Find the lowest color that is not yet assigned to any neighbor. Since
+    // a node has at most node_count - 1 neighbors, a free color always
+    // exists among the node_count candidates.
+    int64 color = 0;
+    for (; color < node_count; ++color) {
+      if (available_colors[color]) {
+        break;
+      }
+    }
+    CHECK_LT(color, node_count);
+    assigned_colors[node] = color;
+  }
+  return assigned_colors;
+}
+
 }  // namespace
 
+std::vector<BufferAssigner::ColocatedBufferSet>
+BufferAssigner::MergeColocatedBufferSets(
+    const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
+    const BufferLiveness& buffer_liveness,
+    const LogicalBuffer::SizeFunction& buffer_size) {
+  VLOG(1) << "colocation sets count before coalescing: "
+          << colocated_buffer_sets.size();
+
+  // Returns true if the given buffer is for the entry parameter.
+  auto is_entry_parameter = [](const LogicalBuffer& buffer) {
+    auto* instruction = buffer.instruction();
+    auto* computation = instruction->parent();
+    auto* module = computation->parent();
+    return instruction->opcode() == HloOpcode::kParameter &&
+           computation == module->entry_computation();
+  };
+
+  // Returns true if the two colocated buffer sets (specified by their indices
+  // into 'colocated_buffer_sets') cannot be merged into a single set.
+  auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness,
+                                   &buffer_size,
+                                   &is_entry_parameter](int64 i, int64 j) {
+    for (auto& buffer_a : colocated_buffer_sets[i]) {
+      for (auto& buffer_b : colocated_buffer_sets[j]) {
+        // Do not merge if a buffer in one set may be live out while the
+        // other set contains an entry parameter.
+        if ((buffer_liveness.MaybeLiveOut(*buffer_a) &&
+             is_entry_parameter(*buffer_b)) ||
+            (buffer_liveness.MaybeLiveOut(*buffer_b) &&
+             is_entry_parameter(*buffer_a))) {
+          return true;
+        }
+        // Do not merge if the buffers interfere with each other.
+        if (buffer_a->id() != buffer_b->id() &&
+            buffer_liveness.MayInterfere(*buffer_a, *buffer_b)) {
+          return true;
+        }
+        // Do not merge if the buffer sizes are different.
+        if (buffer_size(*buffer_a) != buffer_size(*buffer_b)) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+
+  // Build the interference map among the colocated buffer sets (nodes), by
+  // adding an edge between any two nodes that cannot be merged into a single
+  // colocated buffer set.
+  std::vector<std::vector<int64>> interference_map(
+      colocated_buffer_sets.size());
+  for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
+    for (int64 j = i + 1; j < colocated_buffer_sets.size(); ++j) {
+      if (cannot_merge_buffer_sets(i, j)) {
+        interference_map[i].push_back(j);
+        interference_map[j].push_back(i);
+      }
+    }
+  }
+
+  // Assign a color to each colocated buffer set in colocated_buffer_sets,
+  // such that the sets that can be merged are assigned the same color.
+  auto assigned_colors = ColorInterferenceGraph(interference_map);
+
+  // Merge the buffer sets with the same color.
+  CHECK(!assigned_colors.empty());
+  int64 num_sets =
+      *std::max_element(assigned_colors.begin(), assigned_colors.end()) + 1;
+  std::vector<ColocatedBufferSet> new_colocated_buffer_sets(num_sets);
+  for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
+    const auto& buffer_set = colocated_buffer_sets[i];
+    new_colocated_buffer_sets[assigned_colors[i]].insert(buffer_set.begin(),
+                                                         buffer_set.end());
+  }
+
+  VLOG(1) << "colocation sets count after coalescing: "
+          << new_colocated_buffer_sets.size();
+  return new_colocated_buffer_sets;
+}
+
 // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
 // in the same allocation (currently just supports kWhile, kCall, and
 // kConditional).
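
For reference, a minimal standalone sketch of the greedy coloring heuristic
implemented by ColorInterferenceGraph above. This is illustrative only: plain
int64_t and the hypothetical ColorGreedy/main names stand in for the XLA types
and functions, and are not part of this change.

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <numeric>
    #include <vector>

    // Greedy coloring: visit nodes in order of decreasing degree and give
    // each node the lowest color not already taken by a colored neighbor.
    std::vector<int64_t> ColorGreedy(
        const std::vector<std::vector<int64_t>>& interference_map) {
      const int64_t node_count = interference_map.size();
      std::vector<int64_t> order(node_count);
      std::iota(order.begin(), order.end(), int64_t{0});
      std::sort(order.begin(), order.end(), [&](int64_t i, int64_t j) {
        return interference_map[i].size() > interference_map[j].size();
      });
      std::vector<int64_t> color(node_count, -1);
      for (int64_t node : order) {
        // Mark colors already taken by this node's colored neighbors.
        std::vector<bool> taken(node_count, false);
        for (int64_t neighbor : interference_map[node]) {
          if (color[neighbor] >= 0) taken[color[neighbor]] = true;
        }
        // A node has at most node_count - 1 neighbors, so a color is free.
        int64_t c = 0;
        while (taken[c]) ++c;
        color[node] = c;
      }
      return color;
    }

    int main() {
      // Triangle {0,1,2} plus isolated node 3: the three mutually
      // interfering sets need distinct colors; node 3 can reuse one, so
      // four sets coalesce into three.
      std::vector<std::vector<int64_t>> g = {{1, 2}, {0, 2}, {0, 1}, {}};
      for (int64_t c : ColorGreedy(g)) std::cout << c << " ";
      std::cout << "\n";  // e.g. 0 1 2 0 (exact permutation may vary)
      return 0;
    }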
@@ -1299,12 +1287,11 @@ void BufferAssigner::BuildColocatedBufferSets(
                 const Shape& /*subshape*/, const ShapeIndex& index) {
               std::vector<const LogicalBuffer*> colocated_set;
               // Add while.init.
-              auto* init_buffer =
-                  AddBufferToColocatedSet(while_hlo->operand(0), index,
-                                          points_to_analysis, &colocated_set);
+              AddBufferToColocatedSet(while_hlo->operand(0), index,
+                                      points_to_analysis, &colocated_set);
               // Add while.result.
-              auto* result_buffer = AddBufferToColocatedSet(
-                  while_hlo, index, points_to_analysis, &colocated_set);
+              AddBufferToColocatedSet(while_hlo, index, points_to_analysis,
+                                      &colocated_set);
               // Add while.cond.parameter.
               AddBufferToColocatedSet(
                   while_hlo->while_condition()->parameter_instruction(0), index,
@@ -1317,10 +1304,7 @@ void BufferAssigner::BuildColocatedBufferSets(
               AddBufferToColocatedSet(
                   while_hlo->while_body()->root_instruction(), index,
                   points_to_analysis, &colocated_set);
-              AddWhileSetToColocatedBufferSets(
-                  colocated_set, init_buffer, result_buffer, while_hlo,
-                  *computation, buffer_liveness, buffer_size,
-                  colocated_buffer_sets);
+              AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
             });
       } else if (opcode == HloOpcode::kCall) {
         const HloInstruction* call_hlo = instruction;
@@ -1400,6 +1384,22 @@ void BufferAssigner::BuildColocatedBufferSets(
       }
     }
   }
+
+  if (colocated_buffer_sets->empty()) {
+    return;
+  }
+
+  // Try to find more coalescing opportunities among the colocated buffer sets.
+  //
+  // TODO(b/32491382): We should be able to remove this by using the
+  // module-level liveness analysis, which would let us directly detect buffer
+  // sharing opportunities between the while instruction buffer and the buffers
+  // from the predicate and body computation, as well as sharing across
+  // different while instructions.
+  std::vector<ColocatedBufferSet> new_colocated_buffer_sets =
+      MergeColocatedBufferSets(*colocated_buffer_sets, buffer_liveness,
+                               buffer_size);
+  std::swap(*colocated_buffer_sets, new_colocated_buffer_sets);
 }
 
 // Assigns all colocated buffer sets in 'colocated_buffer_sets' to the same
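The coalescing step then unions the sets that received the same color. Below
is a minimal sketch of that reduction under the same illustrative stand-ins;
BufferId and BufferSet are hypothetical, not the XLA ColocatedBufferSet API.

    #include <algorithm>
    #include <cstdint>
    #include <set>
    #include <vector>

    using BufferId = int64_t;              // hypothetical stand-in
    using BufferSet = std::set<BufferId>;  // stand-in for ColocatedBufferSet

    // Union all input sets assigned the same color; mirrors the final loop
    // of MergeColocatedBufferSets above. Assumes 'colors' is non-empty, as
    // the real code CHECKs.
    std::vector<BufferSet> MergeByColor(const std::vector<BufferSet>& sets,
                                        const std::vector<int64_t>& colors) {
      const int64_t num_sets =
          *std::max_element(colors.begin(), colors.end()) + 1;
      std::vector<BufferSet> merged(num_sets);
      for (size_t i = 0; i < sets.size(); ++i) {
        merged[colors[i]].insert(sets[i].begin(), sets[i].end());
      }
      return merged;
    }

With colors {0, 1, 2, 0} from the coloring sketch, sets 0 and 3 land in
merged[0] and share one allocation, while sets 1 and 2 stay separate.
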
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index 08a40bf..65019b6 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -528,15 +528,13 @@ class BufferAssigner {
       const std::vector<const LogicalBuffer*>& colocated_set,
       std::vector<ColocatedBufferSet>* colocated_buffer_sets);
 
-  // Conceptually the same as AddSetToColocatedBufferSets, but specific to the
-  // colocated buffers for while instructions.
-  void AddWhileSetToColocatedBufferSets(
-      const std::vector<const LogicalBuffer*>& colocated_set,
-      const LogicalBuffer* while_init_buffer,
-      const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo,
-      const HloComputation& computation, const BufferLiveness& buffer_liveness,
-      const LogicalBuffer::SizeFunction& buffer_size,
-      std::vector<ColocatedBufferSet>* colocated_buffer_sets);
+  // Given a list of colocated buffer sets (each of which represents the
+  // logical buffers that would be assigned to the same physical buffer),
+  // tries to merge the sets whose buffers can be shared. Returns the merged
+  // sets.
+  std::vector<ColocatedBufferSet> MergeColocatedBufferSets(
+      const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
+      const BufferLiveness& buffer_liveness,
+      const LogicalBuffer::SizeFunction& buffer_size);
 
   // Split a set of buffers into several sets, each of which contains buffers
   // colored with the same color.
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 6fc9d78..ef067cc 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -1587,6 +1587,117 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
             assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie());
 }
 
+// Tests that the colocated buffers for while instructions are properly
+// assigned during buffer assignment, such that the result tuple elements are
+// not assigned to the same buffer.
+//
+// %infeed --> %while.0 --> %while.1 --+
+//                                     +-- %tuple
+//   %zero -->   %add   --> %while.2 --+
+//
+// Execution Order:
+// %infeed -> %while.0 -> %while.1 -> %zero -> %add -> %while.2 -> %tuple
+//
+// The HLO computation used in this test requires specific ordering to expose
+// the bug (b/72496031). During buffer assignment, the visitation order of
+// colocated buffers is %while.2 -> %while.0 -> %while.1, and buffer
+// assignment used to coalesce the colocated buffers for all three while
+// instructions, thereby assigning the same buffer to the two result tuple
+// elements.
+TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
+  const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
+
+  // Builds a condition computation: x -> x < 4
+  auto build_cond = [&]() {
+    auto builder = HloComputation::Builder("cond");
+    auto const4 = builder.AddInstruction(
+        HloInstruction::CreateConstant(Literal::CreateR0<int>(4)));
+    auto param =
+        builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
+    builder.AddInstruction(HloInstruction::CreateBinary(
+        ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, const4));
+    return builder.Build();
+  };
+
+  // Builds a body computation: x -> x + 9
+  auto build_body = [&]() {
+    auto builder = HloComputation::Builder("body");
+    auto const9 = builder.AddInstruction(
+        HloInstruction::CreateConstant(Literal::CreateR0<int>(9)));
+    auto param =
+        builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
+    builder.AddInstruction(
+        HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, param, const9));
+    return builder.Build();
+  };
+
+  // Build the entry computation as described in the comment above.
+  auto module = xla::MakeUnique<HloModule>(TestName());
+  auto builder = HloComputation::Builder("entry");
+
+  auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, ""));
+  auto cond0 = module->AddEmbeddedComputation(build_cond());
+  auto body0 = module->AddEmbeddedComputation(build_body());
+  auto while0 = builder.AddInstruction(
+      HloInstruction::CreateWhile(r0s32, cond0, body0, infeed));
+
+  auto cond1 = module->AddEmbeddedComputation(build_cond());
+  auto body1 = module->AddEmbeddedComputation(build_body());
+  auto while1 = builder.AddInstruction(
+      HloInstruction::CreateWhile(r0s32, cond1, body1, while0));
+
+  auto zero = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+  auto add = builder.AddInstruction(
+      HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero));
+  auto cond2 = module->AddEmbeddedComputation(build_cond());
+  auto body2 = module->AddEmbeddedComputation(build_body());
+  auto while2 = builder.AddInstruction(
+      HloInstruction::CreateWhile(r0s32, cond2, body2, add));
+
+  auto tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({while2, while1}));
+  module->AddEntryComputation(builder.Build());
+
+  // Run CopyInsertion and verify that the graph constructed above needs no
+  // copies inserted for BufferAssignment to run.
+  int64 instruction_count = module->instruction_count();
+  CopyInsertion copy_insertion;
+  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
+  ASSERT_EQ(instruction_count, module->instruction_count());
+
+  // Create a sequential order among all the instructions in the entry
+  // computation, since the issue this test stresses depends on the order in
+  // which the nodes are traversed during BufferAssignment.
+  SequentialHloOrdering::HloModuleSequence sequence;
+  sequence[module->entry_computation()] = {infeed, while0, while1, zero,
+                                           add,    while2, tuple};
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto assignment,
+      BufferAssigner::Run(
+          module.get(),
+          xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence),
+          backend().compiler()->BufferSizeBytesFunction(),
+          [](LogicalBuffer::Color) { return 1; }));
+
+  // The result tuple elements must be assigned different buffers.
+  TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0}));
+  TF_ASSERT_OK_AND_ASSIGN(auto slice1, assignment->GetUniqueSlice(tuple, {1}));
+  EXPECT_NE(slice0, slice1);
+
+  // while0 and while1 result buffers must be equal to slice1.
+  TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
+                          assignment->GetUniqueSlice(while0, {}));
+  TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
+                          assignment->GetUniqueSlice(while1, {}));
+  EXPECT_EQ(slice1, slice_while0);
+  EXPECT_EQ(slice1, slice_while1);
+
+  // while2 result buffer must be equal to slice0.
+  TF_ASSERT_OK_AND_ASSIGN(auto slice_while2,
+                          assignment->GetUniqueSlice(while2, {}));
+  EXPECT_EQ(slice0, slice_while2);
+}
+
 TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
   auto module = xla::MakeUnique<HloModule>(TestName());
   auto builder = HloComputation::Builder("entry");
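
In terms of the coloring sketch earlier in this diff, the ColocatedBuffers
test boils down to a two-node interference graph: one node for the
while0/while1 colocated set and one for the while2 set (hypothetical indices;
the real sets also contain the condition/body parameter and root buffers).
Both while results are live at %tuple, so cannot_merge_buffer_sets holds for
the pair and the two sets must receive distinct colors, which is what the
slice0 != slice1 assertion checks. A small sketch that links against the
ColorGreedy definition from the earlier sketch:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Defined in the coloring sketch above; declared here to link the two
    // sketches together.
    std::vector<int64_t> ColorGreedy(
        const std::vector<std::vector<int64_t>>& interference_map);

    int main() {
      // Node 0 = the {while0, while1, ...} set, node 1 = the {while2, ...}
      // set. The edge between them forces distinct colors.
      std::vector<std::vector<int64_t>> interference = {{1}, {0}};
      for (int64_t c : ColorGreedy(interference)) std::cout << c << " ";
      std::cout << "\n";  // Two distinct colors, e.g.: 0 1
      return 0;
    }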
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc
index e774925..37982aa 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness.cc
@@ -117,11 +117,12 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
 
     // If the root instruction aliases the buffer 'a', the live range of 'a' is
     // until the end of the computation and can never be strictly before another
-    // buffer. This is needed to prevent the root instruction's buffers from
-    // being reused by later instructions even when the root is not the last
-    // instruction in the schedule.
+    // buffer defined in the same computation. This is needed to prevent the
+    // root instruction's buffers from being reused by later instructions even
+    // when the root is not the last instruction in the schedule.
     if (alias.instruction()->parent()->root_instruction() ==
-        alias.instruction()) {
+            alias.instruction() &&
+        alias.instruction()->parent() == b.instruction()->parent()) {
       return false;
     }
   }
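
A condensed restatement of the tightened condition above, written as a
hypothetical helper (in the real code this logic stays inline in
BufferLiveness::live_range_strictly_before): the root's buffers now pin only
buffers defined in the same computation, so, for example, a while body's root
no longer blocks reuse of buffers in the parent computation.

    #include "tensorflow/compiler/xla/service/hlo_computation.h"

    // Returns true if 'alias_instruction' is its computation's root and
    // 'b_instruction' is defined in that same computation; in that case the
    // aliased buffer's live range extends to the end of the computation, so
    // its live range is never strictly before that of 'b_instruction'.
    bool RootAliasBlocksReuse(const xla::HloInstruction* alias_instruction,
                              const xla::HloInstruction* b_instruction) {
      return alias_instruction->parent()->root_instruction() ==
                 alias_instruction &&
             alias_instruction->parent() == b_instruction->parent();
    }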