test::ExpectTensorEqual<float>(expected[0], optimized[0]);
}
+TEST_F(DebugStripperTest, StripCheckNumericsFromGraph) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape({}));
+ Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
+ ops::Placeholder::Shape({}));
+ auto check1 = ops::CheckNumerics(s.WithOpName("CheckNumerics1"), x, "foo");
+ auto check2 = ops::CheckNumerics(s.WithOpName("CheckNumerics2"), y, "foo");
+ Output add = ops::Add(s.WithOpName("z"), check1, check2);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ DebugStripper optimizer;
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ int count = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "x") {
+ count++;
+ EXPECT_EQ("Placeholder", node.op());
+ EXPECT_EQ(0, node.input_size());
+ } else if (node.name() == "y") {
+ count++;
+ EXPECT_EQ("Placeholder", node.op());
+ EXPECT_EQ(0, node.input_size());
+ } else if (node.name() == "CheckNumerics1") {
+ count++;
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ(1, node.attr_size());
+ } else if (node.name() == "CheckNumerics2") {
+ count++;
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("y", node.input(0));
+ EXPECT_EQ(1, node.attr_size());
+ } else if (node.name() == "z") {
+ count++;
+ EXPECT_EQ("Add", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("CheckNumerics1", node.input(0));
+ EXPECT_EQ("CheckNumerics2", node.input(1));
+ }
+ }
+ EXPECT_EQ(5, count);
+
+ Tensor x_t(DT_FLOAT, TensorShape({}));
+ Tensor y_t(DT_FLOAT, TensorShape({}));
+ x_t.flat<float>()(0) = 1.0f;
+ y_t.flat<float>()(0) = 0.5f;
+ std::vector<Tensor> expected =
+ EvaluateNodes(item.graph, {"z"}, {{"x", x_t}, {"y", y_t}});
+ std::vector<Tensor> optimized =
+ EvaluateNodes(output, {"z"}, {{"x", x_t}, {"y", y_t}});
+ test::ExpectTensorEqual<float>(expected[0], optimized[0]);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow