From: A. Unique TensorFlower Date: Wed, 3 Jan 2018 02:19:21 +0000 (-0800) Subject: meta_graph export: Add support to strip default valued attributes. X-Git-Tag: v1.6.0-rc0~304^2~5^2~74 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=da1588ccf5c62eccac8013673359ac15b43eb394;p=platform%2Fupstream%2Ftensorflow.git meta_graph export: Add support to strip default valued attributes. Following APIs now accept an additional argument (`strip_default_attrs`) to enable/disable (default:disabled) stripping of default valued attributes in a NodeDef: o meta_graph: export_meta_graph, create_meta_graph. o saver: Saver.save, Saver.export_meta_graph. o builder: SavedModelBuilder.add_meta_graph, SavedModelBuilder.add_meta_graph_and_variables. o estimator: Estimator.export_savedmodel. Related changes: o Pywrap C++ AreAttrValuesEqual to compare two AttrValue instances. This allows for a single/canonical way of comparing AttrValues in C++/Python. o Add a utility method to meta_graph.py to get the node def by name in a graph def. o Update SavedModelBuilder documentation on relevance of strip_default_attrs flag. PiperOrigin-RevId: 180619001 --- diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 05ed8b3409..2395c7e717 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -1256,7 +1256,9 @@ class Estimator(BaseEstimator): assets_extra=None, as_text=False, checkpoint_path=None, - graph_rewrite_specs=(GraphRewriteSpec((tag_constants.SERVING,), ()),)): + graph_rewrite_specs=(GraphRewriteSpec((tag_constants.SERVING,), ()),), + strip_default_attrs=False): + # pylint: disable=line-too-long """Exports inference graph as a SavedModel into given dir. Args: @@ -1280,6 +1282,9 @@ class Estimator(BaseEstimator): produce a separate MetaGraphDef within the exported SavedModel, tagged and rewritten as specified. Defaults to a single entry using the default serving tag ("serve") and no rewriting. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Returns: The string path to the exported directory. @@ -1287,6 +1292,7 @@ class Estimator(BaseEstimator): Raises: ValueError: if an unrecognized export_type is requested. """ + # pylint: enable=line-too-long if serving_input_fn is None: raise ValueError('serving_input_fn must be defined.') @@ -1366,7 +1372,8 @@ class Estimator(BaseEstimator): signature_def_map=signature_def_map, assets_collection=ops.get_collection( ops.GraphKeys.ASSET_FILEPATHS), - legacy_init_op=init_op) + legacy_init_op=init_op, + strip_default_attrs=strip_default_attrs) # pylint: disable=protected-access base_meta_graph_def = builder._saved_model.meta_graphs[0] diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index 4b404a8e20..03ec66b98b 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -390,7 +390,9 @@ def make_export_strategy(serving_input_fn, default_output_alternative_key=None, assets_extra=None, as_text=False, - exports_to_keep=5): + exports_to_keep=5, + strip_default_attrs=False): + # pylint: disable=line-too-long """Create an ExportStrategy for use with Experiment. Args: @@ -411,10 +413,14 @@ def make_export_strategy(serving_input_fn, exports_to_keep: Number of exports to keep. Older exports will be garbage-collected. Defaults to 5. Set to None to disable garbage collection. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Returns: An ExportStrategy that can be passed to the Experiment constructor. """ + # pylint: enable=line-too-long def export_fn(estimator, export_dir_base, checkpoint_path=None): """Exports the given Estimator as a SavedModel. @@ -443,7 +449,8 @@ def make_export_strategy(serving_input_fn, serving_input_fn, assets_extra=assets_extra, as_text=as_text, - checkpoint_path=checkpoint_path) + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs) else: export_result = estimator.export_savedmodel( export_dir_base, @@ -451,7 +458,8 @@ def make_export_strategy(serving_input_fn, default_output_alternative_key=default_output_alternative_key, assets_extra=assets_extra, as_text=as_text, - checkpoint_path=checkpoint_path) + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs) garbage_collect_exports(export_dir_base, exports_to_keep) return export_result @@ -464,7 +472,9 @@ def make_parsing_export_strategy(feature_columns, assets_extra=None, as_text=False, exports_to_keep=5, - target_core=False): + target_core=False, + strip_default_attrs=False): + # pylint: disable=line-too-long """Create an ExportStrategy for use with Experiment, using `FeatureColumn`s. Creates a SavedModel export that expects to be fed with a single string @@ -492,10 +502,14 @@ def make_parsing_export_strategy(feature_columns, target_core: If True, prepare an ExportStrategy for use with tensorflow.python.estimator.*. If False (default), prepare an ExportStrategy for use with tensorflow.contrib.learn.python.learn.*. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Returns: An ExportStrategy that can be passed to the Experiment constructor. """ + # pylint: enable=line-too-long feature_spec = feature_column.create_feature_spec_for_parsing(feature_columns) if target_core: serving_input_fn = ( @@ -508,7 +522,8 @@ def make_parsing_export_strategy(feature_columns, default_output_alternative_key=default_output_alternative_key, assets_extra=assets_extra, as_text=as_text, - exports_to_keep=exports_to_keep) + exports_to_keep=exports_to_keep, + strip_default_attrs=strip_default_attrs) def _default_compare_fn(curr_best_eval_result, cand_eval_result): @@ -584,7 +599,9 @@ class BestModelSelector(object): def make_best_model_export_strategy(serving_input_fn, exports_to_keep=1, compare_fn=None, - default_output_alternative_key=None): + default_output_alternative_key=None, + strip_default_attrs=False): + # pylint: disable=line-too-long """Creates an custom ExportStrategy for use with tf.contrib.learn.Experiment. Args: @@ -596,14 +613,19 @@ def make_best_model_export_strategy(serving_input_fn, of evaluation result keyed by corresponding checkpoint path. default_output_alternative_key: the key for default serving signature for multi-headed inference graphs. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Returns: An ExportStrategy that can be passed to the Experiment constructor. """ + # pylint: enable=line-too-long best_model_export_strategy = make_export_strategy( serving_input_fn, exports_to_keep=exports_to_keep, - default_output_alternative_key=default_output_alternative_key) + default_output_alternative_key=default_output_alternative_key, + strip_default_attrs=strip_default_attrs) best_model_selector = BestModelSelector(compare_fn) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py index 628eb254c3..531d9c672b 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py @@ -55,7 +55,8 @@ class TestEstimator(core_estimator.Estimator): default_output_alternative_key=None, assets_extra=None, as_text=False, - checkpoint_path=None): + checkpoint_path=None, + strip_default_attrs=False): if not os.path.exists(export_dir): os.makedirs(export_dir) diff --git a/tensorflow/core/protobuf/meta_graph.proto b/tensorflow/core/protobuf/meta_graph.proto index 47ec2aa1ef..fd86c0da12 100644 --- a/tensorflow/core/protobuf/meta_graph.proto +++ b/tensorflow/core/protobuf/meta_graph.proto @@ -61,6 +61,10 @@ message MetaGraphDef { // graph. This will be populated by the framework, which will overwrite any // user supplied value. string tensorflow_git_version = 6; + + // A flag to denote whether default-valued attrs have been stripped from + // the nodes in this graph_def. + bool stripped_default_attrs = 7; } MetaInfoDef meta_info_def = 1; diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index d8e7df48c2..c037a9b122 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -91,10 +91,13 @@ limitations under the License. // 24. Deprecate lookup ops (v1) ops in favor of v2 (30may2017) // 25. Deprecate stack (v1) ops in favor of v2 (2017/6/15). // 25. Deprecate RandomPoisson (v1) ops in favor of v2 (2017/10/25). +// 26. Add a bool 'stripped_default_attrs' to MetaInfoDef indicating +// whether default-valued attrs have been stripped from the nodes in the +// GraphDef. (7dec2017) #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 24 +#define TF_GRAPH_DEF_VERSION 25 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 3f1d63a543..1fd488e7b6 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -485,6 +485,7 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{ %unignore tensorflow; %unignore TF_Run; %unignore EqualGraphDefWrapper; +%unignore EqualAttrValueWrapper; // Include the wrapper for TF_PRunSetup from tf_session_helper.h. diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index 2b83141faa..361dbc22b0 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -21,10 +21,13 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/equal_graph_def.h" #include "tensorflow/python/lib/core/ndarray_tensor.h" @@ -301,6 +304,27 @@ string EqualGraphDefWrapper(const string& actual, const string& expected) { return EqualGraphDef(actual_def, expected_def, &diff) ? "" : diff; } +string EqualAttrValueWrapper(const string& actual, const string& expected) { + AttrValue actual_attr_value; + if (!actual_attr_value.ParseFromString(actual)) { + return "actual is not a valid serialized AttrValue"; + } + + AttrValue expected_attr_value; + if (!expected_attr_value.ParseFromString(expected)) { + return "expected is not a valid serialized AttrValue"; + } + + string diff; + if (!AreAttrValuesEqual(actual_attr_value, expected_attr_value)) { + diff = strings::Printf( + "Actual AttrValue %s does not match Expected AttrValue %s.", + SummarizeAttrValue(actual_attr_value).c_str(), + SummarizeAttrValue(expected_attr_value).c_str()); + } + return diff; +} + // Return value set to 6 inlined elements so it fits in a 64-byte cache line. tensorflow::gtl::InlinedVector TF_GraphGetTensorShapeHelper( TF_Graph* graph, TF_Output output, TF_Status* out_status, diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 8f2499b9a0..29d5b28f40 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -97,6 +97,13 @@ void TF_Reset_wrapper(const TF_SessionOptions* opt, // for no difference. string EqualGraphDefWrapper(const string& actual, const string& expected); +// Convenience wrapper around AreAttrValuesEqual to make it easier to wrap. +// The actual and expected strings must correspond to a serialized binary +// representation of two AttrValue proto instances. +// Returns an explanation if a difference is found, or the empty string +// for no difference. +string EqualAttrValueWrapper(const string& actual, const string& expected); + // Gets shape from C API Graph object. // // If shape is known, returns shape vector where -1 means "unknown diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 1e3d6d5755..c72d37b442 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -461,7 +461,8 @@ class Estimator(object): self, export_dir_base, serving_input_receiver_fn, assets_extra=None, as_text=False, - checkpoint_path=None): + checkpoint_path=None, + strip_default_attrs=False): # pylint: disable=line-too-long """Exports inference graph as a SavedModel into given dir. @@ -503,6 +504,9 @@ class Estimator(object): as_text: whether to write the SavedModel proto in text format. checkpoint_path: The checkpoint path to export. If `None` (the default), the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Returns: The string path to the exported directory. @@ -563,7 +567,8 @@ class Estimator(object): signature_def_map=signature_def_map, assets_collection=ops.get_collection( ops.GraphKeys.ASSET_FILEPATHS), - legacy_init_op=local_init_op) + legacy_init_op=local_init_op, + strip_default_attrs=strip_default_attrs) builder.save(as_text) # Add the extra assets diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index db64fbc9cc..58d0cb0018 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -40,6 +40,7 @@ from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.layers import layers from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops @@ -57,6 +58,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary from tensorflow.python.summary import summary_iterator @@ -2050,6 +2052,65 @@ class EstimatorExportTest(test.TestCase): gfile.DeleteRecursively(tmpdir) + def test_export_savedmodel_proto_strip_default_attrs(self): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=dummy_input_fn, steps=1) + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir_stripped = est.export_savedmodel( + export_dir_base, serving_input_receiver_fn, strip_default_attrs=True) + export_dir_not_stripped = est.export_savedmodel( + export_dir_base, serving_input_receiver_fn, strip_default_attrs=False) + + # Load the SavedModel from disk as-is to verify default attrs + # are stripped. Reimporting the SavedModel via the loader causes the + # default attrs to be populated in the NodeDefs. + + # pylint: disable=protected-access + saved_model_stripped_pb = loader_impl._parse_saved_model( + export_dir_stripped) + saved_model_not_stripped_pb = loader_impl._parse_saved_model( + export_dir_not_stripped) + self.assertIsNotNone(saved_model_stripped_pb) + self.assertIsNotNone(saved_model_not_stripped_pb) + # pylint: enable=protected-access + + meta_graph_def_stripped = [ + x for x in saved_model_stripped_pb.meta_graphs + if x.meta_info_def.tags == [tag_constants.SERVING]][0] + meta_graph_def_not_stripped = [ + x for x in saved_model_not_stripped_pb.meta_graphs + if x.meta_info_def.tags == [tag_constants.SERVING]][0] + + # "weight" node in graph is a "Variable" Op with 2 default valued attrs. + # o "container" : "". + # o "shared_name" : "". + + # saved_model_stripped_pb was exported with strip_default_attrs set to True. + # "weight" node shouldn't have attributes "container" and "shared_name". + node_def = test_util.get_node_def_from_graph( + 'weight', meta_graph_def_stripped.graph_def) + self.assertNotIn('container', node_def.attr) + self.assertNotIn('shared_name', node_def.attr) + + # saved_model_not_stripped_pb was exported with strip_default_attrs + # disabled. "weight" node should have attributes "container" and + # "shared_name". + node_def = test_util.get_node_def_from_graph( + 'weight', meta_graph_def_not_stripped.graph_def) + self.assertIn('container', node_def.attr) + self.assertIn('shared_name', node_def.attr) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + class EstimatorHookOrderingTest(test.TestCase): diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py index c839d7a9a6..65032637d4 100644 --- a/tensorflow/python/framework/meta_graph.py +++ b/tensorflow/python/framework/meta_graph.py @@ -31,6 +31,7 @@ from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import op_def_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.framework import graph_io from tensorflow.python.framework import importer @@ -442,6 +443,67 @@ def add_collection_def(meta_graph_def, key, graph=None, return +def _is_default_attr_value(op_def, attr_name, attr_value): + """Checks if given attribute matches the default value in the op def.""" + for attr_def in op_def.attr: + if attr_def.name == attr_name: + if not attr_def.HasField("default_value"): + return False + # pywrap_tensorflow.EqualAttrValueWrapper returns an empty string + # if both arguments represent an equivalent AttrValue instance. + return not pywrap_tensorflow.EqualAttrValueWrapper( + attr_value.SerializeToString(), + attr_def.default_value.SerializeToString()) + return False + + +def _strip_graph_default_valued_attrs(meta_graph_def): + """Strips default valued attributes for node defs in given MetaGraphDef. + + This method also sets `meta_info_def.stripped_default_attrs` in the given + `MetaGraphDef` proto to True. + + Args: + meta_graph_def: `MetaGraphDef` protocol buffer + + Returns: + None. + """ + # Map function op names to their function definitions. + op_name_to_function = {} + for function_def in meta_graph_def.graph_def.library.function: + op_name_to_function[function_def.signature.name] = function_def + + # Get all registered ops. + registered_ops = op_def_registry.get_registered_ops() + + def _strip_node_default_valued_attrs(node_def): + """Removes default valued attributes from a single node def.""" + if node_def.op in op_name_to_function or node_def.op not in registered_ops: + return + op_def = registered_ops[node_def.op] + + attrs_to_strip = set() + for attr_name, attr_value in node_def.attr.items(): + if _is_default_attr_value(op_def, attr_name, attr_value): + attrs_to_strip.add(attr_name) + + for attr in attrs_to_strip: + del node_def.attr[attr] + + # Process all NodeDef instances in graph_def. + for node_def in meta_graph_def.graph_def.node: + _strip_node_default_valued_attrs(node_def) + + # Process all NodeDef instances in graph_def.library.function. + for function_def in meta_graph_def.graph_def.library.function: + for function_node_def in function_def.node_def: + _strip_node_default_valued_attrs(function_node_def) + + # Tell consumers of this graph that default valued attrs have been stripped. + meta_graph_def.meta_info_def.stripped_default_attrs = True + + def create_meta_graph_def(meta_info_def=None, graph_def=None, saver_def=None, @@ -449,7 +511,9 @@ def create_meta_graph_def(meta_info_def=None, graph=None, export_scope=None, exclude_nodes=None, - clear_extraneous_savers=False): + clear_extraneous_savers=False, + strip_default_attrs=False): + # pylint: disable=line-too-long """Construct and returns a `MetaGraphDef` protocol buffer. Args: @@ -464,12 +528,17 @@ def create_meta_graph_def(meta_info_def=None, clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS collection. Note this method does not alter the graph, so any extraneous Save/Restore ops should have been removed already, as needed. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + Returns: MetaGraphDef protocol buffer. Raises: TypeError: If the arguments are not of the correct proto buffer type. """ + # pylint: enable=line-too-long # Type check. if graph and not isinstance(graph, ops.Graph): raise TypeError("graph must be of type Graph, not %s", type(graph)) @@ -511,6 +580,10 @@ def create_meta_graph_def(meta_info_def=None, stripped_op_list_for_graph(meta_graph_def.graph_def)) # pylint: enable=g-explicit-length-test + # Strip default valued attributes in graph_def. + if strip_default_attrs: + _strip_graph_default_valued_attrs(meta_graph_def) + # Adds saver_def. if saver_def: meta_graph_def.saver_def.MergeFrom(saver_def) @@ -724,6 +797,7 @@ def export_scoped_meta_graph(filename=None, clear_devices=False, saver_def=None, clear_extraneous_savers=False, + strip_default_attrs=False, **kwargs): """Returns `MetaGraphDef` proto. Optionally writes it to filename. @@ -752,6 +826,8 @@ def export_scoped_meta_graph(filename=None, clear_extraneous_savers: Remove any Saver-related information from the graph (both Save/Restore ops and SaverDefs) that are not associated with the provided SaverDef. + strip_default_attrs: Set to true if default valued attributes must be + removed while exporting the GraphDef. **kwargs: Optional keyed arguments, including meta_info_def and collection_list. @@ -837,6 +913,7 @@ def export_scoped_meta_graph(filename=None, exclude_nodes=exclude_nodes, clear_extraneous_savers=clear_extraneous_savers, saver_def=saver_def, + strip_default_attrs=strip_default_attrs, **kwargs) if filename: @@ -881,3 +958,5 @@ def copy_scoped_meta_graph(from_scope, to_scope, graph=to_graph, import_scope=to_scope) return var_list + + diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index 4c22c913b8..ae8c9ea2a4 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -24,6 +24,7 @@ import random import shutil from tensorflow.core.framework import graph_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -154,6 +155,108 @@ class SimpleMetaGraphTest(test.TestCase): op_list = meta_graph.stripped_op_list_for_graph(graph) self.assertEqual(["Const"], [op.name for op in op_list.op]) + def testDefaultAttrStripping(self): + """Verifies that default attributes are stripped from a graph def.""" + + # Complex Op has 2 attributes with defaults: + # o "T" : float32. + # o "Tout" : complex64. + + # When inputs to the Complex Op are float32 instances, "T" maps to float32 + # and "Tout" maps to complex64. Since these attr values map to their + # defaults, they must be stripped unless stripping of default attrs is + # disabled. + with self.test_session(): + real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real") + imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag") + math_ops.complex(real_num, imag_num, name="complex") + + # strip_default_attrs is enabled. + meta_graph_def, _ = meta_graph.export_scoped_meta_graph( + graph_def=ops.get_default_graph().as_graph_def(), + strip_default_attrs=True) + node_def = test_util.get_node_def_from_graph("complex", + meta_graph_def.graph_def) + self.assertNotIn("T", node_def.attr) + self.assertNotIn("Tout", node_def.attr) + self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs) + + # strip_default_attrs is disabled. + meta_graph_def, _ = meta_graph.export_scoped_meta_graph( + graph_def=ops.get_default_graph().as_graph_def(), + strip_default_attrs=False) + node_def = test_util.get_node_def_from_graph("complex", + meta_graph_def.graph_def) + self.assertIn("T", node_def.attr) + self.assertIn("Tout", node_def.attr) + self.assertFalse(meta_graph_def.meta_info_def.stripped_default_attrs) + + # When inputs to the Complex Op are float64 instances, "T" maps to float64 + # and "Tout" maps to complex128. Since these attr values don't map to their + # defaults, they must not be stripped. + with self.test_session(graph=ops.Graph()): + real_num = constant_op.constant(1.0, dtype=dtypes.float64, name="real") + imag_num = constant_op.constant(2.0, dtype=dtypes.float64, name="imag") + math_ops.complex(real_num, imag_num, name="complex") + meta_graph_def, _ = meta_graph.export_scoped_meta_graph( + graph_def=ops.get_default_graph().as_graph_def(), + strip_default_attrs=True) + node_def = test_util.get_node_def_from_graph("complex", + meta_graph_def.graph_def) + self.assertEqual(node_def.attr["T"].type, dtypes.float64) + self.assertEqual(node_def.attr["Tout"].type, dtypes.complex128) + self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs) + + def testDefaultAttrStrippingNestedFunctions(self): + """Verifies that default attributes are stripped from function node defs.""" + with self.test_session(): + @function.Defun(dtypes.float32, dtypes.float32) + def f0(i, j): + return math_ops.complex(i, j, name="double_nested_complex") + + @function.Defun(dtypes.float32, dtypes.float32) + def f1(i, j): + return f0(i, j) + + _ = f1(constant_op.constant(1.0), constant_op.constant(2.0)) + meta_graph_def, _ = meta_graph.export_scoped_meta_graph( + graph_def=ops.get_default_graph().as_graph_def(), + strip_default_attrs=True) + + double_nested_complex_node_def = None + for function_def in meta_graph_def.graph_def.library.function: + for node_def in function_def.node_def: + if node_def.name == "double_nested_complex": + double_nested_complex_node_def = node_def + break + if double_nested_complex_node_def: + break + + self.assertIsNotNone(double_nested_complex_node_def) + self.assertNotIn("T", double_nested_complex_node_def.attr) + self.assertNotIn("Tout", double_nested_complex_node_def.attr) + self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs) + + def testDefaultAttrStrippingUnregisteredOps(self): + """Verifies that nodes with un-registered ops are not stripped.""" + graph_def = graph_pb2.GraphDef() + node = graph_def.node.add() + node.name = "node_with_unreg_op" + node.op = "unreg_op" + node.attr["attr_1"].i = 1 + + meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef() + meta_info_def.stripped_op_list.op.add() + + with self.test_session(): + meta_graph_def = meta_graph.create_meta_graph_def( + meta_info_def=meta_info_def, graph_def=graph_def, + strip_default_attrs=True) + node_def = test_util.get_node_def_from_graph("node_with_unreg_op", + meta_graph_def.graph_def) + self.assertEqual(node_def.attr["attr_1"].i, 1) + self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs) + class ScopedMetaGraphTest(test.TestCase): diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 7627fb3e69..5ac3053749 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1369,3 +1369,21 @@ def create_local_cluster(num_workers, num_ps, protocol="grpc", ] return workers, ps_servers + + +def get_node_def_from_graph(node_name, graph_def): + """Returns the `NodeDef` instance for given node name in the graph def. + + This method explores only the NodeDefs in `graph_def.node`. + + Args: + node_name: Name of the NodeDef to search for. + graph_def: An instance of `GraphDef` proto. + + Returns: + the `NodeDef` instance whose name field matches the given node_name or None. + """ + for node_def in graph_def.node: + if node_def.name == node_name: + return node_def + return None diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index f6aed118ca..4af717cca6 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -349,6 +349,13 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertEqual(expected, self.evaluate(nested)) + def test_get_node_def_from_graph(self): + graph_def = graph_pb2.GraphDef() + node_foo = graph_def.node.add() + node_foo.name = "foo" + self.assertIs(test_util.get_node_def_from_graph("foo", graph_def), node_foo) + self.assertIsNone(test_util.get_node_def_from_graph("bar", graph_def)) + class GarbageCollectionTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/saved_model/README.md b/tensorflow/python/saved_model/README.md index 8c78013ffd..5eeaf73a43 100644 --- a/tensorflow/python/saved_model/README.md +++ b/tensorflow/python/saved_model/README.md @@ -117,6 +117,35 @@ with tf.Session(graph=tf.Graph()) as sess: builder.save() ~~~ +#### Stripping Default valued attributes +The SavedModelBuilder class allows users to control whether default-valued +attributes must be stripped from the NodeDefs while adding a meta graph to the +SavedModel bundle. Both `SavedModelBuilder.add_meta_graph_and_variables` and +`SavedModelBuilder.add_meta_graph` methods accept a Boolean flag +`strip_default_attrs` that controls this behavior. + +If `strip_default_attrs` is `False`, the exported MetaGraphDef will have the +default valued attributes in all it's NodeDef instances. This can break forward +compatibility with a sequence of events such as the following: + +* An existing Op (`Foo`) is updated to include a new attribute (`T`) with a + default (`bool`) at version 101. +* A model producer (such as a Trainer) binary picks up this change + (version 101) to the OpDef and re-exports an existing model that uses Op `Foo`. +* A model consumer (such as Tensorflow Serving) running an older binary + (version 100) doesn't have attribute `T` for Op `Foo`, but tries to import + this model. The model consumer doesn't recognize attribute `T` in a NodeDef + that uses Op `Foo` and therefore fails to load the model. + +By setting `strip_default_attrs` to `True`, the model producers can strip away +any default valued attributes in the NodeDefs. This helps ensure that newly +added attributes with defaults don't cause older model consumers to fail loading +models regenerated with newer training binaries. + +TIP: If you care about forward compatibility, then set `strip_default_attrs` +to `True` while using `SavedModelBuilder.add_meta_graph_and_variables` and +`SavedModelBuilder.add_meta_graph`. + ### Loader The SavedModel loader is implemented in C++ and Python. diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 16651ffebc..62ee53b816 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -239,7 +239,9 @@ class SavedModelBuilder(object): assets_collection=None, legacy_init_op=None, clear_devices=False, - main_op=None): + main_op=None, + strip_default_attrs=False): + # pylint: disable=line-too-long """Adds the current meta graph to the SavedModel. Creates a Saver in the current scope and uses the Saver to export the meta @@ -260,11 +262,15 @@ class SavedModelBuilder(object): main_op: Op or group of ops to execute when the graph is loaded. Note that when the main_op is specified it is run after the restore op at load-time. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Raises: AssertionError: If the variables for the SavedModel have not been saved yet, or if the graph already contains one or more legacy init ops. """ + # pylint: enable=line-too-long if not self._has_saved_variables: raise AssertionError( "Graph state including variables and assets has not been saved yet. " @@ -299,7 +305,8 @@ class SavedModelBuilder(object): # there are edge cases where that option breaks the graph. Until that is # resolved, we just leave the option set to False for now. # TODO(soergel): Reinstate clear_extraneous_savers=True when possible. - meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices) + meta_graph_def = saver.export_meta_graph( + clear_devices=clear_devices, strip_default_attrs=strip_default_attrs) # Tag the meta graph def and add it to the SavedModel. self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map) @@ -311,7 +318,9 @@ class SavedModelBuilder(object): assets_collection=None, legacy_init_op=None, clear_devices=False, - main_op=None): + main_op=None, + strip_default_attrs=False): + # pylint: disable=line-too-long """Adds the current meta graph to the SavedModel and saves variables. Creates a Saver to save the variables from the provided session. Exports the @@ -334,7 +343,11 @@ class SavedModelBuilder(object): main_op: Op or group of ops to execute when the graph is loaded. Note that when the main_op is specified it is run after the restore op at load-time. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). """ + # pylint: enable=line-too-long if self._has_saved_variables: raise AssertionError("Graph state including variables and assets has " "already been saved. Please invoke " @@ -388,7 +401,8 @@ class SavedModelBuilder(object): # there are edge cases where that option breaks the graph. Until that is # resolved, we just leave the option set to False for now. # TODO(soergel): Reinstate clear_extraneous_savers=True when possible. - meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices) + meta_graph_def = saver.export_meta_graph( + clear_devices=clear_devices, strip_default_attrs=strip_default_attrs) # Tag the meta graph def and add it to the SavedModel. self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map) diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 92ca7dec6f..1ea619ff55 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -20,13 +20,17 @@ from __future__ import print_function import os +from tensorflow.core.framework import op_def_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.client import session from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.lib.io import file_io from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -36,6 +40,7 @@ from tensorflow.python.platform import test from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import main_op from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants @@ -865,6 +870,132 @@ class SavedModelTest(test.TestCase): self.assertEqual( 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) + def testStripDefaultAttrs(self): + export_dir = os.path.join(test.get_temp_dir(), "test_strip_default_attrs") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + # Add a graph with two float32 variables and a Complex Op composing them + # with strip_default_attrs enabled. + with session.Session(graph=ops.Graph()) as sess: + real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real") + imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag") + math_ops.complex(real_num, imag_num, name="complex") + sess.run(variables.global_variables_initializer()) + builder.add_meta_graph_and_variables( + sess, ["foo"], strip_default_attrs=True) + + # Add a graph with the same float32 variables and a Complex Op composing + # them with strip_default_attrs disabled. + with session.Session(graph=ops.Graph()) as sess: + real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real") + imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag") + math_ops.complex(real_num, imag_num, name="complex") + sess.run(variables.global_variables_initializer()) + builder.add_meta_graph(["bar"], strip_default_attrs=False) + + # Save the SavedModel to disk in text format. + builder.save(as_text=True) + + # Loading graph "foo" via the loader must restore the defaults for the + # "Complex" node based on the "Complex" OpDef in the Op registry. + sess = session.Session(graph=ops.Graph()) + meta_graph_def = loader.load(sess, ["foo"], export_dir) + complex_node = test_util.get_node_def_from_graph("complex", + meta_graph_def.graph_def) + self.assertIn("T", complex_node.attr) + self.assertIn("Tout", complex_node.attr) + + # Load graph "foo" from disk as-is to verify default attrs are stripped. + # pylint: disable=protected-access + saved_model_pb = loader_impl._parse_saved_model(export_dir) + self.assertIsNotNone(saved_model_pb) + # pylint: enable=protected-access + + meta_graph_foo_def = None + meta_graph_bar_def = None + for meta_graph_def in saved_model_pb.meta_graphs: + if set(meta_graph_def.meta_info_def.tags) == set(["foo"]): + meta_graph_foo_def = meta_graph_def + elif set(meta_graph_def.meta_info_def.tags) == set(["bar"]): + meta_graph_bar_def = meta_graph_def + + self.assertIsNotNone(meta_graph_foo_def) + self.assertIsNotNone(meta_graph_bar_def) + + # "Complex" Op has 2 attributes with defaults: + # o "T" : float32. (input type) + # o "Tout" : complex64. (output type) + + # "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout". + # Graph "foo" was saved with strip_default_attrs set to True. + node_def = test_util.get_node_def_from_graph("complex", + meta_graph_foo_def.graph_def) + self.assertNotIn("T", node_def.attr) + self.assertNotIn("Tout", node_def.attr) + + # "Complex" Op in graph "bar" must have attributes "T" and "Tout". + # Graph "bar" was saved with strip_default_attrs set to False. + node_def = test_util.get_node_def_from_graph("complex", + meta_graph_bar_def.graph_def) + self.assertIn("T", node_def.attr) + self.assertIn("Tout", node_def.attr) + + def testStripDefaultAttrsInconsistentConsumerDefaults(self): + export_dir = os.path.join(test.get_temp_dir(), + "test_strip_default_attrs_no_consumer_defaults") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + # Add a graph with two float32 variables and a Complex Op composing them + # with strip_default_attrs enabled. This must remove the following + # defaults for the "Complex" Op: + # o "T" : float32. (input type) + # o "Tout" : complex64. (output type) + with session.Session(graph=ops.Graph()) as sess: + real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real") + imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag") + math_ops.complex(real_num, imag_num, name="complex") + sess.run(variables.global_variables_initializer()) + builder.add_meta_graph_and_variables( + sess, ["foo"], strip_default_attrs=True) + + # Save the SavedModel to disk in text format. + builder.save(as_text=True) + + # Update the Op registry to remove defaults for all attrs("T", "Tout") from + # the "Complex" OpDef. + complex_op_def = op_def_registry.get_registered_ops()["Complex"] + original_complex_op_def = op_def_pb2.OpDef() + original_complex_op_def.CopyFrom(complex_op_def) + for attr_def in complex_op_def.attr: + attr_def.ClearField("default_value") + + # Loading the SavedModel via the loader must fail because the SavedModel + # does not have any attr values for the "Complex" node and the current + # op registry does not have have any default values for the "Complex" op. + sess = session.Session(graph=ops.Graph()) + with self.assertRaisesRegexp( + ValueError, + "Expected one attr with name .*T(out)?.* in name: \"complex\".*"): + loader.load(sess, ["foo"], export_dir) + + # Update the Op registry to change the defaults for attr "Tout" + # (complex64 -> complex128). + complex_op_def.CopyFrom(original_complex_op_def) + for attr_def in complex_op_def.attr: + if attr_def.name == "Tout": + attr_def.default_value.type = types_pb2.DT_COMPLEX128 + + # Loading the SavedModel via the loader must set "Tout" attr_value for the + # "Complex" node according to the latest defaults (complex128). This is + # expected to fail the model import as there is no OpKernel registered to + # handle attrs "T" (float32) and "Tout" (complex128). + sess = session.Session(graph=ops.Graph()) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + ".*No OpKernel was registered to support Op \'Complex\' with these " + "attrs..*"): + loader.load(sess, ["foo"], export_dir) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index ba6301e785..2330229d56 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -1509,7 +1509,9 @@ class Saver(object): latest_filename=None, meta_graph_suffix="meta", write_meta_graph=True, - write_state=True): + write_state=True, + strip_default_attrs=False): + # pylint: disable=line-too-long """Saves variables. This method runs the ops added by the constructor for saving variables. @@ -1535,6 +1537,9 @@ class Saver(object): graph file. write_state: `Boolean` indicating whether or not to write the `CheckpointStateProto`. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Returns: A string: path prefix used for the checkpoint files. If the saver is @@ -1548,6 +1553,7 @@ class Saver(object): collides with `save_path`. RuntimeError: If save and restore ops weren't built. """ + # pylint: enable=line-too-long if not self._is_built and context.in_graph_mode(): raise RuntimeError( "`build()` should be called before save if defer_build==True") @@ -1618,7 +1624,8 @@ class Saver(object): checkpoint_file, meta_graph_suffix=meta_graph_suffix) if context.in_graph_mode(): with sess.graph.as_default(): - self.export_meta_graph(meta_graph_filename) + self.export_meta_graph( + meta_graph_filename, strip_default_attrs=strip_default_attrs) if self._is_empty: return None @@ -1631,7 +1638,9 @@ class Saver(object): as_text=False, export_scope=None, clear_devices=False, - clear_extraneous_savers=False): + clear_extraneous_savers=False, + strip_default_attrs=False): + # pylint: disable=line-too-long """Writes `MetaGraphDef` to save_path/filename. Args: @@ -1644,10 +1653,14 @@ class Saver(object): clear_extraneous_savers: Remove any Saver-related information from the graph (both Save/Restore ops and SaverDefs) that are not associated with this Saver. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Returns: A `MetaGraphDef` proto. """ + # pylint: enable=line-too-long return export_meta_graph( filename=filename, graph_def=ops.get_default_graph().as_graph_def(add_shapes=True), @@ -1656,7 +1669,8 @@ class Saver(object): as_text=as_text, export_scope=export_scope, clear_devices=clear_devices, - clear_extraneous_savers=clear_extraneous_savers) + clear_extraneous_savers=clear_extraneous_savers, + strip_default_attrs=strip_default_attrs) def restore(self, sess, save_path): """Restores previously saved variables. @@ -1859,7 +1873,9 @@ def export_meta_graph(filename=None, export_scope=None, clear_devices=False, clear_extraneous_savers=False, + strip_default_attrs=False, **kwargs): + # pylint: disable=line-too-long """Returns `MetaGraphDef` proto. Optionally writes it to filename. This function exports the graph, saver, and collection objects into @@ -1885,6 +1901,9 @@ def export_meta_graph(filename=None, clear_extraneous_savers: Remove any Saver-related information from the graph (both Save/Restore ops and SaverDefs) that are not associated with the provided SaverDef. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). **kwargs: Optional keyed arguments. Returns: @@ -1899,6 +1918,7 @@ def export_meta_graph(filename=None, execution is enabled. @end_compatibility """ + # pylint: enable=line-too-long if context.in_eager_mode(): raise RuntimeError("Exporting/importing meta graphs is not supported when " "eager execution is enabled. No graph exists when eager " @@ -1914,6 +1934,7 @@ def export_meta_graph(filename=None, export_scope=export_scope, clear_devices=clear_devices, clear_extraneous_savers=clear_extraneous_savers, + strip_default_attrs=strip_default_attrs, **kwargs) return meta_graph_def diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 207e4a2842..0889ac2516 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -2065,6 +2065,42 @@ class MetaGraphTest(test.TestCase): self.assertEqual(o.summary, "") self.assertEqual(o.description, "") + def testStripDefaultValuedAttrs(self): + """Verifies that default valued attrs are stripped, unless disabled.""" + + # With strip_default_attrs enabled, attributes "T" (float32) and "Tout" + # (complex64) in the "Complex" op must be removed. + with self.test_session(): + real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real") + imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag") + math_ops.complex(real_num, imag_num, name="complex") + + save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num}) + variables.global_variables_initializer() + + meta_graph_def = save.export_meta_graph(strip_default_attrs=True) + node_def = test_util.get_node_def_from_graph("complex", + meta_graph_def.graph_def) + self.assertNotIn("T", node_def.attr) + self.assertNotIn("Tout", node_def.attr) + + # With strip_default_attrs disabled, attributes "T" (float32) and "Tout" + # (complex64) in the "Complex" op must *not* be removed, even if they map + # to their defaults. + with self.test_session(graph=ops_lib.Graph()): + real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real") + imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag") + math_ops.complex(real_num, imag_num, name="complex") + + save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num}) + variables.global_variables_initializer() + + meta_graph_def = save.export_meta_graph(strip_default_attrs=False) + node_def = test_util.get_node_def_from_graph("complex", + meta_graph_def.graph_def) + self.assertIn("T", node_def.attr) + self.assertIn("Tout", node_def.attr) + def testImportIntoNamescope(self): # Test that we can import a meta graph into a namescope. test_dir = self._get_test_dir("import_into_namescope") diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt index ebf49f434a..b0e9831154 100644 --- a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt @@ -18,6 +18,10 @@ tf_class { name: "META_GRAPH_VERSION_FIELD_NUMBER" mtype: "" } + member { + name: "STRIPPED_DEFAULT_ATTRS_FIELD_NUMBER" + mtype: "" + } member { name: "STRIPPED_OP_LIST_FIELD_NUMBER" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt index f5ed263f0e..ab697b1b95 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt @@ -29,7 +29,7 @@ tf_class { } member_method { name: "export_savedmodel" - argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " } member_method { name: "get_variable_names" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt index 61a29942c5..b73f6433e2 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt @@ -29,7 +29,7 @@ tf_class { } member_method { name: "export_savedmodel" - argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " } member_method { name: "get_variable_names" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt index 16e3b24615..24db86c92b 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt @@ -29,7 +29,7 @@ tf_class { } member_method { name: "export_savedmodel" - argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " } member_method { name: "get_variable_names" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt index c6765ae277..47ee2ac51b 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt @@ -29,7 +29,7 @@ tf_class { } member_method { name: "export_savedmodel" - argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " } member_method { name: "get_variable_names" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt index e3a820db46..fbfaa69a1b 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt @@ -29,7 +29,7 @@ tf_class { } member_method { name: "export_savedmodel" - argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " } member_method { name: "get_variable_names" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt index a4c8cf6671..faf55cda86 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt @@ -29,7 +29,7 @@ tf_class { } member_method { name: "export_savedmodel" - argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " } member_method { name: "get_variable_names" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt index 787952eced..d0bf043754 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt @@ -28,7 +28,7 @@ tf_class { } member_method { name: "export_savedmodel" - argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " } member_method { name: "get_variable_names" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt index 99c03aa629..6aec1d3a51 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt @@ -29,7 +29,7 @@ tf_class { } member_method { name: "export_savedmodel" - argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " } member_method { name: "get_variable_names" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt index e2ab96d5b4..9d8c7bb138 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt @@ -29,7 +29,7 @@ tf_class { } member_method { name: "export_savedmodel" - argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " } member_method { name: "get_variable_names" diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt index 56d76902fd..ca8e5884b1 100644 --- a/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt @@ -8,11 +8,11 @@ tf_class { } member_method { name: "add_meta_graph" - argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], " } member_method { name: "add_meta_graph_and_variables" - argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], " } member_method { name: "save" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-saver.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-saver.pbtxt index 04c11712cd..2cda458f46 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-saver.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-saver.pbtxt @@ -20,7 +20,7 @@ tf_class { } member_method { name: "export_meta_graph" - argspec: "args=[\'self\', \'filename\', \'collection_list\', \'as_text\', \'export_scope\', \'clear_devices\', \'clear_extraneous_savers\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'False\', \'False\'], " + argspec: "args=[\'self\', \'filename\', \'collection_list\', \'as_text\', \'export_scope\', \'clear_devices\', \'clear_extraneous_savers\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'False\', \'False\', \'False\'], " } member_method { name: "from_proto" @@ -36,7 +36,7 @@ tf_class { } member_method { name: "save" - argspec: "args=[\'self\', \'sess\', \'save_path\', \'global_step\', \'latest_filename\', \'meta_graph_suffix\', \'write_meta_graph\', \'write_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'meta\', \'True\', \'True\'], " + argspec: "args=[\'self\', \'sess\', \'save_path\', \'global_step\', \'latest_filename\', \'meta_graph_suffix\', \'write_meta_graph\', \'write_state\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'meta\', \'True\', \'True\', \'False\'], " } member_method { name: "set_last_checkpoints" diff --git a/tensorflow/tools/api/golden/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.pbtxt index 3ffc640730..b2ef17b39e 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.pbtxt @@ -282,7 +282,7 @@ tf_module { } member_method { name: "export_meta_graph" - argspec: "args=[\'filename\', \'meta_info_def\', \'graph_def\', \'saver_def\', \'collection_list\', \'as_text\', \'graph\', \'export_scope\', \'clear_devices\', \'clear_extraneous_savers\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\', \'False\'], " + argspec: "args=[\'filename\', \'meta_info_def\', \'graph_def\', \'saver_def\', \'collection_list\', \'as_text\', \'graph\', \'export_scope\', \'clear_devices\', \'clear_extraneous_savers\', \'strip_default_attrs\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\', \'False\', \'False\'], " } member_method { name: "generate_checkpoint_state_proto"