return node.op() == "Conv2DBackpropInput";
}
+bool IsConv3D(const NodeDef& node) { return node.op() == "Conv3D"; }
+
bool IsDepthwiseConv2dNative(const NodeDef& node) {
return node.op() == "DepthwiseConv2dNative";
}
bool IsConv2D(const NodeDef& node);
bool IsConv2DBackpropFilter(const NodeDef& node);
bool IsConv2DBackpropInput(const NodeDef& node);
+bool IsConv3D(const NodeDef& node);
bool IsDepthwiseConv2dNative(const NodeDef& node);
bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node);
bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node);
visibility = ["//visibility:public"],
deps = [
":graph_optimizer",
+ ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
return Status::OK();
}
+ if (MulConvPushDown(node, *properties)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
if (PartialConstPropThroughIdentityN(node)) {
graph_modified_ = true;
return Status::OK();
node_map_->UpdateOutput(node->name(), const_input, reciprocal_node->name());
return true;
}
+
return false;
}
return false;
}
+bool ConstantFolding::MulConvPushDown(NodeDef* node,
+ const GraphProperties& properties) {
+ // Push down multiplication on ConvND.
+ //        *                        ConvND
+ //       / \                      /      \
+ //  ConvND   C2       -->        X        *
+ //   /   \                               / \
+ //  X     C1                           C1   C2
+ //
+ // where C1 and C2 are constants and X is non-constant.
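+ //
+ // For example, with a scalar C2:
+ //   Mul(C2, Conv2D(X, C1)) == Conv2D(X, Mul(C1, C2)),
+ // so the multiplication can be folded into the constant filter input.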
+ if (IsMul(*node) && NumNonControlInputs(*node) == 2) {
+ NodeDef* mul_left_child = node_map_->GetNode(node->input(0));
+ NodeDef* mul_right_child = node_map_->GetNode(node->input(1));
+ // One child must be constant, and the other must be a convolution op.
+ const bool left_child_is_constant = IsReallyConstant(*mul_left_child);
+ const bool right_child_is_constant = IsReallyConstant(*mul_right_child);
+ if (!left_child_is_constant && !right_child_is_constant) {
+ return false;
+ }
+ NodeDef* conv_node =
+ left_child_is_constant ? mul_right_child : mul_left_child;
+ if (!IsConv2D(*conv_node) && !IsConv3D(*conv_node)) {
+ return false;
+ }
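+ // Only rewrite if all three nodes are placed on the same device.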
+ if (node->device() != mul_left_child->device() ||
+ node->device() != mul_right_child->device()) {
+ return false;
+ }
+
+ // Make sure that it is safe to change the value of the convolution
+ // output.
+ if (conv_node->input_size() < 2 ||
+ NumNonControlOutputs(*conv_node, *node_map_) > 1 ||
+ nodes_to_preserve_.find(conv_node->name()) !=
+ nodes_to_preserve_.end()) {
+ return false;
+ }
+
+ // Identify the nodes to swap.
+ NodeDef* conv_left_child = node_map_->GetNode(conv_node->input(0));
+ NodeDef* conv_right_child = node_map_->GetNode(conv_node->input(1));
+ const bool conv_left_is_constant = IsReallyConstant(*conv_left_child);
+ const bool conv_right_is_constant = IsReallyConstant(*conv_right_child);
+ if (!conv_left_is_constant && !conv_right_is_constant) {
+ // At least one of the convolution inputs should be constant.
+ return false;
+ }
+ if (conv_left_is_constant && conv_right_is_constant) {
+ // Leverage regular constant folding to handle this.
+ return false;
+ }
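+ // The Mul must not change the shape of the convolution output: the rewrite
+ // is only valid when the Mul's output shape matches the Conv's output shape.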
+ const auto& mul_props = properties.GetOutputProperties(node->name());
+ const auto& conv_props = properties.GetOutputProperties(conv_node->name());
+ if (mul_props.empty() || conv_props.empty()) {
+ return false;
+ }
+ const auto& mul_shape = mul_props[0].shape();
+ const auto& conv_shape = conv_props[0].shape();
+ if (!ShapesSymbolicallyEqual(mul_shape, conv_shape)) {
+ return false;
+ }
+
+ const auto& input_props = properties.GetInputProperties(conv_node->name());
+ if (input_props.size() < 2) {
+ return false;
+ }
+ const auto& filter_shape = input_props[1].shape();
+
+ NodeDef* const_node =
+ left_child_is_constant ? mul_left_child : mul_right_child;
+ const auto& const_props =
+ properties.GetOutputProperties(const_node->name());
+ if (const_props.empty()) {
+ return false;
+ }
+ const auto& const_shape = const_props[0].shape();
+
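+ // The constant must broadcast onto the filter without changing the filter's
+ // shape; otherwise folding Mul(filter, const) would feed the convolution a
+ // filter of a different shape.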
+ TensorShapeProto new_filter_shape;
+ if (!ShapeAfterBroadcast(filter_shape, const_shape, &new_filter_shape)) {
+ return false;
+ }
+ if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) {
+ return false;
+ }
+
+ string mul_new_name =
+ AddPrefixToNodeName("merged_input", conv_node->name());
+ if (node_map_->NodeExists(mul_new_name)) {
+ return false;
+ }
+ // Make sure we don't introduce loops in the graph by removing control
+ // dependencies from the conv node to C2.
+ NodeDef* conv_const_node =
+ conv_left_is_constant ? conv_left_child : conv_right_child;
+ if (MaybeRemoveControlInput(conv_node->name(), const_node, graph_,
+ node_map_.get())) {
+ // Add a control dep from C1 to C2 to ensure C2 stays in the right frame.
+ *const_node->add_input() = AsControlDependency(*conv_const_node);
+ }
+
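+ // Swap the node names: the convolution takes over the Mul's name, so
+ // downstream consumers now read the convolution's output, while the renamed
+ // Mul becomes the merged constant filter input.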
+ conv_node->set_name(node->name());
+ node->set_name(mul_new_name);
+ if (conv_left_is_constant) {
+ node_map_->UpdateInput(conv_node->name(), node->input(0), mul_new_name);
+ conv_node->set_input(0, mul_new_name);
+ } else {
+ node_map_->UpdateInput(conv_node->name(), node->input(1), mul_new_name);
+ conv_node->set_input(1, mul_new_name);
+ }
+ if (left_child_is_constant) {
+ node->set_input(1, conv_const_node->name());
+ } else {
+ node->set_input(0, conv_const_node->name());
+ }
+ node_map_->AddNode(mul_new_name, node);
+
+ return true;
+ }
+ return false;
+}
+
bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) {
// Partial constant propagation through IdentityN.
if (IsIdentityN(*node) && NumNonControlInputs(*node) > 0) {
// the transformation applied successfully.
bool ConstantPushDown(NodeDef* node);
+ // Pushes down a multiplication by a constant through a Conv2D/Conv3D node,
+ // so the constant can be folded into the constant filter input. Returns
+ // true if the transformation was applied successfully.
+ bool MulConvPushDown(NodeDef* node, const GraphProperties& properties);
+
// Strength reduces floating point division by a constant Div(x, const) to
// multiplication by the reciprocal Mul(x, Reciprocal(const)).
bool ReduceDivToReciprocalMul(GraphDef* optimized_graph, NodeDef* node);
}
}
+TEST_F(ConstantFoldingTest, ConvPushDownTest) {
+ // Tests if the following rewrite is performed:
+ //
+ //       *                       Conv2D
+ //      / \                      /     \
+ //     c   Conv2D     -->      x     (c * filter)
+ //          /  \
+ //         x    filter
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ int input_depth = 3;
+ int filter_count = 5;
+ int filter_size = 2;
+ TensorShape filter_shape(
+ {filter_size, filter_size, input_depth, filter_count});
+ Tensor filter_values(DT_FLOAT, filter_shape);
+ for (int i = 0; i < filter_values.NumElements(); ++i) {
+ filter_values.flat<float>()(i) = std::sqrt(static_cast<float>(i));
+ }
+ Output filter =
+ ops::Const(s.WithOpName("filter"), Input::Initializer(filter_values));
+
+ int batch_size = 4;
+ int input_dim = 10;
+ TensorShape input_shape({batch_size, input_dim, input_dim, input_depth});
+ Output input = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape(input_shape));
+
+ Output conv =
+ ops::Conv2D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1}, "VALID");
+ Output c = ops::Const(s.WithOpName("c"), 3.0f, {1});
+ Output mul = ops::Mul(s.WithOpName("mul"), c, conv);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ ConstantFolding fold(nullptr);
+ GraphDef output;
+ Status status = fold.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(5, output.node_size());
+ int found = 0;
+ for (const auto& node : output.node()) {
+ if (node.name() == "mul") {
+ found++;
+ EXPECT_EQ("Conv2D", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("conv/merged_input", node.input(1));
+ } else if (node.name() == "conv/merged_input") {
+ found++;
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ(0, node.input_size());
+ }
+ }
+ EXPECT_EQ(2, found);
+
+ // Check that the const-folded multiplication node produces the expected value.
+ std::vector<string> fetch = {"mul"};
+ Tensor value(DT_FLOAT, input_shape);
+ for (int i = 0; i < value.NumElements(); ++i) {
+ value.flat<float>()(i) = i;
+ }
+ auto actual = EvaluateNodes(output, fetch, {{"x", value}});
+ auto expected = EvaluateNodes(item.graph, fetch, {{"x", value}});
+ test::ExpectTensorEqual<float>(expected[0], actual[0]);
+}
+
TEST_F(ConstantFoldingTest, NeutralElement) {
int kConst = 0;
int kLike = 1;
return ShapesBroadcastable(left.shape(), right.shape());
}
+bool ShapeAfterBroadcast(const TensorShapeProto& left,
+ const TensorShapeProto& right,
+ TensorShapeProto* output_shape) {
+ if (!ShapeIsSymbolicallyDefined(left) || !ShapeIsSymbolicallyDefined(right)) {
+ return false;
+ }
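+ // Use BCast to validate broadcastability and compute the broadcast shape;
+ // disable the fewer-dims optimization so the result keeps its full rank.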
+ BCast bcast(ShapeDims(left), ShapeDims(right),
+ /*fewer_dims_optimization*/ false);
+ if (!bcast.IsValid()) {
+ return false;
+ }
+ output_shape->set_unknown_rank(false);
+ output_shape->clear_dim();
+ for (const auto& dim : bcast.output_shape()) {
+ output_shape->add_dim()->set_size(dim);
+ }
+ return true;
+}
+
bool CompareSymbolicallyShapedTensorSizes(const TensorShapeProto& left,
const TensorShapeProto& right) {
// if one of the ranks is unknown, it's impossible to compare tensor sizes
const TensorShapeProto& right);
bool ShapesBroadcastable(const OpInfo::TensorProperties& left,
const OpInfo::TensorProperties& right);
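+// Computes the broadcast of 'left' and 'right' and stores the resulting shape
+// in 'output_shape'. Returns false if the shapes are not symbolically defined
+// or are not broadcastable.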
+bool ShapeAfterBroadcast(const TensorShapeProto& left,
+ const TensorShapeProto& right,
+ TensorShapeProto* output_shape);
// Returns true if we can prove that the tensor of size 'left' is smaller than
// the tensor of size 'right'. Returns false if it's larger or equal, or it's impossible to
EXPECT_TRUE(ShapesBroadcastable(MakeShape({-2, 1}), MakeShape({1, -2})));
EXPECT_TRUE(ShapesBroadcastable(MakeShape({-2, 1}), MakeShape({1, -3})));
EXPECT_TRUE(ShapesBroadcastable(MakeShape({-3}), MakeShape({-2, -3})));
+
+ TensorShapeProto output_shape;
+ EXPECT_TRUE(
+ ShapeAfterBroadcast(MakeShape({1, 2}), MakeShape({1, 2}), &output_shape));
+ EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({1, 2}), output_shape));
+ EXPECT_TRUE(ShapeAfterBroadcast(MakeShape({-2, 2}), MakeShape({-2, 2}),
+ &output_shape));
+ EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({-2, 2}), output_shape));
+ EXPECT_TRUE(ShapeAfterBroadcast(MakeShape({-2, 32}), MakeShape({-2, 1}),
+ &output_shape));
+ EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({-2, 32}), output_shape));
+ EXPECT_TRUE(ShapeAfterBroadcast(MakeShape({-2, 1}), MakeShape({1, -2}),
+ &output_shape));
+ EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({-2, -2}), output_shape));
+ EXPECT_TRUE(ShapeAfterBroadcast(MakeShape({-2, 1}), MakeShape({1, -3}),
+ &output_shape));
+ EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({-2, -3}), output_shape));
+ EXPECT_TRUE(
+ ShapeAfterBroadcast(MakeShape({-3}), MakeShape({-2, -3}), &output_shape));
+ EXPECT_TRUE(ShapesSymbolicallyEqual(MakeShape({-2, -3}), output_shape));
}
TEST_F(SymbolicShapesTest, CompareSymbolicallyShapedTensorSizes) {