deps = [
":while_util",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
HloInstruction* new_while = containing_computation->AddInstruction(
HloInstruction::CreateWhile(new_while_shape, new_while_condition,
new_while_body, new_while_init));
- TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction(
- while_instr, TupleUtil::ExtractPrefix(
- new_while, while_instr->shape().tuple_shapes_size())));
+
+ // We want to get rid of the old while instruction even if it has side
+ // effecting operations so we do a manual HloComputation::RemoveInstruction
+ // instead of relying on HloComputation::ReplaceInstruction.
+ TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(TupleUtil::ExtractPrefix(
+ new_while, while_instr->shape().tuple_shapes_size())));
+ TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr));
HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
std::vector<HloInstruction*> live_in_instructions;
};
// Replaces `while_instr` with a new while instruction that is equivalent to
- // `while_instr`, except that it has all of the HLO instructions in
+ // `while_instr` except that it has all of the HLO instructions in
// `instructions` as live-in, loop invariant values. These new live in values
// are represented as new elements appended to the parameter of the while
// loop, which must be of tuple shape. GetTupleElement instructions computing
// each new live in value is returned in the `while_body_live_in_values`
// vector.
//
- // Precondition: `while_instr` must have a tuple shaped state.
+ // Deletes `while_instr` after replacing it.
//
- // Every instruction in `instructions` must be contained in the computation
- // that contains `while_instr`.
+ // Preconditions:
+ //
+ // `while_instr` must have a tuple shaped state.
+ //
+ // Every instruction in `instructions` must be contained in the computation
+ // that contains `while_instr`.
static StatusOr<MakeInstructionsLiveInResult> MakeInstructionsLiveIn(
HloInstruction* while_instr,
tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace {
ASSERT_EQ(gte_list.size(), 1);
EXPECT_EQ((*gte_list.begin())->name(), "gte.0");
}
+
+TEST(WhileUtilTest, AlwaysRemovePreviousWhileBody) {
+ const char* const hlo_string = R"(
+HloModule WhileWithSideEffects
+
+body {
+ param.b = (s32[], s32[]) parameter(0)
+ gte.0 = s32[] get-tuple-element(param.b), index=0
+ gte.1 = s32[] get-tuple-element(param.b), index=1
+ add = s32[] add(gte.0, gte.1)
+ ROOT tuple = (s32[], s32[]) tuple(gte.0, add)
+}
+
+cond {
+ param.c = (s32[], s32[]) parameter(0)
+ ROOT condition = pred[] infeed()
+}
+
+ENTRY main {
+ init = (s32[], s32[]) parameter(0)
+ to_make_live_in = f32[100] parameter(1)
+ ROOT while = (s32[], s32[]) while(init), condition=cond, body=body
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
+
+ HloComputation* main = module->GetComputationWithName("main");
+ HloInstruction* while_instr = main->root_instruction();
+ HloInstruction* to_make_live_in = main->parameter_instruction(1);
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ WhileUtil::MakeInstructionsLiveInResult make_live_in_result,
+ WhileUtil::MakeInstructionsLiveIn(while_instr,
+ /*instructions=*/{to_make_live_in}));
+
+ auto is_while = [](const HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kWhile;
+ };
+ EXPECT_EQ(c_count_if(main->instructions(), is_while), 1);
+}
} // namespace
} // namespace xla
std::forward<BinaryOp>(binary_op));
}
+template <typename C, typename Pred>
+typename std::iterator_traits<
+ decltype(std::begin(std::declval<C>()))>::difference_type
+c_count_if(const C& c, Pred&& pred) {
+ return std::count_if(std::begin(c), std::end(c), std::forward<Pred>(pred));
+}
+
template <typename C, typename Value>
int64 FindIndex(const C& c, Value&& value) {
auto it = c_find(c, std::forward<Value>(value));