bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
-// Returns whether `reshape` is an identity op. The tensor that `reshape`
-// reshapes is the `output_pos`-th output of node `input`.
-bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
- const int output_pos,
- const GraphProperties& graph_properties) {
- const std::vector<OpInfo::TensorProperties>& reshape_props =
- graph_properties.GetOutputProperties(reshape.name());
- const std::vector<OpInfo::TensorProperties>& input_props =
- graph_properties.GetOutputProperties(input.name());
- if (reshape_props.empty() || input_props.size() <= output_pos) {
- return false;
- }
-
- return ShapesSymbolicallyEqual(input_props[output_pos], reshape_props[0]);
-}
-
NodeDef* GetTailOfValuePreservingChain(
const NodeDef& node, const NodeMap& node_map,
const std::unordered_set<string>& nodes_to_preserve) {
}
};
+// Bypass redundant reshape nodes:
+//
+//   Reshape                    Reshape  <-+
+//      ^                                  |
+//      |                                  |
+//   Reshape       becomes      Reshape    |
+//      ^                                  |
+//      |                                  |
+//   input                      input  ---+
+class RemoveRedundantReshape : public ArithmeticOptimizerStage {
+ public:
+  explicit RemoveRedundantReshape(const GraphOptimizerContext& ctx,
+                                  const ArithmeticOptimizerContext& ctx_ext)
+      : ArithmeticOptimizerStage("RemoveRedundantReshape", ctx, ctx_ext) {}
+  ~RemoveRedundantReshape() override = default;
+
+  // Only Reshape nodes are candidates for this stage.
+  bool IsSupported(const NodeDef* node) const override {
+    return IsReshape(*node);
+  }
+
+  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+    NodeDef* input;
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
+
+    // 1. Bypass reshape followed by reshape: rewire this node to consume the
+    // upstream reshape's input directly, then re-queue it so further
+    // simplification (e.g. case 2 below) can run on the shortened chain.
+    if (IsReshape(*input) && !HasControlInputs(*input)) {
+      node->set_input(0, input->input(0));
+      ctx().node_map->UpdateInput(node->name(), input->name(), input->input(0));
+      *simplified_node_name = node->name();
+      AddToOptimizationQueue(node);
+      return Status::OK();
+    }
+
+    // 2. If the reshape is a no-op, forward its input to its consumers, unless
+    // it anchors a control dependency since we want to make sure that control
+    // dependency is triggered.
+    if (ReshapeIsIdentity(*node) && !HasControlInputs(*node)) {
+      *simplified_node_name = node->input(0);
+      return Status::OK();
+    }
+
+    return Status::OK();
+  }
+
+ private:
+  // Returns whether `reshape` is an identity op, i.e. its output shape is
+  // symbolically equal to the shape of its input tensor. Conservatively
+  // returns false when shape information is unavailable for either side.
+  bool ReshapeIsIdentity(const NodeDef& reshape) {
+    OpInfo::TensorProperties reshape_props;
+    OpInfo::TensorProperties input_props;
+
+    if (!GetTensorProperties(reshape.name(), &reshape_props).ok() ||
+        !GetTensorProperties(reshape.input(0), &input_props).ok()) {
+      return false;
+    }
+
+    return ShapesSymbolicallyEqual(input_props.shape(), reshape_props.shape());
+  }
+};
+
} // namespace
class UniqueNodes {
string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) {
- if (node->op() == "Reshape") {
- // Reshape
- // ^
- // |
- // Reshape
- // ^
- // |
- // input
- //
- // becomes
- //
- // Reshape <-+
- // |
- // Reshape |
- // ^ |
- // | |
- // input ---+
- NodeDef* reshape = const_cast<NodeDef*>(node);
- int output_pos = 0;
- string input_node_name = ParseNodeName(reshape->input(0), &output_pos);
- const NodeDef* input = node_map_->GetNode(input_node_name);
- if (input->op() == "Reshape" && !HasControlInputs(*input)) {
- reshape->set_input(0, input->input(0));
- node_map_->UpdateInput(reshape->name(), input->name(), input->input(0));
- nodes_to_simplify->PushBack(reshape);
- return reshape->name();
- }
-
- // If the reshape is a no-op, forward its input to its consumers, unless it
- // anchors a control dependency since we want to make sure that control
- // dependency is triggered.
- if (ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_) &&
- !HasControlInputs(*reshape)) {
- return reshape->input(0);
- }
- }
-
if (node->op() == "Transpose") {
// Reorder Cast and Transpose if beneficial.
//
pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
if (options_.remove_redundant_cast)
pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
+ if (options_.remove_redundant_reshape)
+ pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext);
if (options_.remove_negation)
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
if (options_.remove_logical_not)
options.remove_idempotent = false;
options.remove_redundant_bitcast = false;
options.remove_redundant_cast = false;
+ options.remove_redundant_reshape = false;
options.remove_negation = false;
options.remove_logical_not = false;
optimizer->options_ = options;
optimizer->options_.remove_redundant_cast = true;
}
+  // Test helper: disables every optimizer stage, then enables only the
+  // RemoveRedundantReshape stage so tests exercise it in isolation.
+  void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) {
+    DisableAllStages(optimizer);
+    optimizer->options_.remove_redundant_reshape = true;
+  }
+
void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_negation = true;
test::ExpectTensorNear<complex64>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
+TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_IdentityReshape) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
auto tensors_expected =
EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
EXPECT_EQ(1, tensors_expected.size());
- GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveRedundantReshape(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, IdentityReshapeBetweenSymbolicShapes) {
+TEST_F(ArithmeticOptimizerTest,
+ RemoveRedundantReshape_IdentityReshapeBetweenSymbolicShapes) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, -1, -1}));
Output reshape = ops::Reshape(s, inputs, target_shape);
Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
GrapplerItem item;
item.fetch = {"outputs"};
+ item.feed = {{"Placeholder", x_t}};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
- auto tensors_expected =
- EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
- GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE)
- .Optimize(nullptr, item, &output));
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ GraphDef output;
+ // Assume valid feed shape in aggressive mode.
+ ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ EnableOnlyRemoveRedundantReshape(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
- auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) {
+TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_NotAssumeValidFeeds) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
EXPECT_EQ(1, tensors_expected.size());
GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
-
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveRedundantReshape(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
// The reshape is preserved because the shape of the placeholder can be
// different from the shape of the actual feed.
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) {
+TEST_F(ArithmeticOptimizerTest,
+ RemoveRedundantReshape_AssumeValidFeedsInAggressiveMode) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
- GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE)
- .Optimize(nullptr, item, &output));
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ GraphDef output;
+ ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ EnableOnlyRemoveRedundantReshape(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
+TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_NotIdentityReshape) {
// Reshape from [-1,3,28,28] to [8,-1,28,28] is not identity, because it can
// be from [4,3,28,28] to [8,6,28,28].
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
item.feed = {{"Placeholder", x_t}};
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
- GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveRedundantReshape(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
+TEST_F(ArithmeticOptimizerTest,
+ RemoveRedundantReshape_NotIdentityReshapeTooManyUnknownDimSizes) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3}));
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveRedundantReshape(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
}
-TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
+TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_CombineReshapes) {
// Converts an NCHW_VECT_C tensor to NHWC and then flattens it to 2D. The two
// reshapes should be combined.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
item.feed = {{"nchw_vect_c", x_t}};
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
- GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveRedundantReshape(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
auto tensors = EvaluateNodes(output, item.fetch, item.feed);