Fix the HLO alias analysis and copy insertion to cope with the new kConditional instr...
author: A. Unique TensorFlower <gardener@tensorflow.org>
Thu, 15 Mar 2018 21:27:04 +0000 (14:27 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Thu, 15 Mar 2018 21:31:07 +0000 (14:31 -0700)
PiperOrigin-RevId: 189245979

tensorflow/compiler/xla/service/copy_insertion.cc
tensorflow/compiler/xla/service/hlo_alias_analysis.cc
tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
tensorflow/compiler/xla/service/hlo_ordering.cc
tensorflow/compiler/xla/service/hlo_ordering_test.cc
tensorflow/compiler/xla/tests/conditional_test.cc

index e9c974a..40519ec 100644 (file)
@@ -78,8 +78,9 @@ SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node,
     policy.copy_root_replicated_buffers = true;
   }
   for (const CallSite& site : node.caller_callsites()) {
-    // The kWhile instruction does not have an handling here, as the
-    // AddCopiesForWhile() API takes care of adding its own copies.
+    // The AddCopiesForConditional() already adds copies, but the copy remover
+    // removes them, so we re-add them by returning the policy here. But really
+    // the copy remover should not be removing them.
     if (site.instruction()->opcode() == HloOpcode::kConditional) {
       policy.copy_parameters_and_constants = true;
       policy.copy_root_replicated_buffers = true;
@@ -321,6 +322,29 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
   return Status::OK();
 }
 
+// We add copies for all the indices of the true and false computation roots,
+// in order to resolve interference. We later rely on the CopyRemover to drop
+// the unnecessary ones.
+Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
+                               HloInstruction* conditional) {
+  VLOG(2) << "Adding copies for kConditional instruction "
+          << conditional->name();
+  TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
+
+  for (HloComputation* computation :
+       {conditional->true_computation(), conditional->false_computation()}) {
+    HloInstruction* root = computation->root_instruction();
+    std::vector<HloInstruction*> users = root->users();
+    TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
+                        computation->DeepCopyInstruction(root));
+    for (HloInstruction* user : users) {
+      TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
+    }
+    computation->set_root_instruction(deep_copy);
+  }
+  return Status::OK();
+}
+
 // Removes any control dependencies to or from the given instruction.
 Status StripControlDependenciesFrom(HloInstruction* instruction) {
   while (!instruction->control_successors().empty()) {
@@ -348,6 +372,9 @@ Status AddCopiesToResolveInterference(HloModule* module) {
     for (HloInstruction* instruction : computation->instructions()) {
       if (instruction->opcode() == HloOpcode::kWhile) {
         TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
+      } else if (instruction->opcode() == HloOpcode::kConditional) {
+        TF_RETURN_IF_ERROR(
+            AddCopiesForConditional(*alias_analysis, instruction));
       }
     }
   }
@@ -596,6 +623,7 @@ class CopyRemover {
 
       auto is_live_range_before = [this](const ValueNode& a,
                                          const ValueNode& b) {
+        VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value;
         if (LiveRangeBefore(a, b)) {
           VLOG(2) << "  Live range of " << a.value->ToShortString()
                   << " is before " << b.value->ToShortString();
@@ -610,7 +638,7 @@ class CopyRemover {
       VLOG(3) << copy->name() << " copies value "
               << src->value->ToShortString();
       VLOG(3) << "Source buffer values: " << ValueListToString(src);
-      VLOG(3) << "Dest buffer values: " << ValueListToString(src);
+      VLOG(3) << "Dest buffer values: " << ValueListToString(dest);
 
       // A kCopy instruction copies an HLO value from a source buffer and
       // defines an HLO value in a destination buffer. Most generally, the
@@ -786,16 +814,16 @@ class CopyRemover {
     // updated as copies are removed.
     bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
       if (a.uses.empty()) {
-        VLOG(2) << "Empty uses";
+        VLOG(2) << "Empty uses for " << *a.value;
         return ordering_.IsDefinedBefore(*a.value, *b.value);
       }
       for (const HloUse* use : a.uses) {
-        VLOG(2) << "use: " << *use;
-        VLOG(2) << "is before:" << *b.value;
+        VLOG(2) << "Checking use " << *use << " against " << *b.value;
         if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) {
-          VLOG(2) << "Not before";
+          VLOG(2) << "Use " << *use << " is NOT before " << *b.value;
           return false;
         }
+        VLOG(2) << "Use " << *use << " is before " << *b.value;
       }
       return true;
     }
@@ -931,7 +959,6 @@ Status RemoveUnnecessaryCopies(
   CopyRemover copy_remover(*alias_analysis, ordering, module);
   XLA_VLOG_LINES(3, copy_remover.ToString());
 
-  tensorflow::gtl::FlatSet<int> existing_copies;
   for (HloComputation* computation : module->computations()) {
     for (HloInstruction* instruction : computation->instructions()) {
       if (instruction->opcode() == HloOpcode::kCopy &&
@@ -940,7 +967,6 @@ Status RemoveUnnecessaryCopies(
       }
     }
   }
-
   return Status::OK();
 }
 
index 30e32a4..a88283e 100644 (file)
@@ -171,24 +171,21 @@ class BufferValueMap {
     return value_to_buffer_number_.at(&value);
   }
 
-  // Compute and return a vector of buffers that the given value must be
-  // contained in due to HLO aliasing rules.
-  std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
+  void ComputeWhileAliasedBuffers(const HloValue& value,
+                                  std::vector<BufferNumber>* aliased_buffers) {
+    VLOG(3) << "Compute kWhile aliases";
     // Value is init of a while (use is while).
-    std::vector<BufferNumber> aliased_buffers;
     for (const HloUse& use : value.uses()) {
-      VLOG(2) << "use of value " << value.ToShortString() << ": " << use;
       if (use.instruction->opcode() == HloOpcode::kWhile) {
         // Determine the while value that this shares a buffer with.
         const HloValue& while_value =
             dataflow_.GetUniqueValueAt(use.instruction, use.operand_index);
-        aliased_buffers.push_back(GetBufferForValue(while_value));
+        aliased_buffers->push_back(GetBufferForValue(while_value));
         VLOG(3) << "  value is init value to a while; must share buffer with "
                    "while value "
                 << while_value.ToShortString();
       }
     }
-
     // Value is a parameter of a while body/condition.
     if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
       const HloComputation* computation =
@@ -205,11 +202,10 @@ class BufferValueMap {
           VLOG(3) << "  value is parameter value of the body or condition of a "
                      "while; must share buffer with while value "
                   << while_value.ToShortString();
-          aliased_buffers.push_back(GetBufferForValue(while_value));
+          aliased_buffers->push_back(GetBufferForValue(while_value));
         }
       }
     }
-
     // Value is the root of a while body.
     for (const HloPosition& position : value.positions()) {
       const HloComputation* computation = position.instruction->parent();
@@ -224,27 +220,71 @@ class BufferValueMap {
 
             const HloValue& while_value = dataflow_.GetUniqueValueAt(
                 callsite.instruction(), position.index);
-            VLOG(3) << "  value is root the body computation of a while; must "
-                       "share buffer with while value "
+            VLOG(3) << "  value @ " << position << " is root of "
+                    << callsite.instruction()->name()
+                    << "; body root and while value root must share buffer "
+                       "among them : "
                     << while_value.ToShortString();
-            aliased_buffers.push_back(GetBufferForValue(while_value));
+            aliased_buffers->push_back(GetBufferForValue(while_value));
           }
         }
       }
     }
-
     // Value is the output of the while instruction itself.
     if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
       VLOG(3) << "  value is output of a while instruction";
-      aliased_buffers.push_back(GetBufferForValue(value));
+      aliased_buffers->push_back(GetBufferForValue(value));
+    }
+  }
+
+  void ComputeConditionalAliasedBuffers(
+      const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
+    VLOG(3) << "Compute kConditional aliases";
+    // Alias the buffers of the true/false computation roots with the buffer of
+    // the conditional.
+    for (const HloPosition& position : value.positions()) {
+      const HloComputation* computation = position.instruction->parent();
+      const CallGraphNode& call_graph_node =
+          dataflow_.call_graph().GetNode(computation);
+      if (position.instruction == computation->root_instruction()) {
+        for (const CallSite& callsite : call_graph_node.caller_callsites()) {
+          if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
+            // Call graph must have been flattened.
+            CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
+
+            const HloValue& cond_value = dataflow_.GetUniqueValueAt(
+                callsite.instruction(), position.index);
+            VLOG(3)
+                << "  value @ " << position << " is root of "
+                << callsite.instruction()->name()
+                << "; true/false branch roots must share buffer among them : "
+                << cond_value.ToShortString();
+            aliased_buffers->push_back(GetBufferForValue(cond_value));
+          }
+        }
+      }
+    }
+    // Value is the output of the conditional instruction itself.
+    if (value.defining_instruction()->opcode() == HloOpcode::kConditional) {
+      VLOG(3) << "  value is output of a conditional instruction";
+      aliased_buffers->push_back(GetBufferForValue(value));
     }
+  }
 
+  // Compute and return a vector of buffers that the given value must be
+  // contained in due to HLO aliasing rules.
+  std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
+    for (const HloUse& use : value.uses()) {
+      VLOG(2) << "Use of value " << value.ToShortString() << ": " << use;
+    }
+    std::vector<BufferNumber> aliased_buffers;
+    ComputeWhileAliasedBuffers(value, &aliased_buffers);
+    ComputeConditionalAliasedBuffers(value, &aliased_buffers);
     // Uniquify aliased buffers.
     std::sort(aliased_buffers.begin(), aliased_buffers.end());
     aliased_buffers.erase(
         std::unique(aliased_buffers.begin(), aliased_buffers.end()),
         aliased_buffers.end());
-
     return aliased_buffers;
   }
 
index 934e43b..0c37a8d 100644 (file)
@@ -368,11 +368,11 @@ bool HloDataflowAnalysis::UpdateConditionalValueSet(
           conditional->true_computation()->root_instruction()),
       &GetInstructionValueSet(
           conditional->false_computation()->root_instruction())};
-  // A phi-node is not defined for a kConditional instruction even though it
-  // represents a join point. This is because the current approach is to define
-  // a phi-node only for kWhile to account for the dataflow through back-edges
-  // and deal with the ambiguity in other cases.
-  return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
+  if (ssa_form_) {
+    return Phi(conditional, inputs);
+  } else {
+    return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
+  }
 }
 
 bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
index 7bf3a1a..07f69b8 100644 (file)
@@ -1602,11 +1602,17 @@ TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) {
   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
               ElementsAre(HloUse{conditional, 2, {}}));
 
-  EXPECT_EQ(analysis.values().size(), 3);
-  EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
-  EXPECT_THAT(HloValuesAt(conditional),
-              UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
-                                   analysis.GetValueDefinedAt(constant2)));
+  bool ssa_form = GetParam();
+  if (ssa_form) {
+    EXPECT_EQ(analysis.values().size(), 4);
+    EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
+  } else {
+    EXPECT_EQ(analysis.values().size(), 3);
+    EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
+    EXPECT_THAT(HloValuesAt(conditional),
+                UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
+                                     analysis.GetValueDefinedAt(constant2)));
+  }
 }
 
 TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) {
@@ -1713,11 +1719,17 @@ TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) {
                   HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}},
                   HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}}));
 
-  EXPECT_EQ(analysis.values().size(), 6);
-  EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
-  EXPECT_THAT(HloValuesAt(conditional),
-              UnorderedElementsAre(analysis.GetValueDefinedAt(add),
-                                   analysis.GetValueDefinedAt(sub)));
+  bool ssa_form = GetParam();
+  if (ssa_form) {
+    EXPECT_EQ(analysis.values().size(), 7);
+    EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
+  } else {
+    EXPECT_EQ(analysis.values().size(), 6);
+    EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
+    EXPECT_THAT(HloValuesAt(conditional),
+                UnorderedElementsAre(analysis.GetValueDefinedAt(add),
+                                     analysis.GetValueDefinedAt(sub)));
+  }
 }
 
 TEST_P(HloDataflowAnalysisTest, NestedConditionals) {
@@ -1834,20 +1846,27 @@ TEST_P(HloDataflowAnalysisTest, NestedConditionals) {
   EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond),
             analysis.GetValueDefinedAt(constant2));
 
-  EXPECT_EQ(analysis.values().size(), 9);
-  EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional));
-  EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
-  EXPECT_THAT(
-      HloValuesAt(inner_conditional),
-      UnorderedElementsAre(
-          analysis.GetValueDefinedAt(computation1->root_instruction()),
-          analysis.GetValueDefinedAt(computation2->root_instruction())));
-  EXPECT_THAT(
-      HloValuesAt(conditional),
-      UnorderedElementsAre(
-          analysis.GetValueDefinedAt(computation1->root_instruction()),
-          analysis.GetValueDefinedAt(computation2->root_instruction()),
-          analysis.GetValueDefinedAt(computation3->root_instruction())));
+  bool ssa_form = GetParam();
+  if (ssa_form) {
+    EXPECT_EQ(analysis.values().size(), 11);
+    EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_conditional));
+    EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
+  } else {
+    EXPECT_EQ(analysis.values().size(), 9);
+    EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional));
+    EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
+    EXPECT_THAT(
+        HloValuesAt(inner_conditional),
+        UnorderedElementsAre(
+            analysis.GetValueDefinedAt(computation1->root_instruction()),
+            analysis.GetValueDefinedAt(computation2->root_instruction())));
+    EXPECT_THAT(
+        HloValuesAt(conditional),
+        UnorderedElementsAre(
+            analysis.GetValueDefinedAt(computation1->root_instruction()),
+            analysis.GetValueDefinedAt(computation2->root_instruction()),
+            analysis.GetValueDefinedAt(computation3->root_instruction())));
+  }
 }
 
 INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation,
index 1b24d8d..e89d94b 100644 (file)
@@ -66,6 +66,28 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
     }
   }
 
+  // If the common ancestor is a conditional instruction, even though the true
+  // and false computations are not really ordered per-se, we define the true
+  // computation to be ordered before the false one.
+  // This ensures that buffers can still be shared among the two computations
+  // as they will forcibly have disjoint liveness.
+  if (a_ancestor == b_ancestor &&
+      a_ancestor->opcode() == HloOpcode::kConditional) {
+    const HloComputation* true_computation = a_ancestor->true_computation();
+    const HloComputation* false_computation = a_ancestor->false_computation();
+    if (call_graph_->InstructionIsNestedIn(a, true_computation) &&
+        call_graph_->InstructionIsNestedIn(b, false_computation)) {
+      return true;
+    }
+    // If 'b' is the conditional ancestor, and 'a' is within the true or false
+    // computations, 'a' executes before 'b'.
+    if (b == a_ancestor &&
+        (call_graph_->InstructionIsNestedIn(a, true_computation) ||
+         call_graph_->InstructionIsNestedIn(a, false_computation))) {
+      return true;
+    }
+  }
+
   return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor);
 }
 
@@ -118,7 +140,18 @@ bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
            b.defining_instruction()->while_condition()))) {
     return true;
   }
-
+  // If 'b' is a conditional phi and 'a' is in the true or false computation,
+  // then 'a' executes before 'b'.
+  if (b.is_phi() &&
+      b.defining_instruction()->opcode() == HloOpcode::kConditional &&
+      (call_graph_->InstructionIsNestedIn(
+           a.defining_instruction(),
+           b.defining_instruction()->true_computation()) ||
+       call_graph_->InstructionIsNestedIn(
+           a.defining_instruction(),
+           b.defining_instruction()->false_computation()))) {
+    return true;
+  }
   return ExecutesBefore(a.defining_instruction(), b.defining_instruction());
 }
 
@@ -212,18 +245,17 @@ bool HloOrdering::LiveRangeStrictlyBefore(
   VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
           << ", b = " << b.ToShortString() << ")";
   if (!IsDefinedBefore(a, b)) {
-    VLOG(4) << "a not defined before b";
+    VLOG(4) << a << " not defined before " << b;
     return false;
   }
-
   // All uses of 'a' must be before 'b' is defined.
   for (const HloUse& use : a.uses()) {
     if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
-      VLOG(4) << "use of a (" << use << ") not before b is defined";
+      VLOG(4) << "use of " << a << " (" << use << ") not before " << b
+              << " is defined";
       return false;
     }
   }
-
   return true;
 }
 
index a989fce..441d790 100644 (file)
@@ -362,5 +362,66 @@ ENTRY while.v11 {
   ordering.ToString();  // Shouldn't crash.
 }
 
+TEST_F(HloOrderingTest, ConditionalInstructionOrdering) {
+  const char* module_str = R"(
+HloModule test_conditional_module
+
+true_branch {
+  param.1 = (s32[], s32[]) parameter(0)
+  get-tuple-element.1 = s32[] get-tuple-element(param.1), index=0
+  get-tuple-element.2 = s32[] get-tuple-element(param.1), index=1
+  add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2)
+  ROOT tuple.1 = (s32[], s32[]) tuple(add.1, get-tuple-element.1)
+}
+
+false_branch {
+  param.2 = (s32[], s32[]) parameter(0)
+  get-tuple-element.3 = s32[] get-tuple-element(param.2), index=0
+  get-tuple-element.4 = s32[] get-tuple-element(param.2), index=1
+  add.2 = s32[] add(get-tuple-element.3, get-tuple-element.4)
+  ROOT tuple.2 = (s32[], s32[]) tuple(add.2, get-tuple-element.4)
+}
+
+ENTRY root {
+  param.3 = (pred[], (s32[], s32[])) parameter(0)
+  pred.1 = pred[] get-tuple-element(param.3), index=0
+  cond_arg.1 = (s32[], s32[]) get-tuple-element(param.3), index=1
+  conditional = (s32[], s32[]) conditional(pred.1, cond_arg.1, cond_arg.1), true_computation=true_branch, false_computation=false_branch
+  cond_res.1 = s32[] get-tuple-element(conditional), index=0
+  cond_res.2 = s32[] get-tuple-element(conditional), index=1
+  add.3 = s32[] add(cond_res.1, cond_res.2)
+  ROOT result = (s32[], s32[], s32[]) tuple(add.3, cond_res.1, cond_res.2)
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          tools::Parse(module_str));
+  TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+                          HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+  DependencyHloOrdering ordering(module.get());
+
+  // Even though the true and false branches have no ordering, since they do not
+  // interfere (as they are mutually exclusive), we define the true computation
+  // to be before the false one.
+  // Similarly, any instruction in the true or false branches is considered
+  // before the conditional instruction. The roots are effectively "at the same
+  // time" WRT the conditional, but they are Phi-ed anyway.
+  HloInstruction* add_1 = FindInstruction(module.get(), "add.1");
+  HloInstruction* add_2 = FindInstruction(module.get(), "add.2");
+  HloInstruction* add_3 = FindInstruction(module.get(), "add.3");
+  HloInstruction* conditional = FindInstruction(module.get(), "conditional");
+  EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1),
+                                       dataflow->GetValueDefinedAt(add_2)));
+  EXPECT_TRUE(
+      ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_2),
+                               dataflow->GetValueDefinedAt(conditional)));
+  EXPECT_TRUE(
+      ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1),
+                               dataflow->GetValueDefinedAt(conditional)));
+  EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1),
+                                       dataflow->GetValueDefinedAt(add_3)));
+  EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_2),
+                                       dataflow->GetValueDefinedAt(add_3)));
+}
+
 }  // namespace
 }  // namespace xla
index bc82167..b917dee 100644 (file)
@@ -571,5 +571,56 @@ XLA_TEST_F(ConditionalOpTest, ShapeMismatch) {
                                    "only parameter of true_computation"));
 }
 
+XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
+  Shape tuple_shape = ShapeUtil::MakeTupleShape({r0f32_, r0f32_});
+  Computation swapper;
+  {
+    ComputationBuilder builder(client_, TestName() + ".swapper");
+    auto param0 = builder.Parameter(0, tuple_shape, "sp0");
+    auto x = builder.GetTupleElement(param0, 0);
+    auto y = builder.GetTupleElement(param0, 1);
+    builder.Tuple({y, x});
+    swapper = builder.Build().ConsumeValueOrDie();
+  }
+  Computation forwarder;
+  {
+    ComputationBuilder builder(client_, TestName() + ".forwarder");
+    auto param0 = builder.Parameter(0, tuple_shape, "fp0");
+    auto x = builder.GetTupleElement(param0, 0);
+    auto y = builder.GetTupleElement(param0, 1);
+    builder.Tuple({x, y});
+    forwarder = builder.Build().ConsumeValueOrDie();
+  }
+  Computation main;
+  {
+    ComputationBuilder builder(client_, TestName() + ".main");
+    auto param0 = builder.Parameter(0, tuple_shape, "mp0");
+    auto x = builder.GetTupleElement(param0, 0);
+    auto y = builder.GetTupleElement(param0, 1);
+    auto lt_pred = builder.Lt(x, y);
+    auto res = builder.Conditional(lt_pred, param0, forwarder, param0, swapper);
+    auto ge_pred = builder.Ge(x, y);
+    builder.Conditional(ge_pred, res, swapper, res, forwarder);
+    main = builder.Build().ConsumeValueOrDie();
+  }
+
+  auto test_swap = [&](float a, float b) {
+    ComputationBuilder builder(client_, TestName());
+    auto x = builder.ConstantR0<float>(a);
+    auto y = builder.ConstantR0<float>(b);
+    auto tuple_operand = builder.Tuple({x, y});
+    builder.Call(main, {tuple_operand});
+
+    ComputeAndCompareTuple(
+        &builder,
+        *Literal::MakeTuple({Literal::CreateR0<float>(a).get(),
+                             Literal::CreateR0<float>(b).get()}),
+        {}, error_spec_);
+  };
+
+  test_swap(3.11f, 9.4f);
+  test_swap(11.24f, 5.55f);
+}
+
 }  // namespace
 }  // namespace xla