Enable constant propagation across Enter nodes, but only if is_constant is true.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 4 Apr 2018 23:05:08 +0000 (16:05 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 4 Apr 2018 23:07:48 +0000 (16:07 -0700)
Don't propagate constants with control dependencies through Merge nodes.

PiperOrigin-RevId: 191663396

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding_test.cc
tensorflow/python/BUILD
tensorflow/python/grappler/constant_folding_test.py [new file with mode: 0644]

index dd522aa..d941a0b 100644 (file)
@@ -773,7 +773,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
   // the case of a merge node that propagate the first inputs that becomes
   // available, and therefore only requires a single constant input to be
   // foldable.
-  bool has_constant_input = false;
+  bool merge_has_constant_input = false;
   const bool is_merge = IsMerge(node);
   for (const auto& input : node.input()) {
     if (IsControlInput(input)) {
@@ -784,21 +784,20 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
       return false;
     }
     bool is_const = IsReallyConstant(*input_node);
-    if (!is_const && !is_merge) {
-      return false;
-    }
-    // Don't fold strings constants for now since this causes problems with
-    // checkpointing.
-    if (is_const && input_node->attr().at("dtype").type() == DT_STRING) {
+    if (is_const) {
+      // Don't fold strings constants for now since this causes problems with
+      // checkpointing.
+      if (input_node->attr().at("dtype").type() == DT_STRING) {
+        return false;
+      }
+      // Special case: If a Merge node has at least one constant input that
+      // does not depend on a control input, we can fold it.
+      merge_has_constant_input |= !HasControlInputs(*input_node);
+    } else if (!is_merge) {
       return false;
     }
-    has_constant_input |= is_const;
-  }
-  if (is_merge) {
-    return has_constant_input;
   }
-
-  return true;
+  return !is_merge || merge_has_constant_input;
 }
 
 namespace {
@@ -1714,9 +1713,11 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
     }
 
     // Move constants past Enter.
-    // TODO(rmlarsen): Reenable when we fix the root cause of b/76008022
-    if (opt_level_ == RewriterConfig::AGGRESSIVE && IsEnter(*node) &&
-        node->input_size() > 0) {
+    if (IsEnter(*node) && node->input_size() > 0) {
+      if (node->attr().count("is_constant") == 0 ||
+          !node->attr().at("is_constant").b()) {
+        continue;
+      }
       const string& node_name = node->name();
       const NodeDef* input = node_map_->GetNode(node->input(0));
       if (input != nullptr && IsReallyConstant(*input) &&
@@ -1745,7 +1746,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
           node_map_->AddOutput(node_name, new_node->name());
           for (NodeDef* consumer : consumers) {
             for (int i = 0; i < consumer->input_size(); ++i) {
-              if (consumer->input(i) == node_name) {
+              if (NodeName(consumer->input(i)) == node_name) {
                 node_map_->UpdateInput(consumer->name(), node_name,
                                        new_node->name());
                 consumer->set_input(i, new_node->name());
index 8d14663..71ee81d 100644 (file)
@@ -1256,6 +1256,10 @@ TEST_F(ConstantFoldingTest, MergeNodes) {
   ops::Merge m1(scope.WithOpName("m1"), {x, const1, const2});
   ops::Merge m2(scope.WithOpName("m2"), {const1, const3});
   ops::Merge m3(scope.WithOpName("m3"), {x, y});
+  // m4 is not foldable because the only constant input
+  // has a control input, so we cannot know if it will be
+  // triggered.
+  ops::Merge m4(scope.WithOpName("m4"), {x, const1});
 
   ops::Identity out1(scope.WithOpName("out1"), m1.output);
   ops::Identity idx1(scope.WithOpName("idx1"), m1.value_index);
@@ -1263,9 +1267,11 @@ TEST_F(ConstantFoldingTest, MergeNodes) {
   ops::Identity idx2(scope.WithOpName("idx2"), m2.value_index);
   ops::Identity out3(scope.WithOpName("out3"), m3.output);
   ops::Identity idx3(scope.WithOpName("idx3"), m3.value_index);
+  ops::Identity out4(scope.WithOpName("out4"), m4.output);
+  ops::Identity idx4(scope.WithOpName("idx4"), m4.value_index);
 
   GrapplerItem item;
-  item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3"};
+  item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3", "out4", "idx4"};
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
 
   ConstantFolding optimizer(nullptr /* cpu_device */);
@@ -1273,6 +1279,7 @@ TEST_F(ConstantFoldingTest, MergeNodes) {
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
 
+  EXPECT_EQ(19, output.node_size());
   int found_nodes = 0;
   for (const auto& node : output.node()) {
     if (node.name() == "out1") {
@@ -1309,10 +1316,18 @@ TEST_F(ConstantFoldingTest, MergeNodes) {
       EXPECT_EQ(1, node.input_size());
       EXPECT_EQ("m3:1", node.input(0));
       ++found_nodes;
+    } else if (node.name() == "out4") {
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("m4", node.input(0));
+      ++found_nodes;
+    } else if (node.name() == "idx4") {
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("m4:1", node.input(0));
+      ++found_nodes;
     }
   }
   // Make sure the graph contains all the nodes we're expecting.
-  EXPECT_EQ(6, found_nodes);
+  EXPECT_EQ(8, found_nodes);
 
   std::vector<string> fetch = {"out1", "idx1"};
   auto tensors = EvaluateNodes(output, fetch);
@@ -2320,6 +2335,10 @@ TEST_F(ConstantFoldingTest, Enter) {
   GrapplerItem item;
   AttrValue frame_name;
   frame_name.set_s("foo");
+  AttrValue is_constant_true;
+  is_constant_true.set_b(true);
+  AttrValue is_constant_false;
+  is_constant_false.set_b(false);
   AttrValue type;
   type.set_type(DT_FLOAT);
   AttrValue value;
@@ -2330,19 +2349,31 @@ TEST_F(ConstantFoldingTest, Enter) {
   GraphDef& graph = item.graph;
   AddNode("x", "Placeholder", {}, {{"T", type}}, &graph);
   AddNode("c1", "Const", {"^x"}, {{"value", value}, {"dtype", type}}, &graph);
-  AddNode("enter1", "Enter", {"x"}, {{"T", type}, {"frame_name", frame_name}},
+  AddNode("enter1", "Enter", {"x"},
+          {{"T", type},
+           {"frame_name", frame_name},
+           {"is_constant", is_constant_true}},
+          &graph);
+  AddNode("enter2", "Enter", {"c1"},
+          {{"T", type},
+           {"frame_name", frame_name},
+           {"is_constant", is_constant_true}},
           &graph);
-  AddNode("enter2", "Enter", {"c1"}, {{"T", type}, {"frame_name", frame_name}},
+  AddNode("enter3", "Enter", {"c1"},
+          {{"T", type},
+           {"frame_name", frame_name},
+           {"is_constant", is_constant_false}},
           &graph);
   AddNode("id1", "Identity", {"enter1"}, {{"T", type}}, &graph);
   AddNode("id2", "Identity", {"enter2"}, {{"T", type}}, &graph);
   AddNode("id3", "Identity", {"enter2"}, {{"T", type}}, &graph);
+  AddNode("id4", "Identity", {"enter3"}, {{"T", type}}, &graph);
   item.fetch.push_back("id1");
   item.fetch.push_back("id2");
   item.fetch.push_back("id3");
+  item.fetch.push_back("id4");
 
-  ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
-                            nullptr /* cpu_device */);
+  ConstantFolding optimizer(nullptr /* cpu_device */);
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -2351,7 +2382,7 @@ TEST_F(ConstantFoldingTest, Enter) {
   status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
 
-  EXPECT_EQ(7, output.node_size());
+  EXPECT_EQ(9, output.node_size());
   for (const NodeDef& node : output.node()) {
     if (node.name() == "id1") {
       EXPECT_EQ("Identity", node.op());
@@ -2363,6 +2394,11 @@ TEST_F(ConstantFoldingTest, Enter) {
       EXPECT_EQ(1, node.input_size());
       EXPECT_EQ("^enter2", node.input(0));
     }
+    if (node.name() == "id4") {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("enter3", node.input(0));
+    }
   }
 }
 
index 57b0b78..a936360 100644 (file)
@@ -4844,6 +4844,29 @@ py_test(
 )
 
 cuda_py_test(
+    name = "constant_folding_test",
+    size = "medium",
+    srcs = [
+        "grappler/constant_folding_test.py",
+    ],
+    additional_deps = [
+        ":client_testlib",
+        ":framework_for_generated_wrappers",
+        ":array_ops",
+        ":control_flow_ops",
+        ":dtypes",
+        ":functional_ops",
+        ":math_ops",
+        ":ops",
+        "//third_party/py/numpy",
+        "//tensorflow/core:protos_all_py",
+    ],
+    tags = [
+        "grappler",
+    ],
+)
+
+cuda_py_test(
     name = "layout_optimizer_test",
     size = "medium",
     srcs = [
diff --git a/tensorflow/python/grappler/constant_folding_test.py b/tensorflow/python/grappler/constant_folding_test.py
new file mode 100644 (file)
index 0000000..ab1d0ed
--- /dev/null
@@ -0,0 +1,69 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Grappler Constant Folding."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ConstantFoldingTest(test.TestCase):
+
+  # See b/76008022.
+  def testScanInsideWhile(self):
+
+    def loop_cond(idx_step, *unused_args):
+      return idx_step < 1
+
+    def loop_body(idx_step, y):
+      x = array_ops.zeros([10, 20, 30], dtype=dtypes.float32)
+      x = functional_ops.scan(
+          math_ops.add,
+          x,
+          initializer=array_ops.zeros([20, 30], dtype=dtypes.float32),
+          back_prop=False,
+          parallel_iterations=1)
+
+      with ops.device('/cpu:0'):
+        y = array_ops.identity(x)
+
+        return idx_step + 1, y
+
+    if test.is_gpu_available(cuda_only=True):
+      init_y = array_ops.zeros([10, 20, 30], dtype=dtypes.float32)
+      _, y = control_flow_ops.while_loop(
+          loop_cond,
+          loop_body,
+          loop_vars=[0, init_y],
+          back_prop=False,
+          parallel_iterations=1)
+      with session.Session() as sess:
+        y_v = sess.run(y)
+        self.assertAllEqual(np.zeros([10, 20, 30]), y_v)
+
+
+if __name__ == '__main__':
+  test.main()