[XLA] Add reproducer that shows perf issues in HloDataflowAnalysis::UpdateTupleValueSet
authorNick Desaulniers <ndesaulniers@google.com>
Wed, 14 Feb 2018 21:52:18 +0000 (13:52 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Feb 2018 22:03:01 +0000 (14:03 -0800)
HloDataflowAnalysis::UpdateTupleValueSet starts to show up in profiles for while bodies that have many GetTupleElement nodes.

Use a set to keep track of which HloInstructions we need to propagate DFA for.

PiperOrigin-RevId: 185739365

tensorflow/compiler/xla/service/copy_insertion_test.cc
tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc

index 128ee72..153f062 100644 (file)
@@ -1724,8 +1724,58 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) {
   }
 }
 
+// Builds a while-loop body computation for benchmarking: the loop state is a
+// tuple of `num_tuple_inputs` scalar F32 elements; the body extracts every
+// element with GetTupleElement and immediately re-packs them into the result
+// tuple. The large number of GetTupleElement nodes is the pattern that makes
+// HloDataflowAnalysis::UpdateTupleValueSet show up in profiles (see commit
+// description).
+std::unique_ptr<HloComputation> MakeBenchmarkWhileBody(
+    const int num_tuple_inputs) {
+  auto builder = HloComputation::Builder("benchmark_loop_body");
+  const Shape element_shape = ShapeUtil::MakeShape(F32, {});
+  // Loop state shape: tuple of num_tuple_inputs scalar F32s.
+  std::vector<Shape> input_shape(num_tuple_inputs, element_shape);
+  const Shape loop_state_shape = ShapeUtil::MakeTupleShape(input_shape);
+  HloInstruction* param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
+  // One GetTupleElement per tuple element; these feed the result tuple below.
+  std::vector<HloInstruction*> gte_nodes(num_tuple_inputs);
+  for (int i = 0; i < num_tuple_inputs; ++i) {
+    gte_nodes[i] = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(element_shape, param, i));
+  }
+  builder.AddInstruction(HloInstruction::CreateTuple(gte_nodes));
+  return builder.Build();
+}
+
+// Benchmark: times CopyInsertion on a module whose entry computation wraps a
+// while loop carrying a tuple of `num_tuple_inputs` scalar F32 elements.
+// Module construction is excluded from the measurement — timing is stopped up
+// front and around each iteration, so only copy_insertion.Run() is timed.
+void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) {
+  tensorflow::testing::StopTiming();
+  HloModuleConfig config;
+  config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
+  CopyInsertion copy_insertion;
+  const Shape element_shape = ShapeUtil::MakeShape(F32, {});
+  std::vector<HloInstruction*> tuple_params(num_tuple_inputs);
+  // Each benchmark iteration builds a fresh module so CopyInsertion always
+  // runs against an unmodified graph.
+  for (int i = 0; i < num_iters; ++i) {
+    // NOTE(review): builder name "BM_ParallelWhiles" looks copied from the
+    // benchmark above — presumably should be "BM_ManyElementTuple"; confirm.
+    auto builder = HloComputation::Builder("BM_ParallelWhiles");
+    HloModule module("BM_ManyElementTuple", VersionedComputationHandle(),
+                     config);
+    // Entry takes num_tuple_inputs scalar parameters and tuples them up as
+    // the while loop's init value.
+    for (int j = 0; j < num_tuple_inputs; ++j) {
+      tuple_params[j] = builder.AddInstruction(
+          HloInstruction::CreateParameter(j, element_shape, ""));
+    }
+    HloInstruction* init =
+        builder.AddInstruction(HloInstruction::CreateTuple(tuple_params));
+    HloComputation* condition =
+        module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
+    HloComputation* body =
+        module.AddEmbeddedComputation(MakeBenchmarkWhileBody(num_tuple_inputs));
+    HloInstruction* xla_while = builder.AddInstruction(
+        HloInstruction::CreateWhile(init->shape(), condition, body, init));
+    // Root the computation on a single element of the while's result.
+    builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+        ShapeUtil::MakeShape(F32, {}), xla_while, 0));
+    module.AddEntryComputation(builder.Build());
+    // Time only the CopyInsertion pass itself.
+    tensorflow::testing::StartTiming();
+    ASSERT_IS_OK(copy_insertion.Run(&module).status());
+    tensorflow::testing::StopTiming();
+  }
+}
+
 BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
 BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
+BENCHMARK(BM_ManyElementTuple)->Arg(1024)->Arg(12288);
 
 TEST_F(CopyInsertionTest, SimpleControlFlowTest) {
   const string& hlo_string = R"(
index d25fc5d..ccbbe8f 100644 (file)
@@ -585,16 +585,23 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
 
 void HloDataflowAnalysis::Propagate() {
   std::queue<HloInstruction*> worklist;
+  tensorflow::gtl::FlatSet<HloInstruction*> workset;
+  auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) {
+    if (workset.insert(instruction).second) {
+      worklist.push(instruction);
+    }
+  };
 
   for (HloComputation* computation : module_->computations()) {
     for (HloInstruction* instruction : computation->instructions()) {
-      worklist.push(instruction);
+      add_to_worklist(instruction);
     }
   }
 
   while (!worklist.empty()) {
     HloInstruction* instruction = worklist.front();
     worklist.pop();
+    workset.erase(workset.find(instruction));
 
     VLOG(3) << "Worklist top: " << instruction->name();
     VLOG(3) << ToString();
@@ -608,9 +615,10 @@ void HloDataflowAnalysis::Propagate() {
     VLOG(4) << "New value set for " << instruction->name() << ": "
             << GetInstructionValueSet(instruction);
 
-    // Instruction value was updated. Add users to work list.
+    // Instruction value was updated. Add users to work list if we haven't
+    // already.
     for (HloInstruction* user : instruction->users()) {
-      worklist.push(user);
+      add_to_worklist(user);
 
       // If user sequentially calls a computation, then the respective
       // parameter(s) of the computation need to be updated.
@@ -625,10 +633,10 @@ void HloDataflowAnalysis::Propagate() {
         // Note that the same instruction can be used in both operand 1 and
         // operand 2.
         if (user->operand(1) == instruction) {
-          worklist.push(user->true_computation()->parameter_instruction(0));
+          add_to_worklist(user->true_computation()->parameter_instruction(0));
         }
         if (user->operand(2) == instruction) {
-          worklist.push(user->false_computation()->parameter_instruction(0));
+          add_to_worklist(user->false_computation()->parameter_instruction(0));
         }
       } else {
         for (HloComputation* called_computation : user->called_computations()) {
@@ -636,7 +644,7 @@ void HloDataflowAnalysis::Propagate() {
               call_graph_->GetNode(called_computation);
           if (call_graph_node.context() == CallContext::kSequential) {
             for (int64 operand_number : user->OperandIndices(instruction)) {
-              worklist.push(
+              add_to_worklist(
                   called_computation->parameter_instruction(operand_number));
             }
           }
@@ -652,13 +660,13 @@ void HloDataflowAnalysis::Propagate() {
       for (const CallSite& callsite : call_graph_node.caller_callsites()) {
         if ((callsite.instruction()->opcode() == HloOpcode::kCall) ||
             (callsite.instruction()->opcode() == HloOpcode::kConditional)) {
-          worklist.push(callsite.instruction());
+          add_to_worklist(callsite.instruction());
         } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
           // Add the while itself, and the body and condition parameters.
-          worklist.push(callsite.instruction());
-          worklist.push(
+          add_to_worklist(callsite.instruction());
+          add_to_worklist(
               callsite.instruction()->while_body()->parameter_instruction(0));
-          worklist.push(
+          add_to_worklist(
               callsite.instruction()->while_condition()->parameter_instruction(
                   0));
         }