From 994fef1b4e702fb6cf178cff8a30cd75794c6451 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 7 Apr 2018 02:03:00 -0700 Subject: [PATCH] Remove 'Print' in DebugStripper. PiperOrigin-RevId: 191989327 --- tensorflow/core/grappler/op_types.cc | 2 ++ tensorflow/core/grappler/op_types.h | 1 + .../core/grappler/optimizers/debug_stripper.cc | 17 ++++++++-- .../grappler/optimizers/debug_stripper_test.cc | 36 ++++++++++++++++++++++ 4 files changed, 54 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index a24d2db..1fb1711 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -245,6 +245,8 @@ bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; } bool IsPow(const NodeDef& node) { return node.op() == "Pow"; } +bool IsPrint(const NodeDef& node) { return node.op() == "Print"; } + bool IsProd(const NodeDef& node) { return node.op() == "Prod"; } bool IsReal(const NodeDef& node) { return node.op() == "Real"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 8667f72..d516bae 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -95,6 +95,7 @@ bool IsNoOp(const NodeDef& node); bool IsNotEqual(const NodeDef& node); bool IsPlaceholder(const NodeDef& node); bool IsPolygamma(const NodeDef& node); +bool IsPrint(const NodeDef& node); bool IsProd(const NodeDef& node); bool IsPow(const NodeDef& node); bool IsReal(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.cc b/tensorflow/core/grappler/optimizers/debug_stripper.cc index 8bd1017..9701a03 100644 --- a/tensorflow/core/grappler/optimizers/debug_stripper.cc +++ b/tensorflow/core/grappler/optimizers/debug_stripper.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/platform/protobuf.h" namespace tensorflow { namespace grappler { @@ -40,10 +41,22 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item, inp = AsControlDependency(inp); } } - } else if (IsCheckNumerics(node)) { + } else if (IsCheckNumerics(node) || IsPrint(node)) { // Replace with Identity op which will be pruned later. node.set_op("Identity"); - node.mutable_attr()->erase("message"); + // Only preserve T attribute. + protobuf::Map new_attr; + if (node.attr().find("T") != node.attr().end()) { + new_attr.insert({"T", node.attr().at("T")}); + } + node.mutable_attr()->swap(new_attr); + // As Identity op only takes one input, mark redundant inputs as control + // input. + for (size_t i = 1; i < node.input_size(); ++i) { + if (!IsControlInput(node.input(i))) { + *node.mutable_input(i) = AsControlDependency(node.input(i)); + } + } } } return Status::OK(); diff --git a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc index 3f11feb..96ceee7 100644 --- a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc +++ b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc @@ -164,6 +164,42 @@ TEST_F(DebugStripperTest, StripCheckNumericsFromGraph) { test::ExpectTensorEqual(expected[0], optimized[0]); } +TEST_F(DebugStripperTest, StripPrintFromGraph) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, + ops::Placeholder::Shape({})); + Output print = ops::Print(s.WithOpName("Print"), x, {x}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + DebugStripper optimizer; + GraphDef output; + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + for (const NodeDef& node : output.node()) { + if (node.name() == "x") { + EXPECT_EQ("Placeholder", node.op()); + EXPECT_EQ(0, node.input_size()); + } else if (node.name() == "Print") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("^x", node.input(1)); + EXPECT_EQ(1, node.attr_size()); + } + } + + EXPECT_EQ(2, output.node_size()); + + Tensor x_t(DT_FLOAT, TensorShape({})); + x_t.flat()(0) = 1.0f; + std::vector expected = + EvaluateNodes(item.graph, {"Print"}, {{"x", x_t}}); + std::vector optimized = + EvaluateNodes(output, {"Print"}, {{"x", x_t}}); + test::ExpectTensorEqual(expected[0], optimized[0]); +} + } // namespace } // namespace grappler } // namespace tensorflow -- 2.7.4