// the case of a merge node that propagates the first input that becomes
// available, and therefore only requires a single constant input to be
// foldable.
- bool has_constant_input = false;
+ bool merge_has_constant_input = false;
const bool is_merge = IsMerge(node);
for (const auto& input : node.input()) {
if (IsControlInput(input)) {
return false;
}
bool is_const = IsReallyConstant(*input_node);
- if (!is_const && !is_merge) {
- return false;
- }
- // Don't fold strings constants for now since this causes problems with
- // checkpointing.
- if (is_const && input_node->attr().at("dtype").type() == DT_STRING) {
+ if (is_const) {
+      // Don't fold string constants for now since this causes problems with
+ // checkpointing.
+ if (input_node->attr().at("dtype").type() == DT_STRING) {
+ return false;
+ }
+ // Special case: If a Merge node has at least one constant input that
+ // does not depend on a control input, we can fold it.
+ merge_has_constant_input |= !HasControlInputs(*input_node);
+ } else if (!is_merge) {
return false;
}
- has_constant_input |= is_const;
- }
- if (is_merge) {
- return has_constant_input;
}
-
- return true;
+ return !is_merge || merge_has_constant_input;
}
namespace {
}
// Move constants past Enter.
- // TODO(rmlarsen): Reenable when we fix the root cause of b/76008022
- if (opt_level_ == RewriterConfig::AGGRESSIVE && IsEnter(*node) &&
- node->input_size() > 0) {
+ if (IsEnter(*node) && node->input_size() > 0) {
+ if (node->attr().count("is_constant") == 0 ||
+ !node->attr().at("is_constant").b()) {
+ continue;
+ }
const string& node_name = node->name();
const NodeDef* input = node_map_->GetNode(node->input(0));
if (input != nullptr && IsReallyConstant(*input) &&
node_map_->AddOutput(node_name, new_node->name());
for (NodeDef* consumer : consumers) {
for (int i = 0; i < consumer->input_size(); ++i) {
- if (consumer->input(i) == node_name) {
+ if (NodeName(consumer->input(i)) == node_name) {
node_map_->UpdateInput(consumer->name(), node_name,
new_node->name());
consumer->set_input(i, new_node->name());
ops::Merge m1(scope.WithOpName("m1"), {x, const1, const2});
ops::Merge m2(scope.WithOpName("m2"), {const1, const3});
ops::Merge m3(scope.WithOpName("m3"), {x, y});
+ // m4 is not foldable because the only constant input
+ // has a control input, so we cannot know if it will be
+ // triggered.
+ ops::Merge m4(scope.WithOpName("m4"), {x, const1});
ops::Identity out1(scope.WithOpName("out1"), m1.output);
ops::Identity idx1(scope.WithOpName("idx1"), m1.value_index);
ops::Identity idx2(scope.WithOpName("idx2"), m2.value_index);
ops::Identity out3(scope.WithOpName("out3"), m3.output);
ops::Identity idx3(scope.WithOpName("idx3"), m3.value_index);
+ ops::Identity out4(scope.WithOpName("out4"), m4.output);
+ ops::Identity idx4(scope.WithOpName("idx4"), m4.value_index);
GrapplerItem item;
- item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3"};
+ item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3", "out4", "idx4"};
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding optimizer(nullptr /* cpu_device */);
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
+ EXPECT_EQ(19, output.node_size());
int found_nodes = 0;
for (const auto& node : output.node()) {
if (node.name() == "out1") {
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("m3:1", node.input(0));
++found_nodes;
+ } else if (node.name() == "out4") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("m4", node.input(0));
+ ++found_nodes;
+ } else if (node.name() == "idx4") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("m4:1", node.input(0));
+ ++found_nodes;
}
}
// Make sure the graph contains all the nodes we're expecting.
- EXPECT_EQ(6, found_nodes);
+ EXPECT_EQ(8, found_nodes);
std::vector<string> fetch = {"out1", "idx1"};
auto tensors = EvaluateNodes(output, fetch);
GrapplerItem item;
AttrValue frame_name;
frame_name.set_s("foo");
+ AttrValue is_constant_true;
+ is_constant_true.set_b(true);
+ AttrValue is_constant_false;
+ is_constant_false.set_b(false);
AttrValue type;
type.set_type(DT_FLOAT);
AttrValue value;
GraphDef& graph = item.graph;
AddNode("x", "Placeholder", {}, {{"T", type}}, &graph);
AddNode("c1", "Const", {"^x"}, {{"value", value}, {"dtype", type}}, &graph);
- AddNode("enter1", "Enter", {"x"}, {{"T", type}, {"frame_name", frame_name}},
+ AddNode("enter1", "Enter", {"x"},
+ {{"T", type},
+ {"frame_name", frame_name},
+ {"is_constant", is_constant_true}},
+ &graph);
+ AddNode("enter2", "Enter", {"c1"},
+ {{"T", type},
+ {"frame_name", frame_name},
+ {"is_constant", is_constant_true}},
&graph);
- AddNode("enter2", "Enter", {"c1"}, {{"T", type}, {"frame_name", frame_name}},
+ AddNode("enter3", "Enter", {"c1"},
+ {{"T", type},
+ {"frame_name", frame_name},
+ {"is_constant", is_constant_false}},
&graph);
AddNode("id1", "Identity", {"enter1"}, {{"T", type}}, &graph);
AddNode("id2", "Identity", {"enter2"}, {{"T", type}}, &graph);
AddNode("id3", "Identity", {"enter2"}, {{"T", type}}, &graph);
+ AddNode("id4", "Identity", {"enter3"}, {{"T", type}}, &graph);
item.fetch.push_back("id1");
item.fetch.push_back("id2");
item.fetch.push_back("id3");
+ item.fetch.push_back("id4");
- ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
- nullptr /* cpu_device */);
+ ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- EXPECT_EQ(7, output.node_size());
+ EXPECT_EQ(9, output.node_size());
for (const NodeDef& node : output.node()) {
if (node.name() == "id1") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("^enter2", node.input(0));
}
+ if (node.name() == "id4") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("enter3", node.input(0));
+ }
}
}
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Grappler Constant Folding."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ConstantFoldingTest(test.TestCase):
+
+  # Regression test for b/76008022 (constant folding past Enter nodes); this
+  # accompanies the constant-folding change earlier in this patch.
+  def testScanInsideWhile(self):
+    """Checks that a scan over zeros inside a while_loop yields zeros."""
+
+    def loop_cond(idx_step, *unused_args):
+      # Run the loop body exactly once.
+      return idx_step < 1
+
+    def loop_body(idx_step, y):
+      # Cumulative scan (with math_ops.add) over an all-zeros tensor,
+      # starting from an all-zeros initializer, so the result is all zeros.
+      x = array_ops.zeros([10, 20, 30], dtype=dtypes.float32)
+      x = functional_ops.scan(
+          math_ops.add,
+          x,
+          initializer=array_ops.zeros([20, 30], dtype=dtypes.float32),
+          back_prop=False,
+          parallel_iterations=1)
+
+      # Pin the identity to CPU — presumably to force a cross-device copy of
+      # the scan output, the scenario from b/76008022; confirm against the
+      # bug report.
+      with ops.device('/cpu:0'):
+        y = array_ops.identity(x)
+
+      return idx_step + 1, y
+
+    # Only exercised when a CUDA GPU is available; without one the test
+    # silently passes.
+    if test.is_gpu_available(cuda_only=True):
+      init_y = array_ops.zeros([10, 20, 30], dtype=dtypes.float32)
+      _, y = control_flow_ops.while_loop(
+          loop_cond,
+          loop_body,
+          loop_vars=[0, init_y],
+          back_prop=False,
+          parallel_iterations=1)
+      with session.Session() as sess:
+        y_v = sess.run(y)
+        self.assertAllEqual(np.zeros([10, 20, 30]), y_v)
+
+
+# Standard TensorFlow test entry point: runs all test methods in this file.
+if __name__ == '__main__':
+  test.main()