#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/numbers.h"
    // Partial constant folding: split aggregate nodes to enable constant
    // folding of ops when more than one but not all inputs are constant.
// For AddN and AccumulateNV2, we may furthermore reorder inputs, since
// addition is commutative.
- // TODO(rmlarsen): Concat/Pack/ParallelConcat which are not commutative, so
- // we have to preserve order and can only push consecutive runs of constant
- // inputs into sub-nodes.
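+    // For example, AddN(x, c1, y, c2) with constants c1 and c2 becomes
+    // AddN(x, AddN(c1, c2), y); the inner AddN has only constant inputs
+    // and is subsequently folded to a single Const node.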
+ const int num_non_control_inputs = NumNonControlInputs(*node);
if (IsAggregate(*node) && IsCommutative(*node) &&
- NumNonControlInputs(*node) > 2) {
+ num_non_control_inputs > 2) {
const int num_control_inputs =
- node->input_size() - NumNonControlInputs(*node);
+ node->input_size() - num_non_control_inputs;
std::vector<int> const_inputs;
std::vector<int> nonconst_inputs;
      for (int i = 0; i < node->input_size(); ++i) {
        const string& input = node->input(i);
        const NodeDef* input_node = node_map_->GetNode(NodeName(input));
        if (input_node == nullptr) continue;
        if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
          const_inputs.push_back(i);
        } else {
          // Non-const and control inputs.
          nonconst_inputs.push_back(i);
        }
      }
// Promote AccumulateNV2 with all constant inputs to AddN, since it is
// a fake node that cannot be constant folded by itself.
- if (const_inputs.size() == NumNonControlInputs(*node) &&
+ if (const_inputs.size() == num_non_control_inputs &&
node->op() == "AccumulateNV2") {
node->set_op("AddN");
node->mutable_attr()->erase("shape");
const string new_node_name = OptimizedNodeName(
*node, strings::StrCat("_partial_split_", const_inputs.size()));
if (1 < const_inputs.size() &&
- const_inputs.size() < NumNonControlInputs(*node) &&
+ const_inputs.size() < num_non_control_inputs &&
!node_map_->NodeExists(new_node_name)) {
NodeDef* added_node = output->add_node();
*added_node = *node;
const_inputs.size() - 1);
(*node->mutable_attr())["N"].set_i(node->input_size() -
num_control_inputs);
+ properties->ClearInputProperties(node->name());
(*added_node->mutable_attr())["N"].set_i(const_inputs.size());
graph_modified_ = true;
+ continue;
+ }
+ }
+
+    // Partial constant folding for Concat, which is not commutative: we have
+    // to preserve the input order, and can therefore only push consecutive
+    // runs of constant inputs down into sub-nodes.
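+    // For example, Concat({x, c1, c2, y}, axis) with constants c1 and c2
+    // becomes Concat({x, Concat({c1, c2}, axis), y}, axis); the inner
+    // Concat has only constant inputs and folds to a single Const node.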
+ if (IsConcat(*node) && num_non_control_inputs > 3 &&
+ node->name().rfind("_partial_split_") == string::npos) {
+ int axis_arg = -1;
+ int begin = 0;
+ int end = num_non_control_inputs;
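+      // Concat passes the axis as its first input, followed by the values to
+      // concatenate; ConcatV2 passes the values first and the axis last.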
+ if (node->op() == "Concat") {
+ begin = 1;
+ axis_arg = 0;
+ } else if (node->op() == "ConcatV2") {
+ end = num_non_control_inputs - 1;
+ axis_arg = num_non_control_inputs - 1;
+ } else {
+ continue;
+ }
+
+ const NodeDef* axis_arg_node =
+ node_map_->GetNode(NodeName(node->input(axis_arg)));
+ if (axis_arg_node == nullptr || !IsReallyConstant(*axis_arg_node)) {
+        // We cannot constant fold Concat unless the axis argument is
+        // itself constant, so skip this node.
+ continue;
+ }
+
+      // We search for consecutive runs of constant inputs in the half-open
+      // range [begin, end) and push them down into child nodes.
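+      // For example, inputs (x, c1, c2, y, c3) yield the single run [1, 3);
+      // the lone constant c3 stays in place, since a child node with a
+      // single input would not enable any folding.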
+ std::vector<std::pair<int, int>> constant_input_runs;
+ int first = begin;
+ int last = begin;
+ while (last < end) {
+ while (first < end && !IsReallyConstant(*node_map_->GetNode(
+ NodeName(node->input(first))))) {
+ ++first;
+ }
+ // Invariant: node[first] is constant || first >= end.
+ last = first + 1;
+ while (last < end && IsReallyConstant(*node_map_->GetNode(
+ NodeName(node->input(last))))) {
+ ++last;
+ }
+        // Invariant: node[last] is not constant || last >= end.
+ // Discard intervals shorter than 2 elements.
+ if (first < end && (last - first) > 1) {
+ constant_input_runs.emplace_back(first, last);
+ }
+ first = last;
+ }
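+      // At this point constant_input_runs holds the disjoint maximal runs of
+      // two or more consecutive constant inputs, in increasing index order.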
+
+ // Skip if all inputs are constant, and let constant folding take over.
+ if (constant_input_runs.size() == 1 &&
+ constant_input_runs[0].first == begin &&
+ constant_input_runs[0].second == end) {
+ continue;
+ }
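+      // Indices of constant inputs that are moved into child nodes (all but
+      // the first input of each run) and must be removed from this node.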
+ std::set<int> inputs_to_delete;
+ for (auto interval : constant_input_runs) {
+        // Push the constant inputs in the interval down into a child node
+        // that can be constant folded.
+ const string new_node_name = OptimizedNodeName(
+ *node, strings::StrCat("_partial_split_", interval.first));
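+        // If the child node already exists, this node was already rewritten
+        // in a previous pass; stop here to keep the rewrite idempotent.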
+ if (node_map_->NodeExists(new_node_name)) {
+ break;
+ }
+ NodeDef* added_node = output->add_node();
+ *added_node = *node;
+ added_node->set_name(new_node_name);
+ node_map_->AddNode(added_node->name(), added_node);
+ added_node->clear_input();
+ for (int i = interval.first; i < interval.second; ++i) {
+ added_node->add_input(node->input(i));
+ node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
+ added_node->name());
+ if (i != interval.first) {
+ inputs_to_delete.insert(i);
+ }
+ }
+ added_node->add_input(node->input(axis_arg));
+ (*added_node->mutable_attr())["N"].set_i(interval.second -
+ interval.first);
+ node_map_->AddOutput(NodeName(node->input(axis_arg)),
+ added_node->name());
+
+ // Overwrite the first constant input with the result of the added
+ // child node.
+ node->set_input(interval.first, added_node->name());
+ node_map_->AddOutput(added_node->name(), node->name());
+ }
+ if (!constant_input_runs.empty()) {
+ graph_modified_ = true;
+ if (!inputs_to_delete.empty()) {
+ // Fix up the inputs to the original node.
+ std::vector<string> tmp(node->input().begin(), node->input().end());
+ node->clear_input();
+ for (int i = 0; i < tmp.size(); ++i) {
+ if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
+ node->add_input(tmp[i]);
+ }
+ }
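+          // N is the number of values being concatenated: all remaining
+          // inputs except the axis input.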
+ (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
+ properties->ClearInputProperties(node->name());
+ }
+ continue;
}
}
}
Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
Output sub2 = ops::Sub(s.WithOpName("sub2"), zeros, y);
-  Output concat =
-      ops::Concat(s.WithOpName("concat"),
-                  {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, matmul1,
-                   matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2},
-                  0);
+  Output stack =
+      ops::Stack(s.WithOpName("stack"),
+                 {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, matmul1,
+                  matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- item.fetch = {"concat", "matmul3", "matmul4"};
+ item.fetch = {"stack", "matmul3", "matmul4"};
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
const string ones_name = strings::StrCat("ones", suffix);
const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
const string ctrl_ones_name = strings::StrCat("^ones", suffix);
- EXPECT_EQ(28, output.node_size());
+ EXPECT_EQ(27, output.node_size());
for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i);
const string& name = node.name();
Output acc4 = fun(s.WithOpName("acc4"), {c1, y, c2});
Output acc5 = fun(s.WithOpName("acc5"), {x, c1, c2});
Output acc6 = fun(s.WithOpName("acc6"), {x, c1, y, c2});
- Output concat = ops::Concat(s.WithOpName("concat"),
- {acc0, acc1, acc2, acc3, acc4, acc5, acc6}, 0);
+ Output stack = ops::Stack(s.WithOpName("stack"),
+ {acc0, acc1, acc2, acc3, acc4, acc5, acc6});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- item.fetch = {"concat"};
+ item.fetch = {"stack"};
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- EXPECT_EQ(17, output.node_size());
+ EXPECT_EQ(16, output.node_size());
for (const NodeDef& node : output.node()) {
if (node.name() == "acc0") {
EXPECT_EQ("Const", node.op());
}
}
-TEST_F(ConstantFoldingTest, IdenticalN) {
+TEST_F(ConstantFoldingTest, PartialFolding_Concat) {
+ Scope s = Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({2, 2})));
+ Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({2, 2})));
+ Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({2, 2})));
+ Output axis = ops::Const(s.WithOpName("axis"), 0, {});
+ Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
+ Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
+ Output concat0 = ops::Concat(s.WithOpName("concat0"), {c1, c2, c1}, axis);
+ Output concat1 = ops::Concat(s.WithOpName("concat1"), {x, y, z}, axis);
+ Output concat2 = ops::Concat(s.WithOpName("concat2"), {c1, x, y}, axis);
+ Output concat3 = ops::Concat(s.WithOpName("concat3"), {c1, c2, z}, axis);
+ Output concat4 = ops::Concat(s.WithOpName("concat4"), {c1, y, c2}, axis);
+ Output concat5 = ops::Concat(s.WithOpName("concat5"), {x, c1, c2}, axis);
+ Output concat6 = ops::Concat(s.WithOpName("concat6"), {x, c1, y, c2}, axis);
+ Output concat7 = ops::Concat(s.WithOpName("concat7"), {x, y, c1, c2}, axis);
+ Output concat8 = ops::Concat(s.WithOpName("concat8"), {x, c1, c2, y}, axis);
+ Output concat9 = ops::Concat(s.WithOpName("concat9"), {c1, c2, x, y}, axis);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4",
+ "concat5", "concat6", "concat7", "concat8", "concat9"};
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
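+  // 16 original nodes plus the 5 "_partial_split_" nodes added for concat3,
+  // concat5, concat7, concat8, and concat9.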
+ EXPECT_EQ(21, output.node_size());
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ if (node.name() == "concat0") {
+ EXPECT_EQ("Const", node.op());
+ } else if (node.name() == "concat3") {
+ EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ("ConstantFolding/concat3_partial_split_0", node.input(0));
+ EXPECT_EQ("z", node.input(1));
+ EXPECT_EQ("axis", node.input(2));
+ } else if (node.name() == "concat5") {
+ EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("ConstantFolding/concat5_partial_split_1", node.input(1));
+ EXPECT_EQ("axis", node.input(2));
+ } else if (node.name() == "concat7") {
+ EXPECT_EQ(4, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("y", node.input(1));
+ EXPECT_EQ("ConstantFolding/concat7_partial_split_2", node.input(2));
+ EXPECT_EQ("axis", node.input(3));
+ } else if (node.name() == "concat8") {
+ EXPECT_EQ(4, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("ConstantFolding/concat8_partial_split_1", node.input(1));
+ EXPECT_EQ("y", node.input(2));
+ EXPECT_EQ("axis", node.input(3));
+ } else if (node.name() == "concat9") {
+ EXPECT_EQ(4, node.input_size());
+ EXPECT_EQ("ConstantFolding/concat9_partial_split_0", node.input(0));
+ EXPECT_EQ("x", node.input(1));
+ EXPECT_EQ("y", node.input(2));
+ EXPECT_EQ("axis", node.input(3));
+ } else if (StringPiece(node.name()).starts_with("ConstantFolding/")) {
+ EXPECT_EQ("Const", node.op());
+ } else {
+ EXPECT_EQ(item.graph.node(i).DebugString(), node.DebugString());
+ }
+ }
+
+ auto tensors_expected = EvaluateNodes(item.graph, {"concat0"});
+ auto tensors = EvaluateNodes(output, {"concat0"});
+ EXPECT_EQ(1, tensors_expected.size());
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
+TEST_F(ConstantFoldingTest, PartialFolding_IdentityN) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
Output x = ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({})));