Remove 'Print' in DebugStripper.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 7 Apr 2018 09:03:00 +0000 (02:03 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 7 Apr 2018 09:05:19 +0000 (02:05 -0700)
PiperOrigin-RevId: 191989327

tensorflow/core/grappler/op_types.cc
tensorflow/core/grappler/op_types.h
tensorflow/core/grappler/optimizers/debug_stripper.cc
tensorflow/core/grappler/optimizers/debug_stripper_test.cc

index a24d2db..1fb1711 100644 (file)
@@ -245,6 +245,8 @@ bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; }
 
 bool IsPow(const NodeDef& node) { return node.op() == "Pow"; }
 
+bool IsPrint(const NodeDef& node) { return node.op() == "Print"; }
+
 bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }
 
 bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
index 8667f72..d516bae 100644 (file)
@@ -95,6 +95,7 @@ bool IsNoOp(const NodeDef& node);
 bool IsNotEqual(const NodeDef& node);
 bool IsPlaceholder(const NodeDef& node);
 bool IsPolygamma(const NodeDef& node);
+bool IsPrint(const NodeDef& node);
 bool IsProd(const NodeDef& node);
 bool IsPow(const NodeDef& node);
 bool IsReal(const NodeDef& node);
index 8bd1017..9701a03 100644 (file)
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/platform/protobuf.h"
 
 namespace tensorflow {
 namespace grappler {
@@ -40,10 +41,22 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
           inp = AsControlDependency(inp);
         }
       }
-    } else if (IsCheckNumerics(node)) {
+    } else if (IsCheckNumerics(node) || IsPrint(node)) {
       // Replace with Identity op which will be pruned later.
       node.set_op("Identity");
-      node.mutable_attr()->erase("message");
+      // Only preserve T attribute.
+      protobuf::Map<string, AttrValue> new_attr;
+      if (node.attr().find("T") != node.attr().end()) {
+        new_attr.insert({"T", node.attr().at("T")});
+      }
+      node.mutable_attr()->swap(new_attr);
+      // As Identity op only takes one input, mark redundant inputs as control
+      // input.
+      for (size_t i = 1; i < node.input_size(); ++i) {
+        if (!IsControlInput(node.input(i))) {
+          *node.mutable_input(i) = AsControlDependency(node.input(i));
+        }
+      }
     }
   }
   return Status::OK();
index 3f11feb..96ceee7 100644 (file)
@@ -164,6 +164,42 @@ TEST_F(DebugStripperTest, StripCheckNumericsFromGraph) {
   test::ExpectTensorEqual<float>(expected[0], optimized[0]);
 }
 
+TEST_F(DebugStripperTest, StripPrintFromGraph) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+                              ops::Placeholder::Shape({}));
+  Output print = ops::Print(s.WithOpName("Print"), x, {x});
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  DebugStripper optimizer;
+  GraphDef output;
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "x") {
+      EXPECT_EQ("Placeholder", node.op());
+      EXPECT_EQ(0, node.input_size());
+    } else if (node.name() == "Print") {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+      EXPECT_EQ("^x", node.input(1));
+      EXPECT_EQ(1, node.attr_size());
+    }
+  }
+
+  EXPECT_EQ(2, output.node_size());
+
+  Tensor x_t(DT_FLOAT, TensorShape({}));
+  x_t.flat<float>()(0) = 1.0f;
+  std::vector<Tensor> expected =
+      EvaluateNodes(item.graph, {"Print"}, {{"x", x_t}});
+  std::vector<Tensor> optimized =
+      EvaluateNodes(output, {"Print"}, {{"x", x_t}});
+  test::ExpectTensorEqual<float>(expected[0], optimized[0]);
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow