policy.copy_root_replicated_buffers = true;
}
for (const CallSite& site : node.caller_callsites()) {
- // The kWhile instruction does not have an handling here, as the
- // AddCopiesForWhile() API takes care of adding its own copies.
+      // AddCopiesForConditional() already adds copies, but the copy remover
+ // removes them, so we re-add them by returning the policy here. But really
+ // the copy remover should not be removing them.
if (site.instruction()->opcode() == HloOpcode::kConditional) {
policy.copy_parameters_and_constants = true;
policy.copy_root_replicated_buffers = true;
return Status::OK();
}
+// We add copies for all the indices of the true and false computation roots,
+// in order to resolve interference. We later rely on the CopyRemover to drop
+// the unnecessary ones.
+Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
+                               HloInstruction* conditional) {
+  VLOG(2) << "Adding copies for kConditional instruction "
+          << conditional->name();
+  TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
+
+  for (HloComputation* computation :
+       {conditional->true_computation(), conditional->false_computation()}) {
+    HloInstruction* root = computation->root_instruction();
+    // Take a snapshot of the users first: ReplaceUseWith below rewrites the
+    // use list of 'root' while we iterate over the snapshot.
+    std::vector<HloInstruction*> users = root->users();
+    TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
+                        computation->DeepCopyInstruction(root));
+    for (HloInstruction* user : users) {
+      TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
+    }
+    computation->set_root_instruction(deep_copy);
+  }
+  return Status::OK();
+}
+
// Removes any control dependencies to or from the given instruction.
Status StripControlDependenciesFrom(HloInstruction* instruction) {
while (!instruction->control_successors().empty()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kWhile) {
TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
+ } else if (instruction->opcode() == HloOpcode::kConditional) {
+ TF_RETURN_IF_ERROR(
+ AddCopiesForConditional(*alias_analysis, instruction));
}
}
}
auto is_live_range_before = [this](const ValueNode& a,
const ValueNode& b) {
+ VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value;
if (LiveRangeBefore(a, b)) {
VLOG(2) << " Live range of " << a.value->ToShortString()
<< " is before " << b.value->ToShortString();
VLOG(3) << copy->name() << " copies value "
<< src->value->ToShortString();
VLOG(3) << "Source buffer values: " << ValueListToString(src);
- VLOG(3) << "Dest buffer values: " << ValueListToString(src);
+ VLOG(3) << "Dest buffer values: " << ValueListToString(dest);
// A kCopy instruction copies an HLO value from a source buffer and
// defines an HLO value in a destination buffer. Most generally, the
// updated as copies are removed.
+  // Returns true if the live range of value 'a' is before that of 'b':
+  // either 'a' has no uses and is merely defined before 'b', or every use of
+  // 'a' occurs before 'b' is defined.
  bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
    if (a.uses.empty()) {
-      VLOG(2) << "Empty uses";
+      VLOG(2) << "Empty uses for " << *a.value;
      return ordering_.IsDefinedBefore(*a.value, *b.value);
    }
    for (const HloUse* use : a.uses) {
-      VLOG(2) << "use: " << *use;
-      VLOG(2) << "is before:" << *b.value;
+      VLOG(2) << "Checking use " << *use << " against " << *b.value;
      if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) {
-        VLOG(2) << "Not before";
+        VLOG(2) << "Use " << *use << " is NOT before " << *b.value;
        return false;
      }
+      VLOG(2) << "Use " << *use << " is before " << *b.value;
    }
    return true;
  }
CopyRemover copy_remover(*alias_analysis, ordering, module);
XLA_VLOG_LINES(3, copy_remover.ToString());
- tensorflow::gtl::FlatSet<int> existing_copies;
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kCopy &&
}
}
}
-
return Status::OK();
}
return value_to_buffer_number_.at(&value);
}
- // Compute and return a vector of buffers that the given value must be
- // contained in due to HLO aliasing rules.
- std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
+ void ComputeWhileAliasedBuffers(const HloValue& value,
+ std::vector<BufferNumber>* aliased_buffers) {
+ VLOG(3) << "Compute kWhile aliases";
// Value is init of a while (use is while).
- std::vector<BufferNumber> aliased_buffers;
for (const HloUse& use : value.uses()) {
- VLOG(2) << "use of value " << value.ToShortString() << ": " << use;
if (use.instruction->opcode() == HloOpcode::kWhile) {
// Determine the while value that this shares a buffer with.
const HloValue& while_value =
dataflow_.GetUniqueValueAt(use.instruction, use.operand_index);
- aliased_buffers.push_back(GetBufferForValue(while_value));
+ aliased_buffers->push_back(GetBufferForValue(while_value));
VLOG(3) << " value is init value to a while; must share buffer with "
"while value "
<< while_value.ToShortString();
}
}
-
// Value is a parameter of a while body/condition.
if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
const HloComputation* computation =
VLOG(3) << " value is parameter value of the body or condition of a "
"while; must share buffer with while value "
<< while_value.ToShortString();
- aliased_buffers.push_back(GetBufferForValue(while_value));
+ aliased_buffers->push_back(GetBufferForValue(while_value));
}
}
}
-
// Value is the root of a while body.
for (const HloPosition& position : value.positions()) {
const HloComputation* computation = position.instruction->parent();
const HloValue& while_value = dataflow_.GetUniqueValueAt(
callsite.instruction(), position.index);
- VLOG(3) << " value is root the body computation of a while; must "
- "share buffer with while value "
+ VLOG(3) << " value @ " << position << " is root of "
+ << callsite.instruction()->name()
+ << "; body root and while value root must share buffer "
+ "among them : "
<< while_value.ToShortString();
- aliased_buffers.push_back(GetBufferForValue(while_value));
+ aliased_buffers->push_back(GetBufferForValue(while_value));
}
}
}
}
-
// Value is the output of the while instruction itself.
if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
VLOG(3) << " value is output of a while instruction";
- aliased_buffers.push_back(GetBufferForValue(value));
+ aliased_buffers->push_back(GetBufferForValue(value));
+ }
+ }
+
+  // Appends to '*aliased_buffers' the buffers that 'value' must share due to
+  // kConditional aliasing rules.
+  void ComputeConditionalAliasedBuffers(
+      const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
+    VLOG(3) << "Compute kConditional aliases";
+    // Aliases the buffers of the true/false computation roots with that of
+    // the conditional.
+    for (const HloPosition& position : value.positions()) {
+      const HloComputation* computation = position.instruction->parent();
+      const CallGraphNode& call_graph_node =
+          dataflow_.call_graph().GetNode(computation);
+      if (position.instruction == computation->root_instruction()) {
+        for (const CallSite& callsite : call_graph_node.caller_callsites()) {
+          if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
+            // Call graph must have been flattened.
+            CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
+
+            const HloValue& cond_value = dataflow_.GetUniqueValueAt(
+                callsite.instruction(), position.index);
+            VLOG(3)
+                << " value @ " << position << " is root of "
+                << callsite.instruction()->name()
+                << "; true/false branch roots must share buffer among them : "
+                << cond_value.ToShortString();
+            aliased_buffers->push_back(GetBufferForValue(cond_value));
+          }
+        }
+      }
+    }
+    // Value is the output of the conditional instruction itself.
+    if (value.defining_instruction()->opcode() == HloOpcode::kConditional) {
+      VLOG(3) << " value is output of a conditional instruction";
+      aliased_buffers->push_back(GetBufferForValue(value));
+    }
+  }
+  // Compute and return a vector of buffers that the given value must be
+  // contained in due to HLO aliasing rules.
+  std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
+    // Debug-log every use of 'value' before collecting the aliases.
+    for (const HloUse& use : value.uses()) {
+      VLOG(2) << "Use of value " << value.ToShortString() << ": " << use;
+    }
+    std::vector<BufferNumber> aliased_buffers;
+    ComputeWhileAliasedBuffers(value, &aliased_buffers);
+    ComputeConditionalAliasedBuffers(value, &aliased_buffers);
    // Uniquify aliased buffers. The helpers above may have appended the same
    // buffer more than once.
    std::sort(aliased_buffers.begin(), aliased_buffers.end());
    aliased_buffers.erase(
        std::unique(aliased_buffers.begin(), aliased_buffers.end()),
        aliased_buffers.end());
-
    return aliased_buffers;
  }
conditional->true_computation()->root_instruction()),
&GetInstructionValueSet(
conditional->false_computation()->root_instruction())};
- // A phi-node is not defined for a kConditional instruction even though it
- // represents a join point. This is because the current approach is to define
- // a phi-node only for kWhile to account for the dataflow through back-edges
- // and deal with the ambiguity in other cases.
- return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
+ if (ssa_form_) {
+ return Phi(conditional, inputs);
+ } else {
+ return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
+ }
}
bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
ElementsAre(HloUse{conditional, 2, {}}));
- EXPECT_EQ(analysis.values().size(), 3);
- EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
- EXPECT_THAT(HloValuesAt(conditional),
- UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
- analysis.GetValueDefinedAt(constant2)));
+ bool ssa_form = GetParam();
+ if (ssa_form) {
+ EXPECT_EQ(analysis.values().size(), 4);
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
+ } else {
+ EXPECT_EQ(analysis.values().size(), 3);
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
+ EXPECT_THAT(HloValuesAt(conditional),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
+ analysis.GetValueDefinedAt(constant2)));
+ }
}
TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) {
HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}},
HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}}));
- EXPECT_EQ(analysis.values().size(), 6);
- EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
- EXPECT_THAT(HloValuesAt(conditional),
- UnorderedElementsAre(analysis.GetValueDefinedAt(add),
- analysis.GetValueDefinedAt(sub)));
+ bool ssa_form = GetParam();
+ if (ssa_form) {
+ EXPECT_EQ(analysis.values().size(), 7);
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
+ } else {
+ EXPECT_EQ(analysis.values().size(), 6);
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
+ EXPECT_THAT(HloValuesAt(conditional),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(add),
+ analysis.GetValueDefinedAt(sub)));
+ }
}
TEST_P(HloDataflowAnalysisTest, NestedConditionals) {
EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond),
analysis.GetValueDefinedAt(constant2));
- EXPECT_EQ(analysis.values().size(), 9);
- EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional));
- EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
- EXPECT_THAT(
- HloValuesAt(inner_conditional),
- UnorderedElementsAre(
- analysis.GetValueDefinedAt(computation1->root_instruction()),
- analysis.GetValueDefinedAt(computation2->root_instruction())));
- EXPECT_THAT(
- HloValuesAt(conditional),
- UnorderedElementsAre(
- analysis.GetValueDefinedAt(computation1->root_instruction()),
- analysis.GetValueDefinedAt(computation2->root_instruction()),
- analysis.GetValueDefinedAt(computation3->root_instruction())));
+ bool ssa_form = GetParam();
+ if (ssa_form) {
+ EXPECT_EQ(analysis.values().size(), 11);
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_conditional));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
+ } else {
+ EXPECT_EQ(analysis.values().size(), 9);
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
+ EXPECT_THAT(
+ HloValuesAt(inner_conditional),
+ UnorderedElementsAre(
+ analysis.GetValueDefinedAt(computation1->root_instruction()),
+ analysis.GetValueDefinedAt(computation2->root_instruction())));
+ EXPECT_THAT(
+ HloValuesAt(conditional),
+ UnorderedElementsAre(
+ analysis.GetValueDefinedAt(computation1->root_instruction()),
+ analysis.GetValueDefinedAt(computation2->root_instruction()),
+ analysis.GetValueDefinedAt(computation3->root_instruction())));
+ }
}
INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation,
}
}
+ // If the common ancestor is a conditional instruction, even though the true
+  // and false computations are not really ordered per se, we define the true
+ // computation to be ordered before the false one.
+ // This ensures that buffers can still be shared among the two computations
+ // as they will forcibly have disjoint liveness.
+ if (a_ancestor == b_ancestor &&
+ a_ancestor->opcode() == HloOpcode::kConditional) {
+ const HloComputation* true_computation = a_ancestor->true_computation();
+ const HloComputation* false_computation = a_ancestor->false_computation();
+ if (call_graph_->InstructionIsNestedIn(a, true_computation) &&
+ call_graph_->InstructionIsNestedIn(b, false_computation)) {
+ return true;
+ }
+ // If 'b' is the conditional ancestor, and 'a' is within the true or false
+ // computations, 'a' executes before 'b'.
+ if (b == a_ancestor &&
+ (call_graph_->InstructionIsNestedIn(a, true_computation) ||
+ call_graph_->InstructionIsNestedIn(a, false_computation))) {
+ return true;
+ }
+ }
+
return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor);
}
b.defining_instruction()->while_condition()))) {
return true;
}
-
+ // If 'b' is a conditional phi and 'a' is in the true or false computation,
+ // then 'a' executes before 'b'.
+ if (b.is_phi() &&
+ b.defining_instruction()->opcode() == HloOpcode::kConditional &&
+ (call_graph_->InstructionIsNestedIn(
+ a.defining_instruction(),
+ b.defining_instruction()->true_computation()) ||
+ call_graph_->InstructionIsNestedIn(
+ a.defining_instruction(),
+ b.defining_instruction()->false_computation()))) {
+ return true;
+ }
return ExecutesBefore(a.defining_instruction(), b.defining_instruction());
}
VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
<< ", b = " << b.ToShortString() << ")";
if (!IsDefinedBefore(a, b)) {
- VLOG(4) << "a not defined before b";
+ VLOG(4) << a << " not defined before " << b;
return false;
}
-
// All uses of 'a' must be before 'b' is defined.
for (const HloUse& use : a.uses()) {
if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
- VLOG(4) << "use of a (" << use << ") not before b is defined";
+ VLOG(4) << "use of " << a << " (" << use << ") not before " << b
+ << " is defined";
return false;
}
}
-
return true;
}
ordering.ToString(); // Shouldn't crash.
}
+TEST_F(HloOrderingTest, ConditionalInstructionOrdering) {
+  const char* module_str = R"(
+HloModule test_conditional_module
+
+true_branch {
+  param.1 = (s32[], s32[]) parameter(0)
+  get-tuple-element.1 = s32[] get-tuple-element(param.1), index=0
+  get-tuple-element.2 = s32[] get-tuple-element(param.1), index=1
+  add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2)
+  ROOT tuple.1 = (s32[], s32[]) tuple(add.1, get-tuple-element.1)
+}
+
+false_branch {
+  param.2 = (s32[], s32[]) parameter(0)
+  get-tuple-element.3 = s32[] get-tuple-element(param.2), index=0
+  get-tuple-element.4 = s32[] get-tuple-element(param.2), index=1
+  add.2 = s32[] add(get-tuple-element.3, get-tuple-element.4)
+  ROOT tuple.2 = (s32[], s32[]) tuple(add.2, get-tuple-element.4)
+}
+
+ENTRY root {
+  param.3 = (pred[], (s32[], s32[])) parameter(0)
+  pred.1 = pred[] get-tuple-element(param.3), index=0
+  cond_arg.1 = (s32[], s32[]) get-tuple-element(param.3), index=1
+  conditional = (s32[], s32[]) conditional(pred.1, cond_arg.1, cond_arg.1), true_computation=true_branch, false_computation=false_branch
+  cond_res.1 = s32[] get-tuple-element(conditional), index=0
+  cond_res.2 = s32[] get-tuple-element(conditional), index=1
+  add.3 = s32[] add(cond_res.1, cond_res.2)
+  ROOT result = (s32[], s32[], s32[]) tuple(add.3, cond_res.1, cond_res.2)
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          tools::Parse(module_str));
+  TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+                          HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+  DependencyHloOrdering ordering(module.get());
+
+  // Even though the true and false branches have no ordering, since they do
+  // not interfere (as they are mutually exclusive), we define the true
+  // computation to be before the false one.
+  // Similarly, any instruction in the true or false branches is considered
+  // before the conditional instruction. The roots are effectively "at the same
+  // time" WRT the conditional, but they are Phi-ed anyway.
+  HloInstruction* add_1 = FindInstruction(module.get(), "add.1");
+  HloInstruction* add_2 = FindInstruction(module.get(), "add.2");
+  HloInstruction* add_3 = FindInstruction(module.get(), "add.3");
+  HloInstruction* conditional = FindInstruction(module.get(), "conditional");
+  EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1),
+                                       dataflow->GetValueDefinedAt(add_2)));
+  EXPECT_TRUE(
+      ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_2),
+                               dataflow->GetValueDefinedAt(conditional)));
+  EXPECT_TRUE(
+      ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1),
+                               dataflow->GetValueDefinedAt(conditional)));
+  EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1),
+                                       dataflow->GetValueDefinedAt(add_3)));
+  EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_2),
+                                       dataflow->GetValueDefinedAt(add_3)));
+}
+
} // namespace
} // namespace xla
"only parameter of true_computation"));
}
+XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
+  Shape tuple_shape = ShapeUtil::MakeTupleShape({r0f32_, r0f32_});
+  // Computation that swaps the two elements of its tuple argument.
+  Computation swapper;
+  {
+    ComputationBuilder builder(client_, TestName() + ".swapper");
+    auto param0 = builder.Parameter(0, tuple_shape, "sp0");
+    auto x = builder.GetTupleElement(param0, 0);
+    auto y = builder.GetTupleElement(param0, 1);
+    builder.Tuple({y, x});
+    swapper = builder.Build().ConsumeValueOrDie();
+  }
+  // Computation that returns its tuple argument unchanged (identity).
+  Computation forwarder;
+  {
+    ComputationBuilder builder(client_, TestName() + ".forwarder");
+    auto param0 = builder.Parameter(0, tuple_shape, "fp0");
+    auto x = builder.GetTupleElement(param0, 0);
+    auto y = builder.GetTupleElement(param0, 1);
+    builder.Tuple({x, y});
+    forwarder = builder.Build().ConsumeValueOrDie();
+  }
+  // Two sequential conditionals with complementary predicates (x < y, then
+  // x >= y): when x < y both select the forwarder; otherwise the first swaps
+  // and the second swaps back. Either way the output equals the input tuple.
+  Computation main;
+  {
+    ComputationBuilder builder(client_, TestName() + ".main");
+    auto param0 = builder.Parameter(0, tuple_shape, "mp0");
+    auto x = builder.GetTupleElement(param0, 0);
+    auto y = builder.GetTupleElement(param0, 1);
+    auto lt_pred = builder.Lt(x, y);
+    auto res = builder.Conditional(lt_pred, param0, forwarder, param0, swapper);
+    auto ge_pred = builder.Ge(x, y);
+    builder.Conditional(ge_pred, res, swapper, res, forwarder);
+    main = builder.Build().ConsumeValueOrDie();
+  }
+
+  // Runs 'main' on the tuple (a, b) and checks the result is still (a, b).
+  auto test_swap = [&](float a, float b) {
+    ComputationBuilder builder(client_, TestName());
+    auto x = builder.ConstantR0<float>(a);
+    auto y = builder.ConstantR0<float>(b);
+    auto tuple_operand = builder.Tuple({x, y});
+    builder.Call(main, {tuple_operand});
+
+    ComputeAndCompareTuple(
+        &builder,
+        *Literal::MakeTuple({Literal::CreateR0<float>(a).get(),
+                             Literal::CreateR0<float>(b).get()}),
+        {}, error_spec_);
+  };
+
+  test_swap(3.11f, 9.4f);
+  test_swap(11.24f, 5.55f);
+}
+
} // namespace
} // namespace xla