From 1fc324c6701bc179ca73908731857e8a582437b5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 26 Feb 2018 10:24:08 -0800 Subject: [PATCH] Arithemtic optimization: Rewite Sub(0, y) => Neg(y) PiperOrigin-RevId: 187041872 --- .../core/grappler/optimizers/constant_folding.cc | 18 +++++++++++++++++- tensorflow/core/grappler/optimizers/constant_folding.h | 1 + .../core/grappler/optimizers/constant_folding_test.cc | 7 +++---- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 182e03f..10ca7dc 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1434,6 +1434,17 @@ void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node, graph_modified_ = true; } +void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node, + GraphDef* graph) { + node->set_op("Neg"); + node->mutable_input()->SwapElements(0, 1); + const string ctrl_dep = + AddControlDependency(node->input(1), graph, node_map_.get()); + node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep); + node->set_input(1, ctrl_dep); + graph_modified_ = true; +} + Status ConstantFolding::ReplaceOperationWithConstant( double value, const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) { @@ -1636,12 +1647,17 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape); if (y_matches_output_shape && ((is_mul && x_is_one) || (is_add && x_is_zero))) { - // TODO(rmlarsen): Handle subtraction 0 - y. // 1 * y = y or 0 + y = y. ReplaceOperationWithSnapshot(1, node, output); continue; } + if (y_matches_output_shape && (is_sub && x_is_zero)) { + // Replace 0 - y with Neg(y). + ReplaceSubtractionFromZeroByNegation(node, output); + continue; + } + // Replace 1 / y with Reciprocal op. if (y_matches_output_shape && is_any_div && x_is_one) { DataType type = node->attr().at("T").type(); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 232b2f9..2fd59c7 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -82,6 +82,7 @@ class ConstantFolding : public GraphOptimizer { GraphDef* graph); void ReplaceOperationWithSnapshot(int input_to_forward, NodeDef* node, GraphDef* graph); + void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph); Status ReplaceOperationWithConstant(double value, const TensorShapeProto& shape, NodeDef* node, GraphDef* graph); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 219f3bd..c654019 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -286,10 +286,9 @@ TEST_F(ConstantFoldingTest, NeutralElement) { EXPECT_EQ("x", node.input(0)); EXPECT_EQ("^zeros", node.input(1)); } else if (name == "sub2") { - // We don't handle this case yet. - EXPECT_EQ("Sub", node.op()); - EXPECT_EQ("zeros", node.input(0)); - EXPECT_EQ("y", node.input(1)); + EXPECT_EQ("Neg", node.op()); + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ("^zeros", node.input(1)); } const std::set square_zero_const{"mul1", "mul2", "mul5", "mul6", "matmul1", "matmul2"}; -- 2.7.4