"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
return Status::OK();
}
+// Prunes the portions of the graph that can never execute: consumers reachable
+// only through the dead output of a Switch node whose predicate is a constant.
+// Merge nodes that lose one data input are rewritten as Identity; Merge nodes
+// whose inputs (or triggering control dependency) are all dead are removed.
+// Nodes listed in `nodes_to_preserve` — and the entire switch fanout leading
+// to them — are left untouched. Returns OK; mutates `optimized_graph` in
+// place.
+Status RemoveDeadBranches(const std::unordered_set<string>& nodes_to_preserve,
+                          GraphDef* optimized_graph) {
+  std::unordered_set<const NodeDef*> dead_nodes;
+  std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs;
+  // TODO(bsteiner): also rewrite switches as identity. For now we just record
+  // them.
+  std::unordered_set<GraphView::OutputPort, GraphView::HashPort>
+      identity_switches;
+
+  GraphView view(optimized_graph);
+  for (const NodeDef& node : optimized_graph->node()) {
+    if (!IsSwitch(node)) {
+      continue;
+    }
+    if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
+      continue;
+    }
+    // The switch predicate is fed through input port 1.
+    GraphView::InputPort ctrl_port(const_cast<NodeDef*>(&node), 1);
+    GraphView::OutputPort ctrl_node = view.GetRegularFanin(ctrl_port);
+    if (!IsConstant(*ctrl_node.node)) {
+      continue;
+    }
+    Tensor selector;
+    CHECK(selector.FromProto(ctrl_node.node->attr().at("value").tensor()));
+    // Output 0 is taken when the predicate is false, output 1 when it is
+    // true, so the opposite output is dead.
+    const int dead_fanout = selector.scalar<bool>()() ? 0 : 1;
+    GraphView::OutputPort dead(const_cast<NodeDef*>(&node), dead_fanout);
+    identity_switches.insert(dead);
+
+    SetVector<GraphView::InputPort, GraphView::HashPort> zombie_inputs;
+    for (const GraphView::InputPort& port : view.GetFanout(dead)) {
+      if (dead_nodes.find(port.node) == dead_nodes.end()) {
+        zombie_inputs.PushBack(port);
+      }
+    }
+    // If we encounter a single node that must be preserved in the fanout of
+    // the switch node we need to preserve the entire switch fanout: we
+    // therefore work on a local copy that only gets committed to the master
+    // copy once the whole fanout has been explored.
+    std::unordered_set<const NodeDef*> local_dead_nodes = dead_nodes;
+    std::unordered_map<NodeDef*, std::set<int>> local_dead_merge_inputs =
+        dead_merge_inputs;
+    bool found_node_to_preserve = false;
+    while (!found_node_to_preserve && !zombie_inputs.Empty()) {
+      GraphView::InputPort dead = zombie_inputs.PopBack();
+      if (nodes_to_preserve.find(dead.node->name()) !=
+          nodes_to_preserve.end()) {
+        found_node_to_preserve = true;
+        break;
+      }
+
+      if (local_dead_nodes.find(dead.node) != local_dead_nodes.end()) {
+        continue;
+      }
+
+      if (IsMerge(*dead.node)) {
+        const int num_inputs = dead.node->attr().at("N").i();
+        if (num_inputs > 2) {
+          // This never happens in practice, so we'll just skip these to
+          // simplify the code for now.
+          found_node_to_preserve = true;
+          break;
+        }
+        GraphView::OutputPort value_index(dead.node, 1);
+        const std::unordered_set<GraphView::InputPort, GraphView::HashPort>&
+            index_fanout = view.GetFanout(value_index);
+        if (!index_fanout.empty()) {
+          // The 2nd output (that indicates which input is propagated) is
+          // connected. This never happens in practice, so we'll just skip
+          // this case to simplify the code for now.
+          found_node_to_preserve = true;
+          break;
+        }
+
+        bool fully_dead = false;
+        if (dead.port_id < 0) {
+          // If the control dependency never gets triggered the merge will
+          // also never get triggered.
+          fully_dead = true;
+        } else {
+          local_dead_merge_inputs[dead.node].insert(dead.port_id);
+          if (static_cast<int>(local_dead_merge_inputs[dead.node].size()) ==
+              num_inputs) {
+            fully_dead = true;
+          }
+        }
+        // Propagate the death of the merge to its fanout regardless of
+        // whether it died through a control dependency or through all of its
+        // data inputs. (Previously this only happened in the data-input case,
+        // so consumers of a control-killed merge were never visited.)
+        if (fully_dead) {
+          local_dead_nodes.insert(dead.node);
+          for (const GraphView::InputPort& port :
+               view.GetFanouts(*dead.node, true)) {
+            zombie_inputs.PushBack(port);
+          }
+        }
+      } else {
+        if (local_dead_nodes.insert(dead.node).second) {
+          for (const GraphView::InputPort& dead_fanout :
+               view.GetFanouts(*dead.node, true)) {
+            zombie_inputs.PushBack(dead_fanout);
+          }
+        }
+      }
+    }
+    if (!found_node_to_preserve) {
+      std::swap(dead_nodes, local_dead_nodes);
+      std::swap(dead_merge_inputs, local_dead_merge_inputs);
+    }
+  }
+
+  // Compact the node list: swap every dead node to the tail, then truncate.
+  int last = optimized_graph->node_size() - 1;
+  for (int i = optimized_graph->node_size() - 1; i >= 0; --i) {
+    NodeDef* node = optimized_graph->mutable_node(i);
+    if (dead_nodes.find(node) != dead_nodes.end()) {
+      optimized_graph->mutable_node()->SwapElements(i, last);
+      last--;
+    }
+  }
+  optimized_graph->mutable_node()->DeleteSubrange(last + 1, dead_nodes.size());
+
+  // Rewrite the surviving merges with dead inputs as Identity nodes.
+  for (const auto& itr : dead_merge_inputs) {
+    NodeDef* dead_node = itr.first;
+    if (dead_nodes.find(dead_node) != dead_nodes.end()) {
+      // The node has been pruned since all its inputs are dead.
+      continue;
+    }
+    const std::set<int>& dead_inputs = itr.second;
+    // Delete from the highest index down: deleting a lower index first would
+    // shift the positions of the remaining dead inputs.
+    for (auto it = dead_inputs.rbegin(); it != dead_inputs.rend(); ++it) {
+      dead_node->mutable_input()->DeleteSubrange(*it, 1);
+    }
+    dead_node->set_op("Identity");
+    dead_node->mutable_attr()->erase("N");
+  }
+  return Status::OK();
+}
+
} // namespace
Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (options_.enable_stack_push_removal) {
TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph));
}
+ if (opt_level_ == RewriterConfig::AGGRESSIVE &&
+ options_.enable_dead_branch_removal) {
+ TF_RETURN_IF_ERROR(
+ RemoveDeadBranches(item.NodesToPreserve(), optimized_graph));
+ }
return Status::OK();
}
struct LoopOptimizerOptions {
bool enable_loop_invariant_node_motion = false;
bool enable_stack_push_removal = true;
+ bool enable_dead_branch_removal = true;
static LoopOptimizerOptions Default(RewriterConfig::Toggle opt_level) {
LoopOptimizerOptions options;
}
}
+// Builds a graph with four constant-predicate switches feeding a variety of
+// Merge configurations and verifies that only the truly dead branches are
+// pruned, that half-dead merges become Identity, and that fetched nodes keep
+// their entire switch fanout alive.
+TEST_F(LoopOptimizerTest, RemoveDeadBranches) {
+  Scope scope = Scope::NewRootScope();
+  Output v_in = ops::Variable(scope.WithOpName("v_in"), {3}, DT_FLOAT);
+
+  // ctrl1 is false: the true output of switch1 (feeding sqrt1) is dead.
+  Output ctrl1 = ops::Const(scope.WithOpName("ctrl1"), false, TensorShape({}));
+  ops::Switch s1(scope.WithOpName("switch1"), v_in, ctrl1);
+  Output square1 = ops::Square(scope.WithOpName("square1"), s1.output_false);
+  Output sqrt1 = ops::Sqrt(scope.WithOpName("sqrt1"), s1.output_true);
+
+  // ctrl2 is true: the false output of switch2 (feeding square2) is dead.
+  Output ctrl2 = ops::Const(scope.WithOpName("ctrl2"), true, TensorShape({}));
+  ops::Switch s2(scope.WithOpName("switch2"), v_in, ctrl2);
+  Output square2 = ops::Square(scope.WithOpName("square2"), s2.output_false);
+  Output sqrt2 = ops::Sqrt(scope.WithOpName("sqrt2"), s2.output_true);
+
+  Output ctrl3 = ops::Const(scope.WithOpName("ctrl3"), false, TensorShape({}));
+  ops::Switch s3(scope.WithOpName("switch3"), v_in, ctrl3);
+  Output square3 = ops::Square(scope.WithOpName("square3"), s3.output_false);
+  Output sqrt3 = ops::Sqrt(scope.WithOpName("sqrt3"), s3.output_true);
+
+  Output ctrl4 = ops::Const(scope.WithOpName("ctrl4"), false, TensorShape({}));
+  ops::Switch s4(scope.WithOpName("switch4"), v_in, ctrl4);
+  Output square4 = ops::Square(scope.WithOpName("square4"), s4.output_false);
+  Output sqrt4 = ops::Sqrt(scope.WithOpName("sqrt4"), s4.output_true);
+
+  ops::Merge m1(scope.WithOpName("m1"), {square1, sqrt1});
+  ops::Merge m2(scope.WithOpName("m2"), {v_in, square1});
+  ops::Merge m3(scope.WithOpName("m3"), {v_in, sqrt1});
+  ops::Merge m4(scope.WithOpName("m4"), {square1, sqrt2});
+  ops::Merge m5(scope.WithOpName("m5"), {square2, sqrt1});
+  ops::Merge m6(scope.WithOpName("m6").WithControlDependencies(sqrt2),
+                {v_in, square1});
+  ops::Merge m7(scope.WithOpName("m7").WithControlDependencies(sqrt1),
+                {v_in, square1});
+
+  ops::Switch s5(scope.WithOpName("switch5"), v_in, ctrl1);
+  Output id1 = ops::Identity(scope.WithOpName("id1"), s5.output_false);
+  Output id2 = ops::Identity(scope.WithOpName("id2"), s5.output_true);
+  ops::Merge m8(scope.WithOpName("m8"), {id1, id2});
+
+  ops::Switch s6(scope.WithOpName("switch6"), v_in, ctrl1);
+  Output id3 = ops::Identity(scope.WithOpName("id3"), s6.output_false);
+  Output id4 = ops::Identity(scope.WithOpName("id4"), s6.output_true);
+  ops::Merge m9(scope.WithOpName("m9"), {id3, id4});
+
+  GrapplerItem item;
+  item.fetch.push_back("m8");
+  item.fetch.push_back("id4");
+
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+  LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_CHECK_OK(status);
+
+  for (const NodeDef& node : output.node()) {
+    // These nodes should have been pruned: sqrt1 hangs off the dead true
+    // output of switch1, square2 off the dead false output of switch2, m5
+    // has only dead inputs, and m7 has a control dependency on the dead
+    // sqrt1. (Node names are lowercase; checking "Square1"/"Sqrt2" would be
+    // vacuous, and square1/sqrt2 are the live branches anyway.)
+    EXPECT_NE("sqrt1", node.name());
+    EXPECT_NE("square2", node.name());
+    EXPECT_NE("m5", node.name());
+    EXPECT_NE("m7", node.name());
+
+    if (node.name() == "m1") {
+      // sqrt1 is dead
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("square1", node.input(0));
+    } else if (node.name() == "m2") {
+      // both inputs are alive
+      EXPECT_EQ("Merge", node.op());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("v_in", node.input(0));
+      EXPECT_EQ("square1", node.input(1));
+    } else if (node.name() == "m3") {
+      // sqrt1 is dead
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("v_in", node.input(0));
+    } else if (node.name() == "m4") {
+      // both inputs are alive
+      EXPECT_EQ("Merge", node.op());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("square1", node.input(0));
+      EXPECT_EQ("sqrt2", node.input(1));
+    } else if (node.name() == "m6") {
+      // both inputs are alive and the control dependency can get triggered
+      EXPECT_EQ("Merge", node.op());
+      EXPECT_EQ(3, node.input_size());
+      EXPECT_EQ("v_in", node.input(0));
+      EXPECT_EQ("square1", node.input(1));
+      EXPECT_EQ("^sqrt2", node.input(2));
+    } else if (node.name() == "m8") {
+      // The node is to be preserved because of a fetch
+      EXPECT_EQ("Merge", node.op());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("id1", node.input(0));
+      EXPECT_EQ("id2", node.input(1));
+    } else if (node.name() == "m9") {
+      // The node is to be preserved because of a fetch
+      EXPECT_EQ("Merge", node.op());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("id3", node.input(0));
+      EXPECT_EQ("id4", node.input(1));
+    }
+  }
+}
+
} // namespace grappler
} // namespace tensorflow
// A vector with a set. The set stores the same elements as the vector, and
// quickly answers whether a value is in the vector. Duplicated elements are not
// allowed for now.
-template <class T>
+template <class T, class Hash = std::hash<T>>
class SetVector {
public:
// Returns false if value already existed in the set, true otherwise.
void Reserve(int64 size) { vector_.reserve(size); }
private:
- std::unordered_set<T> set_;
+ std::unordered_set<T, Hash> set_;
std::vector<T> vector_;
};