Optimize multiplications by constants in more cases.
authorBenoit Steiner <bsteiner@google.com>
Mon, 21 May 2018 18:12:55 +0000 (11:12 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 21 May 2018 18:17:30 +0000 (11:17 -0700)
PiperOrigin-RevId: 197422256

tensorflow/core/grappler/op_types.cc
tensorflow/core/grappler/op_types.h
tensorflow/core/grappler/optimizers/BUILD
tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h
tensorflow/core/grappler/optimizers/constant_folding_test.cc
tensorflow/core/grappler/optimizers/symbolic_shapes.cc
tensorflow/core/grappler/optimizers/symbolic_shapes.h
tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc

index 07f826b..9258194 100644 (file)
@@ -106,6 +106,8 @@ bool IsConv2DBackpropInput(const NodeDef& node) {
   return node.op() == "Conv2DBackpropInput";
 }
 
+bool IsConv3D(const NodeDef& node) { return node.op() == "Conv3D"; }
+
 bool IsDepthwiseConv2dNative(const NodeDef& node) {
   return node.op() == "DepthwiseConv2dNative";
 }
index a5599eb..9d91ba1 100644 (file)
@@ -48,6 +48,7 @@ bool IsConstant(const NodeDef& node);
 bool IsConv2D(const NodeDef& node);
 bool IsConv2DBackpropFilter(const NodeDef& node);
 bool IsConv2DBackpropInput(const NodeDef& node);
+bool IsConv3D(const NodeDef& node);
 bool IsDepthwiseConv2dNative(const NodeDef& node);
 bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node);
 bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node);
index 56c23db..104a042 100644 (file)
@@ -96,6 +96,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":graph_optimizer",
+        ":symbolic_shapes",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
index 38e4a02..2d13c78 100644 (file)
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/costs/graph_properties.h"
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
@@ -2146,6 +2147,11 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
     return Status::OK();
   }
 
+  if (MulConvPushDown(node, *properties)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
   if (PartialConstPropThroughIdentityN(node)) {
     graph_modified_ = true;
     return Status::OK();
@@ -2321,6 +2327,7 @@ bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
     node_map_->UpdateOutput(node->name(), const_input, reciprocal_node->name());
     return true;
   }
+
   return false;
 }
 
@@ -2411,6 +2418,127 @@ bool ConstantFolding::ConstantPushDown(NodeDef* node) {
   return false;
 }
 
+bool ConstantFolding::MulConvPushDown(NodeDef* node,
+                                      const GraphProperties& properties) {
+  // Push down multiplication on ConvND.
+  //                       *                  ConvND
+  //                     /   \                /    \
+  //                 ConvND  C2    -- >      X      *
+  //                  / \                          / \
+  //                 X  C1                       C1  C2
+  //
+  // where C1 and C2 are constants and X is non-constant.
+  if (IsMul(*node) && NumNonControlInputs(*node) == 2) {
+    NodeDef* mul_left_child = node_map_->GetNode(node->input(0));
+    NodeDef* mul_right_child = node_map_->GetNode(node->input(1));
+    // One child must be constant, and the second must be Conv op.
+    const bool left_child_is_constant = IsReallyConstant(*mul_left_child);
+    const bool right_child_is_constant = IsReallyConstant(*mul_right_child);
+    if (!left_child_is_constant && !right_child_is_constant) {
+      return false;
+    }
+    NodeDef* conv_node =
+        left_child_is_constant ? mul_right_child : mul_left_child;
+    if (!IsConv2D(*conv_node) && !IsConv3D(*conv_node)) {
+      return false;
+    }
+    if (node->device() != mul_left_child->device() ||
+        node->device() != mul_right_child->device()) {
+      return false;
+    }
+
+    // Make sure that it is safe to change the value of the convolution
+    // output.
+    if (conv_node->input_size() < 2 ||
+        NumNonControlOutputs(*conv_node, *node_map_) > 1 ||
+        nodes_to_preserve_.find(conv_node->name()) !=
+            nodes_to_preserve_.end()) {
+      return false;
+    }
+
+    // Identify the nodes to swap.
+    NodeDef* conv_left_child = node_map_->GetNode(conv_node->input(0));
+    NodeDef* conv_right_child = node_map_->GetNode(conv_node->input(1));
+    const bool conv_left_is_constant = IsReallyConstant(*conv_left_child);
+    const bool conv_right_is_constant = IsReallyConstant(*conv_right_child);
+    if (!conv_left_is_constant && !conv_right_is_constant) {
+      // At least one of the convolution inputs should be constant.
+      return false;
+    }
+    if (conv_left_is_constant && conv_right_is_constant) {
+      // Leverage regular constant folding to handle this.
+      return false;
+    }
+    const auto& mul_props = properties.GetOutputProperties(node->name());
+    const auto& conv_props = properties.GetOutputProperties(conv_node->name());
+    if (mul_props.empty() || conv_props.empty()) {
+      return false;
+    }
+    const auto& mul_shape = mul_props[0].shape();
+    const auto& conv_shape = conv_props[0].shape();
+    if (!ShapesSymbolicallyEqual(mul_shape, conv_shape)) {
+      return false;
+    }
+
+    const auto& input_props = properties.GetInputProperties(conv_node->name());
+    if (input_props.size() < 2) {
+      return false;
+    }
+    const auto& filter_shape = input_props[1].shape();
+
+    NodeDef* const_node =
+        left_child_is_constant ? mul_left_child : mul_right_child;
+    const auto& const_props =
+        properties.GetOutputProperties(const_node->name());
+    if (const_props.empty()) {
+      return false;
+    }
+    const auto& const_shape = const_props[0].shape();
+
+    TensorShapeProto new_filter_shape;
+    if (!ShapeAfterBroadcast(filter_shape, const_shape, &new_filter_shape)) {
+      return false;
+    }
+    if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) {
+      return false;
+    }
+
+    string mul_new_name =
+        AddPrefixToNodeName("merged_input", conv_node->name());
+    if (node_map_->NodeExists(mul_new_name)) {
+      return false;
+    }
+    // Make sure we don't introduce loops in the graph by removing control
+    // dependencies from the conv2d node to c2.
+    NodeDef* conv_const_node =
+        conv_left_is_constant ? conv_left_child : conv_right_child;
+    if (MaybeRemoveControlInput(conv_node->name(), const_node, graph_,
+                                node_map_.get())) {
+      // Add a control dep from c1 to c2 to ensure c2 is in the right frame
+      *const_node->add_input() = AsControlDependency(*conv_const_node);
+    }
+
+    conv_node->set_name(node->name());
+    node->set_name(mul_new_name);
+    if (conv_left_is_constant) {
+      node_map_->UpdateInput(conv_node->name(), node->input(0), mul_new_name);
+      conv_node->set_input(0, mul_new_name);
+    } else {
+      node_map_->UpdateInput(conv_node->name(), node->input(1), mul_new_name);
+      conv_node->set_input(1, mul_new_name);
+    }
+    if (left_child_is_constant) {
+      node->set_input(1, conv_const_node->name());
+    } else {
+      node->set_input(0, conv_const_node->name());
+    }
+    node_map_->AddNode(mul_new_name, node);
+
+    return true;
+  }
+  return false;
+}
+
 bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) {
   // Partial constant propagation through IdentityN.
   if (IsIdentityN(*node) && NumNonControlInputs(*node) > 0) {
index 6c99120..3efea4f 100644 (file)
@@ -122,6 +122,10 @@ class ConstantFolding : public GraphOptimizer {
   // the transformation applied successfully.
   bool ConstantPushDown(NodeDef* node);
 
+  // Aggregate constants present around a conv operator. Returns true if the
+  // transformation was applied successfully.
+  bool MulConvPushDown(NodeDef* node, const GraphProperties& properties);
+
   // Strength reduces floating point division by a constant Div(x, const) to
   // multiplication by the reciprocal Mul(x, Reciprocal(const)).
   bool ReduceDivToReciprocalMul(GraphDef* optimized_graph, NodeDef* node);
index ce38e0f..2a3758a 100644 (file)
@@ -240,6 +240,77 @@ TEST_F(ConstantFoldingTest, AddTree) {
   }
 }
 
+TEST_F(ConstantFoldingTest, ConvPushDownTest) {
+  // Tests if the following rewrite is performed:
+  //
+  //         *                       Conv2D
+  //        / \                       / \
+  //       c  Conv2D        -->      x  (c * filter)
+  //           / \
+  //          x  filter
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  int input_depth = 3;
+  int filter_count = 5;
+  int filter_size = 2;
+  TensorShape filter_shape(
+      {filter_size, filter_size, input_depth, filter_count});
+  Tensor filter_values(DT_FLOAT, filter_shape);
+  for (int i = 0; i < filter_values.NumElements(); ++i) {
+    filter_values.flat<float>()(i) = std::sqrt(static_cast<float>(i));
+  }
+  Output filter =
+      ops::Const(s.WithOpName("filter"), Input::Initializer(filter_values));
+
+  int batch_size = 4;
+  int input_dim = 10;
+  TensorShape input_shape({batch_size, input_dim, input_dim, input_depth});
+  Output input = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+                                  ops::Placeholder::Shape(input_shape));
+
+  Output conv =
+      ops::Conv2D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1}, "VALID");
+  Output c = ops::Const(s.WithOpName("c"), 3.0f, {1});
+  Output mul = ops::Mul(s.WithOpName("mul"), c, conv);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  ConstantFolding fold(nullptr);
+  GraphDef output;
+  Status status = fold.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  std::cout << output.DebugString() << std::endl;
+
+  EXPECT_EQ(5, output.node_size());
+  int found = 0;
+  for (const auto& node : output.node()) {
+    if (node.name() == "mul") {
+      found++;
+      EXPECT_EQ("Conv2D", node.op());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+      EXPECT_EQ("conv/merged_input", node.input(1));
+    } else if (node.name() == "conv/merged_input") {
+      found++;
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(0, node.input_size());
+    }
+  }
+  EXPECT_EQ(2, found);
+
+  // Check that const folded multiplication node has the expected value.
+  std::vector<string> fetch = {"mul"};
+  Tensor value(DT_FLOAT, input_shape);
+  for (int i = 0; i < value.NumElements(); ++i) {
+    value.flat<float>()(i) = i;
+  }
+  auto actual = EvaluateNodes(output, fetch, {{"x", value}});
+  auto expected = EvaluateNodes(item.graph, fetch, {{"x", value}});
+  test::ExpectTensorEqual<float>(expected[0], actual[0]);
+}
+
 TEST_F(ConstantFoldingTest, NeutralElement) {
   int kConst = 0;
   int kLike = 1;
index 32e86f8..155843a 100644 (file)
@@ -106,6 +106,25 @@ bool ShapesBroadcastable(const OpInfo::TensorProperties& left,
   return ShapesBroadcastable(left.shape(), right.shape());
 }
 
+bool ShapeAfterBroadcast(const TensorShapeProto& left,
+                         const TensorShapeProto& right,
+                         TensorShapeProto* output_shape) {
+  if (!ShapeIsSymbolicallyDefined(left) || !ShapeIsSymbolicallyDefined(right)) {
+    return false;
+  }
+  BCast bcast(ShapeDims(left), ShapeDims(right),
+              /*fewer_dims_optimization*/ false);
+  if (!bcast.IsValid()) {
+    return false;
+  }
+  output_shape->set_unknown_rank(false);
+  output_shape->clear_dim();
+  for (const auto& dim : bcast.output_shape()) {
+    output_shape->add_dim()->set_size(dim);
+  }
+  return true;
+}
+
 bool CompareSymbolicallyShapedTensorSizes(const TensorShapeProto& left,
                                           const TensorShapeProto& right) {
   // if one of the ranks is unknown, it's impossible to compare tensor sizes
index 38d7fbf..ace7bd1 100644 (file)
@@ -53,6 +53,9 @@ bool ShapesBroadcastable(const TensorShapeProto& left,
                          const TensorShapeProto& right);
 bool ShapesBroadcastable(const OpInfo::TensorProperties& left,
                          const OpInfo::TensorProperties& right);
+bool ShapeAfterBroadcast(const TensorShapeProto& left,
+                         const TensorShapeProto& right,
+                         TensorShapeProto* output_shape);
 
 // Return true if can prove, that tensor of size 'left' is smaller than tensor
 // of size 'right'. Return false if it's larger or equal, or it's impossible to
index 5720fbd..7ce995d 100644 (file)
@@ -74,6 +74,26 @@ TEST_F(SymbolicShapesTest, ShapesBroadcastable) {
   EXPECT_TRUE(ShapesBroadcastable(MakeShape({-2, 1}), MakeShape({1, -2})));
   EXPECT_TRUE(ShapesBroadcastable(MakeShape({-2, 1}), MakeShape({1, -3})));
   EXPECT_TRUE(ShapesBroadcastable(MakeShape({-3}), MakeShape({-2, -3})));
+
+  TensorShapeProto output_shape;
+  EXPECT_TRUE(
+      ShapeAfterBroadcast(MakeShape({1, 2}), MakeShape({1, 2}), &output_shape));
+  EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({1, 2}), output_shape));
+  EXPECT_TRUE(ShapeAfterBroadcast(MakeShape({-2, 2}), MakeShape({-2, 2}),
+                                  &output_shape));
+  EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({-2, 2}), output_shape));
+  EXPECT_TRUE(ShapeAfterBroadcast(MakeShape({-2, 32}), MakeShape({-2, 1}),
+                                  &output_shape));
+  EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({-2, 32}), output_shape));
+  EXPECT_TRUE(ShapeAfterBroadcast(MakeShape({-2, 1}), MakeShape({1, -2}),
+                                  &output_shape));
+  EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({-2, -2}), output_shape));
+  EXPECT_TRUE(ShapeAfterBroadcast(MakeShape({-2, 1}), MakeShape({1, -3}),
+                                  &output_shape));
+  EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({-2, -3}), output_shape));
+  EXPECT_TRUE(
+      ShapeAfterBroadcast(MakeShape({-3}), MakeShape({-2, -3}), &output_shape));
+  EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({-2, -3}), output_shape));
 }
 
 TEST_F(SymbolicShapesTest, CompareSymbolicallyShapedTensorSizes) {