Implement strip CheckNumerics in DebugStripper.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 30 Mar 2018 23:16:13 +0000 (16:16 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 30 Mar 2018 23:18:22 +0000 (16:18 -0700)
PiperOrigin-RevId: 191131935

tensorflow/core/grappler/op_types.cc
tensorflow/core/grappler/op_types.h
tensorflow/core/grappler/optimizers/debug_stripper.cc
tensorflow/core/grappler/optimizers/debug_stripper_test.cc

index c31ac9b..e0ee49d 100644 (file)
@@ -68,6 +68,10 @@ bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }
 
 bool IsCast(const NodeDef& node) { return node.op() == "Cast"; }
 
+bool IsCheckNumerics(const NodeDef& node) {
+  return node.op() == "CheckNumerics";
+}
+
 bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }
 
 bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }
index 39affcb..aa6750d 100644 (file)
@@ -37,6 +37,7 @@ bool IsBiasAdd(const NodeDef& node);
 bool IsBiasAddGrad(const NodeDef& node);
 bool IsBitcast(const NodeDef& node);
 bool IsCast(const NodeDef& node);
+bool IsCheckNumerics(const NodeDef& node);
 bool IsComplex(const NodeDef& node);
 bool IsComplexAbs(const NodeDef& node);
 bool IsConj(const NodeDef& node);
index 0e058e3..8bd1017 100644 (file)
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/core/grappler/optimizers/debug_stripper.h"
 
+#include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/grappler/clusters/cluster.h"
 #include "tensorflow/core/grappler/grappler_item.h"
@@ -39,6 +40,10 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
           inp = AsControlDependency(inp);
         }
       }
+    } else if (IsCheckNumerics(node)) {
+      // Replace with Identity op which will be pruned later.
+      node.set_op("Identity");
+      node.mutable_attr()->erase("message");
     }
   }
   return Status::OK();
index c79c368..3f11feb 100644 (file)
@@ -105,6 +105,65 @@ TEST_F(DebugStripperTest, StripAssertFromGraph) {
   test::ExpectTensorEqual<float>(expected[0], optimized[0]);
 }
 
+TEST_F(DebugStripperTest, StripCheckNumericsFromGraph) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+                              ops::Placeholder::Shape({}));
+  Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
+                              ops::Placeholder::Shape({}));
+  auto check1 = ops::CheckNumerics(s.WithOpName("CheckNumerics1"), x, "foo");
+  auto check2 = ops::CheckNumerics(s.WithOpName("CheckNumerics2"), y, "foo");
+  Output add = ops::Add(s.WithOpName("z"), check1, check2);
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  DebugStripper optimizer;
+  GraphDef output;
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  int count = 0;
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "x") {
+      count++;
+      EXPECT_EQ("Placeholder", node.op());
+      EXPECT_EQ(0, node.input_size());
+    } else if (node.name() == "y") {
+      count++;
+      EXPECT_EQ("Placeholder", node.op());
+      EXPECT_EQ(0, node.input_size());
+    } else if (node.name() == "CheckNumerics1") {
+      count++;
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+      EXPECT_EQ(1, node.attr_size());
+    } else if (node.name() == "CheckNumerics2") {
+      count++;
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("y", node.input(0));
+      EXPECT_EQ(1, node.attr_size());
+    } else if (node.name() == "z") {
+      count++;
+      EXPECT_EQ("Add", node.op());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("CheckNumerics1", node.input(0));
+      EXPECT_EQ("CheckNumerics2", node.input(1));
+    }
+  }
+  EXPECT_EQ(5, count);
+
+  Tensor x_t(DT_FLOAT, TensorShape({}));
+  Tensor y_t(DT_FLOAT, TensorShape({}));
+  x_t.flat<float>()(0) = 1.0f;
+  y_t.flat<float>()(0) = 0.5f;
+  std::vector<Tensor> expected =
+      EvaluateNodes(item.graph, {"z"}, {{"x", x_t}, {"y", y_t}});
+  std::vector<Tensor> optimized =
+      EvaluateNodes(output, {"z"}, {{"x", x_t}, {"y", y_t}});
+  test::ExpectTensorEqual<float>(expected[0], optimized[0]);
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow