From ed5f003cc2c542c3c545369f71d4b57429da33fc Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Fri, 9 Feb 2018 14:25:28 -0800 Subject: [PATCH] Make import_graph_def add default attr values with the C API enabled. It turns out that the original Python code modifies the graph_def argument to add default attr values. I'm not sure if the behavior is covered by our API guarantees since it's not documented, but let's keep the behavior consistent for now. PiperOrigin-RevId: 185193037 --- tensorflow/python/framework/importer.py | 42 +++++++++++++++++------ tensorflow/python/saved_model/saved_model_test.py | 2 -- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index c266443..cc8f239 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -150,7 +150,7 @@ def _MaybeDevice(device): yield -def _ProcessGraphDefParam(graph_def): +def _ProcessGraphDefParam(graph_def, op_dict): """Type-checks and possibly canonicalizes `graph_def`.""" if not isinstance(graph_def, graph_pb2.GraphDef): # `graph_def` could be a dynamically-created message, so try a duck-typed @@ -161,6 +161,22 @@ def _ProcessGraphDefParam(graph_def): graph_def.MergeFrom(old_graph_def) except TypeError: raise TypeError('graph_def must be a GraphDef proto.') + else: + # If we're using the graph_def provided by the caller, modify graph_def + # in-place to add attr defaults to the NodeDefs (this is visible to the + # caller). + # NOTE(skyewm): this is undocumented behavior that at least meta_graph.py + # depends on. It might make sense to move this to meta_graph.py and have + # import_graph_def not modify the graph_def argument (we'd have to make sure + # this doesn't break anything else.) + for node in graph_def.node: + if node.op not in op_dict: + # Assume unrecognized ops are functions for now. TF_ImportGraphDef will + # report an error if the op is actually missing. + continue + op_def = op_dict[node.op] + _SetDefaultAttrValues(node, op_def) + return graph_def @@ -369,6 +385,17 @@ def _GatherReturnElements(requested_return_elements, graph, results): return combined_return_elements +def _SetDefaultAttrValues(node_def, op_def): + """Set any default attr values in `node_def` that aren't present.""" + assert node_def.op == op_def.name + for attr_def in op_def.attr: + key = attr_def.name + if attr_def.HasField('default_value'): + value = node_def.attr[key] + if value is None or value.WhichOneof('value') is None: + node_def.attr[key].CopyFrom(attr_def.default_value) + + @tf_export('import_graph_def') @deprecated_args(None, 'Please file an issue at ' 'https://github.com/tensorflow/tensorflow/issues if you depend' @@ -420,12 +447,12 @@ def import_graph_def(graph_def, do not appear in `graph_def`, or `graph_def` is not well-formed (e.g. it refers to an unknown tensor). """ - graph_def = _ProcessGraphDefParam(graph_def) + op_dict = op_def_registry.get_registered_ops() + + graph_def = _ProcessGraphDefParam(graph_def, op_dict) input_map = _ProcessInputMapParam(input_map) return_elements = _ProcessReturnElementsParam(return_elements) - op_dict = op_def_registry.get_registered_ops() - if producer_op_list is not None: # TODO(skyewm): make a copy of graph_def so we're not mutating the argument? _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def) @@ -535,16 +562,9 @@ def import_graph_def(graph_def, # Check to see if this op's name matches a previously seen op if node.name in name_to_op: raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name) - # Set any default attr values that aren't present. if node.op not in op_dict: raise ValueError('No op named %s in defined operations.' % node.op) op_def = op_dict[node.op] - for attr_def in op_def.attr: - key = attr_def.name - if attr_def.HasField('default_value'): - value = node.attr[key] - if value is None or value.WhichOneof('value') is None: - node.attr[key].CopyFrom(attr_def.default_value) output_types = _OutputTypes(node, op_dict) name_to_op[node.name] = g.create_op( diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index d1f6bc2..d9d3168 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -873,8 +873,6 @@ class SavedModelTest(test.TestCase): 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) def testStripDefaultAttrs(self): - if ops._USE_C_API: return # TODO(skyewm): get this working - export_dir = self._get_export_dir("test_strip_default_attrs") builder = saved_model_builder.SavedModelBuilder(export_dir) -- 2.7.4