Implement partial constant folding for Concat.
author    A. Unique TensorFlower <gardener@tensorflow.org>
Wed, 14 Mar 2018 18:14:23 +0000 (11:14 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Mar 2018 18:18:10 +0000 (11:18 -0700)
PiperOrigin-RevId: 189055561
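
For Concat and ConcatV2 nodes whose axis argument is constant, the optimizer now pushes each run of two or more consecutive constant inputs into a child concat node, which regular constant folding can then collapse into a single Const. An illustrative sketch of the rewrite (names hypothetical; c1 and c2 constant, x and y not):

    ConcatV2(x, c1, c2, y, axis)
      => ConcatV2(x, ConcatV2(c1, c2, axis), y, axis)   // this change
      => ConcatV2(x, folded_const, y, axis)             // regular folding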

tensorflow/core/grappler/costs/graph_properties.cc
tensorflow/core/grappler/costs/graph_properties.h
tensorflow/core/grappler/costs/graph_properties_test.cc
tensorflow/core/grappler/op_types.cc
tensorflow/core/grappler/op_types.h
tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding_test.cc

diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 243ca91..817247e 100644
@@ -1182,5 +1182,12 @@ GraphProperties::GetOutputProperties(const string& node_name) const {
   return missing_properties_;
 }
 
+void GraphProperties::ClearInputProperties(const string& node_name) {
+  input_properties_.erase(node_name);
+}
+void GraphProperties::ClearOutputProperties(const string& node_name) {
+  output_properties_.erase(node_name);
+}
+
 }  // end namespace grappler
 }  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index 6fc53a7..5aa4962 100644
@@ -64,6 +64,8 @@ class GraphProperties {
       const string& node_name) const;
   const std::vector<OpInfo::TensorProperties>& GetOutputProperties(
       const string& node_name) const;
+  void ClearInputProperties(const string& node_name);
+  void ClearOutputProperties(const string& node_name);
 
   static void FillTensorPropertiesFromContext(
       const shape_inference::ShapeHandle&, const DataType&,
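
The new Clear{Input,Output}Properties methods let the optimizer invalidate cached shape information after it rewrites a node in place. A minimal usage sketch (the rewrite step is hypothetical; assumes a GrapplerItem `item` and a NodeDef* `node` whose inputs were just changed):

    GraphProperties properties(item);
    TF_RETURN_IF_ERROR(properties.InferStatically(/*assume_valid_feeds=*/true));
    // ... rewrite node->input(...) here ...
    // Drop the stale cached shapes; GetInputProperties() for this node now
    // falls back to the empty missing_properties_ vector.
    properties.ClearInputProperties(node->name());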
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 5012069..284d9d4 100644
@@ -113,6 +113,33 @@ TEST_F(GraphPropertiesTest, StaticProperties) {
   }
 }
 
+TEST_F(GraphPropertiesTest, ClearProperties) {
+  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
+                                          cluster_->GetDeviceNames());
+  GrapplerItem item;
+  CHECK(fake_input.NextItem(&item));
+
+  GraphProperties properties(item);
+  Status s = properties.InferStatically(true);
+  TF_CHECK_OK(s);
+
+  for (const auto& node : item.graph.node()) {
+    if (node.op() == "RandomStandardNormal") {
+      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
+      EXPECT_EQ(1, properties.GetOutputProperties(node.name()).size());
+      properties.ClearOutputProperties(node.name());
+      const auto cleared_props = properties.GetOutputProperties(node.name());
+      EXPECT_TRUE(cleared_props.empty());
+    } else if (node.op() == "AddN") {
+      const auto in_props = properties.GetInputProperties(node.name());
+      EXPECT_EQ(1, in_props.size());
+      properties.ClearInputProperties(node.name());
+      const auto cleared_props = properties.GetInputProperties(node.name());
+      EXPECT_TRUE(cleared_props.empty());
+    }
+  }
+}
+
 TEST_F(GraphPropertiesTest, DynamicProperties) {
   TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                           cluster_->GetDeviceNames());
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 53c177b..9c9600d 100644
@@ -72,6 +72,10 @@ bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }
 
 bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }
 
+bool IsConcat(const NodeDef& node) {
+  return node.op() == "Concat" || node.op() == "ConcatV2";
+}
+
 bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }
 
 bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
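
IsConcat matches both flavors of the op, which differ in where they take the axis argument: Concat takes it as input 0, ConcatV2 as the last non-control input. A condensed sketch of how the rewrite in constant_folding.cc further down accounts for this (assumes num_non_control_inputs is already computed):

    // Value inputs live in the half-open range [begin, end).
    int axis_arg, begin = 0, end = num_non_control_inputs;
    if (node->op() == "Concat") {
      axis_arg = 0;                           // axis comes first
      begin = 1;
    } else {  // "ConcatV2"
      axis_arg = num_non_control_inputs - 1;  // axis comes last
      end = num_non_control_inputs - 1;
    }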
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index cd5b464..41ba8bb 100644
@@ -88,6 +88,7 @@ bool IsMatMul(const NodeDef& node);
 bool IsNextIteration(const NodeDef& node);
 bool IsPack(const NodeDef& node);
 bool IsPad(const NodeDef& node);
+bool IsConcat(const NodeDef& node);
 bool IsNoOp(const NodeDef& node);
 bool IsNotEqual(const NodeDef& node);
 bool IsPlaceholder(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 821b22a..77ccd4f 100644
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/lib/strings/numbers.h"
@@ -2093,13 +2094,11 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     // folding of ops when more than one but not all inputs are constant.
     // For AddN and AccumulateNV2, we may furthermore reorder inputs, since
     // addition is commutative.
-    // TODO(rmlarsen): Concat/Pack/ParallelConcat which are not commutative, so
-    // we have to preserve order and can only push consecutive runs of constant
-    // inputs into sub-nodes.
+    const int num_non_control_inputs = NumNonControlInputs(*node);
     if (IsAggregate(*node) && IsCommutative(*node) &&
-        NumNonControlInputs(*node) > 2) {
+        num_non_control_inputs > 2) {
       const int num_control_inputs =
-          node->input_size() - NumNonControlInputs(*node);
+          node->input_size() - num_non_control_inputs;
       std::vector<int> const_inputs;
       std::vector<int> nonconst_inputs;
       for (int i = 0; i < node->input_size(); ++i) {
@@ -2115,7 +2114,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       }
       // Promote AccumulateNV2 with all constant inputs to AddN, since it is
       // a fake node that cannot be constant folded by itself.
-      if (const_inputs.size() == NumNonControlInputs(*node) &&
+      if (const_inputs.size() == num_non_control_inputs &&
           node->op() == "AccumulateNV2") {
         node->set_op("AddN");
         node->mutable_attr()->erase("shape");
@@ -2125,7 +2124,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       const string new_node_name = OptimizedNodeName(
           *node, strings::StrCat("_partial_split_", const_inputs.size()));
       if (1 < const_inputs.size() &&
-          const_inputs.size() < NumNonControlInputs(*node) &&
+          const_inputs.size() < num_non_control_inputs &&
           !node_map_->NodeExists(new_node_name)) {
         NodeDef* added_node = output->add_node();
         *added_node = *node;
@@ -2159,8 +2158,117 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
                                               const_inputs.size() - 1);
         (*node->mutable_attr())["N"].set_i(node->input_size() -
                                            num_control_inputs);
+        properties->ClearInputProperties(node->name());
         (*added_node->mutable_attr())["N"].set_i(const_inputs.size());
         graph_modified_ = true;
+        continue;
+      }
+    }
+
+    // Partial constant folding for Concat, which is not commutative: we have
+    // to preserve the input order, so we can only push consecutive runs of
+    // constant inputs down into child nodes.
+    if (IsConcat(*node) && num_non_control_inputs > 3 &&
+        node->name().rfind("_partial_split_") == string::npos) {
+      int axis_arg = -1;
+      int begin = 0;
+      int end = num_non_control_inputs;
+      if (node->op() == "Concat") {
+        begin = 1;
+        axis_arg = 0;
+      } else if (node->op() == "ConcatV2") {
+        end = num_non_control_inputs - 1;
+        axis_arg = num_non_control_inputs - 1;
+      } else {
+        continue;
+      }
+
+      const NodeDef* axis_arg_node =
+          node_map_->GetNode(NodeName(node->input(axis_arg)));
+      if (axis_arg_node == nullptr || !IsReallyConstant(*axis_arg_node)) {
+        // We cannot constant fold Concat unless the axis argument is
+        // constant. Skip this node.
+        continue;
+      }
+
+      // We search for consecutive runs of constant inputs in the range
+      // [begin, end) and push them down into child nodes.
+      std::vector<std::pair<int, int>> constant_input_runs;
+      int first = begin;
+      int last = begin;
+      while (last < end) {
+        while (first < end && !IsReallyConstant(*node_map_->GetNode(
+                                  NodeName(node->input(first))))) {
+          ++first;
+        }
+        // Invariant: node[first] is constant || first >= end.
+        last = first + 1;
+        while (last < end && IsReallyConstant(*node_map_->GetNode(
+                                 NodeName(node->input(last))))) {
+          ++last;
+        }
+        // Invariant: node[last] is not constant || last >= end.
+        // Discard intervals shorter than 2 elements.
+        if (first < end && (last - first) > 1) {
+          constant_input_runs.emplace_back(first, last);
+        }
+        first = last;
+      }
+
+      // Skip if all inputs are constant, and let constant folding take over.
+      if (constant_input_runs.size() == 1 &&
+          constant_input_runs[0].first == begin &&
+          constant_input_runs[0].second == end) {
+        continue;
+      }
+      std::set<int> inputs_to_delete;
+      for (auto interval : constant_input_runs) {
+        // Push the constant inputs in the interval into a child node that can
+        // be constant folded.
+        const string new_node_name = OptimizedNodeName(
+            *node, strings::StrCat("_partial_split_", interval.first));
+        if (node_map_->NodeExists(new_node_name)) {
+          break;
+        }
+        NodeDef* added_node = output->add_node();
+        *added_node = *node;
+        added_node->set_name(new_node_name);
+        node_map_->AddNode(added_node->name(), added_node);
+        added_node->clear_input();
+        for (int i = interval.first; i < interval.second; ++i) {
+          added_node->add_input(node->input(i));
+          node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
+                                  added_node->name());
+          if (i != interval.first) {
+            inputs_to_delete.insert(i);
+          }
+        }
+        added_node->add_input(node->input(axis_arg));
+        (*added_node->mutable_attr())["N"].set_i(interval.second -
+                                                 interval.first);
+        node_map_->AddOutput(NodeName(node->input(axis_arg)),
+                             added_node->name());
+
+        // Overwrite the first constant input with the result of the added
+        // child node.
+        node->set_input(interval.first, added_node->name());
+        node_map_->AddOutput(added_node->name(), node->name());
+      }
+      if (!constant_input_runs.empty()) {
+        graph_modified_ = true;
+        if (!inputs_to_delete.empty()) {
+          // Fix up the inputs to the original node.
+          std::vector<string> tmp(node->input().begin(), node->input().end());
+          node->clear_input();
+          for (int i = 0; i < tmp.size(); ++i) {
+            if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
+              node->add_input(tmp[i]);
+            }
+          }
+          (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
+          properties->ClearInputProperties(node->name());
+        }
+        continue;
       }
     }
   }
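
The core of the rewrite above is the scan for maximal runs of consecutive constant inputs in [begin, end). A self-contained sketch of just that scan (hypothetical helper; constants are modeled as a bool vector instead of NodeDef lookups):

    #include <iostream>
    #include <utility>
    #include <vector>

    // Collect maximal runs of consecutive `true` (constant) entries in
    // [begin, end). Runs shorter than 2 are discarded, since folding a
    // single-input child concat would not shrink the graph.
    std::vector<std::pair<int, int>> FindConstantRuns(
        const std::vector<bool>& is_const, int begin, int end) {
      std::vector<std::pair<int, int>> runs;
      int first = begin;
      while (first < end) {
        // Skip ahead to the next constant input.
        while (first < end && !is_const[first]) ++first;
        // Extend the run across all following constants.
        int last = first + 1;
        while (last < end && is_const[last]) ++last;
        if (first < end && last - first > 1) runs.emplace_back(first, last);
        first = last;
      }
      return runs;
    }

    int main() {
      // {x, c1, c2, y}, as in test case concat8: one run [1, 3).
      for (const auto& run : FindConstantRuns({false, true, true, false}, 0, 4))
        std::cout << '[' << run.first << ',' << run.second << ") ";
      std::cout << '\n';
      // {x, c1, y, c2}, as in concat6: singleton constants, no runs.
      std::cout << FindConstantRuns({false, true, false, true}, 0, 4).size()
                << " runs\n";
      return 0;
    }

Like the diff, the sketch leaves a lone constant input where it is; only runs of length >= 2 justify creating a child node.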
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index cf151d4..9050ccb 100644
@@ -199,13 +199,12 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
     Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
     Output sub2 = ops::Sub(s.WithOpName("sub2"), zeros, y);
     Output concat =
-        ops::Concat(s.WithOpName("concat"),
-                    {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, matmul1,
-                     matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2},
-                    0);
+        ops::Stack(s.WithOpName("stack"),
+                   {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, matmul1,
+                    matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
     GrapplerItem item;
     TF_CHECK_OK(s.ToGraphDef(&item.graph));
-    item.fetch = {"concat", "matmul3", "matmul4"};
+    item.fetch = {"stack", "matmul3", "matmul4"};
 
     ConstantFolding optimizer(nullptr /* cpu_device */);
     GraphDef output;
@@ -219,7 +218,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
     const string ones_name = strings::StrCat("ones", suffix);
     const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
     const string ctrl_ones_name = strings::StrCat("^ones", suffix);
-    EXPECT_EQ(28, output.node_size());
+    EXPECT_EQ(27, output.node_size());
     for (int i = 0; i < output.node_size(); ++i) {
       const NodeDef& node = output.node(i);
       const string& name = node.name();
@@ -1825,19 +1824,19 @@ TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) {
     Output acc4 = fun(s.WithOpName("acc4"), {c1, y, c2});
     Output acc5 = fun(s.WithOpName("acc5"), {x, c1, c2});
     Output acc6 = fun(s.WithOpName("acc6"), {x, c1, y, c2});
-    Output concat = ops::Concat(s.WithOpName("concat"),
-                                {acc0, acc1, acc2, acc3, acc4, acc5, acc6}, 0);
+    Output stack = ops::Stack(s.WithOpName("stack"),
+                              {acc0, acc1, acc2, acc3, acc4, acc5, acc6});
 
     GrapplerItem item;
     TF_CHECK_OK(s.ToGraphDef(&item.graph));
-    item.fetch = {"concat"};
+    item.fetch = {"stack"};
 
     ConstantFolding optimizer(nullptr /* cpu_device */);
     GraphDef output;
     Status status = optimizer.Optimize(nullptr, item, &output);
     TF_EXPECT_OK(status);
 
-    EXPECT_EQ(17, output.node_size());
+    EXPECT_EQ(16, output.node_size());
     for (const NodeDef& node : output.node()) {
       if (node.name() == "acc0") {
         EXPECT_EQ("Const", node.op());
@@ -1895,7 +1894,90 @@ TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) {
   }
 }
 
-TEST_F(ConstantFoldingTest, IdenticalN) {
+TEST_F(ConstantFoldingTest, PartialFolding_Concat) {
+  Scope s = Scope::NewRootScope();
+  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+                              ops::Placeholder::Shape(TensorShape({2, 2})));
+  Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
+                              ops::Placeholder::Shape(TensorShape({2, 2})));
+  Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT,
+                              ops::Placeholder::Shape(TensorShape({2, 2})));
+  Output axis = ops::Const(s.WithOpName("axis"), 0, {});
+  Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
+  Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
+  Output concat0 = ops::Concat(s.WithOpName("concat0"), {c1, c2, c1}, axis);
+  Output concat1 = ops::Concat(s.WithOpName("concat1"), {x, y, z}, axis);
+  Output concat2 = ops::Concat(s.WithOpName("concat2"), {c1, x, y}, axis);
+  Output concat3 = ops::Concat(s.WithOpName("concat3"), {c1, c2, z}, axis);
+  Output concat4 = ops::Concat(s.WithOpName("concat4"), {c1, y, c2}, axis);
+  Output concat5 = ops::Concat(s.WithOpName("concat5"), {x, c1, c2}, axis);
+  Output concat6 = ops::Concat(s.WithOpName("concat6"), {x, c1, y, c2}, axis);
+  Output concat7 = ops::Concat(s.WithOpName("concat7"), {x, y, c1, c2}, axis);
+  Output concat8 = ops::Concat(s.WithOpName("concat8"), {x, c1, c2, y}, axis);
+  Output concat9 = ops::Concat(s.WithOpName("concat9"), {c1, c2, x, y}, axis);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4",
+                "concat5", "concat6", "concat7", "concat8", "concat9"};
+
+  ConstantFolding optimizer(nullptr /* cpu_device */);
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+  // Run the optimizer twice to make sure the rewrite is idempotent.
+  item.graph.Swap(&output);
+  status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  EXPECT_EQ(21, output.node_size());
+  for (int i = 0; i < output.node_size(); ++i) {
+    const NodeDef& node = output.node(i);
+    if (node.name() == "concat0") {
+      EXPECT_EQ("Const", node.op());
+    } else if (node.name() == "concat3") {
+      EXPECT_EQ(3, node.input_size());
+      EXPECT_EQ("ConstantFolding/concat3_partial_split_0", node.input(0));
+      EXPECT_EQ("z", node.input(1));
+      EXPECT_EQ("axis", node.input(2));
+    } else if (node.name() == "concat5") {
+      EXPECT_EQ(3, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+      EXPECT_EQ("ConstantFolding/concat5_partial_split_1", node.input(1));
+      EXPECT_EQ("axis", node.input(2));
+    } else if (node.name() == "concat7") {
+      EXPECT_EQ(4, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+      EXPECT_EQ("y", node.input(1));
+      EXPECT_EQ("ConstantFolding/concat7_partial_split_2", node.input(2));
+      EXPECT_EQ("axis", node.input(3));
+    } else if (node.name() == "concat8") {
+      EXPECT_EQ(4, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+      EXPECT_EQ("ConstantFolding/concat8_partial_split_1", node.input(1));
+      EXPECT_EQ("y", node.input(2));
+      EXPECT_EQ("axis", node.input(3));
+    } else if (node.name() == "concat9") {
+      EXPECT_EQ(4, node.input_size());
+      EXPECT_EQ("ConstantFolding/concat9_partial_split_0", node.input(0));
+      EXPECT_EQ("x", node.input(1));
+      EXPECT_EQ("y", node.input(2));
+      EXPECT_EQ("axis", node.input(3));
+    } else if (StringPiece(node.name()).starts_with("ConstantFolding/")) {
+      EXPECT_EQ("Const", node.op());
+    } else {
+      EXPECT_EQ(item.graph.node(i).DebugString(), node.DebugString());
+    }
+  }
+
+  auto tensors_expected = EvaluateNodes(item.graph, {"concat0"});
+  auto tensors = EvaluateNodes(output, {"concat0"});
+  EXPECT_EQ(1, tensors_expected.size());
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
+TEST_F(ConstantFoldingTest, PartialFolding_IdentityN) {
   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
   Output x = ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
                               ops::Placeholder::Shape(TensorShape({})));