)
cc_library(
+ name = "hlo_liveness_analysis",
+ srcs = ["hlo_liveness_analysis.cc"],
+ hdrs = ["hlo_liveness_analysis.h"],
+ deps = [
+ ":call_graph",
+ ":hlo",
+ ":hlo_value",
+ "//tensorflow/compiler/xla:shape_tree",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "hlo_liveness_analysis_test",
+ srcs = ["hlo_liveness_analysis_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_liveness_analysis",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
name = "hlo_buffer",
srcs = ["hlo_buffer.cc"],
hdrs = ["hlo_buffer.h"],
)
cc_library(
+ name = "hlo_module_dce",
+ srcs = ["hlo_module_dce.cc"],
+ hdrs = ["hlo_module_dce.h"],
+ deps = [
+ ":hlo",
+ ":hlo_dce",
+ ":hlo_liveness_analysis",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "hlo_verifier",
srcs = ["hlo_verifier.cc"],
hdrs = ["hlo_verifier.h"],
)
tf_cc_test(
+ name = "hlo_module_dce_test",
+ srcs = ["hlo_module_dce_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_module_dce",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+tf_cc_test(
name = "layout_assignment_test",
srcs = ["layout_assignment_test.cc"],
deps = [
return distribution_;
}
-bool HloInstruction::HasSideEffect() const {
+bool HloInstruction::HasSideEffectNoRecurse() const {
switch (opcode_) {
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kTrace:
case HloOpcode::kHostCompute:
return true;
- default: {
- // Check if any of the called computations has a side effect.
- for (const auto& computation : called_computations()) {
- if (computation->HasSideEffect()) {
- return true;
- }
- }
+ default:
return false;
+ }
+}
+
+bool HloInstruction::HasSideEffect() const {
+ if (HasSideEffectNoRecurse()) {
+ return true;
+ }
+ // Check if any of the called computations has a side effect.
+ for (const auto& computation : called_computations()) {
+ if (computation->HasSideEffect()) {
+ return true;
}
}
+ return false;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
// Returns the opcode for this instruction.
HloOpcode opcode() const { return opcode_; }
+ // Returns true if this instruction has a side effect, irrespective of whether
+ // any called computations may contain an instruction with side effects.
+ bool HasSideEffectNoRecurse() const;
+
// Returns true if this instruction has a side effect. An instruction has a
// side effect if it uses certain opcodes or calls a computation with a side
// effect.
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
+
+#include <deque>
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/call_graph.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+using Worklist = std::deque<const HloInstruction*>;
+using Workset = std::unordered_set<const HloInstruction*>;
+
+namespace {
+
+void AddToWorklist(const HloInstruction* instruction, Worklist* worklist,
+ Workset* workset) {
+ if (workset->count(instruction) == 0) {
+ worklist->push_back(instruction);
+ workset->insert(instruction);
+ VLOG(3) << "ADD instruction: " << instruction->name();
+ }
+}
+
+using VisitorFunction = std::function<void(const ShapeIndex& /*index*/)>;
+
+void ForEachLiveIndex(const ShapeTree<bool>& index_tree,
+ const VisitorFunction& func) {
+ index_tree.ForEachElement([&](const ShapeIndex& shape_index, bool live) {
+ if (live) {
+ func(shape_index);
+ }
+ });
+}
+
+// Marks 'instruction' output live at 'shape_index'.
+// Adds to 'worklist' iff:
+// *) 'instruction' is not already on worklist.
+// *) 'shape_index' has not yet been visited.
+void MarkLiveAtIndex(const HloInstruction* instruction,
+ const ShapeIndex& shape_index,
+ HloLivenessAnalysis::HloIndexMap* live_index_map,
+ Worklist* worklist, Workset* workset) {
+ auto it = live_index_map->find(instruction);
+ if (it == live_index_map->end()) {
+ auto it_added = live_index_map->emplace(
+ std::piecewise_construct, std::forward_as_tuple(instruction),
+ std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
+ it = it_added.first;
+ }
+ if (it->second.element(shape_index) == false) {
+ AddToWorklist(instruction, worklist, workset);
+ *it->second.mutable_element(shape_index) = true;
+ VLOG(3) << "MARK instruction: " << instruction->name()
+ << " shape_index: " << shape_index.ToString();
+ }
+}
+
+// Marks 'instruction' live at all shape indices in its output.
+void MarkLiveAtAllIndices(const HloInstruction* instruction,
+ HloLivenessAnalysis::HloIndexMap* live_index_map,
+ Worklist* worklist, Workset* workset) {
+ bool add_to_worklist = false;
+ auto it = live_index_map->find(instruction);
+ if (it == live_index_map->end()) {
+ live_index_map->emplace(
+ std::piecewise_construct, std::forward_as_tuple(instruction),
+ std::forward_as_tuple(instruction->shape(), /*init_value=*/true));
+ add_to_worklist = true;
+ } else {
+ ShapeUtil::ForEachSubshape(
+ instruction->shape(),
+ [&](const Shape& sub_shape, const ShapeIndex& shape_index) {
+ if (it->second.element(shape_index) == false) {
+ add_to_worklist = true;
+ *it->second.mutable_element(shape_index) = true;
+ VLOG(3) << "MARK instruction: " << instruction->name()
+ << " shape_index: " << shape_index.ToString();
+ }
+ });
+ }
+ if (add_to_worklist) {
+ AddToWorklist(instruction, worklist, workset);
+ }
+}
+
+// Propagates liveness through Tuple instructions.
+// *) For each tuple operand:
+// *) For tuple output shape index associated with operand:
+// *) Propgate live shape indices to tuple operand at the associated
+// shape index in the operands output, and add to worklist.
+void PropagateLivenessThroughTuple(
+ const HloInstruction* instruction,
+ HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
+ Workset* workset) {
+ CHECK_EQ(instruction->opcode(), HloOpcode::kTuple);
+ for (int64 operand_index = 0; operand_index < instruction->operand_count();
+ ++operand_index) {
+ const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
+ ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
+ if (shape_index.empty() || shape_index[0] != operand_index) {
+ return;
+ }
+ // Mark top-level index of operand at 'operand_index'.
+ MarkLiveAtIndex(instruction->operand(operand_index), {}, live_index_map,
+ worklist, workset);
+ // Mark sub-shape index of operand at 'operand_index'.
+ ShapeIndex operand_shape_index;
+ for (int i = 1; i < shape_index.size(); ++i) {
+ operand_shape_index.push_back(shape_index[i]);
+ }
+ MarkLiveAtIndex(instruction->operand(operand_index), operand_shape_index,
+ live_index_map, worklist, workset);
+ });
+ }
+}
+
+// Propagates liveness through GetTupleElement instructions.
+// *) For each live index in GetTupleElement output, mark output of GTE operand
+// at associated shape index in its output, and add to worklist.
+void PropagateLivenessThroughGTE(
+ const HloInstruction* instruction,
+ HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
+ Workset* workset) {
+ CHECK_EQ(instruction->opcode(), HloOpcode::kGetTupleElement);
+ // Mark operand top-level index.
+ MarkLiveAtIndex(instruction->operand(0), {}, live_index_map, worklist,
+ workset);
+ const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
+ // Propagate live shape indices along GTE -> Tuple edge.
+ ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
+ ShapeIndex operand_shape_index(shape_index);
+ operand_shape_index.push_front(instruction->tuple_index());
+ MarkLiveAtIndex(instruction->operand(0), operand_shape_index,
+ live_index_map, worklist, workset);
+ });
+}
+
+// Propagates liveness through While instructions.
+// *) For each live index in While output, mark shape index of while.body.root
+// and while.operand (adding each to worklist).
+// *) Mark while.cond.root and add to worklist.
+void PropagateLivenessThroughWhile(
+ const HloInstruction* instruction,
+ HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
+ Workset* workset) {
+ CHECK_EQ(instruction->opcode(), HloOpcode::kWhile);
+ const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
+
+ ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
+ // Propagate liveness to while body computation root instruction.
+ MarkLiveAtIndex(instruction->while_body()->root_instruction(), shape_index,
+ live_index_map, worklist, workset);
+ // Propagate liveness to tuple-shaped operand.
+ MarkLiveAtIndex(instruction->operand(0), shape_index, live_index_map,
+ worklist, workset);
+ });
+
+ // Propagate liveness to while condition computation root instruction.
+ MarkLiveAtIndex(instruction->while_condition()->root_instruction(), {},
+ live_index_map, worklist, workset);
+}
+
+// Propagates liveness out of Parameter instructions to callers and aliasing
+// positions. This can occur if liveness propagates to a parameter in the
+// while.condition computation, requiring liveness to propagate out to caller
+// callsite while (and while.body.root).
+void PropagateLivenessToParameterCallers(
+ const HloInstruction* instruction,
+ HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
+ Workset* workset, CallGraph* call_graph) {
+ CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
+ const CallGraphNode& call_graph_node =
+ call_graph->GetNode(instruction->parent());
+ if (call_graph_node.context() == CallContext::kSequential) {
+ for (const CallSite& callsite : call_graph_node.caller_callsites()) {
+ if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
+ auto* xla_while = callsite.instruction();
+ const ShapeTree<bool>& index_tree =
+ FindOrDie(*live_index_map, instruction);
+ ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
+ // Propagate liveness to while result{shape_index}
+ MarkLiveAtIndex(xla_while, shape_index, live_index_map, worklist,
+ workset);
+ // Propagate liveness to while body root{shape_index}.
+ MarkLiveAtIndex(xla_while->while_body()->root_instruction(),
+ shape_index, live_index_map, worklist, workset);
+ // Propagate liveness to operand(0){shape_index}.
+ MarkLiveAtIndex(xla_while->operand(0), shape_index, live_index_map,
+ worklist, workset);
+ });
+ }
+ }
+ }
+}
+
+} // namespace
+
+HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module)
+ : module_(module), call_graph_(CallGraph::Build(&module)) {}
+
+// Runs liveness analysis on 'module_'.
+// Initializes worklist with entry root instruction (and any instruction with
+// side-effects), marking all of their output shape indices live.
+// Visits elements on worklist, propagating liveness from an instructions
+// live output shape indices to its called computations and operands.
+void HloLivenessAnalysis::RunAnalysis() {
+ Worklist worklist;
+ Workset workset;
+ // Add entry compuation root instruction.
+ MarkLiveAtAllIndices(module_.entry_computation()->root_instruction(),
+ &live_index_map_, &worklist, &workset);
+ for (auto* computation : module_.computations()) {
+ for (auto* instruction : computation->instructions()) {
+ if (instruction->HasSideEffectNoRecurse()) {
+ // Add instructions with side effects.
+ MarkLiveAtAllIndices(instruction, &live_index_map_, &worklist,
+ &workset);
+ }
+ }
+ }
+
+ while (!worklist.empty()) {
+ const HloInstruction* instruction = worklist.front();
+ worklist.pop_front();
+ workset.erase(workset.find(instruction));
+ VLOG(1) << "VISIT instruction: " << instruction->name();
+
+ if (instruction->opcode() == HloOpcode::kTuple) {
+ PropagateLivenessThroughTuple(instruction, &live_index_map_, &worklist,
+ &workset);
+ } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
+ PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist,
+ &workset);
+ } else if (instruction->opcode() == HloOpcode::kWhile &&
+ ShapeUtil::IsTuple(instruction->shape())) {
+ PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist,
+ &workset);
+ } else if (instruction->opcode() == HloOpcode::kParameter &&
+ ShapeUtil::IsTuple(instruction->shape())) {
+ PropagateLivenessToParameterCallers(instruction, &live_index_map_,
+ &worklist, &workset,
+ call_graph_.get());
+ } else {
+ // Propagate liveness to called computations.
+ for (auto* called_computation : instruction->called_computations()) {
+ MarkLiveAtAllIndices(called_computation->root_instruction(),
+ &live_index_map_, &worklist, &workset);
+ }
+ // Propagate liveness to operands.
+ for (HloInstruction* operand : instruction->operands()) {
+ MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset);
+ }
+ }
+ }
+}
+
+bool HloLivenessAnalysis::IsLive(const HloInstruction* instruction,
+ const ShapeIndex& shape_index) const {
+ if (ContainsKey(live_index_map_, instruction)) {
+ return FindOrDie(live_index_map_, instruction).element(shape_index);
+ }
+ return false;
+}
+
+/* static */
+StatusOr<std::unique_ptr<HloLivenessAnalysis>> HloLivenessAnalysis::Run(
+ const HloModule& module) {
+ VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name();
+ XLA_VLOG_LINES(2, module.ToString());
+
+ auto liveness_analysis = WrapUnique(new HloLivenessAnalysis(module));
+
+ liveness_analysis->RunAnalysis();
+
+ return std::move(liveness_analysis);
+}
+
+} // namespace xla
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_
+
+#include <unordered_map>
+
+#include "tensorflow/compiler/xla/service/call_graph.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_value.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// Analysis which identifies all live {HloInstruction, ShapeIndex} pairs in
+// an HLO module.
+//
+// HloLivenessAnalysis marks the shape index of each live output of each
+// instruction in the module, by propagating live shape index information
+// from an instruction to its called computations and operands.
+class HloLivenessAnalysis {
+ public:
+ // Maps from an HloInstruction to its live/dead output shape indices.
+ using HloIndexMap =
+ std::unordered_map<const HloInstruction*, ShapeTree<bool>>;
+
+ // Runs liveness analysis on 'module'. Returns HloLivenessAnalysis object
+ // which exports liveness for each {HloInstruction, ShapeIndex} in 'module'.
+ static StatusOr<std::unique_ptr<HloLivenessAnalysis>> Run(
+ const HloModule& module);
+
+ // Returns true if output of 'instruction' at 'shape_index' is live.
+ // Returns false otherwise.
+ bool IsLive(const HloInstruction* instruction,
+ const ShapeIndex& shape_index) const;
+
+ private:
+ HloLivenessAnalysis(const HloModule& module);
+
+ void RunAnalysis();
+
+ const HloModule& module_;
+ std::unique_ptr<CallGraph> call_graph_;
+ HloIndexMap live_index_map_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class HloLivenessAnalysisTest : public HloTestBase {
+ protected:
+ HloLivenessAnalysisTest() {}
+
+ // Run liveness analysis on the member module. For convenience returns a
+ // reference to the generated analysis stored in analysis_.
+ const HloLivenessAnalysis& RunLiveness(HloModule* module) {
+ liveness_ = HloLivenessAnalysis::Run(*module).ConsumeValueOrDie();
+ return *liveness_;
+ }
+
+ HloInstruction* GetInstruction(HloModule* module, const string& name) {
+ HloInstruction* to_return = nullptr;
+ for (auto* comp : module->computations()) {
+ for (auto* inst : comp->instructions()) {
+ if (inst->name() == name) {
+ to_return = inst;
+ break;
+ }
+ }
+ }
+ return CHECK_NOTNULL(to_return);
+ }
+
+ std::unique_ptr<HloLivenessAnalysis> liveness_;
+};
+
+// Test that add instruction at entry root is live at all output shape indices.
+TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) {
+ auto module = tools::Parse(R"(
+ HloModule SimpleModule
+ ENTRY SimpleComputation {
+ constant.1 = s32[] constant(0)
+ constant.2 = s32[] constant(1)
+ ROOT add = s32[] add(constant.1, constant.2)
+ })")
+ .ValueOrDie();
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
+}
+
+// Test that a dead add instruction is marked as dead by analysis.
+TEST_F(HloLivenessAnalysisTest, DeadAdd) {
+ auto module = tools::Parse(R"(
+ HloModule SimpleModule
+ ENTRY SimpleComputation {
+ constant.1 = s32[] constant(0)
+ constant.2 = s32[] constant(1)
+ add.1 = s32[] add(constant.1, constant.2)
+ ROOT add.2 = s32[] add(constant.1, constant.2)
+ })")
+ .ValueOrDie();
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
+ EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "add.1"), {}));
+}
+
+// Test that all output shape indices of entry root tuple (and defining
+// instruction in its output) are marked live.
+TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) {
+ auto module = tools::Parse(R"(
+ HloModule SimpleModule
+ ENTRY SimpleComputation {
+ constant.1 = s32[] constant(0)
+ constant.2 = s32[] constant(1)
+ ROOT tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2)
+ })")
+ .ValueOrDie();
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
+}
+
+// Tests that all outputs of nested tuple and entry root (and defining
+// instruction values appearing in its output) are marked live.
+TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) {
+ auto module = tools::Parse(R"(
+ HloModule SimpleModule
+ ENTRY SimpleComputation {
+ constant.1 = s32[] constant(1)
+ constant.2 = s32[] constant(2)
+ constant.3 = s32[] constant(3)
+ tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3)
+ ROOT tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1)
+ })")
+ .ValueOrDie();
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+}
+
+// Tests that GTE at entry root of Tuple instruction only propgates liveness
+// to the live elements in tuple.
+TEST_F(HloLivenessAnalysisTest, GteOfTuple) {
+ auto module = tools::Parse(R"(
+ HloModule SimpleModule
+ ENTRY SimpleComputation {
+ constant.1 = s32[] constant(0)
+ constant.2 = s32[] constant(1)
+ tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2)
+ ROOT get-tuple-element.1 = s32[] get-tuple-element(tuple.1), index=0
+ })")
+ .ValueOrDie();
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(
+ liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
+ EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
+ EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
+}
+
+// Tests that GTE at entry root of nested Tuple instruction only propgates
+// liveness to the live elements in tuple.
+TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) {
+ auto module = tools::Parse(R"(
+ HloModule SimpleModule
+ ENTRY SimpleComputation {
+ constant.1 = s32[] constant(0)
+ constant.2 = s32[] constant(1)
+ constant.3 = s32[] constant(2)
+ tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3)
+ tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1)
+ ROOT get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1
+ })")
+ .ValueOrDie();
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(
+ liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(
+ GetInstruction(module.get(), "get-tuple-element.1"), {0}));
+ EXPECT_TRUE(liveness.IsLive(
+ GetInstruction(module.get(), "get-tuple-element.1"), {1}));
+
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {}));
+ EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1}));
+
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
+
+ EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+}
+
+// Tests that GTE of GTE (at entry root) of nested Tuple instruction only
+// propgates liveness to the live elements in tuple.
+TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) {
+ auto module = tools::Parse(R"(
+ HloModule SimpleModule
+ ENTRY SimpleComputation {
+ constant.1 = s32[] constant(0)
+ constant.2 = s32[] constant(1)
+ constant.3 = s32[] constant(2)
+ tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3)
+ tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1)
+ get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1
+ ROOT get-tuple-element.2 = s32[] get-tuple-element(get-tuple-element.1), index=0
+ })")
+ .ValueOrDie();
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(
+ liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.2"), {}));
+
+ EXPECT_TRUE(
+ liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(
+ GetInstruction(module.get(), "get-tuple-element.1"), {0}));
+ EXPECT_FALSE(liveness.IsLive(
+ GetInstruction(module.get(), "get-tuple-element.1"), {1}));
+
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {}));
+ EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0}));
+ EXPECT_FALSE(
+ liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1}));
+
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
+ EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
+
+ EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
+ EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+}
+
+// Test that live/dead while tuple elements are marked live/dead correctly.
+TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) {
+ auto module = tools::Parse(R"(
+ HloModule SimpleLoop
+ SimpleLoop.body {
+ loop_var.1 = (s32[], s32[3]{0}) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+ constant.1 = s32[] constant(1)
+ add.0 = s32[] add(get-tuple-element.1, constant.1)
+ get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
+ multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
+ ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0)
+ }
+ SimpleLoop.condition {
+ loop_var.2 = (s32[], s32[3]{0}) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
+ constant.2 = s32[] constant(5)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ constant.4 = s32[3]{0} constant({0, 1, 2})
+ tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
+ while.0 = (s32[], s32[3]{0}) while(tuple.1), condition=
+ SimpleLoop.condition, body=SimpleLoop.body
+ ROOT get-tuple-element.4 = s32[] get-tuple-element(while.0), index=0
+ })")
+ .ValueOrDie();
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(
+ liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.4"), {}));
+
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {0}));
+ EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {1}));
+
+ // While operand.
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
+ EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+
+ // While body.
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0}));
+ EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {}));
+ EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {}));
+}
+
+// Tests that a tuple element live in while.cond computation, propagates
+// liveness to while.body.root/while.result/while.operand (where it is unused).
+TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) {
+ auto module = tools::Parse(R"(
+ HloModule SimpleLoop
+ SimpleLoop.body {
+ loop_var.1 = (s32[], s32[3]{0}) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+ constant.1 = s32[] constant(1)
+ add.0 = s32[] add(get-tuple-element.1, constant.1)
+ get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
+ multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
+ ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0)
+ }
+ SimpleLoop.condition {
+ loop_var.2 = (s32[], s32[3]{0}) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
+ get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1
+ add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4)
+ constant.2 = s32[] constant(5)
+ ROOT less-than = pred[] less-than(add.1, constant.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ constant.4 = s32[3]{0} constant({0, 1, 2})
+ tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
+ while.0 = (s32[], s32[3]{0}) while(tuple.1), condition=
+ SimpleLoop.condition, body=SimpleLoop.body
+ ROOT get-tuple-element.5 = s32[] get-tuple-element(while.0), index=0
+ })")
+ .ValueOrDie();
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(
+ liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.5"), {}));
+
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {0}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {1}));
+
+ // While operand.
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.4"), {}));
+
+ // While body.
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {}));
+}
+
+// Tests that a use of while.result{0} propagates liveness to
+// while.body.param{1} to while.body.root{1}, and then to while.body.param{2}.
+TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) {
+ auto module = tools::Parse(R"(
+ HloModule SimpleLoop
+ SimpleLoop.body {
+ loop_var.1 = (s32[], s32[], s32[]) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+ get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1
+ add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2)
+ get-tuple-element.3 = s32[] get-tuple-element(loop_var.1), index=2
+ multiply.1 = s32[] multiply(get-tuple-element.3, get-tuple-element.3)
+ ROOT tuple.1 = (s32[], s32[], s32[]) tuple(add.1, get-tuple-element.3, multiply.1)
+ }
+ SimpleLoop.condition {
+ loop_var.2 = (s32[], s32[], s32[]) parameter(0)
+ get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0
+ constant.1 = s32[] constant(5)
+ ROOT less-than = pred[] less-than(get-tuple-element.4, constant.1)
+ }
+ ENTRY SimpleLoop {
+ constant.2 = s32[] constant(0)
+ constant.3 = s32[] constant(1)
+ constant.4 = s32[] constant(2)
+ tuple.2 = (s32[], s32[], s32[]) tuple(constant.2, constant.3, constant.4)
+ while.1 = (s32[], s32[], s32[]) while(tuple.2), condition=
+ SimpleLoop.condition, body=SimpleLoop.body
+ ROOT get-tuple-element.5 = s32[] get-tuple-element(while.1), index=0
+ })")
+ .ValueOrDie();
+
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(
+ liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.5"), {}));
+
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {0}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {1}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {2}));
+ // While operand.
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {2}));
+ // While body root.
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {2}));
+ // While body param.
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {0}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {1}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2}));
+}
+
+} // namespace
+} // namespace xla
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_dce.h"
+
+#include <deque>
+#include <unordered_set>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+namespace {
+
+bool HasSendRecv(HloComputation* computation) {
+ for (auto* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kSend ||
+ instruction->opcode() == HloOpcode::kSendDone ||
+ instruction->opcode() == HloOpcode::kRecv ||
+ instruction->opcode() == HloOpcode::kRecvDone) {
+ return true;
+ }
+ for (auto* sub_computation : instruction->called_computations()) {
+ if (HasSendRecv(sub_computation)) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
+ bool changed = false;
+ for (auto* computation : module->computations()) {
+ for (auto* instruction : computation->instructions()) {
+ if (instruction->opcode() != HloOpcode::kWhile) {
+ continue;
+ }
+
+ const auto* xla_while = instruction;
+ auto* while_body_comp = xla_while->while_body();
+ auto* while_body_param = while_body_comp->parameter_instruction(0);
+ auto* while_body_root = while_body_comp->root_instruction();
+
+ if (!ShapeUtil::IsTuple(xla_while->shape()) ||
+ while_body_root->opcode() != HloOpcode::kTuple ||
+ HasSendRecv(while_body_comp)) {
+ // Only run DCE on tuple-shaped while loops where body root is Tuple,
+ // with no send/recv instructions.
+ VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString();
+ continue;
+ }
+
+ // Remove dead tuple elements.
+ const int64 tuple_element_count =
+ ShapeUtil::TupleElementCount(xla_while->shape());
+ for (int64 i = 0; i < tuple_element_count; ++i) {
+ if (liveness->IsLive(xla_while, {i})) {
+ continue;
+ }
+ VLOG(1) << "WhileDCE Dead while tuple element."
+ << " while: " << xla_while->name() << " tuple_index: " << i;
+ // Transform while.body computation to make tuple element at
+ // 'shape_index' as simple pass-through parameter (which candidate
+ // be removed later by simplification pass).
+ HloInstruction* pass_thru_gte = while_body_comp->AddInstruction(
+ HloInstruction::CreateGetTupleElement(
+ while_body_param->shape().tuple_shapes(i), while_body_param,
+ i));
+ // Replace while.body.root Tuple operand at 'tuple_index' with
+ // 'pass_thru_gte', making prior operand a dead root (to be cleaned
+ // up with a subsequent DCE pass).
+ TF_RETURN_IF_ERROR(
+ while_body_root->ReplaceOperandWith(i, pass_thru_gte));
+ changed = true;
+ }
+ }
+ }
+ return changed;
+}
+
+} // namespace
+
+StatusOr<bool> HloModuleDCE::Run(HloModule* module) {
+ VLOG(2) << "Before HloModuleDCE:";
+ XLA_VLOG_LINES(3, module->ToString());
+
+ std::unique_ptr<HloLivenessAnalysis> liveness;
+ TF_ASSIGN_OR_RETURN(liveness, HloLivenessAnalysis::Run(*module));
+
+ // Sweep through while instructions, transforming dead while tuple element
+ // computations to pass through tuple values (creating dead roots in while
+ // body computation in the process).
+ TF_ASSIGN_OR_RETURN(bool hlo_module_dce_changed,
+ RunWhileDCE(module, liveness.get()));
+
+ // Run HloDCE to clean up any dead code created during HloModuleDCE.
+ HloDCE hlo_dce;
+ TF_ASSIGN_OR_RETURN(bool hlo_dce_changed, hlo_dce.Run(module));
+
+ VLOG(2) << "After HloModuleDCE:";
+ XLA_VLOG_LINES(3, module->ToString());
+
+ return hlo_module_dce_changed | hlo_dce_changed;
+}
+
+} // namespace xla
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// HLO pass which removes dead code from computations in the module using
+// HloModule-scoped analysis (HloLivenessAnalysis).
+//
+// Sweeps through live instructions which cross computation boundaries (kWhile),
+// and removes code at dead shape indices.
+//
+class HloModuleDCE : public HloPassInterface {
+ public:
+ ~HloModuleDCE() override {}
+ tensorflow::StringPiece name() const override { return "hlo-module-dce"; }
+
+ // Run the pass on the given module. Returns whether the module was changed
+ // (instructions were removed).
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_dce.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class HloModuleDceTest : public HloTestBase {
+ protected:
+ HloModuleDceTest() {}
+
+ // Returns whether the given instruction exists in the given computation.
+ bool HasInstruction(const HloComputation& computation,
+ const HloInstruction* instruction) {
+ return std::find(computation.instructions().begin(),
+ computation.instructions().end(),
+ instruction) != computation.instructions().end();
+ }
+
+ // Returns whether the while instruction with name 'while_name' in
+ // 'computation' passes through its tuple element at 'tuple_index' from
+ // parameter to root instruction.
+ bool WhileBodyHasPassThroughTupleElement(const HloComputation* computation,
+ const string& while_name,
+ const int64 tuple_index) {
+ for (auto* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kWhile &&
+ instruction->name() == while_name) {
+ auto* while_body_comp = instruction->while_body();
+ auto* while_body_param = while_body_comp->parameter_instruction(0);
+ auto* while_body_root = while_body_comp->root_instruction();
+ if (while_body_root->opcode() != HloOpcode::kTuple) {
+ return false;
+ }
+ auto* operand = while_body_root->operand(tuple_index);
+ if (operand->opcode() == HloOpcode::kGetTupleElement &&
+ operand->tuple_index() == tuple_index &&
+ operand->operand(0) == while_body_param) {
+ return true;
+ }
+ return false;
+ }
+ }
+ return false;
+ }
+};
+
+// Tests that a while with all outputs live is unmodified.
+TEST_F(HloModuleDceTest, WhileWithLiveOutputs) {
+ auto module = tools::Parse(R"(
+ HloModule SimpleLoop
+ SimpleLoop.body {
+ loop_var.1 = (s32[], s32[3]{0}) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
+ multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
+ ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
+ }
+ SimpleLoop.condition {
+ loop_var.2 = (s32[], s32[3]{0}) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
+ constant.2 = s32[] constant(5)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ constant.4 = s32[3]{0} constant({0, 1, 2})
+ tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
+ ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=
+ SimpleLoop.condition, body=SimpleLoop.body
+ })")
+ .ValueOrDie();
+
+ HloModuleDCE dce;
+ EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 0));
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 1));
+}
+
+// Tests a while loop with one unused output (which is used in the while loop
+// body by an instruction with side-effects: rng) is unmodified.
+TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
+ auto module = tools::Parse(R"(
+ HloModule SimpleLoop
+ SimpleLoop.body {
+ loop_var.1 = (s32[], f32[]) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ get-tuple-element.2 = f32[] get-tuple-element(loop_var.1), index=1
+ constant.2 = f32[] constant(1.0)
+ rng = f32[] rng(constant.2, get-tuple-element.2), distribution=rng_uniform
+ add.1 = s32[] add(get-tuple-element.2, constant.2)
+ ROOT tuple = (s32[], f32[]) tuple(add, add.1)
+ }
+ SimpleLoop.condition {
+ loop_var.2 = (s32[], f32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
+ constant.3 = s32[] constant(5)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.3)
+ }
+ ENTRY SimpleLoop {
+ constant.4 = s32[] constant(0)
+ constant.5 = f32[] constant(0.0)
+ tuple.1 = (s32[], f32[]) tuple(constant.4, constant.5)
+ while = (s32[], f32[]) while(tuple.1), condition=
+ SimpleLoop.condition, body=SimpleLoop.body
+ ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0
+ })")
+ .ValueOrDie();
+
+ HloModuleDCE dce;
+ EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 0));
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 1));
+}
+
+// Tests that a while loop with one dead tuple element at {1} has its while
+// loop body modified to make that tuple element pass-through the while body.
+TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
+ auto module = tools::Parse(R"(
+ HloModule SimpleLoop
+ SimpleLoop.body {
+ loop_var.1 = (s32[], s32[3]{0}) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
+ multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
+ ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
+ }
+ SimpleLoop.condition {
+ loop_var.2 = (s32[], s32[3]{0}) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
+ constant.2 = s32[] constant(5)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ constant.4 = s32[3]{0} constant({0, 1, 2})
+ tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
+ while = (s32[], s32[3]{0}) while(tuple.1), condition=
+ SimpleLoop.condition, body=SimpleLoop.body
+ ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0
+ })")
+ .ValueOrDie();
+
+ HloModuleDCE dce;
+ // While tuple element {1} should not be pass-through before ModuleDCE.
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 1));
+ EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 0));
+ // While tuple element {1} should now be pass-through after ModuleDCE.
+ EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 1));
+}
+
+// Tests that a tuple element {1} used by condition computation (which appears
+// dead in while.body{1} and at while.result{1}) propgates liveness of this
+// tuple element to while.body{1} and at while.result{1}.
+TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) {
+ auto module = tools::Parse(R"(
+ HloModule SimpleLoop
+ SimpleLoop.body {
+ loop_var.1 = (s32[], s32[]) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1
+ multiply = s32[] multiply(get-tuple-element.2, get-tuple-element.2)
+ ROOT tuple = (s32[], s32[]) tuple(add, multiply)
+ }
+ SimpleLoop.condition {
+ loop_var.2 = (s32[], s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1
+ constant.2 = s32[] constant(5)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ constant.4 = s32[] constant(0)
+ tuple.1 = (s32[], s32[]) tuple(constant.3, constant.4)
+ while = (s32[], s32[]) while(tuple.1), condition=
+ SimpleLoop.condition, body=SimpleLoop.body
+ ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0
+ })")
+ .ValueOrDie();
+
+ HloModuleDCE dce;
+ // While tuple element {1} should not be pass-through before ModuleDCE.
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 1));
+ EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 0));
+ // While tuple element {1} still be pass-through after ModuleDCE.
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 1));
+}
+
+// Tests that HloModuleDCE can remove a dead tuple element at index {1} between
+// two dependent while loops.
+TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
+ auto module = tools::Parse(R"(
+ HloModule SimpleLoop
+ SimpleLoop.body0 {
+ loop_var.1 = (s32[], s32[3]{0}) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
+ multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
+ ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
+ }
+ SimpleLoop.condition0 {
+ loop_var.2 = (s32[], s32[3]{0}) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
+ constant.2 = s32[] constant(5)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ SimpleLoop.body1 {
+ loop_var.3 = (s32[], s32[3]{0}) parameter(0)
+ get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0
+ constant.3 = s32[] constant(1)
+ add.1 = s32[] add(get-tuple-element.4, constant.3)
+ get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1
+ multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5)
+ ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1)
+ }
+ SimpleLoop.condition1 {
+ loop_var.4 = (s32[], s32[3]{0}) parameter(0)
+ get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
+ constant.4 = s32[] constant(5)
+ ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4)
+ }
+ ENTRY SimpleLoop {
+ constant.5 = s32[] constant(0)
+ constant.6 = s32[3]{0} constant({0, 1, 2})
+ tuple.2 = (s32[], s32[3]{0}) tuple(constant.5, constant.6)
+ while.1 = (s32[], s32[3]{0}) while(tuple.2), condition=
+ SimpleLoop.condition0, body=SimpleLoop.body0
+ get-tuple-element.7 = s32[] get-tuple-element(while.1), index=0
+ tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6)
+ while.2 = (s32[], s32[3]{0}) while(tuple.3), condition=
+ SimpleLoop.condition1, body=SimpleLoop.body1
+ ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0
+ })")
+ .ValueOrDie();
+
+ HloModuleDCE dce;
+ // Before HloModuleDCE while.1 and while.2 should not have pass-thru elements.
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while.1", 1));
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while.2", 1));
+ EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
+ // After HloModuleDCE while.1 and while.2 should have pass-thru elements,
+ // after being modified to pass through unused tuple element {1}.
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while.1", 0));
+ EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while.1", 1));
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while.2", 0));
+ EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while.2", 1));
+}
+
+// Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and
+// while.2{1}, between two dependent while loops.
+TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
+ auto module = tools::Parse(R"(
+ HloModule SimpleLoop
+ SimpleLoop.body0 {
+ loop_var.1 = (s32[3]{0}, s32[]) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=1
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=0
+ multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
+ ROOT tuple = (s32[3]{0}, s32[]) tuple(multiply, add)
+ }
+ SimpleLoop.condition0 {
+ loop_var.2 = (s32[3]{0}, s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1
+ constant.2 = s32[] constant(5)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ SimpleLoop.body1 {
+ loop_var.3 = (s32[], s32[3]{0}) parameter(0)
+ get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0
+ constant.3 = s32[] constant(1)
+ add.1 = s32[] add(get-tuple-element.4, constant.3)
+ get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1
+ multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5)
+ ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1)
+ }
+ SimpleLoop.condition1 {
+ loop_var.4 = (s32[], s32[3]{0}) parameter(0)
+ get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
+ constant.4 = s32[] constant(5)
+ ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4)
+ }
+ ENTRY SimpleLoop {
+ constant.5 = s32[] constant(0)
+ constant.6 = s32[3]{0} constant({0, 1, 2})
+ tuple.2 = (s32[3]{0}, s32[]) tuple(constant.6, constant.5)
+ while.1 = (s32[3]{0}, s32[]) while(tuple.2), condition=
+ SimpleLoop.condition0, body=SimpleLoop.body0
+ get-tuple-element.7 = s32[] get-tuple-element(while.1), index=1
+ tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6)
+ while.2 = (s32[], s32[3]{0}) while(tuple.3), condition=
+ SimpleLoop.condition1, body=SimpleLoop.body1
+ ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0
+ })")
+ .ValueOrDie();
+
+ HloModuleDCE dce;
+ // Before HloModuleDCE while.1{0} and while.2{1} should not be pass-thru.
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while.1", 0));
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while.2", 1));
+ EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
+ // After HloModuleDCE while.1{0} and while.2{1} not be pass-thru elements.
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while.1", 1));
+ EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while.1", 0));
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while.2", 0));
+ EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while.2", 1));
+}
+
+} // namespace
+} // namespace xla