From fdd06c559d4dcf9b8c7a4f3bc54540a5aa99d083 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Mon, 21 May 2018 11:12:55 -0700 Subject: [PATCH] Optimize multiplications by constants in more cases. PiperOrigin-RevId: 197422256 --- tensorflow/core/grappler/op_types.cc | 2 + tensorflow/core/grappler/op_types.h | 1 + tensorflow/core/grappler/optimizers/BUILD | 1 + .../core/grappler/optimizers/constant_folding.cc | 128 +++++++++++++++++++++ .../core/grappler/optimizers/constant_folding.h | 4 + .../grappler/optimizers/constant_folding_test.cc | 71 ++++++++++++ .../core/grappler/optimizers/symbolic_shapes.cc | 19 +++ .../core/grappler/optimizers/symbolic_shapes.h | 3 + .../grappler/optimizers/symbolic_shapes_test.cc | 20 ++++ 9 files changed, 249 insertions(+) diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 07f826b..9258194 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -106,6 +106,8 @@ bool IsConv2DBackpropInput(const NodeDef& node) { return node.op() == "Conv2DBackpropInput"; } +bool IsConv3D(const NodeDef& node) { return node.op() == "Conv3D"; } + bool IsDepthwiseConv2dNative(const NodeDef& node) { return node.op() == "DepthwiseConv2dNative"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index a5599eb..9d91ba1 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -48,6 +48,7 @@ bool IsConstant(const NodeDef& node); bool IsConv2D(const NodeDef& node); bool IsConv2DBackpropFilter(const NodeDef& node); bool IsConv2DBackpropInput(const NodeDef& node); +bool IsConv3D(const NodeDef& node); bool IsDepthwiseConv2dNative(const NodeDef& node); bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node); bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 56c23db..104a042 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -96,6 +96,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":graph_optimizer", + ":symbolic_shapes", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 38e4a02..2d13c78 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -2146,6 +2147,11 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, return Status::OK(); } + if (MulConvPushDown(node, *properties)) { + graph_modified_ = true; + return Status::OK(); + } + if (PartialConstPropThroughIdentityN(node)) { graph_modified_ = true; return Status::OK(); @@ -2321,6 +2327,7 @@ bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph, node_map_->UpdateOutput(node->name(), const_input, reciprocal_node->name()); return true; } + return false; } @@ -2411,6 +2418,127 @@ bool ConstantFolding::ConstantPushDown(NodeDef* node) { return false; } +bool ConstantFolding::MulConvPushDown(NodeDef* node, + const GraphProperties& properties) { + // Push down multiplication on ConvND. + // * ConvND + // / \ / \ + // ConvND C2 -- > X * + // / \ / \ + // X C1 C1 C2 + // + // where C1 and C2 are constants and X is non-constant. + if (IsMul(*node) && NumNonControlInputs(*node) == 2) { + NodeDef* mul_left_child = node_map_->GetNode(node->input(0)); + NodeDef* mul_right_child = node_map_->GetNode(node->input(1)); + // One child must be constant, and the second must be Conv op. + const bool left_child_is_constant = IsReallyConstant(*mul_left_child); + const bool right_child_is_constant = IsReallyConstant(*mul_right_child); + if (!left_child_is_constant && !right_child_is_constant) { + return false; + } + NodeDef* conv_node = + left_child_is_constant ? mul_right_child : mul_left_child; + if (!IsConv2D(*conv_node) && !IsConv3D(*conv_node)) { + return false; + } + if (node->device() != mul_left_child->device() || + node->device() != mul_right_child->device()) { + return false; + } + + // Make sure that it is safe to change the value of the convolution + // output. + if (conv_node->input_size() < 2 || + NumNonControlOutputs(*conv_node, *node_map_) > 1 || + nodes_to_preserve_.find(conv_node->name()) != + nodes_to_preserve_.end()) { + return false; + } + + // Identify the nodes to swap. + NodeDef* conv_left_child = node_map_->GetNode(conv_node->input(0)); + NodeDef* conv_right_child = node_map_->GetNode(conv_node->input(1)); + const bool conv_left_is_constant = IsReallyConstant(*conv_left_child); + const bool conv_right_is_constant = IsReallyConstant(*conv_right_child); + if (!conv_left_is_constant && !conv_right_is_constant) { + // At least one of the convolution inputs should be constant. + return false; + } + if (conv_left_is_constant && conv_right_is_constant) { + // Leverage regular constant folding to handle this. + return false; + } + const auto& mul_props = properties.GetOutputProperties(node->name()); + const auto& conv_props = properties.GetOutputProperties(conv_node->name()); + if (mul_props.empty() || conv_props.empty()) { + return false; + } + const auto& mul_shape = mul_props[0].shape(); + const auto& conv_shape = conv_props[0].shape(); + if (!ShapesSymbolicallyEqual(mul_shape, conv_shape)) { + return false; + } + + const auto& input_props = properties.GetInputProperties(conv_node->name()); + if (input_props.size() < 2) { + return false; + } + const auto& filter_shape = input_props[1].shape(); + + NodeDef* const_node = + left_child_is_constant ? mul_left_child : mul_right_child; + const auto& const_props = + properties.GetOutputProperties(const_node->name()); + if (const_props.empty()) { + return false; + } + const auto& const_shape = const_props[0].shape(); + + TensorShapeProto new_filter_shape; + if (!ShapeAfterBroadcast(filter_shape, const_shape, &new_filter_shape)) { + return false; + } + if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) { + return false; + } + + string mul_new_name = + AddPrefixToNodeName("merged_input", conv_node->name()); + if (node_map_->NodeExists(mul_new_name)) { + return false; + } + // Make sure we don't introduce loops in the graph by removing control + // dependencies from the conv2d node to c2. + NodeDef* conv_const_node = + conv_left_is_constant ? conv_left_child : conv_right_child; + if (MaybeRemoveControlInput(conv_node->name(), const_node, graph_, + node_map_.get())) { + // Add a control dep from c1 to c2 to ensure c2 is in the right frame + *const_node->add_input() = AsControlDependency(*conv_const_node); + } + + conv_node->set_name(node->name()); + node->set_name(mul_new_name); + if (conv_left_is_constant) { + node_map_->UpdateInput(conv_node->name(), node->input(0), mul_new_name); + conv_node->set_input(0, mul_new_name); + } else { + node_map_->UpdateInput(conv_node->name(), node->input(1), mul_new_name); + conv_node->set_input(1, mul_new_name); + } + if (left_child_is_constant) { + node->set_input(1, conv_const_node->name()); + } else { + node->set_input(0, conv_const_node->name()); + } + node_map_->AddNode(mul_new_name, node); + + return true; + } + return false; +} + bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) { // Partial constant propagation through IdentityN. if (IsIdentityN(*node) && NumNonControlInputs(*node) > 0) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 6c99120..3efea4f 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -122,6 +122,10 @@ class ConstantFolding : public GraphOptimizer { // the transformation applied successfully. bool ConstantPushDown(NodeDef* node); + // Aggregate constants present around a conv operator. Returns true if the + // transformation was applied successfully. + bool MulConvPushDown(NodeDef* node, const GraphProperties& properties); + // Strength reduces floating point division by a constant Div(x, const) to // multiplication by the reciprocal Mul(x, Reciprocal(const)). bool ReduceDivToReciprocalMul(GraphDef* optimized_graph, NodeDef* node); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index ce38e0f..2a3758a 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -240,6 +240,77 @@ TEST_F(ConstantFoldingTest, AddTree) { } } +TEST_F(ConstantFoldingTest, ConvPushDownTest) { + // Tests if the following rewrite is performed: + // + // * Conv2D + // / \ / \ + // c Conv2D --> x (c * filter) + // / \ + // x filter + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + int input_depth = 3; + int filter_count = 5; + int filter_size = 2; + TensorShape filter_shape( + {filter_size, filter_size, input_depth, filter_count}); + Tensor filter_values(DT_FLOAT, filter_shape); + for (int i = 0; i < filter_values.NumElements(); ++i) { + filter_values.flat()(i) = std::sqrt(static_cast(i)); + } + Output filter = + ops::Const(s.WithOpName("filter"), Input::Initializer(filter_values)); + + int batch_size = 4; + int input_dim = 10; + TensorShape input_shape({batch_size, input_dim, input_dim, input_depth}); + Output input = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, + ops::Placeholder::Shape(input_shape)); + + Output conv = + ops::Conv2D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1}, "VALID"); + Output c = ops::Const(s.WithOpName("c"), 3.0f, {1}); + Output mul = ops::Mul(s.WithOpName("mul"), c, conv); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ConstantFolding fold(nullptr); + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + std::cout << output.DebugString() << std::endl; + + EXPECT_EQ(5, output.node_size()); + int found = 0; + for (const auto& node : output.node()) { + if (node.name() == "mul") { + found++; + EXPECT_EQ("Conv2D", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("conv/merged_input", node.input(1)); + } else if (node.name() == "conv/merged_input") { + found++; + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(0, node.input_size()); + } + } + EXPECT_EQ(2, found); + + // Check that const folded multiplication node has the expected value. + std::vector fetch = {"mul"}; + Tensor value(DT_FLOAT, input_shape); + for (int i = 0; i < value.NumElements(); ++i) { + value.flat()(i) = i; + } + auto actual = EvaluateNodes(output, fetch, {{"x", value}}); + auto expected = EvaluateNodes(item.graph, fetch, {{"x", value}}); + test::ExpectTensorEqual(expected[0], actual[0]); +} + TEST_F(ConstantFoldingTest, NeutralElement) { int kConst = 0; int kLike = 1; diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc b/tensorflow/core/grappler/optimizers/symbolic_shapes.cc index 32e86f8..155843a 100644 --- a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc +++ b/tensorflow/core/grappler/optimizers/symbolic_shapes.cc @@ -106,6 +106,25 @@ bool ShapesBroadcastable(const OpInfo::TensorProperties& left, return ShapesBroadcastable(left.shape(), right.shape()); } +bool ShapeAfterBroadcast(const TensorShapeProto& left, + const TensorShapeProto& right, + TensorShapeProto* output_shape) { + if (!ShapeIsSymbolicallyDefined(left) || !ShapeIsSymbolicallyDefined(right)) { + return false; + } + BCast bcast(ShapeDims(left), ShapeDims(right), + /*fewer_dims_optimization*/ false); + if (!bcast.IsValid()) { + return false; + } + output_shape->set_unknown_rank(false); + output_shape->clear_dim(); + for (const auto& dim : bcast.output_shape()) { + output_shape->add_dim()->set_size(dim); + } + return true; +} + bool CompareSymbolicallyShapedTensorSizes(const TensorShapeProto& left, const TensorShapeProto& right) { // if one of the ranks is unknown, it's impossible to compare tensor sizes diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.h b/tensorflow/core/grappler/optimizers/symbolic_shapes.h index 38d7fbf..ace7bd1 100644 --- a/tensorflow/core/grappler/optimizers/symbolic_shapes.h +++ b/tensorflow/core/grappler/optimizers/symbolic_shapes.h @@ -53,6 +53,9 @@ bool ShapesBroadcastable(const TensorShapeProto& left, const TensorShapeProto& right); bool ShapesBroadcastable(const OpInfo::TensorProperties& left, const OpInfo::TensorProperties& right); +bool ShapeAfterBroadcast(const TensorShapeProto& left, + const TensorShapeProto& right, + TensorShapeProto* output_shape); // Return true if can prove, that tensor of size 'left' is smaller than tensor // of size 'right'. Return false if it's larger or equal, or it's impossible to diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc b/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc index 5720fbd..7ce995d 100644 --- a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc +++ b/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc @@ -74,6 +74,26 @@ TEST_F(SymbolicShapesTest, ShapesBroadcastable) { EXPECT_TRUE(ShapesBroadcastable(MakeShape({-2, 1}), MakeShape({1, -2}))); EXPECT_TRUE(ShapesBroadcastable(MakeShape({-2, 1}), MakeShape({1, -3}))); EXPECT_TRUE(ShapesBroadcastable(MakeShape({-3}), MakeShape({-2, -3}))); + + TensorShapeProto output_shape; + EXPECT_TRUE( + ShapeAfterBroadcast(MakeShape({1, 2}), MakeShape({1, 2}), &output_shape)); + EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({1, 2}), output_shape)); + EXPECT_TRUE(ShapeAfterBroadcast(MakeShape({-2, 2}), MakeShape({-2, 2}), + &output_shape)); + EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({-2, 2}), output_shape)); + EXPECT_TRUE(ShapeAfterBroadcast(MakeShape({-2, 32}), MakeShape({-2, 1}), + &output_shape)); + EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({-2, 32}), output_shape)); + EXPECT_TRUE(ShapeAfterBroadcast(MakeShape({-2, 1}), MakeShape({1, -2}), + &output_shape)); + EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({-2, -2}), output_shape)); + EXPECT_TRUE(ShapeAfterBroadcast(MakeShape({-2, 1}), MakeShape({1, -3}), + &output_shape)); + EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({-2, -3}), output_shape)); + EXPECT_TRUE( + ShapeAfterBroadcast(MakeShape({-3}), MakeShape({-2, -3}), &output_shape)); + EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({-2, -3}), output_shape)); } TEST_F(SymbolicShapesTest, CompareSymbolicallyShapedTensorSizes) { -- 2.7.4