Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
const OpRegistryInterface& op_registry,
int node_offset) {
+ return AddDefaultAttrsToGraphDef(graph_def, op_registry, node_offset, false);
+}
+
+Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
+ const OpRegistryInterface& op_registry,
+ int node_offset, bool skip_unknown_ops) {
if (node_offset > graph_def->node_size()) {
return errors::InvalidArgument(
"Tried to add default attrs to GraphDef "
for (int i = node_offset; i < graph_def->node_size(); ++i) {
NodeDef* node_def = graph_def->mutable_node(i);
const OpDef* op_def;
- TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(node_def->op(), &op_def));
- AddDefaultsToNodeDef(*op_def, node_def);
+ Status s = op_registry.LookUpOpDef(node_def->op(), &op_def);
+ if (s.ok()) {
+ AddDefaultsToNodeDef(*op_def, node_def);
+ } else if (!skip_unknown_ops) {
+ return s;
+ }
}
return Status::OK();
const OpRegistryInterface& op_registry,
int node_offset);
+// Same as above, except for the fact that it skips nodes that aren't found in
+// op_registry if skip_unknown_ops is true.
+Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
+ const OpRegistryInterface& op_registry,
+ int node_offset, bool skip_unknown_ops);
+
// Remove attrs from 'graph_def' that have the default value according
// to 'producer_op_registry', but don't exist according to
// 'consumer_op_registry'. This can allow 'graph_def' to run on the
// The default values of attributes might have been stripped by the optimizer.
// Add them back.
return AddDefaultAttrsToGraphDef(output_graph_def, *graphptr->op_registry(),
- 0);
+ 0, true);
}
// Applies the same graph pruning logic to the graph as Session.Run in TF.
&new_item->graph,
FunctionLibraryDefinition(OpRegistry::Global(),
new_item->graph.library()),
- 0);
+ 0, true);
if (!attr_status.ok()) {
LOG(ERROR) << "Failed to instantiate default attribute values: "
<< attr_status.error_message();
ASSERT_TRUE(item != nullptr);
}
+TEST_F(GrapplerItemBuilderTest, GraphWithCustomOps) {
+ MetaGraphDef meta_graph;
+ // y = XTimesTwo(x)
+ constexpr char device[] = "/cpu:0";
+ *meta_graph.mutable_graph_def() = test::function::GDef(
+ {test::function::NDef("x", "Const", {}, {{"dtype", DT_FLOAT}}, device),
+ test::function::NDef("y", "CustomOp", {"x"}, {{"T", DT_FLOAT}}, device)},
+ {});
+
+ CollectionDef train_op;
+ train_op.mutable_node_list()->add_value("y");
+ (*meta_graph.mutable_collection_def())["train_op"] = train_op;
+
+ ItemConfig cfg;
+ cfg.inline_functions = false;
+
+ std::unique_ptr<GrapplerItem> item =
+ GrapplerItemFromMetaGraphDef("0", meta_graph, cfg);
+ ASSERT_TRUE(item != nullptr);
+}
+
TEST_F(GrapplerItemBuilderTest, FromGraphWithSignatureDef) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto x = ops::Const(s.WithOpName("x"), 0);