[XLA]: Enhancement to the while loop simplifier HLO pass.
authorBixia Zheng <bixia@google.com>
Thu, 8 Mar 2018 23:10:36 +0000 (15:10 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 8 Mar 2018 23:14:52 +0000 (15:14 -0800)
If a while-loop tuple element is initialized as a constant and isn't changed by
the while-body, replace the use of the tuple element in while-condition and
while-body with the constant value. This enables the simplification of
while-loops that have 0/1 iteration and loop bound passed in through the
while-loop tuple.

Add test cases for while-loops with 0/1 iteration and loop bound passed in
through the while-loop tuple.

PiperOrigin-RevId: 188397087

tensorflow/compiler/xla/service/while_loop_simplifier.cc
tensorflow/compiler/xla/service/while_loop_simplifier_test.cc

index c9d77c9..1a93a88 100644 (file)
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
 #include "tensorflow/compiler/xla/service/call_inliner.h"
 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
 #include "tensorflow/core/lib/gtl/optional.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
@@ -605,6 +606,75 @@ static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
   return false;
 }
 
+static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) {
+  auto while_init = while_op->operand(0);
+  if (while_init->opcode() != HloOpcode::kTuple) {
+    return false;
+  }
+
+  auto while_body = while_op->while_body();
+  auto while_body_root = while_body->root_instruction();
+  if (while_body_root->opcode() != HloOpcode::kTuple) {
+    return false;
+  }
+
+  auto while_body_param = while_body->parameter_instruction(0);
+  const HloInstruction::InstructionVector& root_operands =
+      while_body_root->operands();
+
+  // Find the loop invariant tuple elements with constant init value and
+  // build a map from the tuple element index to the constant value.
+  tensorflow::gtl::FlatMap<int, const HloInstruction*> index_to_constant;
+  for (int i = 0; i < root_operands.size(); i++) {
+    HloInstruction* instr = root_operands[i];
+    if (instr->opcode() == HloOpcode::kGetTupleElement &&
+        instr->tuple_index() == i && instr->operand(0) == while_body_param) {
+      auto tuple_element = while_init->operand(i);
+      if (tuple_element->IsConstant()) {
+        VLOG(3) << "Found loop invariant tuple element " << i << " "
+                << tuple_element->ToString();
+        index_to_constant[i] = tuple_element;
+      }
+    }
+  }
+
+  if (index_to_constant.empty()) {
+    return false;
+  }
+
+  // Replace the use of each constant tuple element in the loop_condition and
+  // loop_body with the corresponding constant value.
+  auto propagate_constant = [&](HloComputation* computation) -> StatusOr<bool> {
+    HloInstruction* param = computation->parameter_instruction(0);
+    bool changed = false;
+    for (auto instr : param->users()) {
+      // Since only a while-loop with a tuple result reaches here, we can safely
+      // assume that `param` is a tuple and the first operand of the
+      // GetTupleElement instruction is a use of `param`.
+      if (instr->opcode() == HloOpcode::kGetTupleElement) {
+        VLOG(3) << "tuple index " << instr->tuple_index() << " "
+                << instr->ToString();
+        auto iter = index_to_constant.find(instr->tuple_index());
+        if (iter != index_to_constant.end()) {
+          const HloInstruction* hlo_constant = (*iter).second;
+          VLOG(3) << "Replace use of " << instr->ToString() << " with "
+                  << hlo_constant->ToString();
+          TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(
+              computation->AddInstruction(hlo_constant->Clone())));
+          changed = true;
+        }
+      }
+    }
+    return changed;
+  };
+
+  TF_ASSIGN_OR_RETURN(bool changed_cond,
+                      propagate_constant(while_op->while_condition()));
+  TF_ASSIGN_OR_RETURN(bool changed_body, propagate_constant(while_body));
+
+  return changed_cond || changed_body;
+}
+
 StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
   XLA_VLOG_LINES(3,
                  "WhileLoopSimplifier::Run(), before:\n" + module->ToString());
@@ -635,7 +705,11 @@ StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
       continue;
     }
 
-    StatusOr<bool> result = TryRemoveWhileLoop(while_op);
+    StatusOr<bool> result = TryPropagateConstant(while_op);
+    TF_RETURN_IF_ERROR(result.status());
+    changed |= result.ValueOrDie();
+
+    result = TryRemoveWhileLoop(while_op);
     TF_RETURN_IF_ERROR(result.status());
     if (result.ValueOrDie()) {
       changed = true;
index cbea3e3..396f942 100644 (file)
@@ -30,6 +30,11 @@ class WhileLoopSimplifierTest : public HloVerifiedTestBase {
  protected:
   // Makes an HloModule that contains a loop with `num_iters` iteration.
   void MakeModuleWithSimpleLoop(int num_iters);
+
+  // Similar to MakeModuleWithSimpleLoop except that the loop bound is passed to
+  // the loop-condition through an element of a tuple which is the
+  // loop-condition parameter.
+  void MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters);
 };
 
 void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) {
@@ -66,6 +71,45 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) {
   ParseAndVerifyModule(hlo_string.c_str());
 }
 
+void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound(
+    int num_iters) {
+  string hlo_string_template = R"(
+  HloModule SimpleLoopWithIndirectLoopBound
+  SimpleLoopWithIndirectLoopBound.body {
+    loop_var.1 = (s32[], s32[3]{0}, s32[]) parameter(0)
+    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+    constant.1 = s32[] constant(1)
+    add = s32[] add(get-tuple-element.1, constant.1)
+    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
+    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
+    limit = s32[] get-tuple-element(loop_var.1), index=2
+    ROOT tuple = (s32[], s32[3]{0}, s32[]) tuple(add, multiply, limit)
+  }
+  SimpleLoopWithIndirectLoopBound.condition {
+    loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0)
+    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
+    get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2
+    ROOT less-than = pred[] less-than(get-tuple-element.3, get-tuple-element.4)
+  }
+  ENTRY SimpleLoopWithIndirectLoopBound {
+    constant.3 = s32[] constant(42)
+    constant.4 = s32[3]{0} constant({0, 1, 2})
+    constant.2 = s32[] constant({{LOOP_BOUND}})
+    tuple.1 = (s32[], s32[3]{0}, s32[]) tuple(constant.3, constant.4,
+      constant.2)
+    ROOT while = (s32[], s32[3]{0}, s32[]) while(tuple.1),
+      condition=SimpleLoopWithIndirectLoopBound.condition,
+      body=SimpleLoopWithIndirectLoopBound.body
+  }
+  )";
+
+  string hlo_string = tensorflow::str_util::StringReplace(
+      hlo_string_template, "{{LOOP_BOUND}}",
+      tensorflow::strings::StrCat(42 + num_iters),
+      /*replace_all=*/true);
+  ParseAndVerifyModule(hlo_string.c_str());
+}
+
 TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) {
   MakeModuleWithSimpleLoop(/*num_iters=*/0);
   HloModule* the_module = &module();
@@ -74,6 +118,15 @@ TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) {
               op::Tuple(op::Constant(), op::Constant()));
 }
 
+TEST_F(WhileLoopSimplifierTest,
+       LoopWithZeroIterationTupleElementLoopBoundSimplified) {
+  MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0);
+  HloModule* the_module = &module();
+  ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
+  EXPECT_THAT(the_module->entry_computation()->root_instruction(),
+              op::Tuple(op::Constant(), op::Constant(), op::Constant()));
+}
+
 TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) {
   MakeModuleWithSimpleLoop(/*num_iters=*/1);
   HloModule* the_module = &module();
@@ -82,6 +135,15 @@ TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) {
               op::Tuple(op::Add(), op::Multiply()));
 }
 
+TEST_F(WhileLoopSimplifierTest,
+       LoopWithOneIterationTupleELementLoopBoundSimplified) {
+  MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1);
+  HloModule* the_module = &module();
+  ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
+  EXPECT_THAT(the_module->entry_computation()->root_instruction(),
+              op::Tuple(op::Add(), op::Multiply(), op::Constant()));
+}
+
 TEST_F(WhileLoopSimplifierTest, LoopWithTwoIterationsNotSimplified) {
   MakeModuleWithSimpleLoop(/*num_iters=*/2);
   EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
@@ -364,7 +426,6 @@ TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) {
   HloModule BodyHasNonTupleRoot
   BodyHasNonTupleRoot.passthrough {
     ROOT param = (s32[], s32[]) parameter(0)
-    get-tuple-element = s32[] get-tuple-element((s32[], s32[]) param), index=1
   }
   BodyHasNonTupleRoot.always_true {
     param.1 = (s32[], s32[]) parameter(0)
@@ -382,5 +443,38 @@ TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) {
   EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
 }
 
+TEST_F(WhileLoopSimplifierTest,
+       LoopWithNonTupleBodyRootInstructionNotSimplified) {
+  const string hlo_string = R"(
+  HloModule SimpleLoop
+  SimpleLoop.body {
+    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
+    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+    constant.1 = s32[] constant(1)
+    add = s32[] add(get-tuple-element.1, constant.1)
+    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
+    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
+    ROOT custom-call = (s32[], s32[3]{0}) custom-call(add, multiply),
+      custom_call_target="x"
+  }
+  SimpleLoop.condition {
+    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
+    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
+    constant.2 = s32[] constant(44)
+    ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+  }
+  ENTRY SimpleLoop {
+    constant.3 = s32[] constant(42)
+    constant.4 = s32[3]{0} constant({0, 1, 2})
+    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
+    ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=
+      SimpleLoop.condition, body=SimpleLoop.body
+  }
+  )";
+
+  ParseAndVerifyModule(hlo_string.c_str());
+  EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
+}
+
 }  // namespace
 }  // namespace xla