Always delete old while loop after LICM
authorSanjoy Das <sanjoy@google.com>
Wed, 30 May 2018 23:29:25 +0000 (16:29 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 30 May 2018 23:32:05 +0000 (16:32 -0700)
Right now the old while loop can stick around if it had side effects, which is
incorrect.

PiperOrigin-RevId: 198639203

tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/while_util.cc
tensorflow/compiler/xla/service/while_util.h
tensorflow/compiler/xla/service/while_util_test.cc
tensorflow/compiler/xla/util.h

index 4d653a0..cd3d55e 100644 (file)
@@ -2920,6 +2920,7 @@ tf_cc_test(
     deps = [
         ":while_util",
         "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service:hlo_matchers",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/compiler/xla/tools/parser:hlo_parser",
index ed20b36..473eab2 100644 (file)
@@ -117,9 +117,13 @@ WhileUtil::MakeInstructionsLiveIn(
   HloInstruction* new_while = containing_computation->AddInstruction(
       HloInstruction::CreateWhile(new_while_shape, new_while_condition,
                                   new_while_body, new_while_init));
-  TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction(
-      while_instr, TupleUtil::ExtractPrefix(
-                       new_while, while_instr->shape().tuple_shapes_size())));
+
+  // We want to get rid of the old while instruction even if it has side
+  // effecting operations so we do a manual HloComputation::RemoveInstruction
+  // instead of relying on HloComputation::ReplaceInstruction.
+  TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(TupleUtil::ExtractPrefix(
+      new_while, while_instr->shape().tuple_shapes_size())));
+  TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr));
 
   HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
   std::vector<HloInstruction*> live_in_instructions;
index 322d27b..e67636d 100644 (file)
@@ -38,17 +38,21 @@ class WhileUtil {
   };
 
   // Replaces `while_instr` with a new while instruction that is equivalent to
-  // `while_instr`, except that it has all of the HLO instructions in
+  // `while_instr` except that it has all of the HLO instructions in
   // `instructions` as live-in, loop invariant values.  These new live in values
   // are represented as new elements appended to the parameter of the while
   // loop, which must be of tuple shape.  GetTupleElement instructions computing
   // each new live in value is returned in the `while_body_live_in_values`
   // vector.
   //
-  // Precondition: `while_instr` must have a tuple shaped state.
+  // Deletes `while_instr` after replacing it.
   //
-  // Every instruction in `instructions` must be contained in the computation
-  // that contains `while_instr`.
+  // Preconditions:
+  //
+  //  `while_instr` must have a tuple shaped state.
+  //
+  //   Every instruction in `instructions` must be contained in the computation
+  //   that contains `while_instr`.
   static StatusOr<MakeInstructionsLiveInResult> MakeInstructionsLiveIn(
       HloInstruction* while_instr,
       tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
index 974bc54..bcc545c 100644 (file)
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/compiler/xla/util.h"
 
 namespace xla {
 namespace {
@@ -163,5 +164,47 @@ ENTRY main {
   ASSERT_EQ(gte_list.size(), 1);
   EXPECT_EQ((*gte_list.begin())->name(), "gte.0");
 }
+
+TEST(WhileUtilTest, AlwaysRemovePreviousWhileBody) {
+  const char* const hlo_string = R"(
+HloModule WhileWithSideEffects
+
+body {
+  param.b = (s32[], s32[]) parameter(0)
+  gte.0 = s32[] get-tuple-element(param.b), index=0
+  gte.1 = s32[] get-tuple-element(param.b), index=1
+  add = s32[] add(gte.0, gte.1)
+  ROOT tuple = (s32[], s32[]) tuple(gte.0, add)
+}
+
+cond {
+  param.c = (s32[], s32[]) parameter(0)
+  ROOT condition = pred[] infeed()
+}
+
+ENTRY main {
+  init = (s32[], s32[]) parameter(0)
+  to_make_live_in = f32[100] parameter(1)
+  ROOT while = (s32[], s32[]) while(init), condition=cond, body=body
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          tools::Parse(hlo_string));
+
+  HloComputation* main = module->GetComputationWithName("main");
+  HloInstruction* while_instr = main->root_instruction();
+  HloInstruction* to_make_live_in = main->parameter_instruction(1);
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      WhileUtil::MakeInstructionsLiveInResult make_live_in_result,
+      WhileUtil::MakeInstructionsLiveIn(while_instr,
+                                        /*instructions=*/{to_make_live_in}));
+
+  auto is_while = [](const HloInstruction* instr) {
+    return instr->opcode() == HloOpcode::kWhile;
+  };
+  EXPECT_EQ(c_count_if(main->instructions(), is_while), 1);
+}
 }  // namespace
 }  // namespace xla
index 7303640..b4f45cc 100644 (file)
@@ -526,6 +526,13 @@ typename std::decay<T>::type c_accumulate(const Sequence& sequence, T&& init,
                          std::forward<BinaryOp>(binary_op));
 }
 
+template <typename C, typename Pred>
+typename std::iterator_traits<
+    decltype(std::begin(std::declval<C>()))>::difference_type
+c_count_if(const C& c, Pred&& pred) {
+  return std::count_if(std::begin(c), std::end(c), std::forward<Pred>(pred));
+}
+
 template <typename C, typename Value>
 int64 FindIndex(const C& c, Value&& value) {
   auto it = c_find(c, std::forward<Value>(value));