Use Identity instead of Snapshot when the graph does not contain ops that modify...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 11 May 2018 17:43:30 +0000 (10:43 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 11 May 2018 17:48:25 +0000 (10:48 -0700)
PiperOrigin-RevId: 196275133

tensorflow/core/grappler/op_types.cc
tensorflow/core/grappler/op_types.h
tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h
tensorflow/core/grappler/optimizers/constant_folding_test.cc

index e633ecf..07f826b 100644 (file)
@@ -408,6 +408,21 @@ bool IsPersistent(const NodeDef& node) {
   return IsConstant(node) || IsVariable(node);
 }
 
+bool MaybeHasRefInput(const NodeDef& node) {
+  const OpDef* op_def;
+  Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
+  if (!status.ok()) {
+    return true;
+  }
+  // Nodes such as Assign or AssignAdd modify one of their inputs.
+  for (const auto& input : op_def->input_arg()) {
+    if (input.is_ref()) {
+      return true;
+    }
+  }
+  return false;
+}
+
 bool IsFreeOfSideEffect(const NodeDef& node) {
   // Placeholders must be preserved to keep the graph feedable.
   if (IsPlaceholder(node)) {
index f6105d7..a5599eb 100644 (file)
@@ -166,6 +166,10 @@ bool IsPersistent(const NodeDef& node);
 
 bool IsFreeOfSideEffect(const NodeDef& node);
 
+// Returns true if the takes a tensor reference as input, or if looking up its
+// OpDef failed.
+bool MaybeHasRefInput(const NodeDef& node);
+
 bool ModifiesFrameInfo(const NodeDef& node);
 
 // Returns true if the op is known to write to one or more of its inputs.
index d5c583a..171d492 100644 (file)
@@ -1514,6 +1514,16 @@ void ConstantFolding::ReplaceOperationWithIdentity(
 void ConstantFolding::ReplaceOperationWithSnapshot(
     int input_to_forward, const GraphProperties& properties, NodeDef* node,
     GraphDef* graph) {
+  // If the graph contains no ops that mutate their inputs, we can
+  // use Identity insted of Snapshot.
+
+  // TODO(rmlarsen): Enable in regular mode after May 15, 2018.
+  if (opt_level_ == RewriterConfig::AGGRESSIVE &&
+      !graph_contains_assign_or_inplace_op_) {
+    ReplaceOperationWithIdentity(input_to_forward, properties, node, graph);
+    return;
+  }
+
   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
   if (dtype == DT_INVALID) return;
 
@@ -2546,6 +2556,17 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
     cpu_device_ = owned_device_.get();
   }
 
+  graph_contains_assign_or_inplace_op_ = false;
+  // TODO(rmlarsen): Enable in regular mode after May 15, 2018.
+  if (opt_level_ == RewriterConfig::AGGRESSIVE) {
+    for (const NodeDef& node : item.graph.node()) {
+      if (ModifiesInputsInPlace(node) || MaybeHasRefInput(node)) {
+        graph_contains_assign_or_inplace_op_ = true;
+        break;
+      }
+    }
+  }
+
   has_fetch_ = !item.fetch.empty();
   GrapplerItem item_to_optimize = item;
   *optimized_graph = item.graph;
index 7aad3a6..f92f755 100644 (file)
@@ -126,6 +126,7 @@ class ConstantFolding : public GraphOptimizer {
   std::unordered_set<string> feed_nodes_;
   bool has_fetch_;
   bool graph_modified_;
+  bool graph_contains_assign_or_inplace_op_;
 };
 
 }  // end namespace grappler
index f018b21..0bf51c4 100644 (file)
@@ -33,77 +33,89 @@ class ConstantFoldingTest : public GrapplerTest {
  protected:
   template <DataType DTYPE>
   void SimpleNeutralElementTest() {
-    typedef typename EnumToDataType<DTYPE>::Type T;
-    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-    Output x = ops::Placeholder(s.WithOpName("x"), DTYPE,
-                                ops::Placeholder::Shape(TensorShape({2, 2})));
-    Tensor zeros_t(DTYPE, TensorShape({2, 2}));
-    Tensor ones_t(DTYPE, TensorShape({2, 2}));
-    Tensor x_t(DTYPE, TensorShape({2, 2}));
-    for (int i = 0; i < 4; ++i) {
-      zeros_t.flat<T>()(i) = T(0);
-      ones_t.flat<T>()(i) = T(1);
-      x_t.flat<T>()(i) = T(i + 1);
-    }
-    Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
-    Output ones = ops::Const(s.WithOpName("ones"), ones_t);
-    Output mul1;
-    Output mul2;
-    Output add1;
-    Output add2;
-    if (DTYPE == DT_BOOL) {
-      mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros);
-      mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones);
-      add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros);
-      add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones);
-    } else {
-      mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
-      mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
-      add1 = ops::Add(s.WithOpName("add1"), x, zeros);
-      add1 = ops::Add(s.WithOpName("add2"), x, ones);
-    }
-    GrapplerItem item;
-    TF_CHECK_OK(s.ToGraphDef(&item.graph));
-    item.fetch = {"mul1", "mul2", "add1", "add2"};
-    ConstantFolding optimizer(nullptr /* cpu_device */);
-    GraphDef output;
-    Status status = optimizer.Optimize(nullptr, item, &output);
-    TF_EXPECT_OK(status);
-
-    EXPECT_EQ(7, output.node_size());
-    for (int i = 0; i < output.node_size(); ++i) {
-      const NodeDef& node = output.node(i);
-      const string& name = node.name();
-      if (name == "mul1") {
-        EXPECT_EQ("Const", node.op());
-        EXPECT_EQ("^x", node.input(0));
-        EXPECT_EQ("^zeros", node.input(1));
-      } else if (name == "mul2") {
-        EXPECT_EQ("Snapshot", node.op());
-        EXPECT_EQ("x", node.input(0));
-        EXPECT_EQ("^ones", node.input(1));
-      } else if (name == "add1") {
-        EXPECT_EQ("Snapshot", node.op());
-        EXPECT_EQ("x", node.input(0));
-        EXPECT_EQ("^zeros", node.input(1));
-      } else if (name == "add2") {
-        if (DTYPE == DT_BOOL) {
+    for (bool use_snapshot : {false, true}) {
+      typedef typename EnumToDataType<DTYPE>::Type T;
+      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+      Output x = ops::Placeholder(s.WithOpName("x"), DTYPE,
+                                  ops::Placeholder::Shape(TensorShape({2, 2})));
+      Output v = ops::Variable(s.WithOpName("v"), {2, 2}, DTYPE);
+      Tensor zeros_t(DTYPE, TensorShape({2, 2}));
+      Tensor ones_t(DTYPE, TensorShape({2, 2}));
+      Tensor x_t(DTYPE, TensorShape({2, 2}));
+      for (int i = 0; i < 4; ++i) {
+        zeros_t.flat<T>()(i) = T(0);
+        ones_t.flat<T>()(i) = T(1);
+        x_t.flat<T>()(i) = T(i + 1);
+      }
+      Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
+      Output ones = ops::Const(s.WithOpName("ones"), ones_t);
+      Output mul1;
+      Output mul2;
+      Output add1;
+      Output add2;
+      if (DTYPE == DT_BOOL) {
+        mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros);
+        mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones);
+        add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros);
+        add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones);
+      } else {
+        mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
+        mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
+        add1 = ops::Add(s.WithOpName("add1"), x, zeros);
+        add1 = ops::Add(s.WithOpName("add2"), x, ones);
+      }
+      if (use_snapshot) {
+        // Add an op with ref input to prevent Snapshot from being
+        // turned into Identity.
+        ops::Assign(s.WithOpName("assign"), v, ones);
+      }
+      GrapplerItem item;
+      TF_CHECK_OK(s.ToGraphDef(&item.graph));
+      item.fetch = {"mul1", "mul2", "add1", "add2"};
+      ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+                                nullptr /* cpu_device */);
+      GraphDef output;
+      Status status = optimizer.Optimize(nullptr, item, &output);
+      TF_EXPECT_OK(status);
+
+      EXPECT_EQ(7, output.node_size());
+      const string snapshot_or_identity =
+          use_snapshot ? "Snapshot" : "Identity";
+      for (int i = 0; i < output.node_size(); ++i) {
+        const NodeDef& node = output.node(i);
+        const string& name = node.name();
+        if (name == "mul1") {
           EXPECT_EQ("Const", node.op());
           EXPECT_EQ("^x", node.input(0));
+          EXPECT_EQ("^zeros", node.input(1));
+        } else if (name == "mul2") {
+          EXPECT_EQ(snapshot_or_identity, node.op());
+          EXPECT_EQ("x", node.input(0));
           EXPECT_EQ("^ones", node.input(1));
-        } else {
-          EXPECT_EQ("Add", node.op());
+        } else if (name == "add1") {
+          EXPECT_EQ(snapshot_or_identity, node.op());
           EXPECT_EQ("x", node.input(0));
-          EXPECT_EQ("ones", node.input(1));
+          EXPECT_EQ("^zeros", node.input(1));
+        } else if (name == "add2") {
+          if (DTYPE == DT_BOOL) {
+            EXPECT_EQ("Const", node.op());
+            EXPECT_EQ("^x", node.input(0));
+            EXPECT_EQ("^ones", node.input(1));
+          } else {
+            EXPECT_EQ("Add", node.op());
+            EXPECT_EQ("x", node.input(0));
+            EXPECT_EQ("ones", node.input(1));
+          }
         }
       }
-    }
-    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
-    auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
-    EXPECT_EQ(4, tensors_expected.size());
-    EXPECT_EQ(4, tensors.size());
-    for (int i = 0; i < item.fetch.size(); ++i) {
-      test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
+      auto tensors_expected =
+          EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
+      auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
+      EXPECT_EQ(4, tensors_expected.size());
+      EXPECT_EQ(4, tensors.size());
+      for (int i = 0; i < item.fetch.size(); ++i) {
+        test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
+      }
     }
   }
 };
@@ -284,7 +296,8 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
     TF_CHECK_OK(s.ToGraphDef(&item.graph));
     item.fetch = {"stack", "matmul3", "matmul4"};
 
-    ConstantFolding optimizer(nullptr /* cpu_device */);
+    ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+                              nullptr /* cpu_device */);
     GraphDef output;
     Status status = optimizer.Optimize(nullptr, item, &output);
     TF_EXPECT_OK(status);
@@ -309,11 +322,11 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
         EXPECT_EQ(ctrl_zeros_name, node.input(0));
         EXPECT_EQ("^y", node.input(1));
       } else if (name == "mul3") {
-        EXPECT_EQ("Snapshot", node.op());
+        EXPECT_EQ("Identity", node.op());
         EXPECT_EQ("x", node.input(0));
         EXPECT_EQ(ctrl_ones_name, node.input(1));
       } else if (name == "mul4") {
-        EXPECT_EQ("Snapshot", node.op());
+        EXPECT_EQ("Identity", node.op());
         EXPECT_EQ("y", node.input(0));
         EXPECT_EQ(ctrl_ones_name, node.input(1));
       } else if (name == "mul5") {
@@ -325,7 +338,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
         EXPECT_EQ("^zeros_1d", node.input(0));
         EXPECT_EQ("^y", node.input(1));
       } else if (name == "div1") {
-        EXPECT_EQ("Snapshot", node.op());
+        EXPECT_EQ("Identity", node.op());
         EXPECT_EQ("x", node.input(0));
         EXPECT_EQ(ctrl_ones_name, node.input(1));
       } else if (name == "div2") {
@@ -361,15 +374,15 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
         EXPECT_EQ(2, t.tensor_shape().dim(0).size());
         EXPECT_EQ(3, t.tensor_shape().dim(1).size());
       } else if (name == "add1") {
-        EXPECT_EQ("Snapshot", node.op());
+        EXPECT_EQ("Identity", node.op());
         EXPECT_EQ("x", node.input(0));
         EXPECT_EQ(ctrl_zeros_name, node.input(1));
       } else if (name == "add2") {
-        EXPECT_EQ("Snapshot", node.op());
+        EXPECT_EQ("Identity", node.op());
         EXPECT_EQ("y", node.input(0));
         EXPECT_EQ(ctrl_zeros_name, node.input(1));
       } else if (name == "bias_add1") {
-        EXPECT_EQ("Snapshot", node.op());
+        EXPECT_EQ("Identity", node.op());
         EXPECT_EQ("x", node.input(0));
         EXPECT_EQ("^zeros_1d", node.input(1));
       } else if (name == "bias_add2") {
@@ -378,7 +391,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
         EXPECT_EQ(zeros_name, node.input(0));
         EXPECT_EQ("bias", node.input(1));
       } else if (name == "sub1") {
-        EXPECT_EQ("Snapshot", node.op());
+        EXPECT_EQ("Identity", node.op());
         EXPECT_EQ("x", node.input(0));
         EXPECT_EQ(ctrl_zeros_name, node.input(1));
       } else if (name == "sub2") {