Implement partial constant folding for Concat.
author    A. Unique TensorFlower <gardener@tensorflow.org>
          Fri, 9 Mar 2018 18:22:16 +0000 (10:22 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Fri, 9 Mar 2018 18:26:07 +0000 (10:26 -0800)
PiperOrigin-RevId: 188501394

tensorflow/core/grappler/costs/graph_properties.cc
tensorflow/core/grappler/costs/graph_properties.h
tensorflow/core/grappler/costs/graph_properties_test.cc
tensorflow/core/grappler/op_types.cc
tensorflow/core/grappler/op_types.h
tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h
tensorflow/core/grappler/optimizers/constant_folding_test.cc

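At its core, the change teaches the constant folder to split a Concat/ConcatV2 whose input list contains a consecutive run of two or more constants: the run is pushed into a child concat node that regular folding can evaluate to a single Const, while the parent keeps its remaining inputs in their original order. A minimal sketch of the "before" graph, built with the same C++ Scope API the new tests use (the names x, y, c1, c2, axis are illustrative):

    #include "tensorflow/cc/framework/scope.h"
    #include "tensorflow/cc/ops/standard_ops.h"

    using namespace tensorflow;

    // Before: concat = Concat(x, c1, c2, y, axis), with constant run {c1, c2}.
    // After this pass:
    //   split  = Concat(c1, c2, axis)   // later folded to a single Const
    //   concat = Concat(x, split, y, axis)
    GraphDef BuildExample() {
      Scope s = Scope::NewRootScope();
      Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                  ops::Placeholder::Shape(TensorShape({2, 2})));
      Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                                  ops::Placeholder::Shape(TensorShape({2, 2})));
      Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
      Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
      Output axis = ops::Const(s.WithOpName("axis"), 0, {});
      Output concat = ops::Concat(s.WithOpName("concat"), {x, c1, c2, y}, axis);
      GraphDef graph;
      TF_CHECK_OK(s.ToGraphDef(&graph));
      return graph;
    }
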
tensorflow/core/grappler/costs/graph_properties.cc
index 243ca91..817247e 100644
@@ -1182,5 +1182,12 @@ GraphProperties::GetOutputProperties(const string& node_name) const {
   return missing_properties_;
 }
 
+void GraphProperties::ClearInputProperties(const string& node_name) {
+  input_properties_.erase(node_name);
+}
+void GraphProperties::ClearOutputProperties(const string& node_name) {
+  output_properties_.erase(node_name);
+}
+
 }  // end namespace grappler
 }  // end namespace tensorflow
tensorflow/core/grappler/costs/graph_properties.h
index 6fc53a7..5aa4962 100644
@@ -64,6 +64,8 @@ class GraphProperties {
       const string& node_name) const;
   const std::vector<OpInfo::TensorProperties>& GetOutputProperties(
       const string& node_name) const;
+  void ClearInputProperties(const string& node_name);
+  void ClearOutputProperties(const string& node_name);
 
   static void FillTensorPropertiesFromContext(
       const shape_inference::ShapeHandle&, const DataType&,
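
The two new hooks exist because SimplifyGraph (changed below) now rewires node inputs in place, which leaves previously inferred shapes for those nodes stale. A minimal sketch of the invalidation pattern, with a hypothetical helper name:

    #include <string>

    #include "tensorflow/core/framework/node_def.pb.h"
    #include "tensorflow/core/grappler/costs/graph_properties.h"

    // Hypothetical helper: once a node's input list changes, drop its cached
    // input shapes so no later rewrite consumes out-of-date properties.
    void RewireFirstInput(tensorflow::NodeDef* node,
                          const std::string& new_input,
                          tensorflow::grappler::GraphProperties* properties) {
      node->set_input(0, new_input);
      properties->ClearInputProperties(node->name());
    }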
tensorflow/core/grappler/costs/graph_properties_test.cc
index 5012069..284d9d4 100644
@@ -113,6 +113,33 @@ TEST_F(GraphPropertiesTest, StaticProperties) {
   }
 }
 
+TEST_F(GraphPropertiesTest, ClearProperties) {
+  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
+                                          cluster_->GetDeviceNames());
+  GrapplerItem item;
+  CHECK(fake_input.NextItem(&item));
+
+  GraphProperties properties(item);
+  Status s = properties.InferStatically(true);
+  TF_CHECK_OK(s);
+
+  for (const auto& node : item.graph.node()) {
+    if (node.op() == "RandomStandardNormal") {
+      EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
+      const auto props = properties.GetOutputProperties(node.name());
+      properties.ClearOutputProperties(node.name());
+      const auto cleared_props = properties.GetOutputProperties(node.name());
+      EXPECT_TRUE(cleared_props.empty());
+    } else if (node.op() == "AddN") {
+      const auto in_props = properties.GetInputProperties(node.name());
+      EXPECT_EQ(1, in_props.size());
+      properties.ClearInputProperties(node.name());
+      const auto cleared_props = properties.GetInputProperties(node.name());
+      EXPECT_TRUE(cleared_props.empty());
+    }
+  }
+}
+
 TEST_F(GraphPropertiesTest, DynamicProperties) {
   TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                           cluster_->GetDeviceNames());
tensorflow/core/grappler/op_types.cc
index 8cf1402..ae71094 100644
@@ -72,6 +72,10 @@ bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }
 
 bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }
 
+bool IsConcat(const NodeDef& node) {
+  return node.op() == "Concat" || node.op() == "ConcatV2";
+}
+
 bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }
 
 bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
@@ -213,6 +217,8 @@ bool IsNextIteration(const NodeDef& node) {
   return op == "NextIteration" || op == "RefNextIteration";
 }
 
+bool IsPack(const NodeDef& node) { return node.op() == "Pack"; }
+
 bool IsPad(const NodeDef& node) {
   const auto& op = node.op();
   return op == "Pad" || op == "PadV2";
tensorflow/core/grappler/op_types.h
index a7c33ef..690275d 100644
@@ -40,6 +40,7 @@ bool IsCast(const NodeDef& node);
 bool IsComplex(const NodeDef& node);
 bool IsComplexAbs(const NodeDef& node);
 bool IsConj(const NodeDef& node);
+bool IsConcat(const NodeDef& node);
 bool IsConcatOffset(const NodeDef& node);
 bool IsConstant(const NodeDef& node);
 bool IsConv2D(const NodeDef& node);
@@ -85,6 +86,7 @@ bool IsMul(const NodeDef& node);
 bool IsMatMul(const NodeDef& node);
 bool IsNextIteration(const NodeDef& node);
 bool IsPad(const NodeDef& node);
+bool IsPack(const NodeDef& node);
 bool IsNoOp(const NodeDef& node);
 bool IsNotEqual(const NodeDef& node);
 bool IsPlaceholder(const NodeDef& node);
tensorflow/core/grappler/optimizers/constant_folding.cc
index 31dc1b7..4036ea3 100644
@@ -1510,7 +1510,7 @@ Status ConstantFolding::ReplaceOperationWithConstant(
 }
 
 Status ConstantFolding::SimplifyGraph(GraphDef* output,
-                                      const GraphProperties& properties,
+                                      GraphProperties* properties,
                                       bool use_shape_info) {
   const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
   for (int i = 0; i < output->node_size(); ++i) {
@@ -1520,7 +1520,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     if (use_shape_info &&
         (IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) {
       const auto& shape =
-          properties.GetInputProperties(node->name())[0].shape();
+          properties->GetInputProperties(node->name())[0].shape();
       // The node is replaceable iff
       // unknown_rank == false && (dim_size == 0 || all dims have size 1)
       bool replaceable = !shape.unknown_rank();
@@ -1649,7 +1649,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       graph_modified_ = true;
       continue;
     }
-    if (use_shape_info && IsSimplifiableReshape(*node, properties)) {
+    if (use_shape_info && IsSimplifiableReshape(*node, *properties)) {
       DataType output_type = node->attr().at("T").type();
       node->set_op("Identity");
       node->clear_attr();
@@ -1667,8 +1667,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     // Simplify arithmetic operations with ones or zeros.
     if (use_shape_info &&
         (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
-        properties.HasInputProperties(node->name()) &&
-        properties.HasOutputProperties(node->name())) {
+        properties->HasInputProperties(node->name()) &&
+        properties->HasOutputProperties(node->name())) {
       const NodeDef* x = node_map_->GetNode(node->input(0));
       const NodeDef* y = node_map_->GetNode(node->input(1));
       if (x == nullptr || y == nullptr) {
@@ -1676,12 +1676,12 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
                                        node->DebugString());
       }
       const TensorShapeProto& output_shape =
-          properties.GetOutputProperties(node->name())[0].shape();
+          properties->GetOutputProperties(node->name())[0].shape();
 
       // Simplify element-wise multiplication by ones or addition/subtraction
       // of zeros.
       const TensorShapeProto& y_shape =
-          properties.GetInputProperties(node->name())[1].shape();
+          properties->GetInputProperties(node->name())[1].shape();
       const bool x_is_zero = IsZeros(*x);
       const bool x_is_one = IsOnes(*x);
       const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
@@ -1708,7 +1708,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       }
 
       const TensorShapeProto& x_shape =
-          properties.GetInputProperties(node->name())[0].shape();
+          properties->GetInputProperties(node->name())[0].shape();
       const bool y_is_zero = IsZeros(*y);
       const bool y_is_one = IsOnes(*y);
       const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
@@ -1921,13 +1921,11 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     // folding of ops when more than one but not all inputs are constant.
     // For AddN and AccumulateNV2, we may furthermore reorder inputs, since
     // addition is commutative.
-    // TODO(rmlarsen): Concat/Pack/ParallelConcat which are not commutative, so
-    // we have to preserve order and can only push consecutive runs of constant
-    // inputs into sub-nodes.
+    const int num_non_control_inputs = NumNonControlInputs(*node);
     if (IsAggregate(*node) && IsCommutative(*node) &&
-        NumNonControlInputs(*node) > 2) {
+        num_non_control_inputs > 2) {
       const int num_control_inputs =
-          node->input_size() - NumNonControlInputs(*node);
+          node->input_size() - num_non_control_inputs;
       std::vector<int> const_inputs;
       std::vector<int> nonconst_inputs;
       for (int i = 0; i < node->input_size(); ++i) {
@@ -1943,7 +1941,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       }
       // Promote AccumulateNV2 with all constant inputs to AddN, since it is
       // a fake node that cannot be constant folded by itself.
-      if (const_inputs.size() == NumNonControlInputs(*node) &&
+      if (const_inputs.size() == num_non_control_inputs &&
           node->op() == "AccumulateNV2") {
         node->set_op("AddN");
         node->mutable_attr()->erase("shape");
@@ -1953,7 +1951,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       const string new_node_name = OptimizedNodeName(
           *node, strings::StrCat("_partial_split_", const_inputs.size()));
       if (1 < const_inputs.size() &&
-          const_inputs.size() < NumNonControlInputs(*node) &&
+          const_inputs.size() < num_non_control_inputs &&
           !node_map_->NodeExists(new_node_name)) {
         NodeDef* added_node = output->add_node();
         *added_node = *node;
@@ -1987,8 +1985,121 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
                                               const_inputs.size() - 1);
         (*node->mutable_attr())["N"].set_i(node->input_size() -
                                            num_control_inputs);
+        properties->ClearInputProperties(node->name());
         (*added_node->mutable_attr())["N"].set_i(const_inputs.size());
         graph_modified_ = true;
+        continue;
+      }
+    }
+
+    // Partial constant folding for Concat, which is not commutative: we
+    // have to preserve input order, so we can only push consecutive runs
+    // of constant inputs down into sub-nodes.
+    if (IsConcat(*node) && num_non_control_inputs > 3) {
+      bool already_optimized = false;
+      const string optimized = strings::StrCat(node->name(), "_partial_split_");
+      for (const string& input : node->input()) {
+        if (input.rfind(optimized) != string::npos) {
+          already_optimized = true;
+          break;
+        }
+      }
+      if (already_optimized) {
+        continue;
+      }
+      int axis_arg = -1;
+      int begin = 0;
+      int end = num_non_control_inputs;
+      if (node->op() == "Concat") {
+        begin = 1;
+        axis_arg = 0;
+      } else if (node->op() == "ConcatV2") {
+        end = num_non_control_inputs - 1;
+        axis_arg = num_non_control_inputs - 1;
+      } else {
+        continue;
+      }
+
+      const NodeDef* axis_arg_node =
+          node_map_->GetNode(NodeName(node->input(axis_arg)));
+      if (axis_arg_node == nullptr || !IsReallyConstant(*axis_arg_node)) {
+        // We cannot constant fold Concat unless we know the axis.
+        // Skip node.
+        continue;
+      }
+
+      // We search for consecutive runs of constant inputs in the range
+      // [begin, end) and push them down into child nodes.
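+      // E.g. for ConcatV2(x, c1, c2, y, c3, c4, c5, axis), begin = 0 and
+      // end = 7, and the runs found below are [1, 3) and [4, 7); each run
+      // becomes a child concat that regular folding then turns into a Const.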
+      std::vector<std::pair<int, int>> constant_input_runs;
+      int first = begin;
+      int last = begin;
+      while (last < end) {
+        while (first < end && !IsReallyConstant(*node_map_->GetNode(
+                                  NodeName(node->input(first))))) {
+          ++first;
+        }
+        // Invariant: node[first] is constant || first >= end.
+        last = first + 1;
+        while (last < end && IsReallyConstant(*node_map_->GetNode(
+                                 NodeName(node->input(last))))) {
+          ++last;
+        }
+        // Invariant: node[last] is not constant || last >= end.
+        // Discard intervals shorter than 2 elements.
+        if (first < end && (last - first) > 1) {
+          constant_input_runs.emplace_back(first, last);
+        }
+        first = last;
+      }
+
+      std::set<int> inputs_to_delete;
+      for (auto interval : constant_input_runs) {
+        // Push the constant inputs in the interval to a child node that can be
+        // constant folded.
+        const string new_node_name = OptimizedNodeName(
+            *node, strings::StrCat("_partial_split_", interval.first));
+        if (node_map_->NodeExists(new_node_name)) {
+          break;
+        }
+        NodeDef* added_node = output->add_node();
+        *added_node = *node;
+        added_node->set_name(new_node_name);
+        node_map_->AddNode(added_node->name(), added_node);
+        added_node->clear_input();
+        for (int i = interval.first; i < interval.second; ++i) {
+          added_node->add_input(node->input(i));
+          node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
+                                  added_node->name());
+          if (i != interval.first) {
+            inputs_to_delete.insert(i);
+          }
+        }
+        added_node->add_input(node->input(axis_arg));
+        (*added_node->mutable_attr())["N"].set_i(interval.second -
+                                                 interval.first);
+        node_map_->AddOutput(NodeName(node->input(axis_arg)),
+                             added_node->name());
+
+        // Overwrite the first constant input with the result of the added
+        // child node.
+        node->set_input(interval.first, added_node->name());
+        node_map_->AddOutput(added_node->name(), node->name());
+      }
+      if (!constant_input_runs.empty()) {
+        graph_modified_ = true;
+        if (!inputs_to_delete.empty()) {
+          // Fix up the inputs to the original node.
+          std::vector<string> tmp(node->input().begin(), node->input().end());
+          node->clear_input();
+          for (int i = 0; i < tmp.size(); ++i) {
+            if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
+              node->add_input(tmp[i]);
+            }
+          }
+          (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
+          properties->ClearInputProperties(node->name());
+        }
+        continue;
       }
     }
   }
@@ -2030,7 +2141,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
 
   TF_RETURN_IF_ERROR(FoldGraph(output));
   node_map_.reset(new NodeMap(output));
-  TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info));
+  TF_RETURN_IF_ERROR(SimplifyGraph(output, &properties, can_use_shape_info));
 
   return Status::OK();
 }
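
The interval search added to SimplifyGraph is worth restating in isolation; a self-contained sketch (hypothetical function name, a plain bool vector in place of the NodeMap lookups):

    #include <utility>
    #include <vector>

    // Returns the maximal runs of length >= 2 of constant inputs inside
    // [begin, end), mirroring the loop added above.
    std::vector<std::pair<int, int>> ConstantRuns(
        const std::vector<bool>& is_const, int begin, int end) {
      std::vector<std::pair<int, int>> runs;
      int first = begin;
      while (first < end) {
        // Skip non-constant inputs.
        while (first < end && !is_const[first]) ++first;
        int last = first;
        // Extend the run of constants.
        while (last < end && is_const[last]) ++last;
        // Single constants are left alone; splitting them off buys nothing.
        if (last - first > 1) runs.emplace_back(first, last);
        first = last;
      }
      return runs;
    }

    // ConstantRuns({false, true, true, false, true, true, true}, 0, 7)
    // returns {{1, 3}, {4, 7}}, matching the ConcatV2 example above.
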
tensorflow/core/grappler/optimizers/constant_folding.h
index 2fd59c7..13ecfcd 100644
@@ -92,7 +92,7 @@ class ConstantFolding : public GraphOptimizer {
   bool IsSimplifiableReduction(const NodeDef& node) const;
   bool IsSimplifiableReshape(const NodeDef& node,
                              const GraphProperties& properties) const;
-  Status SimplifyGraph(GraphDef* output, const GraphProperties& properties,
+  Status SimplifyGraph(GraphDef* output, GraphProperties* properties,
                        bool use_shape_info);
 
   Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 4b97708..9214695 100644
@@ -188,20 +188,19 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
     Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
     Output sub2 = ops::Sub(s.WithOpName("sub2"), zeros, y);
     Output concat =
-        ops::Concat(s.WithOpName("concat"),
-                    {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, matmul1,
-                     matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2},
-                    0);
+        ops::Stack(s.WithOpName("stack"),
+                   {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, matmul1,
+                    matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
     GrapplerItem item;
     TF_CHECK_OK(s.ToGraphDef(&item.graph));
-    item.fetch = {"concat", "matmul3", "matmul4"};
+    item.fetch = {"stack", "matmul3", "matmul4"};
 
     ConstantFolding optimizer(nullptr /* cpu_device */);
     GraphDef output;
     Status status = optimizer.Optimize(nullptr, item, &output);
     TF_EXPECT_OK(status);
 
-    EXPECT_EQ(28, output.node_size());
+    EXPECT_EQ(27, output.node_size());
     for (int i = 0; i < output.node_size(); ++i) {
       const NodeDef& node = output.node(i);
       const string& name = node.name();
@@ -1626,19 +1625,19 @@ TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) {
     Output acc4 = fun(s.WithOpName("acc4"), {c1, y, c2});
     Output acc5 = fun(s.WithOpName("acc5"), {x, c1, c2});
     Output acc6 = fun(s.WithOpName("acc6"), {x, c1, y, c2});
-    Output concat = ops::Concat(s.WithOpName("concat"),
-                                {acc0, acc1, acc2, acc3, acc4, acc5, acc6}, 0);
+    Output stack = ops::Stack(s.WithOpName("stack"),
+                              {acc0, acc1, acc2, acc3, acc4, acc5, acc6});
 
     GrapplerItem item;
     TF_CHECK_OK(s.ToGraphDef(&item.graph));
-    item.fetch = {"concat"};
+    item.fetch = {"stack"};
 
     ConstantFolding optimizer(nullptr /* cpu_device */);
     GraphDef output;
     Status status = optimizer.Optimize(nullptr, item, &output);
     TF_EXPECT_OK(status);
 
-    EXPECT_EQ(17, output.node_size());
+    EXPECT_EQ(16, output.node_size());
     for (const NodeDef& node : output.node()) {
       if (node.name() == "acc0") {
         EXPECT_EQ("Const", node.op());
@@ -1696,7 +1695,86 @@ TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) {
   }
 }
 
-TEST_F(ConstantFoldingTest, IdenticalN) {
+TEST_F(ConstantFoldingTest, PartialFolding_Concat) {
+  Scope s = Scope::NewRootScope();
+  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+                              ops::Placeholder::Shape(TensorShape({2, 2})));
+  Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
+                              ops::Placeholder::Shape(TensorShape({2, 2})));
+  Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT,
+                              ops::Placeholder::Shape(TensorShape({2, 2})));
+  Output axis = ops::Const(s.WithOpName("axis"), 0, {});
+  Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2});
+  Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2});
+  Output concat0 = ops::Concat(s.WithOpName("concat0"), {c1, c2, c1}, axis);
+  Output concat1 = ops::Concat(s.WithOpName("concat1"), {x, y, z}, axis);
+  Output concat2 = ops::Concat(s.WithOpName("concat2"), {c1, x, y}, axis);
+  Output concat3 = ops::Concat(s.WithOpName("concat3"), {c1, c2, z}, axis);
+  Output concat4 = ops::Concat(s.WithOpName("concat4"), {c1, y, c2}, axis);
+  Output concat5 = ops::Concat(s.WithOpName("concat5"), {x, c1, c2}, axis);
+  Output concat6 = ops::Concat(s.WithOpName("concat6"), {x, c1, y, c2}, axis);
+  Output concat7 = ops::Concat(s.WithOpName("concat7"), {x, y, c1, c2}, axis);
+  Output concat8 = ops::Concat(s.WithOpName("concat8"), {x, c1, c2, y}, axis);
+  Output concat9 = ops::Concat(s.WithOpName("concat9"), {c1, c2, x, y}, axis);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4",
+                "concat5", "concat6", "concat7", "concat8", "concat9"};
+
+  ConstantFolding optimizer(nullptr /* cpu_device */);
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  EXPECT_EQ(21, output.node_size());
+  for (int i = 0; i < output.node_size(); ++i) {
+    const NodeDef& node = output.node(i);
+    if (node.name() == "concat0") {
+      EXPECT_EQ("Const", node.op());
+    } else if (node.name() == "concat3") {
+      EXPECT_EQ(3, node.input_size());
+      EXPECT_EQ("ConstantFolding/concat3_partial_split_0", node.input(0));
+      EXPECT_EQ("z", node.input(1));
+      EXPECT_EQ("axis", node.input(2));
+    } else if (node.name() == "concat5") {
+      EXPECT_EQ(3, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+      EXPECT_EQ("ConstantFolding/concat5_partial_split_1", node.input(1));
+      EXPECT_EQ("axis", node.input(2));
+    } else if (node.name() == "concat7") {
+      EXPECT_EQ(4, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+      EXPECT_EQ("y", node.input(1));
+      EXPECT_EQ("ConstantFolding/concat7_partial_split_2", node.input(2));
+      EXPECT_EQ("axis", node.input(3));
+    } else if (node.name() == "concat8") {
+      EXPECT_EQ(4, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+      EXPECT_EQ("ConstantFolding/concat8_partial_split_1", node.input(1));
+      EXPECT_EQ("y", node.input(2));
+      EXPECT_EQ("axis", node.input(3));
+    } else if (node.name() == "concat9") {
+      EXPECT_EQ(4, node.input_size());
+      EXPECT_EQ("ConstantFolding/concat9_partial_split_0", node.input(0));
+      EXPECT_EQ("x", node.input(1));
+      EXPECT_EQ("y", node.input(2));
+      EXPECT_EQ("axis", node.input(3));
+    } else if (StringPiece(node.name()).starts_with("ConstantFolding/")) {
+      EXPECT_EQ("Const", node.op());
+    } else {
+      EXPECT_EQ(item.graph.node(i).DebugString(), node.DebugString());
+    }
+  }
+
+  auto tensors_expected = EvaluateNodes(item.graph, {"concat0"});
+  auto tensors = EvaluateNodes(output, {"concat0"});
+  EXPECT_EQ(1, tensors_expected.size());
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
+TEST_F(ConstantFoldingTest, PartialFolding_IdentityN) {
   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
   Output x = ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
                               ops::Placeholder::Shape(TensorShape({})));
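
For reference, driving the pass looks the same as in the tests above; a short sketch against the BuildExample() graph from the top of this page:

    GrapplerItem item;
    item.graph = BuildExample();
    item.fetch = {"concat"};

    ConstantFolding optimizer(nullptr /* cpu_device */);
    GraphDef output;
    TF_EXPECT_OK(optimizer.Optimize(nullptr /* cluster */, item, &output));
    // "concat" now reads Concat(x, <folded const>, y, axis); the numerics are
    // unchanged, which the new test verifies with EvaluateNodes and
    // test::ExpectTensorNear.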