#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
namespace grappler {
inp = AsControlDependency(inp);
}
}
- } else if (IsCheckNumerics(node)) {
+ } else if (IsCheckNumerics(node) || IsPrint(node)) {
// Replace with Identity op which will be pruned later.
node.set_op("Identity");
- node.mutable_attr()->erase("message");
+ // Only preserve T attribute.
+ protobuf::Map<string, AttrValue> new_attr;
+ if (node.attr().find("T") != node.attr().end()) {
+ new_attr.insert({"T", node.attr().at("T")});
+ }
+ node.mutable_attr()->swap(new_attr);
+ // As Identity op only takes one input, mark redundant inputs as control
+ // input.
+ for (size_t i = 1; i < node.input_size(); ++i) {
+ if (!IsControlInput(node.input(i))) {
+ *node.mutable_input(i) = AsControlDependency(node.input(i));
+ }
+ }
}
}
return Status::OK();
test::ExpectTensorEqual<float>(expected[0], optimized[0]);
}
+TEST_F(DebugStripperTest, StripPrintFromGraph) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape({}));
+ Output print = ops::Print(s.WithOpName("Print"), x, {x});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ DebugStripper optimizer;
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "x") {
+ EXPECT_EQ("Placeholder", node.op());
+ EXPECT_EQ(0, node.input_size());
+ } else if (node.name() == "Print") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("^x", node.input(1));
+ EXPECT_EQ(1, node.attr_size());
+ }
+ }
+
+ EXPECT_EQ(2, output.node_size());
+
+ Tensor x_t(DT_FLOAT, TensorShape({}));
+ x_t.flat<float>()(0) = 1.0f;
+ std::vector<Tensor> expected =
+ EvaluateNodes(item.graph, {"Print"}, {{"x", x_t}});
+ std::vector<Tensor> optimized =
+ EvaluateNodes(output, {"Print"}, {{"x", x_t}});
+ test::ExpectTensorEqual<float>(expected[0], optimized[0]);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow