Don't let the grappler item builder fail if the graph contains unknown custom
authorBenoit Steiner <bsteiner@google.com>
Mon, 12 Mar 2018 16:16:08 +0000 (09:16 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Mar 2018 16:21:01 +0000 (09:21 -0700)
ops.

PiperOrigin-RevId: 188730560

tensorflow/core/framework/graph_def_util.cc
tensorflow/core/framework/graph_def_util.h
tensorflow/core/grappler/grappler_item_builder.cc
tensorflow/core/grappler/grappler_item_builder_test.cc

index 1f67053..896cb3c 100644 (file)
@@ -53,6 +53,12 @@ Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) {
 Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
                                  const OpRegistryInterface& op_registry,
                                  int node_offset) {
+  return AddDefaultAttrsToGraphDef(graph_def, op_registry, node_offset, false);
+}
+
+Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
+                                 const OpRegistryInterface& op_registry,
+                                 int node_offset, bool skip_unknown_ops) {
   if (node_offset > graph_def->node_size()) {
     return errors::InvalidArgument(
         "Tried to add default attrs to GraphDef "
@@ -63,8 +69,12 @@ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
   for (int i = node_offset; i < graph_def->node_size(); ++i) {
     NodeDef* node_def = graph_def->mutable_node(i);
     const OpDef* op_def;
-    TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(node_def->op(), &op_def));
-    AddDefaultsToNodeDef(*op_def, node_def);
+    Status s = op_registry.LookUpOpDef(node_def->op(), &op_def);
+    if (s.ok()) {
+      AddDefaultsToNodeDef(*op_def, node_def);
+    } else if (!skip_unknown_ops) {
+      return s;
+    }
   }
 
   return Status::OK();
index 0c6542a..525e84a 100644 (file)
@@ -56,6 +56,12 @@ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
                                  const OpRegistryInterface& op_registry,
                                  int node_offset);
 
+// Same as above, except for the fact that it skips nodes that aren't found in
+// op_registry if skip_unknown_ops is true.
+Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
+                                 const OpRegistryInterface& op_registry,
+                                 int node_offset, bool skip_unknown_ops);
+
 // Remove attrs from 'graph_def' that have the default value according
 // to 'producer_op_registry', but don't exist according to
 // 'consumer_op_registry'. This can allow 'graph_def' to run on the
index 33ad426..04c7dae 100644 (file)
@@ -138,7 +138,7 @@ Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
   // The default values of attributes might have been stripped by the optimizer.
   // Add them back.
   return AddDefaultAttrsToGraphDef(output_graph_def, *graphptr->op_registry(),
-                                   0);
+                                   0, true);
 }
 
 // Applies the same graph pruning logic to the graph as Session.Run in TF.
@@ -514,7 +514,7 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
       &new_item->graph,
       FunctionLibraryDefinition(OpRegistry::Global(),
                                 new_item->graph.library()),
-      0);
+      0, true);
   if (!attr_status.ok()) {
     LOG(ERROR) << "Failed to instantiate default attribute values: "
                << attr_status.error_message();
index 78cbff6..ada9092 100644 (file)
@@ -280,6 +280,27 @@ TEST_F(GrapplerItemBuilderTest, GraphWithFunctions) {
   ASSERT_TRUE(item != nullptr);
 }
 
+TEST_F(GrapplerItemBuilderTest, GraphWithCustomOps) {
+  MetaGraphDef meta_graph;
+  // y = XTimesTwo(x)
+  constexpr char device[] = "/cpu:0";
+  *meta_graph.mutable_graph_def() = test::function::GDef(
+      {test::function::NDef("x", "Const", {}, {{"dtype", DT_FLOAT}}, device),
+       test::function::NDef("y", "CustomOp", {"x"}, {{"T", DT_FLOAT}}, device)},
+      {});
+
+  CollectionDef train_op;
+  train_op.mutable_node_list()->add_value("y");
+  (*meta_graph.mutable_collection_def())["train_op"] = train_op;
+
+  ItemConfig cfg;
+  cfg.inline_functions = false;
+
+  std::unique_ptr<GrapplerItem> item =
+      GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
+  ASSERT_TRUE(item != nullptr);
+}
+
 TEST_F(GrapplerItemBuilderTest, FromGraphWithSignatureDef) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   auto x = ops::Const(s.WithOpName("x"), 0);