#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
return false;
}
+// Propagates loop-invariant constant tuple elements into a while loop.
+//
+// Tuple element `i` qualifies when operand `i` of the body's root tuple is
+// get-tuple-element(body_parameter, i), i.e. the body passes it through
+// unchanged, and operand `i` of the while-init tuple is a constant. For every
+// get-tuple-element user of the condition/body parameter that reads such an
+// element, all uses are redirected to a fresh clone of that constant.
+//
+// Returns false when the while-init or the body root is not a kTuple
+// instruction, or when no element qualifies; returns true when at least one
+// matching get-tuple-element was rewritten in the condition or the body. A
+// failure of ReplaceAllUsesWith is returned to the caller.
+//
+// NOTE(review): the condition and body computations are rewritten in place;
+// this presumably relies on them not being shared with another while
+// instruction that has a different init value -- TODO confirm.
+static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) {
+  const HloInstruction* while_init = while_op->operand(0);
+  if (while_init->opcode() != HloOpcode::kTuple) {
+    return false;
+  }
+
+  HloComputation* while_body = while_op->while_body();
+  const HloInstruction* while_body_root = while_body->root_instruction();
+  if (while_body_root->opcode() != HloOpcode::kTuple) {
+    return false;
+  }
+
+  const HloInstruction* while_body_param =
+      while_body->parameter_instruction(0);
+  const HloInstruction::InstructionVector& root_operands =
+      while_body_root->operands();
+
+  // Find the loop invariant tuple elements with constant init value and
+  // build a map from the tuple element index to the constant value.
+  tensorflow::gtl::FlatMap<int, const HloInstruction*> index_to_constant;
+  // Convert the size once so the loop compares two signed values instead of
+  // mixing `int` with the container's unsigned size type.
+  const int num_elements = static_cast<int>(root_operands.size());
+  for (int i = 0; i < num_elements; ++i) {
+    const HloInstruction* root_operand = root_operands[i];
+    const bool is_pass_through =
+        root_operand->opcode() == HloOpcode::kGetTupleElement &&
+        root_operand->tuple_index() == i &&
+        root_operand->operand(0) == while_body_param;
+    if (!is_pass_through) {
+      continue;
+    }
+    const HloInstruction* tuple_element = while_init->operand(i);
+    if (tuple_element->IsConstant()) {
+      VLOG(3) << "Found loop invariant tuple element " << i << " "
+              << tuple_element->ToString();
+      index_to_constant[i] = tuple_element;
+    }
+  }
+
+  if (index_to_constant.empty()) {
+    return false;
+  }
+
+  // Replace the use of each constant tuple element in the loop_condition and
+  // loop_body with the corresponding constant value.
+  auto propagate_constant = [&](HloComputation* computation) -> StatusOr<bool> {
+    HloInstruction* param = computation->parameter_instruction(0);
+    bool changed = false;
+    for (HloInstruction* user : param->users()) {
+      // Since only a while-loop with a tuple result reaches here, we can safely
+      // assume that `param` is a tuple and the first operand of the
+      // GetTupleElement instruction is a use of `param`.
+      if (user->opcode() != HloOpcode::kGetTupleElement) {
+        continue;
+      }
+      VLOG(3) << "tuple index " << user->tuple_index() << " "
+              << user->ToString();
+      auto iter = index_to_constant.find(user->tuple_index());
+      if (iter == index_to_constant.end()) {
+        continue;
+      }
+      const HloInstruction* hlo_constant = iter->second;
+      VLOG(3) << "Replace use of " << user->ToString() << " with "
+              << hlo_constant->ToString();
+      // The constant lives in the computation that contains the while op, so
+      // a clone is added to the computation being rewritten.
+      TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(
+          computation->AddInstruction(hlo_constant->Clone())));
+      changed = true;
+    }
+    return changed;
+  };
+
+  TF_ASSIGN_OR_RETURN(bool changed_cond,
+                      propagate_constant(while_op->while_condition()));
+  TF_ASSIGN_OR_RETURN(bool changed_body, propagate_constant(while_body));
+
+  return changed_cond || changed_body;
+}
+
StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
XLA_VLOG_LINES(3,
"WhileLoopSimplifier::Run(), before:\n" + module->ToString());
continue;
}
- StatusOr<bool> result = TryRemoveWhileLoop(while_op);
+ StatusOr<bool> result = TryPropagateConstant(while_op);
+ TF_RETURN_IF_ERROR(result.status());
+ changed |= result.ValueOrDie();
+
+ result = TryRemoveWhileLoop(while_op);
TF_RETURN_IF_ERROR(result.status());
if (result.ValueOrDie()) {
changed = true;
protected:
// Makes an HloModule that contains a loop with `num_iters` iterations.
void MakeModuleWithSimpleLoop(int num_iters);
+
+ // Similar to MakeModuleWithSimpleLoop except that the loop bound is passed to
+ // the loop-condition through an element of a tuple which is the
+ // loop-condition parameter.
+ void MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters);
};
void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) {
ParseAndVerifyModule(hlo_string.c_str());
}
+// Parses and verifies a module whose while loop reads its bound from tuple
+// element 2 of the loop state instead of from a constant in the condition.
+void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound(
+    int num_iters) {
+  const string hlo_template = R"(
+ HloModule SimpleLoopWithIndirectLoopBound
+ SimpleLoopWithIndirectLoopBound.body {
+ loop_var.1 = (s32[], s32[3]{0}, s32[]) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
+ multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
+ limit = s32[] get-tuple-element(loop_var.1), index=2
+ ROOT tuple = (s32[], s32[3]{0}, s32[]) tuple(add, multiply, limit)
+ }
+ SimpleLoopWithIndirectLoopBound.condition {
+ loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
+ get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2
+ ROOT less-than = pred[] less-than(get-tuple-element.3, get-tuple-element.4)
+ }
+ ENTRY SimpleLoopWithIndirectLoopBound {
+ constant.3 = s32[] constant(42)
+ constant.4 = s32[3]{0} constant({0, 1, 2})
+ constant.2 = s32[] constant({{LOOP_BOUND}})
+ tuple.1 = (s32[], s32[3]{0}, s32[]) tuple(constant.3, constant.4,
+ constant.2)
+ ROOT while = (s32[], s32[3]{0}, s32[]) while(tuple.1),
+ condition=SimpleLoopWithIndirectLoopBound.condition,
+ body=SimpleLoopWithIndirectLoopBound.body
+ }
+ )";
+
+  // The counter starts at 42 (constant.3) and the body adds 1 per iteration,
+  // so a bound of 42 + num_iters makes the loop run `num_iters` times.
+  const string loop_bound = tensorflow::strings::StrCat(42 + num_iters);
+  const string hlo_text = tensorflow::str_util::StringReplace(
+      hlo_template, "{{LOOP_BOUND}}", loop_bound, /*replace_all=*/true);
+  ParseAndVerifyModule(hlo_text.c_str());
+}
+
TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) {
MakeModuleWithSimpleLoop(/*num_iters=*/0);
HloModule* the_module = &module();
op::Tuple(op::Constant(), op::Constant()));
}
+// With a zero-trip loop whose bound arrives through the loop-state tuple, the
+// pass must report a change and leave a tuple of three constants as the entry
+// root.
+TEST_F(WhileLoopSimplifierTest,
+       LoopWithZeroIterationTupleElementLoopBoundSimplified) {
+  MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0);
+  HloModule& hlo_module = module();
+  const bool changed = WhileLoopSimplifier().Run(&hlo_module).ValueOrDie();
+  ASSERT_TRUE(changed);
+  HloInstruction* entry_root =
+      hlo_module.entry_computation()->root_instruction();
+  EXPECT_THAT(entry_root,
+              op::Tuple(op::Constant(), op::Constant(), op::Constant()));
+}
+
TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) {
MakeModuleWithSimpleLoop(/*num_iters=*/1);
HloModule* the_module = &module();
op::Tuple(op::Add(), op::Multiply()));
}
+// With a single-trip loop whose bound arrives through the loop-state tuple,
+// the pass must report a change and leave tuple(add, multiply, constant) as
+// the entry root.
+TEST_F(WhileLoopSimplifierTest,
+       LoopWithOneIterationTupleELementLoopBoundSimplified) {
+  MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1);
+  HloModule& hlo_module = module();
+  const bool changed = WhileLoopSimplifier().Run(&hlo_module).ValueOrDie();
+  ASSERT_TRUE(changed);
+  HloInstruction* entry_root =
+      hlo_module.entry_computation()->root_instruction();
+  EXPECT_THAT(entry_root,
+              op::Tuple(op::Add(), op::Multiply(), op::Constant()));
+}
+
TEST_F(WhileLoopSimplifierTest, LoopWithTwoIterationsNotSimplified) {
MakeModuleWithSimpleLoop(/*num_iters=*/2);
EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
HloModule BodyHasNonTupleRoot
BodyHasNonTupleRoot.passthrough {
ROOT param = (s32[], s32[]) parameter(0)
- get-tuple-element = s32[] get-tuple-element((s32[], s32[]) param), index=1
}
BodyHasNonTupleRoot.always_true {
param.1 = (s32[], s32[]) parameter(0)
EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
}
+// The loop body's root is a custom-call rather than a tuple instruction; the
+// pass is expected to report that it changed nothing.
+TEST_F(WhileLoopSimplifierTest,
+       LoopWithNonTupleBodyRootInstructionNotSimplified) {
+  const string hlo_text = R"(
+ HloModule SimpleLoop
+ SimpleLoop.body {
+ loop_var.1 = (s32[], s32[3]{0}) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
+ multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
+ ROOT custom-call = (s32[], s32[3]{0}) custom-call(add, multiply),
+ custom_call_target="x"
+ }
+ SimpleLoop.condition {
+ loop_var.2 = (s32[], s32[3]{0}) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
+ constant.2 = s32[] constant(44)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(42)
+ constant.4 = s32[3]{0} constant({0, 1, 2})
+ tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
+ ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=
+ SimpleLoop.condition, body=SimpleLoop.body
+ }
+ )";
+
+  ParseAndVerifyModule(hlo_text.c_str());
+  const bool changed = WhileLoopSimplifier().Run(&module()).ValueOrDie();
+  EXPECT_FALSE(changed);
+}
+
+
} // namespace
} // namespace xla