meta_graph export: Add support to strip default valued attributes.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 3 Jan 2018 02:19:21 +0000 (18:19 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 3 Jan 2018 02:28:12 +0000 (18:28 -0800)
Following APIs now accept an additional argument (`strip_default_attrs`) to
enable/disable (default:disabled) stripping of default valued attributes in a NodeDef:
  o meta_graph: export_meta_graph, create_meta_graph.
  o saver: Saver.save, Saver.export_meta_graph.
  o builder: SavedModelBuilder.add_meta_graph,
             SavedModelBuilder.add_meta_graph_and_variables.
  o estimator: Estimator.export_savedmodel.

Related changes:
  o Pywrap C++ AreAttrValuesEqual to compare two AttrValue instances.
    This allows for a single/canonical way of comparing AttrValues in C++/Python.
  o Add a utility method to meta_graph.py to get the node def by name in a graph def.
  o Update SavedModelBuilder documentation on relevance of strip_default_attrs flag.

PiperOrigin-RevId: 180619001

32 files changed:
tensorflow/contrib/learn/python/learn/estimators/estimator.py
tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/public/version.h
tensorflow/python/client/tf_session.i
tensorflow/python/client/tf_session_helper.cc
tensorflow/python/client/tf_session_helper.h
tensorflow/python/estimator/estimator.py
tensorflow/python/estimator/estimator_test.py
tensorflow/python/framework/meta_graph.py
tensorflow/python/framework/meta_graph_test.py
tensorflow/python/framework/test_util.py
tensorflow/python/framework/test_util_test.py
tensorflow/python/saved_model/README.md
tensorflow/python/saved_model/builder_impl.py
tensorflow/python/saved_model/saved_model_test.py
tensorflow/python/training/saver.py
tensorflow/python/training/saver_test.py
tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt
tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt
tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt
tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt
tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt
tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt
tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt
tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt
tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt
tensorflow/tools/api/golden/tensorflow.train.-saver.pbtxt
tensorflow/tools/api/golden/tensorflow.train.pbtxt

index 05ed8b3409e68ae54e5ef89b3a1592a6f285565b..2395c7e7172c5b63db4900524a64c11d78b079b5 100644 (file)
@@ -1256,7 +1256,9 @@ class Estimator(BaseEstimator):
       assets_extra=None,
       as_text=False,
       checkpoint_path=None,
-      graph_rewrite_specs=(GraphRewriteSpec((tag_constants.SERVING,), ()),)):
+      graph_rewrite_specs=(GraphRewriteSpec((tag_constants.SERVING,), ()),),
+      strip_default_attrs=False):
+    # pylint: disable=line-too-long
     """Exports inference graph as a SavedModel into given dir.
 
     Args:
@@ -1280,6 +1282,9 @@ class Estimator(BaseEstimator):
         produce a separate MetaGraphDef within the exported SavedModel, tagged
         and rewritten as specified.  Defaults to a single entry using the
         default serving tag ("serve") and no rewriting.
+      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+        removed from the NodeDefs. For a detailed guide, see
+        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
 
     Returns:
       The string path to the exported directory.
@@ -1287,6 +1292,7 @@ class Estimator(BaseEstimator):
     Raises:
       ValueError: if an unrecognized export_type is requested.
     """
+    # pylint: enable=line-too-long
     if serving_input_fn is None:
       raise ValueError('serving_input_fn must be defined.')
 
@@ -1366,7 +1372,8 @@ class Estimator(BaseEstimator):
             signature_def_map=signature_def_map,
             assets_collection=ops.get_collection(
                 ops.GraphKeys.ASSET_FILEPATHS),
-            legacy_init_op=init_op)
+            legacy_init_op=init_op,
+            strip_default_attrs=strip_default_attrs)
 
     # pylint: disable=protected-access
     base_meta_graph_def = builder._saved_model.meta_graphs[0]
index 4b404a8e20e33a17a0d5f857e4220f90c7bc799f..03ec66b98b598f454be47cc5186c161dd207036e 100644 (file)
@@ -390,7 +390,9 @@ def make_export_strategy(serving_input_fn,
                          default_output_alternative_key=None,
                          assets_extra=None,
                          as_text=False,
-                         exports_to_keep=5):
+                         exports_to_keep=5,
+                         strip_default_attrs=False):
+  # pylint: disable=line-too-long
   """Create an ExportStrategy for use with Experiment.
 
   Args:
@@ -411,10 +413,14 @@ def make_export_strategy(serving_input_fn,
     exports_to_keep: Number of exports to keep.  Older exports will be
       garbage-collected.  Defaults to 5.  Set to None to disable garbage
       collection.
+    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+      removed from the NodeDefs. For a detailed guide, see
+      [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
 
   Returns:
     An ExportStrategy that can be passed to the Experiment constructor.
   """
+  # pylint: enable=line-too-long
 
   def export_fn(estimator, export_dir_base, checkpoint_path=None):
     """Exports the given Estimator as a SavedModel.
@@ -443,7 +449,8 @@ def make_export_strategy(serving_input_fn,
           serving_input_fn,
           assets_extra=assets_extra,
           as_text=as_text,
-          checkpoint_path=checkpoint_path)
+          checkpoint_path=checkpoint_path,
+          strip_default_attrs=strip_default_attrs)
     else:
       export_result = estimator.export_savedmodel(
           export_dir_base,
@@ -451,7 +458,8 @@ def make_export_strategy(serving_input_fn,
           default_output_alternative_key=default_output_alternative_key,
           assets_extra=assets_extra,
           as_text=as_text,
-          checkpoint_path=checkpoint_path)
+          checkpoint_path=checkpoint_path,
+          strip_default_attrs=strip_default_attrs)
 
     garbage_collect_exports(export_dir_base, exports_to_keep)
     return export_result
@@ -464,7 +472,9 @@ def make_parsing_export_strategy(feature_columns,
                                  assets_extra=None,
                                  as_text=False,
                                  exports_to_keep=5,
-                                 target_core=False):
+                                 target_core=False,
+                                 strip_default_attrs=False):
+  # pylint: disable=line-too-long
   """Create an ExportStrategy for use with Experiment, using `FeatureColumn`s.
 
   Creates a SavedModel export that expects to be fed with a single string
@@ -492,10 +502,14 @@ def make_parsing_export_strategy(feature_columns,
     target_core: If True, prepare an ExportStrategy for use with
       tensorflow.python.estimator.*.  If False (default), prepare an
       ExportStrategy for use with tensorflow.contrib.learn.python.learn.*.
+    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+      removed from the NodeDefs. For a detailed guide, see
+      [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
 
   Returns:
     An ExportStrategy that can be passed to the Experiment constructor.
   """
+  # pylint: enable=line-too-long
   feature_spec = feature_column.create_feature_spec_for_parsing(feature_columns)
   if target_core:
     serving_input_fn = (
@@ -508,7 +522,8 @@ def make_parsing_export_strategy(feature_columns,
       default_output_alternative_key=default_output_alternative_key,
       assets_extra=assets_extra,
       as_text=as_text,
-      exports_to_keep=exports_to_keep)
+      exports_to_keep=exports_to_keep,
+      strip_default_attrs=strip_default_attrs)
 
 
 def _default_compare_fn(curr_best_eval_result, cand_eval_result):
@@ -584,7 +599,9 @@ class BestModelSelector(object):
 def make_best_model_export_strategy(serving_input_fn,
                                     exports_to_keep=1,
                                     compare_fn=None,
-                                    default_output_alternative_key=None):
+                                    default_output_alternative_key=None,
+                                    strip_default_attrs=False):
+  # pylint: disable=line-too-long
   """Creates an custom ExportStrategy for use with tf.contrib.learn.Experiment.
 
   Args:
@@ -596,14 +613,19 @@ def make_best_model_export_strategy(serving_input_fn,
         of evaluation result keyed by corresponding checkpoint path.
     default_output_alternative_key: the key for default serving signature for
         multi-headed inference graphs.
+    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+      removed from the NodeDefs. For a detailed guide, see
+      [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
 
   Returns:
     An ExportStrategy that can be passed to the Experiment constructor.
   """
+  # pylint: enable=line-too-long
   best_model_export_strategy = make_export_strategy(
       serving_input_fn,
       exports_to_keep=exports_to_keep,
-      default_output_alternative_key=default_output_alternative_key)
+      default_output_alternative_key=default_output_alternative_key,
+      strip_default_attrs=strip_default_attrs)
 
   best_model_selector = BestModelSelector(compare_fn)
 
index 628eb254c3b1129648c453dc47f0c0919891de6f..531d9c672bff6b244365c8b47f532f9ced333fc5 100644 (file)
@@ -55,7 +55,8 @@ class TestEstimator(core_estimator.Estimator):
                         default_output_alternative_key=None,
                         assets_extra=None,
                         as_text=False,
-                        checkpoint_path=None):
+                        checkpoint_path=None,
+                        strip_default_attrs=False):
 
     if not os.path.exists(export_dir):
       os.makedirs(export_dir)
index 47ec2aa1efeb11135b95b3b2c4342b77f0a9866b..fd86c0da12b26cf5ed8a7846d159dd6feb4ddc4e 100644 (file)
@@ -61,6 +61,10 @@ message MetaGraphDef {
     // graph. This will be populated by the framework, which will overwrite any
     // user supplied value.
     string tensorflow_git_version = 6;
+
+    // A flag to denote whether default-valued attrs have been stripped from
+    // the nodes in this graph_def.
+    bool stripped_default_attrs = 7;
   }
   MetaInfoDef meta_info_def = 1;
 
index d8e7df48c2d9f7023f15ffab7a62ccafbce4458d..c037a9b122cff43abb27810703e3bf6fe51486d6 100644 (file)
@@ -91,10 +91,13 @@ limitations under the License.
 // 24. Deprecate lookup ops (v1) ops in favor of v2 (30may2017)
 // 25. Deprecate stack (v1) ops in favor of v2 (2017/6/15).
 // 25. Deprecate RandomPoisson (v1) ops in favor of v2 (2017/10/25).
+// 26. Add a bool 'stripped_default_attrs' to MetaInfoDef indicating
+//     whether default-valued attrs have been stripped from the nodes in the
+//     GraphDef. (7dec2017)
 
 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 24
+#define TF_GRAPH_DEF_VERSION 25
 
 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
 //
index 3f1d63a54375879370dd93ab1ebaae6d5f765e09..1fd488e7b6388f7953a279dca8f93ab57a85f63d 100644 (file)
@@ -485,6 +485,7 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
 %unignore tensorflow;
 %unignore TF_Run;
 %unignore EqualGraphDefWrapper;
+%unignore EqualAttrValueWrapper;
 
 // Include the wrapper for TF_PRunSetup from tf_session_helper.h.
 
index 2b83141faa0302eddb875b059ab69fc523a29c4c..361dbc22b097a9bc82f656d7416b88c4a3a1ec2d 100644 (file)
@@ -21,10 +21,13 @@ limitations under the License.
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/log_memory.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/graph/tensor_id.h"
 #include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/equal_graph_def.h"
 #include "tensorflow/python/lib/core/ndarray_tensor.h"
@@ -301,6 +304,27 @@ string EqualGraphDefWrapper(const string& actual, const string& expected) {
   return EqualGraphDef(actual_def, expected_def, &diff) ? "" : diff;
 }
 
+string EqualAttrValueWrapper(const string& actual, const string& expected) {
+  AttrValue actual_attr_value;
+  if (!actual_attr_value.ParseFromString(actual)) {
+    return "actual is not a valid serialized AttrValue";
+  }
+
+  AttrValue expected_attr_value;
+  if (!expected_attr_value.ParseFromString(expected)) {
+    return "expected is not a valid serialized AttrValue";
+  }
+
+  string diff;
+  if (!AreAttrValuesEqual(actual_attr_value, expected_attr_value)) {
+    diff = strings::Printf(
+        "Actual AttrValue %s does not match Expected AttrValue %s.",
+        SummarizeAttrValue(actual_attr_value).c_str(),
+        SummarizeAttrValue(expected_attr_value).c_str());
+  }
+  return diff;
+}
+
 // Return value set to 6 inlined elements so it fits in a 64-byte cache line.
 tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper(
     TF_Graph* graph, TF_Output output, TF_Status* out_status,
index 8f2499b9a05b5f133ddab0a4d2e0315154ed53ed..29d5b28f40a7c07c199eec8c8cd85de626f6b068 100644 (file)
@@ -97,6 +97,13 @@ void TF_Reset_wrapper(const TF_SessionOptions* opt,
 // for no difference.
 string EqualGraphDefWrapper(const string& actual, const string& expected);
 
+// Convenience wrapper around AreAttrValuesEqual to make it easier to wrap.
+// The actual and expected strings must correspond to a serialized binary
+// representation of two AttrValue proto instances.
+// Returns an explanation if a difference is found, or the empty string
+// for no difference.
+string EqualAttrValueWrapper(const string& actual, const string& expected);
+
 // Gets shape from C API Graph object.
 //
 // If shape is known, returns shape vector where -1 means "unknown
index 1e3d6d5755fa4ea31ad4957dd7124343deaaee57..c72d37b44208fc40a0c9f476a1b1c06db1bd8084 100644 (file)
@@ -461,7 +461,8 @@ class Estimator(object):
       self, export_dir_base, serving_input_receiver_fn,
       assets_extra=None,
       as_text=False,
-      checkpoint_path=None):
+      checkpoint_path=None,
+      strip_default_attrs=False):
     # pylint: disable=line-too-long
     """Exports inference graph as a SavedModel into given dir.
 
@@ -503,6 +504,9 @@ class Estimator(object):
       as_text: whether to write the SavedModel proto in text format.
       checkpoint_path: The checkpoint path to export.  If `None` (the default),
         the most recent checkpoint found within the model directory is chosen.
+      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+        removed from the NodeDefs. For a detailed guide, see
+        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
 
     Returns:
       The string path to the exported directory.
@@ -563,7 +567,8 @@ class Estimator(object):
             signature_def_map=signature_def_map,
             assets_collection=ops.get_collection(
                 ops.GraphKeys.ASSET_FILEPATHS),
-            legacy_init_op=local_init_op)
+            legacy_init_op=local_init_op,
+            strip_default_attrs=strip_default_attrs)
         builder.save(as_text)
 
       # Add the extra assets
index db64fbc9ccc3a212e7dfa1ad4d82e3138e3a3d56..58d0cb0018ea618fe872da38a92f08b37fd4a0db 100644 (file)
@@ -40,6 +40,7 @@ from tensorflow.python.estimator.inputs import numpy_io
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.layers import layers
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import array_ops
@@ -57,6 +58,7 @@ from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import loader_impl
 from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.summary import summary
 from tensorflow.python.summary import summary_iterator
@@ -2050,6 +2052,65 @@ class EstimatorExportTest(test.TestCase):
 
     gfile.DeleteRecursively(tmpdir)
 
+  def test_export_savedmodel_proto_strip_default_attrs(self):
+    tmpdir = tempfile.mkdtemp()
+    est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
+    est.train(input_fn=dummy_input_fn, steps=1)
+    feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
+                    'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
+    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+        feature_spec)
+
+    # Perform the export.
+    export_dir_base = os.path.join(
+        compat.as_bytes(tmpdir), compat.as_bytes('export'))
+    export_dir_stripped = est.export_savedmodel(
+        export_dir_base, serving_input_receiver_fn, strip_default_attrs=True)
+    export_dir_not_stripped = est.export_savedmodel(
+        export_dir_base, serving_input_receiver_fn, strip_default_attrs=False)
+
+    # Load the SavedModel from disk as-is to verify default attrs
+    # are stripped. Reimporting the SavedModel via the loader causes the
+    # default attrs to be populated in the NodeDefs.
+
+    # pylint: disable=protected-access
+    saved_model_stripped_pb = loader_impl._parse_saved_model(
+        export_dir_stripped)
+    saved_model_not_stripped_pb = loader_impl._parse_saved_model(
+        export_dir_not_stripped)
+    self.assertIsNotNone(saved_model_stripped_pb)
+    self.assertIsNotNone(saved_model_not_stripped_pb)
+    # pylint: enable=protected-access
+
+    meta_graph_def_stripped = [
+        x for x in saved_model_stripped_pb.meta_graphs
+        if x.meta_info_def.tags == [tag_constants.SERVING]][0]
+    meta_graph_def_not_stripped = [
+        x for x in saved_model_not_stripped_pb.meta_graphs
+        if x.meta_info_def.tags == [tag_constants.SERVING]][0]
+
+    # "weight" node in graph is a "Variable" Op with 2 default valued attrs.
+    #   o "container"    : "".
+    #   o "shared_name"  : "".
+
+    # saved_model_stripped_pb was exported with strip_default_attrs set to True.
+    # "weight" node shouldn't have attributes "container" and "shared_name".
+    node_def = test_util.get_node_def_from_graph(
+        'weight', meta_graph_def_stripped.graph_def)
+    self.assertNotIn('container', node_def.attr)
+    self.assertNotIn('shared_name', node_def.attr)
+
+    # saved_model_not_stripped_pb was exported with strip_default_attrs
+    # disabled. "weight" node should have attributes "container" and
+    # "shared_name".
+    node_def = test_util.get_node_def_from_graph(
+        'weight', meta_graph_def_not_stripped.graph_def)
+    self.assertIn('container', node_def.attr)
+    self.assertIn('shared_name', node_def.attr)
+
+    # Clean up.
+    gfile.DeleteRecursively(tmpdir)
+
 
 class EstimatorHookOrderingTest(test.TestCase):
 
index c839d7a9a693a4e1201c558173662fd24b5036dd..65032637d4fdea91c2ec96d63499b5122b21f2f1 100644 (file)
@@ -31,6 +31,7 @@ from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import op_def_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
 from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
 from tensorflow.python.framework import graph_io
 from tensorflow.python.framework import importer
@@ -442,6 +443,67 @@ def add_collection_def(meta_graph_def, key, graph=None,
     return
 
 
+def _is_default_attr_value(op_def, attr_name, attr_value):
+  """Checks if given attribute matches the default value in the op def."""
+  for attr_def in op_def.attr:
+    if attr_def.name == attr_name:
+      if not attr_def.HasField("default_value"):
+        return False
+      # pywrap_tensorflow.EqualAttrValueWrapper returns an empty string
+      # if both arguments represent an equivalent AttrValue instance.
+      return not pywrap_tensorflow.EqualAttrValueWrapper(
+          attr_value.SerializeToString(),
+          attr_def.default_value.SerializeToString())
+  return False
+
+
+def _strip_graph_default_valued_attrs(meta_graph_def):
+  """Strips default valued attributes for node defs in given MetaGraphDef.
+
+  This method also sets `meta_info_def.stripped_default_attrs` in the given
+  `MetaGraphDef` proto to True.
+
+  Args:
+    meta_graph_def: `MetaGraphDef` protocol buffer
+
+  Returns:
+    None.
+  """
+  # Map function op names to their function definitions.
+  op_name_to_function = {}
+  for function_def in meta_graph_def.graph_def.library.function:
+    op_name_to_function[function_def.signature.name] = function_def
+
+  # Get all registered ops.
+  registered_ops = op_def_registry.get_registered_ops()
+
+  def _strip_node_default_valued_attrs(node_def):
+    """Removes default valued attributes from a single node def."""
+    if node_def.op in op_name_to_function or node_def.op not in registered_ops:
+      return
+    op_def = registered_ops[node_def.op]
+
+    attrs_to_strip = set()
+    for attr_name, attr_value in node_def.attr.items():
+      if _is_default_attr_value(op_def, attr_name, attr_value):
+        attrs_to_strip.add(attr_name)
+
+    for attr in attrs_to_strip:
+      del node_def.attr[attr]
+
+  # Process all NodeDef instances in graph_def.
+  for node_def in meta_graph_def.graph_def.node:
+    _strip_node_default_valued_attrs(node_def)
+
+  # Process all NodeDef instances in graph_def.library.function.
+  for function_def in meta_graph_def.graph_def.library.function:
+    for function_node_def in function_def.node_def:
+      _strip_node_default_valued_attrs(function_node_def)
+
+  # Tell consumers of this graph that default valued attrs have been stripped.
+  meta_graph_def.meta_info_def.stripped_default_attrs = True
+
+
 def create_meta_graph_def(meta_info_def=None,
                           graph_def=None,
                           saver_def=None,
@@ -449,7 +511,9 @@ def create_meta_graph_def(meta_info_def=None,
                           graph=None,
                           export_scope=None,
                           exclude_nodes=None,
-                          clear_extraneous_savers=False):
+                          clear_extraneous_savers=False,
+                          strip_default_attrs=False):
+  # pylint: disable=line-too-long
   """Construct and returns a `MetaGraphDef` protocol buffer.
 
   Args:
@@ -464,12 +528,17 @@ def create_meta_graph_def(meta_info_def=None,
     clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS
         collection.  Note this method does not alter the graph, so any
         extraneous Save/Restore ops should have been removed already, as needed.
+    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+        removed from the NodeDefs. For a detailed guide, see
+        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+
   Returns:
     MetaGraphDef protocol buffer.
 
   Raises:
     TypeError: If the arguments are not of the correct proto buffer type.
   """
+  # pylint: enable=line-too-long
   # Type check.
   if graph and not isinstance(graph, ops.Graph):
     raise TypeError("graph must be of type Graph, not %s", type(graph))
@@ -511,6 +580,10 @@ def create_meta_graph_def(meta_info_def=None,
         stripped_op_list_for_graph(meta_graph_def.graph_def))
   # pylint: enable=g-explicit-length-test
 
+  # Strip default valued attributes in graph_def.
+  if strip_default_attrs:
+    _strip_graph_default_valued_attrs(meta_graph_def)
+
   # Adds saver_def.
   if saver_def:
     meta_graph_def.saver_def.MergeFrom(saver_def)
@@ -724,6 +797,7 @@ def export_scoped_meta_graph(filename=None,
                              clear_devices=False,
                              saver_def=None,
                              clear_extraneous_savers=False,
+                             strip_default_attrs=False,
                              **kwargs):
   """Returns `MetaGraphDef` proto. Optionally writes it to filename.
 
@@ -752,6 +826,8 @@ def export_scoped_meta_graph(filename=None,
     clear_extraneous_savers: Remove any Saver-related information from the
         graph (both Save/Restore ops and SaverDefs) that are not associated
         with the provided SaverDef.
+    strip_default_attrs: Set to true if default valued attributes must be
+        removed while exporting the GraphDef.
     **kwargs: Optional keyed arguments, including meta_info_def and
         collection_list.
 
@@ -837,6 +913,7 @@ def export_scoped_meta_graph(filename=None,
       exclude_nodes=exclude_nodes,
       clear_extraneous_savers=clear_extraneous_savers,
       saver_def=saver_def,
+      strip_default_attrs=strip_default_attrs,
       **kwargs)
 
   if filename:
@@ -881,3 +958,5 @@ def copy_scoped_meta_graph(from_scope, to_scope,
                                       graph=to_graph,
                                       import_scope=to_scope)
   return var_list
+
+
index 4c22c913b850685bd6e50b03b5fbb09a01441b68..ae8c9ea2a4d928271921515c88b86138a23a7d51 100644 (file)
@@ -24,6 +24,7 @@ import random
 import shutil
 
 from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import meta_graph_pb2
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
@@ -154,6 +155,108 @@ class SimpleMetaGraphTest(test.TestCase):
     op_list = meta_graph.stripped_op_list_for_graph(graph)
     self.assertEqual(["Const"], [op.name for op in op_list.op])
 
+  def testDefaultAttrStripping(self):
+    """Verifies that default attributes are stripped from a graph def."""
+
+    # Complex Op has 2 attributes with defaults:
+    #   o "T"    : float32.
+    #   o "Tout" : complex64.
+
+    # When inputs to the Complex Op are float32 instances, "T" maps to float32
+    # and "Tout" maps to complex64. Since these attr values map to their
+    # defaults, they must be stripped unless stripping of default attrs is
+    # disabled.
+    with self.test_session():
+      real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real")
+      imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag")
+      math_ops.complex(real_num, imag_num, name="complex")
+
+      # strip_default_attrs is enabled.
+      meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
+          graph_def=ops.get_default_graph().as_graph_def(),
+          strip_default_attrs=True)
+      node_def = test_util.get_node_def_from_graph("complex",
+                                                   meta_graph_def.graph_def)
+      self.assertNotIn("T", node_def.attr)
+      self.assertNotIn("Tout", node_def.attr)
+      self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
+
+      # strip_default_attrs is disabled.
+      meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
+          graph_def=ops.get_default_graph().as_graph_def(),
+          strip_default_attrs=False)
+      node_def = test_util.get_node_def_from_graph("complex",
+                                                   meta_graph_def.graph_def)
+      self.assertIn("T", node_def.attr)
+      self.assertIn("Tout", node_def.attr)
+      self.assertFalse(meta_graph_def.meta_info_def.stripped_default_attrs)
+
+    # When inputs to the Complex Op are float64 instances, "T" maps to float64
+    # and "Tout" maps to complex128. Since these attr values don't map to their
+    # defaults, they must not be stripped.
+    with self.test_session(graph=ops.Graph()):
+      real_num = constant_op.constant(1.0, dtype=dtypes.float64, name="real")
+      imag_num = constant_op.constant(2.0, dtype=dtypes.float64, name="imag")
+      math_ops.complex(real_num, imag_num, name="complex")
+      meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
+          graph_def=ops.get_default_graph().as_graph_def(),
+          strip_default_attrs=True)
+      node_def = test_util.get_node_def_from_graph("complex",
+                                                   meta_graph_def.graph_def)
+      self.assertEqual(node_def.attr["T"].type, dtypes.float64)
+      self.assertEqual(node_def.attr["Tout"].type, dtypes.complex128)
+      self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
+
+  def testDefaultAttrStrippingNestedFunctions(self):
+    """Verifies that default attributes are stripped from function node defs."""
+    with self.test_session():
+      @function.Defun(dtypes.float32, dtypes.float32)
+      def f0(i, j):
+        return math_ops.complex(i, j, name="double_nested_complex")
+
+      @function.Defun(dtypes.float32, dtypes.float32)
+      def f1(i, j):
+        return f0(i, j)
+
+      _ = f1(constant_op.constant(1.0), constant_op.constant(2.0))
+      meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
+          graph_def=ops.get_default_graph().as_graph_def(),
+          strip_default_attrs=True)
+
+      double_nested_complex_node_def = None
+      for function_def in meta_graph_def.graph_def.library.function:
+        for node_def in function_def.node_def:
+          if node_def.name == "double_nested_complex":
+            double_nested_complex_node_def = node_def
+            break
+        if double_nested_complex_node_def:
+          break
+
+      self.assertIsNotNone(double_nested_complex_node_def)
+      self.assertNotIn("T", double_nested_complex_node_def.attr)
+      self.assertNotIn("Tout", double_nested_complex_node_def.attr)
+      self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
+
+  def testDefaultAttrStrippingUnregisteredOps(self):
+    """Verifies that nodes with un-registered ops are not stripped."""
+    graph_def = graph_pb2.GraphDef()
+    node = graph_def.node.add()
+    node.name = "node_with_unreg_op"
+    node.op = "unreg_op"
+    node.attr["attr_1"].i = 1
+
+    meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
+    meta_info_def.stripped_op_list.op.add()
+
+    with self.test_session():
+      meta_graph_def = meta_graph.create_meta_graph_def(
+          meta_info_def=meta_info_def, graph_def=graph_def,
+          strip_default_attrs=True)
+      node_def = test_util.get_node_def_from_graph("node_with_unreg_op",
+                                                   meta_graph_def.graph_def)
+      self.assertEqual(node_def.attr["attr_1"].i, 1)
+      self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
+
 
 class ScopedMetaGraphTest(test.TestCase):
 
index 7627fb3e69d5e8c71363c6f2dff24a069b139f42..5ac30537499166f24f29ba0db3cc3feada695aa4 100644 (file)
@@ -1369,3 +1369,21 @@ def create_local_cluster(num_workers, num_ps, protocol="grpc",
   ]
 
   return workers, ps_servers
+
+
+def get_node_def_from_graph(node_name, graph_def):
+  """Returns the `NodeDef` instance for given node name in the graph def.
+
+  This method explores only the NodeDefs in `graph_def.node`.
+
+  Args:
+    node_name: Name of the NodeDef to search for.
+    graph_def: An instance of `GraphDef` proto.
+
+  Returns:
+    the `NodeDef` instance whose name field matches the given node_name or None.
+  """
+  for node_def in graph_def.node:
+    if node_def.name == node_name:
+      return node_def
+  return None
index f6aed118ca478daf1f1926ec9d9653015194cab3..4af717cca6547ad452a86b82e2a0b88bbae13bd4 100644 (file)
@@ -349,6 +349,13 @@ class TestUtilTest(test_util.TensorFlowTestCase):
 
     self.assertEqual(expected, self.evaluate(nested))
 
+  def test_get_node_def_from_graph(self):
+    graph_def = graph_pb2.GraphDef()
+    node_foo = graph_def.node.add()
+    node_foo.name = "foo"
+    self.assertIs(test_util.get_node_def_from_graph("foo", graph_def), node_foo)
+    self.assertIsNone(test_util.get_node_def_from_graph("bar", graph_def))
+
 
 class GarbageCollectionTest(test_util.TensorFlowTestCase):
 
index 8c78013ffd25feda7315657bfe070f8243959ae1..5eeaf73a4370b0558a2c11d17a3546171b886a69 100644 (file)
@@ -117,6 +117,35 @@ with tf.Session(graph=tf.Graph()) as sess:
 builder.save()
 ~~~
 
+#### Stripping Default valued attributes
+The SavedModelBuilder class allows users to control whether default-valued
+attributes must be stripped from the NodeDefs while adding a meta graph to the
+SavedModel bundle. Both `SavedModelBuilder.add_meta_graph_and_variables` and
+`SavedModelBuilder.add_meta_graph` methods accept a Boolean flag
+`strip_default_attrs` that controls this behavior.
+
+If `strip_default_attrs` is `False`, the exported MetaGraphDef will have the
+default valued attributes in all it's NodeDef instances. This can break forward
+compatibility with a sequence of events such as the following:
+
+* An existing Op (`Foo`) is updated to include a new attribute (`T`) with a
+  default (`bool`) at version 101.
+* A model producer (such as a Trainer) binary picks up this change
+  (version 101) to the OpDef and re-exports an existing model that uses Op `Foo`.
+* A model consumer (such as Tensorflow Serving) running an older binary
+  (version 100) doesn't have attribute `T` for Op `Foo`, but tries to import
+  this model. The model consumer doesn't recognize attribute `T` in a NodeDef
+  that uses Op `Foo` and therefore fails to load the model.
+
+By setting `strip_default_attrs` to `True`, the model producers can strip away
+any default valued attributes in the NodeDefs. This helps ensure that newly
+added attributes with defaults don't cause older model consumers to fail loading
+models regenerated with newer training binaries.
+
+TIP: If you care about forward compatibility, then set `strip_default_attrs`
+to `True` while using `SavedModelBuilder.add_meta_graph_and_variables` and
+`SavedModelBuilder.add_meta_graph`.
+
 ### Loader
 The SavedModel loader is implemented in C++ and Python.
 
index 16651ffebc5f5911d7c270425f599036a8e80e0c..62ee53b816c2a38327fa116d2924446e6bf24a1e 100644 (file)
@@ -239,7 +239,9 @@ class SavedModelBuilder(object):
                      assets_collection=None,
                      legacy_init_op=None,
                      clear_devices=False,
-                     main_op=None):
+                     main_op=None,
+                     strip_default_attrs=False):
+    # pylint: disable=line-too-long
     """Adds the current meta graph to the SavedModel.
 
     Creates a Saver in the current scope and uses the Saver to export the meta
@@ -260,11 +262,15 @@ class SavedModelBuilder(object):
       main_op: Op or group of ops to execute when the graph is loaded. Note
           that when the main_op is specified it is run after the restore op at
           load-time.
+      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+        removed from the NodeDefs. For a detailed guide, see
+        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
 
     Raises:
       AssertionError: If the variables for the SavedModel have not been saved
           yet, or if the graph already contains one or more legacy init ops.
     """
+    # pylint: enable=line-too-long
     if not self._has_saved_variables:
       raise AssertionError(
           "Graph state including variables and assets has not been saved yet. "
@@ -299,7 +305,8 @@ class SavedModelBuilder(object):
     # there are edge cases where that option breaks the graph.  Until that is
     # resolved, we just leave the option set to False for now.
     # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
-    meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices)
+    meta_graph_def = saver.export_meta_graph(
+        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)
 
     # Tag the meta graph def and add it to the SavedModel.
     self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
@@ -311,7 +318,9 @@ class SavedModelBuilder(object):
                                    assets_collection=None,
                                    legacy_init_op=None,
                                    clear_devices=False,
-                                   main_op=None):
+                                   main_op=None,
+                                   strip_default_attrs=False):
+    # pylint: disable=line-too-long
     """Adds the current meta graph to the SavedModel and saves variables.
 
     Creates a Saver to save the variables from the provided session. Exports the
@@ -334,7 +343,11 @@ class SavedModelBuilder(object):
       main_op: Op or group of ops to execute when the graph is loaded. Note
           that when the main_op is specified it is run after the restore op at
           load-time.
+      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+        removed from the NodeDefs. For a detailed guide, see
+        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
     """
+    # pylint: enable=line-too-long
     if self._has_saved_variables:
       raise AssertionError("Graph state including variables and assets has "
                            "already been saved. Please invoke "
@@ -388,7 +401,8 @@ class SavedModelBuilder(object):
     # there are edge cases where that option breaks the graph.  Until that is
     # resolved, we just leave the option set to False for now.
     # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
-    meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices)
+    meta_graph_def = saver.export_meta_graph(
+        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)
 
     # Tag the meta graph def and add it to the SavedModel.
     self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
index 92ca7dec6f63b50b33dde9909b4738676fb8c783..1ea619ff55dea00f8ee09024ab45dcd324a2ddce 100644 (file)
@@ -20,13 +20,17 @@ from __future__ import print_function
 
 import os
 
+from tensorflow.core.framework import op_def_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import op_def_registry
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
@@ -36,6 +40,7 @@ from tensorflow.python.platform import test
 from tensorflow.python.saved_model import builder as saved_model_builder
 from tensorflow.python.saved_model import constants
 from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import loader_impl
 from tensorflow.python.saved_model import main_op
 from tensorflow.python.saved_model import signature_def_utils
 from tensorflow.python.saved_model import tag_constants
@@ -865,6 +870,132 @@ class SavedModelTest(test.TestCase):
       self.assertEqual(
           42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
 
+  def testStripDefaultAttrs(self):
+    export_dir = os.path.join(test.get_temp_dir(), "test_strip_default_attrs")
+    builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+    # Add a graph with two float32 variables and a Complex Op composing them
+    # with strip_default_attrs enabled.
+    with session.Session(graph=ops.Graph()) as sess:
+      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
+      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+      math_ops.complex(real_num, imag_num, name="complex")
+      sess.run(variables.global_variables_initializer())
+      builder.add_meta_graph_and_variables(
+          sess, ["foo"], strip_default_attrs=True)
+
+    # Add a graph with the same float32 variables and a Complex Op composing
+    # them with strip_default_attrs disabled.
+    with session.Session(graph=ops.Graph()) as sess:
+      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
+      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+      math_ops.complex(real_num, imag_num, name="complex")
+      sess.run(variables.global_variables_initializer())
+      builder.add_meta_graph(["bar"], strip_default_attrs=False)
+
+    # Save the SavedModel to disk in text format.
+    builder.save(as_text=True)
+
+    # Loading graph "foo" via the loader must restore the defaults for the
+    # "Complex" node based on the "Complex" OpDef in the Op registry.
+    sess = session.Session(graph=ops.Graph())
+    meta_graph_def = loader.load(sess, ["foo"], export_dir)
+    complex_node = test_util.get_node_def_from_graph("complex",
+                                                     meta_graph_def.graph_def)
+    self.assertIn("T", complex_node.attr)
+    self.assertIn("Tout", complex_node.attr)
+
+    # Load graph "foo" from disk as-is to verify default attrs are stripped.
+    # pylint: disable=protected-access
+    saved_model_pb = loader_impl._parse_saved_model(export_dir)
+    self.assertIsNotNone(saved_model_pb)
+    # pylint: enable=protected-access
+
+    meta_graph_foo_def = None
+    meta_graph_bar_def = None
+    for meta_graph_def in saved_model_pb.meta_graphs:
+      if set(meta_graph_def.meta_info_def.tags) == set(["foo"]):
+        meta_graph_foo_def = meta_graph_def
+      elif set(meta_graph_def.meta_info_def.tags) == set(["bar"]):
+        meta_graph_bar_def = meta_graph_def
+
+    self.assertIsNotNone(meta_graph_foo_def)
+    self.assertIsNotNone(meta_graph_bar_def)
+
+    # "Complex" Op has 2 attributes with defaults:
+    #   o "T"    : float32.   (input type)
+    #   o "Tout" : complex64. (output type)
+
+    # "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout".
+    # Graph "foo" was saved with strip_default_attrs set to True.
+    node_def = test_util.get_node_def_from_graph("complex",
+                                                 meta_graph_foo_def.graph_def)
+    self.assertNotIn("T", node_def.attr)
+    self.assertNotIn("Tout", node_def.attr)
+
+    # "Complex" Op in graph "bar" must have attributes "T" and "Tout".
+    # Graph "bar" was saved with strip_default_attrs set to False.
+    node_def = test_util.get_node_def_from_graph("complex",
+                                                 meta_graph_bar_def.graph_def)
+    self.assertIn("T", node_def.attr)
+    self.assertIn("Tout", node_def.attr)
+
+  def testStripDefaultAttrsInconsistentConsumerDefaults(self):
+    export_dir = os.path.join(test.get_temp_dir(),
+                              "test_strip_default_attrs_no_consumer_defaults")
+    builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+    # Add a graph with two float32 variables and a Complex Op composing them
+    # with strip_default_attrs enabled. This must remove the following
+    # defaults for the "Complex" Op:
+    #   o "T"    : float32.   (input type)
+    #   o "Tout" : complex64. (output type)
+    with session.Session(graph=ops.Graph()) as sess:
+      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
+      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+      math_ops.complex(real_num, imag_num, name="complex")
+      sess.run(variables.global_variables_initializer())
+      builder.add_meta_graph_and_variables(
+          sess, ["foo"], strip_default_attrs=True)
+
+    # Save the SavedModel to disk in text format.
+    builder.save(as_text=True)
+
+    # Update the Op registry to remove defaults for all attrs("T", "Tout") from
+    # the "Complex" OpDef.
+    complex_op_def = op_def_registry.get_registered_ops()["Complex"]
+    original_complex_op_def = op_def_pb2.OpDef()
+    original_complex_op_def.CopyFrom(complex_op_def)
+    for attr_def in complex_op_def.attr:
+      attr_def.ClearField("default_value")
+
+    # Loading the SavedModel via the loader must fail because the SavedModel
+    # does not have any attr values for the "Complex" node and the current
+    # op registry does not have have any default values for the "Complex" op.
+    sess = session.Session(graph=ops.Graph())
+    with self.assertRaisesRegexp(
+        ValueError,
+        "Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
+      loader.load(sess, ["foo"], export_dir)
+
+    # Update the Op registry to change the defaults for attr "Tout"
+    # (complex64 -> complex128).
+    complex_op_def.CopyFrom(original_complex_op_def)
+    for attr_def in complex_op_def.attr:
+      if attr_def.name == "Tout":
+        attr_def.default_value.type = types_pb2.DT_COMPLEX128
+
+    # Loading the SavedModel via the loader must set "Tout" attr_value for the
+    # "Complex" node according to the latest defaults (complex128). This is
+    # expected to fail the model import as there is no OpKernel registered to
+    # handle attrs "T" (float32) and "Tout" (complex128).
+    sess = session.Session(graph=ops.Graph())
+    with self.assertRaisesRegexp(
+        errors.InvalidArgumentError,
+        ".*No OpKernel was registered to support Op \'Complex\' with these "
+        "attrs..*"):
+      loader.load(sess, ["foo"], export_dir)
+
 
 if __name__ == "__main__":
   test.main()
index ba6301e785947c8347ef23b81491e684bee62974..2330229d56cc80b837619ec5ddc8e7b87b4ef8bf 100644 (file)
@@ -1509,7 +1509,9 @@ class Saver(object):
            latest_filename=None,
            meta_graph_suffix="meta",
            write_meta_graph=True,
-           write_state=True):
+           write_state=True,
+           strip_default_attrs=False):
+    # pylint: disable=line-too-long
     """Saves variables.
 
     This method runs the ops added by the constructor for saving variables.
@@ -1535,6 +1537,9 @@ class Saver(object):
         graph file.
       write_state: `Boolean` indicating whether or not to write the
         `CheckpointStateProto`.
+      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+        removed from the NodeDefs. For a detailed guide, see
+        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
 
     Returns:
       A string: path prefix used for the checkpoint files.  If the saver is
@@ -1548,6 +1553,7 @@ class Saver(object):
         collides with `save_path`.
       RuntimeError: If save and restore ops weren't built.
     """
+    # pylint: enable=line-too-long
     if not self._is_built and context.in_graph_mode():
       raise RuntimeError(
           "`build()` should be called before save if defer_build==True")
@@ -1618,7 +1624,8 @@ class Saver(object):
           checkpoint_file, meta_graph_suffix=meta_graph_suffix)
       if context.in_graph_mode():
         with sess.graph.as_default():
-          self.export_meta_graph(meta_graph_filename)
+          self.export_meta_graph(
+              meta_graph_filename, strip_default_attrs=strip_default_attrs)
 
     if self._is_empty:
       return None
@@ -1631,7 +1638,9 @@ class Saver(object):
                         as_text=False,
                         export_scope=None,
                         clear_devices=False,
-                        clear_extraneous_savers=False):
+                        clear_extraneous_savers=False,
+                        strip_default_attrs=False):
+    # pylint: disable=line-too-long
     """Writes `MetaGraphDef` to save_path/filename.
 
     Args:
@@ -1644,10 +1653,14 @@ class Saver(object):
       clear_extraneous_savers: Remove any Saver-related information from the
         graph (both Save/Restore ops and SaverDefs) that are not associated
         with this Saver.
+      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+        removed from the NodeDefs. For a detailed guide, see
+        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
 
     Returns:
       A `MetaGraphDef` proto.
     """
+    # pylint: enable=line-too-long
     return export_meta_graph(
         filename=filename,
         graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
@@ -1656,7 +1669,8 @@ class Saver(object):
         as_text=as_text,
         export_scope=export_scope,
         clear_devices=clear_devices,
-        clear_extraneous_savers=clear_extraneous_savers)
+        clear_extraneous_savers=clear_extraneous_savers,
+        strip_default_attrs=strip_default_attrs)
 
   def restore(self, sess, save_path):
     """Restores previously saved variables.
@@ -1859,7 +1873,9 @@ def export_meta_graph(filename=None,
                       export_scope=None,
                       clear_devices=False,
                       clear_extraneous_savers=False,
+                      strip_default_attrs=False,
                       **kwargs):
+  # pylint: disable=line-too-long
   """Returns `MetaGraphDef` proto. Optionally writes it to filename.
 
   This function exports the graph, saver, and collection objects into
@@ -1885,6 +1901,9 @@ def export_meta_graph(filename=None,
     clear_extraneous_savers: Remove any Saver-related information from the
         graph (both Save/Restore ops and SaverDefs) that are not associated
         with the provided SaverDef.
+    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+      removed from the NodeDefs. For a detailed guide, see
+      [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
     **kwargs: Optional keyed arguments.
 
   Returns:
@@ -1899,6 +1918,7 @@ def export_meta_graph(filename=None,
   execution is enabled.
   @end_compatibility
   """
+  # pylint: enable=line-too-long
   if context.in_eager_mode():
     raise RuntimeError("Exporting/importing meta graphs is not supported when "
                        "eager execution is enabled. No graph exists when eager "
@@ -1914,6 +1934,7 @@ def export_meta_graph(filename=None,
       export_scope=export_scope,
       clear_devices=clear_devices,
       clear_extraneous_savers=clear_extraneous_savers,
+      strip_default_attrs=strip_default_attrs,
       **kwargs)
   return meta_graph_def
 
index 207e4a28426f95af4d5947964cf9133be10bc0fa..0889ac251631a6d07eff0221feb3009576a970dd 100644 (file)
@@ -2065,6 +2065,42 @@ class MetaGraphTest(test.TestCase):
         self.assertEqual(o.summary, "")
         self.assertEqual(o.description, "")
 
+  def testStripDefaultValuedAttrs(self):
+    """Verifies that default valued attrs are stripped, unless disabled."""
+
+    # With strip_default_attrs enabled, attributes "T" (float32) and "Tout"
+    # (complex64) in the "Complex" op must be removed.
+    with self.test_session():
+      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
+      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+      math_ops.complex(real_num, imag_num, name="complex")
+
+      save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
+      variables.global_variables_initializer()
+
+      meta_graph_def = save.export_meta_graph(strip_default_attrs=True)
+      node_def = test_util.get_node_def_from_graph("complex",
+                                                   meta_graph_def.graph_def)
+      self.assertNotIn("T", node_def.attr)
+      self.assertNotIn("Tout", node_def.attr)
+
+    # With strip_default_attrs disabled, attributes "T" (float32) and "Tout"
+    # (complex64) in the "Complex" op must *not* be removed, even if they map
+    # to their defaults.
+    with self.test_session(graph=ops_lib.Graph()):
+      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
+      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+      math_ops.complex(real_num, imag_num, name="complex")
+
+      save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
+      variables.global_variables_initializer()
+
+      meta_graph_def = save.export_meta_graph(strip_default_attrs=False)
+      node_def = test_util.get_node_def_from_graph("complex",
+                                                   meta_graph_def.graph_def)
+      self.assertIn("T", node_def.attr)
+      self.assertIn("Tout", node_def.attr)
+
   def testImportIntoNamescope(self):
     # Test that we can import a meta graph into a namescope.
     test_dir = self._get_test_dir("import_into_namescope")
index ebf49f434ae468311a07374cdca1140336983a81..b0e983115499c5b5b79459affc931600ad16256b 100644 (file)
@@ -18,6 +18,10 @@ tf_class {
     name: "META_GRAPH_VERSION_FIELD_NUMBER"
     mtype: "<type \'int\'>"
   }
+  member {
+    name: "STRIPPED_DEFAULT_ATTRS_FIELD_NUMBER"
+    mtype: "<type \'int\'>"
+  }
   member {
     name: "STRIPPED_OP_LIST_FIELD_NUMBER"
     mtype: "<type \'int\'>"
index f5ed263f0e20d6fdf7f23a3a2ab06029084d20e4..ab697b1b95b15e3ac7974e7092f1d5934b088bb6 100644 (file)
@@ -29,7 +29,7 @@ tf_class {
   }
   member_method {
     name: "export_savedmodel"
-    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
   }
   member_method {
     name: "get_variable_names"
index 61a29942c577a056e94dfe661fa5fec952b4f634..b73f6433e226f6b570b68c6a419c53d5c808d9d6 100644 (file)
@@ -29,7 +29,7 @@ tf_class {
   }
   member_method {
     name: "export_savedmodel"
-    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
   }
   member_method {
     name: "get_variable_names"
index 16e3b246156792418109981cc85ce0b07854a62c..24db86c92be66f4ad609211e9b23bba1a63a047e 100644 (file)
@@ -29,7 +29,7 @@ tf_class {
   }
   member_method {
     name: "export_savedmodel"
-    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
   }
   member_method {
     name: "get_variable_names"
index c6765ae277983eee54d0d998d6ad85c065460653..47ee2ac51b42cb0b2f03b1a6cf525d2793d5d7e8 100644 (file)
@@ -29,7 +29,7 @@ tf_class {
   }
   member_method {
     name: "export_savedmodel"
-    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
   }
   member_method {
     name: "get_variable_names"
index e3a820db46e085d0aa61f76e2ffd6e32abbfd855..fbfaa69a1b9ecc10986f59dd8a50dff9d298e8d5 100644 (file)
@@ -29,7 +29,7 @@ tf_class {
   }
   member_method {
     name: "export_savedmodel"
-    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
   }
   member_method {
     name: "get_variable_names"
index a4c8cf667179ba9863251469195cb75f1a60560e..faf55cda86471a9ab3c8ed4ff581ce0ee65d8bf3 100644 (file)
@@ -29,7 +29,7 @@ tf_class {
   }
   member_method {
     name: "export_savedmodel"
-    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
   }
   member_method {
     name: "get_variable_names"
index 787952eced27532cbd8596e9aacb3ce5abd7fade..d0bf043754b60240c507fe34b21b0599b94b69e2 100644 (file)
@@ -28,7 +28,7 @@ tf_class {
   }
   member_method {
     name: "export_savedmodel"
-    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
   }
   member_method {
     name: "get_variable_names"
index 99c03aa6297f4726970b83ad1f88924d320c5e33..6aec1d3a513cd103d00476b62c8f0fc2d4fcd766 100644 (file)
@@ -29,7 +29,7 @@ tf_class {
   }
   member_method {
     name: "export_savedmodel"
-    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
   }
   member_method {
     name: "get_variable_names"
index e2ab96d5b46d9cdebc558e756ca26158fddb3f26..9d8c7bb1381efe123d30c0b30f1b8c868d919cf0 100644 (file)
@@ -29,7 +29,7 @@ tf_class {
   }
   member_method {
     name: "export_savedmodel"
-    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
   }
   member_method {
     name: "get_variable_names"
index 56d76902fd0fe72ced6c0267295d9a9dc822a745..ca8e5884b18110d4293225e595c030e9629b5663 100644 (file)
@@ -8,11 +8,11 @@ tf_class {
   }
   member_method {
     name: "add_meta_graph"
-    argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
+    argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
   }
   member_method {
     name: "add_meta_graph_and_variables"
-    argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
+    argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
   }
   member_method {
     name: "save"
index 04c11712cd4c200bb2c04342e66924abf59c5f73..2cda458f468b2d748b43954b14b670df7145243f 100644 (file)
@@ -20,7 +20,7 @@ tf_class {
   }
   member_method {
     name: "export_meta_graph"
-    argspec: "args=[\'self\', \'filename\', \'collection_list\', \'as_text\', \'export_scope\', \'clear_devices\', \'clear_extraneous_savers\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'False\', \'False\'], "
+    argspec: "args=[\'self\', \'filename\', \'collection_list\', \'as_text\', \'export_scope\', \'clear_devices\', \'clear_extraneous_savers\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'False\', \'False\', \'False\'], "
   }
   member_method {
     name: "from_proto"
@@ -36,7 +36,7 @@ tf_class {
   }
   member_method {
     name: "save"
-    argspec: "args=[\'self\', \'sess\', \'save_path\', \'global_step\', \'latest_filename\', \'meta_graph_suffix\', \'write_meta_graph\', \'write_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'meta\', \'True\', \'True\'], "
+    argspec: "args=[\'self\', \'sess\', \'save_path\', \'global_step\', \'latest_filename\', \'meta_graph_suffix\', \'write_meta_graph\', \'write_state\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'meta\', \'True\', \'True\', \'False\'], "
   }
   member_method {
     name: "set_last_checkpoints"
index 3ffc6407306b4e44ec23052187b6f9376bba833c..b2ef17b39e71380dce4c65df4d36a3f76e198c04 100644 (file)
@@ -282,7 +282,7 @@ tf_module {
   }
   member_method {
     name: "export_meta_graph"
-    argspec: "args=[\'filename\', \'meta_info_def\', \'graph_def\', \'saver_def\', \'collection_list\', \'as_text\', \'graph\', \'export_scope\', \'clear_devices\', \'clear_extraneous_savers\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\', \'False\'], "
+    argspec: "args=[\'filename\', \'meta_info_def\', \'graph_def\', \'saver_def\', \'collection_list\', \'as_text\', \'graph\', \'export_scope\', \'clear_devices\', \'clear_extraneous_savers\', \'strip_default_attrs\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\', \'False\', \'False\'], "
   }
   member_method {
     name: "generate_checkpoint_state_proto"