}
}
-// Conceptually the same as AddSetToColocatedBufferSets, but specific to the
-// colocated buffers for while instructions. 'colocated_set' contains the
-// buffers for a single while instruction that must be colocated. The idea here
-// is to apply a memory-saving heuristic for separate while instructions whose
-// buffers are disjoint in liveness, by using the colocation mechanism to force
-// buffer sharing. This often reduces memory for multi-layer RNNs.
-//
-// TODO(b/32491382): We should be able to remove this heuristic after we
-// implement module-level liveness analysis, which would let us directly detect
-// buffer sharing opportunities between the while instruction buffer and the
-// buffers from the predicate and body computation, as well as sharing across
-// different while instructions.
-void BufferAssigner::AddWhileSetToColocatedBufferSets(
- const std::vector<const LogicalBuffer*>& colocated_set,
- const LogicalBuffer* while_init_buffer,
- const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo,
- const HloComputation& computation, const BufferLiveness& buffer_liveness,
- const LogicalBuffer::SizeFunction& buffer_size,
- std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
- CHECK(!colocated_set.empty());
- const TuplePointsToAnalysis& points_to_analysis =
- buffer_liveness.points_to_analysis();
-
- // Parallel while loops cannot safely share colocated buffer sets.
- if (buffer_liveness.hlo_ordering().SequentialOrder(computation) == nullptr) {
- AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
- return;
- }
-
- // Scan 'colocated_buffer_sets' in reverse order for locality; colocated sets
- // are added in postorder over computations and instructions.
- const int64 init_buffer_size = buffer_size(*while_init_buffer);
- const bool is_live_out = buffer_liveness.MaybeLiveOut(*while_result_buffer);
- for (int i = colocated_buffer_sets->size() - 1; i >= 0; --i) {
- const ColocatedBufferSet& predecessor_set = (*colocated_buffer_sets)[i];
-
- // Skip predecessor sets not associated with while loops.
- if (std::all_of(predecessor_set.begin(), predecessor_set.end(),
- [](const LogicalBuffer* buffer) {
- return buffer->instruction()->opcode() !=
- HloOpcode::kWhile;
- })) {
- continue;
- }
-
- // Skip predecessor sets already associated with 'while_hlo'.
- if (std::any_of(predecessor_set.begin(), predecessor_set.end(),
- [&while_hlo](const LogicalBuffer* buffer) {
- return buffer->instruction() == while_hlo;
- })) {
- continue;
- }
-
- // Skip predecessor sets with entry parameter if the while result is live
- // out.
- if (is_live_out &&
- std::any_of(predecessor_set.begin(), predecessor_set.end(),
- [](const LogicalBuffer* buffer) {
- auto* instruction = buffer->instruction();
- auto* computation = instruction->parent();
- auto* module = computation->parent();
- return instruction->opcode() == HloOpcode::kParameter &&
- computation == module->entry_computation();
- })) {
- continue;
- }
-
- // Build vector of predecessor while result and init buffers, which are
- // checked for liveness interference below. We must check both the result
- // and init buffers because they're aliased together, but
- // TuplePointsToAnalysis is unaware of this aliasing.
- std::vector<const LogicalBuffer*> predecessor_while_buffers;
- for (const LogicalBuffer* buffer : predecessor_set) {
- const HloInstruction* instruction = buffer->instruction();
- if (instruction->opcode() == HloOpcode::kWhile &&
- buffer_size(*buffer) == init_buffer_size &&
- instruction->parent() == &computation) {
- predecessor_while_buffers.push_back(buffer);
- // Add the init buffer at the same index, which must also exist in the
- // predecessor set, and must be unambiguous.
- const PointsToSet& init_points_to =
- points_to_analysis.GetPointsToSet(instruction->operand(0));
- const auto& init_buffers = init_points_to.element(buffer->index());
- CHECK_EQ(init_buffers.size(), 1);
- CHECK_GT(predecessor_set.count(init_buffers[0]), 0);
- predecessor_while_buffers.push_back(init_buffers[0]);
- }
- }
- if (predecessor_while_buffers.empty()) {
- continue;
- }
-
- // Skip predecessor set if the live range of any predecessor
- // buffers overlaps with 'while_init_buffer' or
- // 'while_result_buffer' (we need to check both since they're
- // aliased together, but the points-to analysis is unaware of this
- // aliasing). Note that tuple element buffer forwarding can cause
- // the same buffer to appear on both sides of the interference
- // comparison below.
- auto may_interfere_with_init_or_result = [&](const LogicalBuffer* buffer) {
- if (while_init_buffer->id() != buffer->id() &&
- buffer_liveness.MayInterfere(*while_init_buffer, *buffer)) {
- return true;
- }
-
- if (while_result_buffer->id() != buffer->id() &&
- buffer_liveness.MayInterfere(*while_result_buffer, *buffer)) {
- return true;
- }
-
- return false;
- };
-
- if (std::any_of(predecessor_while_buffers.begin(),
- predecessor_while_buffers.end(),
- may_interfere_with_init_or_result)) {
- continue;
- }
-
- // All our checks have passed; merge 'predecessor_set' with 'colocated_set',
- // and add the merged set to 'colocated_buffer_sets'. This forces the
- // colocation of buffers across different while instructions.
- FlatSet<const LogicalBuffer*> unique;
- unique.insert(predecessor_set.begin(), predecessor_set.end());
- unique.insert(colocated_set.begin(), colocated_set.end());
- std::vector<const LogicalBuffer*> merged_set(unique.begin(), unique.end());
- AddSetToColocatedBufferSets(merged_set, colocated_buffer_sets);
- return;
- }
-
- // Failed to merge into predecessor set; add 'colocated_set' as-is.
- AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
-}
-
namespace {
// Checks that points-to set of 'instruction' is unambiguous and distinct
return colocated_set->back();
}
+// Given the interference map of a graph (the list of interfering node indices
+// for each node), performs graph coloring such that interfering nodes are
+// assigned different colors. Returns the assigned color of each node, where
+// colors are represented as integers in [0, color_count).
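+//
+// For example, given the interference map {{1, 2}, {0}, {0}}, node 0
+// interferes with nodes 1 and 2. Node 0 has the highest degree, so it is
+// colored first and gets color 0; nodes 1 and 2 each interfere only with
+// node 0, so both get color 1, yielding the assignment {0, 1, 1}.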
+std::vector<int64> ColorInterferenceGraph(
+ const std::vector<std::vector<int64>>& interference_map) {
+ const int64 node_count = interference_map.size();
+
+ // Sort the nodes such that we assign nodes with more interference first. This
+ // relies on the common heuristic of assigning the most constrained node
+ // first, but it would be good to investigate other ordering heuristics too.
+ std::vector<int64> nodes(node_count);
+ std::iota(nodes.begin(), nodes.end(), 0);
+ std::sort(nodes.begin(), nodes.end(),
+ [&interference_map](const int64 i, const int64 j) {
+ return interference_map[i].size() > interference_map[j].size();
+ });
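+ // (Together with the greedy assignment below, this is essentially the
+ // Welsh-Powell coloring heuristic.)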
+
+ const int64 kColorUnassigned = -1;
+ std::vector<int64> assigned_colors(node_count, kColorUnassigned);
+ for (int64 node : nodes) {
+ // Mark the colors that are already assigned to the neighbors.
+ std::vector<bool> available_colors(node_count, true);
+ for (int64 neighbor : interference_map[node]) {
+ int64 color = assigned_colors[neighbor];
+ if (color != kColorUnassigned) {
+ available_colors[color] = false;
+ }
+ }
+
+ // Find the first color not yet assigned to any neighbor. Such a color
+ // always exists: a node has at most node_count - 1 neighbors, and there
+ // are node_count candidate colors.
+ int64 color = 0;
+ for (; color < node_count; ++color) {
+ if (available_colors[color]) {
+ break;
+ }
+ }
+ CHECK_LT(color, node_count);
+ assigned_colors[node] = color;
+ }
+ return assigned_colors;
+}
+
} // namespace
+std::vector<BufferAssigner::ColocatedBufferSet>
+BufferAssigner::MergeColocatedBufferSets(
+ const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
+ const BufferLiveness& buffer_liveness,
+ const LogicalBuffer::SizeFunction& buffer_size) {
+ VLOG(1) << "colocation sets count before coalescing:"
+ << colocated_buffer_sets.size();
+
+ // Returns true if the given buffer is a parameter of the entry computation.
+ auto is_entry_parameter = [](const LogicalBuffer& buffer) {
+ auto* instruction = buffer.instruction();
+ auto* computation = instruction->parent();
+ auto* module = computation->parent();
+ return instruction->opcode() == HloOpcode::kParameter &&
+ computation == module->entry_computation();
+ };
+
+ // Returns true if the two colocated buffer sets (specified by their indices
+ // into colocated_buffer_sets) cannot be merged into a single set.
+ auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness,
+ &buffer_size,
+ &is_entry_parameter](int64 i, int64 j) {
+ for (auto& buffer_a : colocated_buffer_sets[i]) {
+ for (auto& buffer_b : colocated_buffer_sets[j]) {
+ // Do not merge sets if one may be live out and the other contains an
+ // entry parameter; colocating them could alias a computation input
+ // with its result.
+ if ((buffer_liveness.MaybeLiveOut(*buffer_a) &&
+ is_entry_parameter(*buffer_b)) ||
+ (buffer_liveness.MaybeLiveOut(*buffer_b) &&
+ is_entry_parameter(*buffer_a))) {
+ return true;
+ }
+ // Do not merge if the buffers interfere with each other. The same
+ // buffer may appear in both sets (e.g. via tuple element forwarding),
+ // so identical buffers are exempt from the interference check.
+ if (buffer_a->id() != buffer_b->id() &&
+ buffer_liveness.MayInterfere(*buffer_a, *buffer_b)) {
+ return true;
+ }
+ // Do not merge if the buffer sizes are different.
+ if (buffer_size(*buffer_a) != buffer_size(*buffer_b)) {
+ return true;
+ }
+ }
+ }
+ return false;
+ };
+
+ // Build the interference map among the colocated buffer sets (nodes), by
+ // adding an edge between any two nodes that cannot be merged into a single
+ // colocated buffer set.
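+ // Note that this examines every pair of sets, and each pairwise check
+ // compares every pair of buffers across the two sets, so the cost grows
+ // quadratically with both the number of sets and their sizes.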
+ std::vector<std::vector<int64>> interference_map(
+ colocated_buffer_sets.size());
+ for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
+ for (int64 j = i + 1; j < colocated_buffer_sets.size(); ++j) {
+ if (cannot_merge_buffer_sets(i, j)) {
+ interference_map[i].push_back(j);
+ interference_map[j].push_back(i);
+ }
+ }
+ }
+
+ // Assign a color to each colocation set in colocated_buffer_sets, such that
+ // any two sets that cannot be merged receive different colors; sets sharing
+ // a color can therefore be merged safely.
+ auto assigned_colors = ColorInterferenceGraph(interference_map);
+
+ // Merge the buffer sets with the same color.
+ CHECK(!assigned_colors.empty());
+ int64 num_sets =
+ *std::max_element(assigned_colors.begin(), assigned_colors.end()) + 1;
+ std::vector<ColocatedBufferSet> new_colocated_buffer_sets(num_sets);
+ for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
+ const auto& buffer_set = colocated_buffer_sets[i];
+ new_colocated_buffer_sets[assigned_colors[i]].insert(buffer_set.begin(),
+ buffer_set.end());
+ }
+
+ VLOG(1) << "colocation sets count after coalescing:"
+ << colocated_buffer_sets.size();
+ return new_colocated_buffer_sets;
+}
+
// Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
// in the same allocation (currently just supports kWhile, kCall, and
// kConditional).
const Shape& /*subshape*/, const ShapeIndex& index) {
std::vector<const LogicalBuffer*> colocated_set;
// Add while.init.
- auto* init_buffer =
- AddBufferToColocatedSet(while_hlo->operand(0), index,
- points_to_analysis, &colocated_set);
+ AddBufferToColocatedSet(while_hlo->operand(0), index,
+ points_to_analysis, &colocated_set);
// Add while.result.
- auto* result_buffer = AddBufferToColocatedSet(
- while_hlo, index, points_to_analysis, &colocated_set);
+ AddBufferToColocatedSet(while_hlo, index, points_to_analysis,
+ &colocated_set);
// Add while.cond.parameter.
AddBufferToColocatedSet(
while_hlo->while_condition()->parameter_instruction(0), index,
AddBufferToColocatedSet(
while_hlo->while_body()->root_instruction(), index,
points_to_analysis, &colocated_set);
- AddWhileSetToColocatedBufferSets(
- colocated_set, init_buffer, result_buffer, while_hlo,
- *computation, buffer_liveness, buffer_size,
- colocated_buffer_sets);
+ AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
});
} else if (opcode == HloOpcode::kCall) {
const HloInstruction* call_hlo = instruction;
}
}
}
+
+ if (colocated_buffer_sets->empty()) {
+ return;
+ }
+
+ // Try to find more coalescing opportunities among the colocated buffer sets.
+ //
+ // TODO(b/32491382): We should be able to remove this by using the
+ // module-level liveness analysis, which would let us directly detect buffer
+ // sharing opportunities between the while instruction buffer and the buffers
+ // from the predicate and body computation, as well as sharing across
+ // different while instructions.
+ std::vector<ColocatedBufferSet> new_colocated_buffer_sets =
+ MergeColocatedBufferSets(*colocated_buffer_sets, buffer_liveness,
+ buffer_size);
+ std::swap(*colocated_buffer_sets, new_colocated_buffer_sets);
}
// Assigns all colocated buffer sets in 'colocated_buffer_sets' to the same
assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie());
}
+// Tests that the colocated buffer sets for while instructions are merged
+// correctly during buffer assignment, so that the two result tuple elements
+// are not assigned to the same buffer.
+//
+// %infeed --> %while.0 --> %while.1 --+
+//                                     +-- %tuple
+// %zero --> %add --> %while.2 --------+
+//
+// Execution Order:
+// %infeed -> %while.0 -> %while.1 -> %zero -> %add -> %while.2 -> %tuple
+//
+// The HLO computation used in this test requires a specific ordering to
+// expose the bug (b/72496031). During buffer assignment, the colocated
+// buffers are visited in the order %while.2 -> %while.0 -> %while.1, and
+// buffer assignment was coalescing the colocated buffers for all three while
+// instructions, thereby assigning the same buffer to the two result tuple
+// elements.
+TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
+ const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
+
+ // Builds a condition computation: x -> x < 4
+ auto build_cond = [&]() {
+ auto builder = HloComputation::Builder("cond");
+ auto const4 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(4)));
+ auto param =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, const4));
+ return builder.Build();
+ };
+
+ // Builds a body computation: x -> x + 9
+ auto build_body = [&]() {
+ auto builder = HloComputation::Builder("body");
+ auto const9 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(9)));
+ auto param =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, param, const9));
+ return builder.Build();
+ };
+
+ // Build the entry computation as described in the comment above.
+ auto module = xla::MakeUnique<HloModule>(TestName());
+ auto builder = HloComputation::Builder("entry");
+
+ auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, ""));
+ auto cond0 = module->AddEmbeddedComputation(build_cond());
+ auto body0 = module->AddEmbeddedComputation(build_body());
+ auto while0 = builder.AddInstruction(
+ HloInstruction::CreateWhile(r0s32, cond0, body0, infeed));
+
+ auto cond1 = module->AddEmbeddedComputation(build_cond());
+ auto body1 = module->AddEmbeddedComputation(build_body());
+ auto while1 = builder.AddInstruction(
+ HloInstruction::CreateWhile(r0s32, cond1, body1, while0));
+
+ auto zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ auto add = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero));
+ auto cond2 = module->AddEmbeddedComputation(build_cond());
+ auto body2 = module->AddEmbeddedComputation(build_body());
+ auto while2 = builder.AddInstruction(
+ HloInstruction::CreateWhile(r0s32, cond2, body2, add));
+
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({while2, while1}));
+ module->AddEntryComputation(builder.Build());
+
+ // Run CopyInsertion and check that the graph constructed above does not
+ // need any copies inserted for BufferAssignment to run.
+ int64 instruction_count = module->instruction_count();
+ CopyInsertion copy_insertion;
+ ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
+ ASSERT_EQ(instruction_count, module->instruction_count());
+
+ // Create a sequential order among all the instructions in the entry
+ // computation, since the issue this test stresses depends on the order in
+ // which the nodes are traversed during BufferAssignment.
+ SequentialHloOrdering::HloModuleSequence sequence;
+ sequence[module->entry_computation()] = {infeed, while0, while1, zero,
+ add, while2, tuple};
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto assignment,
+ BufferAssigner::Run(
+ module.get(),
+ xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence),
+ backend().compiler()->BufferSizeBytesFunction(),
+ [](LogicalBuffer::Color) { return 1; }));
+
+ // The result tuple elements must be assigned different buffers.
+ TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0}));
+ TF_ASSERT_OK_AND_ASSIGN(auto slice1, assignment->GetUniqueSlice(tuple, {1}));
+ EXPECT_NE(slice0, slice1);
+
+ // while0 and while1 result buffers must be equal to slice1.
+ TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
+ assignment->GetUniqueSlice(while0, {}));
+ TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
+ assignment->GetUniqueSlice(while1, {}));
+ EXPECT_EQ(slice1, slice_while0);
+ EXPECT_EQ(slice1, slice_while1);
+
+ // while2 result buffer must be equal to slice0.
+ TF_ASSERT_OK_AND_ASSIGN(auto slice_while2,
+ assignment->GetUniqueSlice(while2, {}));
+ EXPECT_EQ(slice0, slice_while2);
+}
+
TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
auto module = xla::MakeUnique<HloModule>(TestName());
auto builder = HloComputation::Builder("entry");