Arithemtic optimization: Rewite Sub(0, y) => Neg(y)
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 26 Feb 2018 18:24:08 +0000 (10:24 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 26 Feb 2018 18:28:36 +0000 (10:28 -0800)
PiperOrigin-RevId: 187041872

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h
tensorflow/core/grappler/optimizers/constant_folding_test.cc

index 182e03f04e205f4426db716b1ac29fe18c8acc7e..10ca7dcce0e920542da5a5b1e14c0b3c50d420c9 100644 (file)
@@ -1434,6 +1434,17 @@ void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
   graph_modified_ = true;
 }
 
+void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
+                                                           GraphDef* graph) {
+  node->set_op("Neg");
+  node->mutable_input()->SwapElements(0, 1);
+  const string ctrl_dep =
+      AddControlDependency(node->input(1), graph, node_map_.get());
+  node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
+  node->set_input(1, ctrl_dep);
+  graph_modified_ = true;
+}
+
 Status ConstantFolding::ReplaceOperationWithConstant(
     double value, const TensorShapeProto& shape, NodeDef* node,
     GraphDef* graph) {
@@ -1636,12 +1647,17 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
       if (y_matches_output_shape &&
           ((is_mul && x_is_one) || (is_add && x_is_zero))) {
-        // TODO(rmlarsen): Handle subtraction 0 - y.
         // 1 * y = y or 0 + y = y.
         ReplaceOperationWithSnapshot(1, node, output);
         continue;
       }
 
+      if (y_matches_output_shape && (is_sub && x_is_zero)) {
+        // Replace 0 - y with Neg(y).
+        ReplaceSubtractionFromZeroByNegation(node, output);
+        continue;
+      }
+
       // Replace 1 / y with Reciprocal op.
       if (y_matches_output_shape && is_any_div && x_is_one) {
         DataType type = node->attr().at("T").type();
index 232b2f9fa05d86877e681c8eff3606725cf6fdb9..2fd59c7f9ccf3f94e683d7ec41a5848b9eec4a8f 100644 (file)
@@ -82,6 +82,7 @@ class ConstantFolding : public GraphOptimizer {
                                     GraphDef* graph);
   void ReplaceOperationWithSnapshot(int input_to_forward, NodeDef* node,
                                     GraphDef* graph);
+  void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
   Status ReplaceOperationWithConstant(double value,
                                       const TensorShapeProto& shape,
                                       NodeDef* node, GraphDef* graph);
index 219f3bd5ec2a1c15078972bdea69a7642bb4af46..c6540192d7f85098f64ba42c0d4bf27dafc762ab 100644 (file)
@@ -286,10 +286,9 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
         EXPECT_EQ("x", node.input(0));
         EXPECT_EQ("^zeros", node.input(1));
       } else if (name == "sub2") {
-        // We don't handle this case yet.
-        EXPECT_EQ("Sub", node.op());
-        EXPECT_EQ("zeros", node.input(0));
-        EXPECT_EQ("y", node.input(1));
+        EXPECT_EQ("Neg", node.op());
+        EXPECT_EQ("y", node.input(0));
+        EXPECT_EQ("^zeros", node.input(1));
       }
       const std::set<string> square_zero_const{"mul1", "mul2",    "mul5",
                                                "mul6", "matmul1", "matmul2"};