From: A. Unique TensorFlower Date: Wed, 4 Apr 2018 23:05:08 +0000 (-0700) Subject: Enable constant propagation across Enter nodes, but only if is_constant is true. X-Git-Tag: tflite-v0.1.7~39^2^2~15 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=7cee71e28e98bac613623feea19c4a51439e9a0a;p=platform%2Fupstream%2Ftensorflow.git Enable constant propagation across Enter nodes, but only if is_constant is true. Don't propagate constants with control dependencies through Merge nodes. PiperOrigin-RevId: 191663396 --- diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index dd522aa..d941a0b 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -773,7 +773,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { // the case of a merge node that propagate the first inputs that becomes // available, and therefore only requires a single constant input to be // foldable. - bool has_constant_input = false; + bool merge_has_constant_input = false; const bool is_merge = IsMerge(node); for (const auto& input : node.input()) { if (IsControlInput(input)) { @@ -784,21 +784,20 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { return false; } bool is_const = IsReallyConstant(*input_node); - if (!is_const && !is_merge) { - return false; - } - // Don't fold strings constants for now since this causes problems with - // checkpointing. - if (is_const && input_node->attr().at("dtype").type() == DT_STRING) { + if (is_const) { + // Don't fold strings constants for now since this causes problems with + // checkpointing. + if (input_node->attr().at("dtype").type() == DT_STRING) { + return false; + } + // Special case: If a Merge node has at least one constant input that + // does not depend on a control input, we can fold it. + merge_has_constant_input |= !HasControlInputs(*input_node); + } else if (!is_merge) { return false; } - has_constant_input |= is_const; - } - if (is_merge) { - return has_constant_input; } - - return true; + return !is_merge || merge_has_constant_input; } namespace { @@ -1714,9 +1713,11 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, } // Move constants past Enter. - // TODO(rmlarsen): Reenable when we fix the root cause of b/76008022 - if (opt_level_ == RewriterConfig::AGGRESSIVE && IsEnter(*node) && - node->input_size() > 0) { + if (IsEnter(*node) && node->input_size() > 0) { + if (node->attr().count("is_constant") == 0 || + !node->attr().at("is_constant").b()) { + continue; + } const string& node_name = node->name(); const NodeDef* input = node_map_->GetNode(node->input(0)); if (input != nullptr && IsReallyConstant(*input) && @@ -1745,7 +1746,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, node_map_->AddOutput(node_name, new_node->name()); for (NodeDef* consumer : consumers) { for (int i = 0; i < consumer->input_size(); ++i) { - if (consumer->input(i) == node_name) { + if (NodeName(consumer->input(i)) == node_name) { node_map_->UpdateInput(consumer->name(), node_name, new_node->name()); consumer->set_input(i, new_node->name()); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 8d146637..71ee81d 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -1256,6 +1256,10 @@ TEST_F(ConstantFoldingTest, MergeNodes) { ops::Merge m1(scope.WithOpName("m1"), {x, const1, const2}); ops::Merge m2(scope.WithOpName("m2"), {const1, const3}); ops::Merge m3(scope.WithOpName("m3"), {x, y}); + // m4 is not foldable because the only constant input + // has a control input, so we cannot know if it will be + // triggered. + ops::Merge m4(scope.WithOpName("m4"), {x, const1}); ops::Identity out1(scope.WithOpName("out1"), m1.output); ops::Identity idx1(scope.WithOpName("idx1"), m1.value_index); @@ -1263,9 +1267,11 @@ TEST_F(ConstantFoldingTest, MergeNodes) { ops::Identity idx2(scope.WithOpName("idx2"), m2.value_index); ops::Identity out3(scope.WithOpName("out3"), m3.output); ops::Identity idx3(scope.WithOpName("idx3"), m3.value_index); + ops::Identity out4(scope.WithOpName("out4"), m4.output); + ops::Identity idx4(scope.WithOpName("idx4"), m4.value_index); GrapplerItem item; - item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3"}; + item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3", "out4", "idx4"}; TF_CHECK_OK(scope.ToGraphDef(&item.graph)); ConstantFolding optimizer(nullptr /* cpu_device */); @@ -1273,6 +1279,7 @@ TEST_F(ConstantFoldingTest, MergeNodes) { Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + EXPECT_EQ(19, output.node_size()); int found_nodes = 0; for (const auto& node : output.node()) { if (node.name() == "out1") { @@ -1309,10 +1316,18 @@ TEST_F(ConstantFoldingTest, MergeNodes) { EXPECT_EQ(1, node.input_size()); EXPECT_EQ("m3:1", node.input(0)); ++found_nodes; + } else if (node.name() == "out4") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("m4", node.input(0)); + ++found_nodes; + } else if (node.name() == "idx4") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("m4:1", node.input(0)); + ++found_nodes; } } // Make sure the graph contains all the nodes we're expecting. - EXPECT_EQ(6, found_nodes); + EXPECT_EQ(8, found_nodes); std::vector fetch = {"out1", "idx1"}; auto tensors = EvaluateNodes(output, fetch); @@ -2320,6 +2335,10 @@ TEST_F(ConstantFoldingTest, Enter) { GrapplerItem item; AttrValue frame_name; frame_name.set_s("foo"); + AttrValue is_constant_true; + is_constant_true.set_b(true); + AttrValue is_constant_false; + is_constant_false.set_b(false); AttrValue type; type.set_type(DT_FLOAT); AttrValue value; @@ -2330,19 +2349,31 @@ TEST_F(ConstantFoldingTest, Enter) { GraphDef& graph = item.graph; AddNode("x", "Placeholder", {}, {{"T", type}}, &graph); AddNode("c1", "Const", {"^x"}, {{"value", value}, {"dtype", type}}, &graph); - AddNode("enter1", "Enter", {"x"}, {{"T", type}, {"frame_name", frame_name}}, + AddNode("enter1", "Enter", {"x"}, + {{"T", type}, + {"frame_name", frame_name}, + {"is_constant", is_constant_true}}, + &graph); + AddNode("enter2", "Enter", {"c1"}, + {{"T", type}, + {"frame_name", frame_name}, + {"is_constant", is_constant_true}}, &graph); - AddNode("enter2", "Enter", {"c1"}, {{"T", type}, {"frame_name", frame_name}}, + AddNode("enter3", "Enter", {"c1"}, + {{"T", type}, + {"frame_name", frame_name}, + {"is_constant", is_constant_false}}, &graph); AddNode("id1", "Identity", {"enter1"}, {{"T", type}}, &graph); AddNode("id2", "Identity", {"enter2"}, {{"T", type}}, &graph); AddNode("id3", "Identity", {"enter2"}, {{"T", type}}, &graph); + AddNode("id4", "Identity", {"enter3"}, {{"T", type}}, &graph); item.fetch.push_back("id1"); item.fetch.push_back("id2"); item.fetch.push_back("id3"); + item.fetch.push_back("id4"); - ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, - nullptr /* cpu_device */); + ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -2351,7 +2382,7 @@ TEST_F(ConstantFoldingTest, Enter) { status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(7, output.node_size()); + EXPECT_EQ(9, output.node_size()); for (const NodeDef& node : output.node()) { if (node.name() == "id1") { EXPECT_EQ("Identity", node.op()); @@ -2363,6 +2394,11 @@ TEST_F(ConstantFoldingTest, Enter) { EXPECT_EQ(1, node.input_size()); EXPECT_EQ("^enter2", node.input(0)); } + if (node.name() == "id4") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("enter3", node.input(0)); + } } } diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 57b0b78..a936360 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4844,6 +4844,29 @@ py_test( ) cuda_py_test( + name = "constant_folding_test", + size = "medium", + srcs = [ + "grappler/constant_folding_test.py", + ], + additional_deps = [ + ":client_testlib", + ":framework_for_generated_wrappers", + ":array_ops", + ":control_flow_ops", + ":dtypes", + ":functional_ops", + ":math_ops", + ":ops", + "//third_party/py/numpy", + "//tensorflow/core:protos_all_py", + ], + tags = [ + "grappler", + ], +) + +cuda_py_test( name = "layout_optimizer_test", size = "medium", srcs = [ diff --git a/tensorflow/python/grappler/constant_folding_test.py b/tensorflow/python/grappler/constant_folding_test.py new file mode 100644 index 0000000..ab1d0ed --- /dev/null +++ b/tensorflow/python/grappler/constant_folding_test.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Grappler Constant Folding.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class ConstantFoldingTest(test.TestCase): + + # See b/76008022. + def testScanInsideWhile(self): + + def loop_cond(idx_step, *unused_args): + return idx_step < 1 + + def loop_body(idx_step, y): + x = array_ops.zeros([10, 20, 30], dtype=dtypes.float32) + x = functional_ops.scan( + math_ops.add, + x, + initializer=array_ops.zeros([20, 30], dtype=dtypes.float32), + back_prop=False, + parallel_iterations=1) + + with ops.device('/cpu:0'): + y = array_ops.identity(x) + + return idx_step + 1, y + + if test.is_gpu_available(cuda_only=True): + init_y = array_ops.zeros([10, 20, 30], dtype=dtypes.float32) + _, y = control_flow_ops.while_loop( + loop_cond, + loop_body, + loop_vars=[0, init_y], + back_prop=False, + parallel_iterations=1) + with session.Session() as sess: + y_v = sess.run(y) + self.assertAllEqual(np.zeros([10, 20, 30]), y_v) + + +if __name__ == '__main__': + test.main()