Move warm_starting_util from third_party/tensorflow/python/estimator to third_party...
author A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 9 Mar 2018 20:39:41 +0000 (12:39 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 9 Mar 2018 20:43:57 +0000 (12:43 -0800)
PiperOrigin-RevId: 188522820

14 files changed:
tensorflow/python/BUILD
tensorflow/python/estimator/BUILD
tensorflow/python/estimator/canned/dnn_linear_combined_test.py
tensorflow/python/estimator/canned/dnn_testing_utils.py
tensorflow/python/estimator/canned/linear_testing_utils.py
tensorflow/python/estimator/estimator.py
tensorflow/python/estimator/estimator_lib.py
tensorflow/python/training/training.py
tensorflow/python/training/warm_starting_util.py [moved from tensorflow/python/estimator/warm_starting_util.py with 67% similarity]
tensorflow/python/training/warm_starting_util_test.py [moved from tensorflow/python/estimator/warm_starting_util_test.py with 94% similarity]
tensorflow/tools/api/golden/tensorflow.estimator.-vocab-info.pbtxt
tensorflow/tools/api/golden/tensorflow.estimator.-warm-start-settings.pbtxt
tensorflow/tools/api/golden/tensorflow.train.-vocab-info.pbtxt [new file with mode: 0644]
tensorflow/tools/api/golden/tensorflow.train.pbtxt
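In effect, this change promotes warm-starting from a private estimator utility (`warm_starting_util._warm_start`) to the public `tf.train.warm_start` and `tf.train.VocabInfo` symbols, while `tf.estimator.WarmStartSettings` and `tf.estimator.VocabInfo` remain available from the estimator package. A minimal sketch of the new entry point (the checkpoint path is illustrative, and a graph with matching trainable variables must already be built):

```python
import tensorflow as tf

# Warm-start every trainable variable whose name matches the regex, from the
# latest checkpoint in the directory (a specific checkpoint path also works).
tf.train.warm_start(
    ckpt_to_initialize_from="/tmp/prev_model",  # illustrative path
    vars_to_warm_start=".*input_layer.*")
```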

diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 3b050a8..ccc1f4c 100644 (file)
@@ -4002,6 +4002,25 @@ py_test(
 )
 
 py_test(
+    name = "warm_starting_util_test",
+    size = "small",
+    srcs = ["training/warm_starting_util_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":array_ops",
+        ":client_testlib",
+        ":dtypes",
+        ":framework_ops",
+        ":init_ops",
+        ":training",
+        ":variable_scope",
+        ":variables",
+        "//tensorflow/python/feature_column",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
     name = "monitored_session_test",
     size = "medium",
     srcs = ["training/monitored_session_test.py"],
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index c519fd5..e3a6708 100644 (file)
@@ -37,7 +37,6 @@ py_library(
         ":parsing_utils",
         ":run_config",
         ":training",
-        ":warm_starting_util",
         "//tensorflow/python:util",
     ],
 )
@@ -278,12 +277,12 @@ py_library(
     srcs = ["canned/dnn_testing_utils.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":estimator",
         ":head",
         ":metric_keys",
         ":model_fn",
         ":numpy_io",
         ":prediction_keys",
-        ":warm_starting_util",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:check_ops",
@@ -427,7 +426,6 @@ py_library(
         ":model_fn",
         ":run_config",
         ":util",
-        ":warm_starting_util",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:client",
         "//tensorflow/python:control_flow_ops",
@@ -868,39 +866,3 @@ py_test(
         "//tensorflow/python:training",
     ],
 )
-
-py_library(
-    name = "warm_starting_util",
-    srcs = ["warm_starting_util.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:platform",
-        "//tensorflow/python:state_ops",
-        "//tensorflow/python:training",
-        "//tensorflow/python:variable_scope",
-        "//tensorflow/python:variables",
-        "//tensorflow/python/feature_column",
-    ],
-)
-
-py_test(
-    name = "warm_starting_util_test",
-    size = "small",
-    srcs = ["warm_starting_util_test.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":warm_starting_util",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:init_ops",
-        "//tensorflow/python:training",
-        "//tensorflow/python:variable_scope",
-        "//tensorflow/python:variables",
-        "//tensorflow/python/feature_column",
-        "//third_party/py/numpy",
-    ],
-)
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
index 84675bf..d275695 100644 (file)
@@ -26,7 +26,7 @@ import six
 
 from tensorflow.core.example import example_pb2
 from tensorflow.core.example import feature_pb2
-from tensorflow.python.estimator import warm_starting_util
+from tensorflow.python.estimator import estimator
 from tensorflow.python.estimator.canned import dnn_linear_combined
 from tensorflow.python.estimator.canned import dnn_testing_utils
 from tensorflow.python.estimator.canned import linear_testing_utils
@@ -866,7 +866,7 @@ class DNNLinearCombinedWarmStartingTest(test.TestCase):
                 learning_rate=0.0),
             # The provided regular expression will only warm-start the deep
             # portion of the model.
-            warm_start_from=warm_starting_util.WarmStartSettings(
+            warm_start_from=estimator.WarmStartSettings(
                 ckpt_to_initialize_from=dnn_lc_classifier.model_dir,
                 vars_to_warm_start='.*(dnn).*')))
 
diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py
index 7065759..9a7d088 100644 (file)
@@ -27,8 +27,8 @@ import six
 
 from tensorflow.core.framework import summary_pb2
 from tensorflow.python.client import session as tf_session
+from tensorflow.python.estimator import estimator
 from tensorflow.python.estimator import model_fn
-from tensorflow.python.estimator import warm_starting_util
 from tensorflow.python.estimator.canned import head as head_lib
 from tensorflow.python.estimator.canned import metric_keys
 from tensorflow.python.estimator.canned import prediction_keys
@@ -828,7 +828,7 @@ class BaseDNNWarmStartingTest(object):
         optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0),
         # The provided regular expression will only warm-start the city
         # embedding, not the kernels and biases of the hidden weights.
-        warm_start_from=warm_starting_util.WarmStartSettings(
+        warm_start_from=estimator.WarmStartSettings(
             ckpt_to_initialize_from=dnn_classifier.model_dir,
             vars_to_warm_start='.*(city).*'))
 
@@ -892,7 +892,7 @@ class BaseDNNWarmStartingTest(object):
         dimension=2)
     # We can create our VocabInfo object from the new and old occupation
     # FeatureColumns.
-    occupation_vocab_info = warm_starting_util.VocabInfo(
+    occupation_vocab_info = estimator.VocabInfo(
         new_vocab=new_occupation.categorical_column.vocabulary_file,
         new_vocab_size=new_occupation.categorical_column.vocabulary_size,
         num_oov_buckets=new_occupation.categorical_column.num_oov_buckets,
@@ -907,7 +907,7 @@ class BaseDNNWarmStartingTest(object):
         feature_columns=[occupation],
         n_classes=4,
         optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0),
-        warm_start_from=warm_starting_util.WarmStartSettings(
+        warm_start_from=estimator.WarmStartSettings(
             ckpt_to_initialize_from=dnn_classifier.model_dir,
             var_name_to_vocab_info={
                 OCCUPATION_EMBEDDING_NAME: occupation_vocab_info
@@ -978,7 +978,7 @@ class BaseDNNWarmStartingTest(object):
         optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0),
         # The 'city' variable corresponds to the 'locality' variable in the
         # previous model.
-        warm_start_from=warm_starting_util.WarmStartSettings(
+        warm_start_from=estimator.WarmStartSettings(
             ckpt_to_initialize_from=dnn_classifier.model_dir,
             var_name_to_prev_var_name={
                 CITY_EMBEDDING_NAME:
diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py
index 3e9183c..8e506a7 100644 (file)
@@ -31,7 +31,6 @@ from tensorflow.core.example import feature_pb2
 from tensorflow.python.client import session as tf_session
 from tensorflow.python.estimator import estimator
 from tensorflow.python.estimator import run_config
-from tensorflow.python.estimator import warm_starting_util
 from tensorflow.python.estimator.canned import linear
 from tensorflow.python.estimator.canned import metric_keys
 from tensorflow.python.estimator.export import export
@@ -1968,7 +1967,7 @@ class BaseLinearWarmStartingTest(object):
         optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0),
         # The provided regular expression will only warm-start the age variable
         # and not the bias.
-        warm_start_from=warm_starting_util.WarmStartSettings(
+        warm_start_from=estimator.WarmStartSettings(
             ckpt_to_initialize_from=linear_classifier.model_dir,
             vars_to_warm_start='.*(age).*'))
 
@@ -2016,7 +2015,7 @@ class BaseLinearWarmStartingTest(object):
         vocabulary_size=len(new_vocab_list))
     # We can create our VocabInfo object from the new and old occupation
     # FeatureColumns.
-    occupation_vocab_info = warm_starting_util.VocabInfo(
+    occupation_vocab_info = estimator.VocabInfo(
         new_vocab=new_occupation.vocabulary_file,
         new_vocab_size=new_occupation.vocabulary_size,
         num_oov_buckets=new_occupation.num_oov_buckets,
@@ -2030,7 +2029,7 @@ class BaseLinearWarmStartingTest(object):
         feature_columns=[occupation],
         n_classes=4,
         optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0),
-        warm_start_from=warm_starting_util.WarmStartSettings(
+        warm_start_from=estimator.WarmStartSettings(
             ckpt_to_initialize_from=linear_classifier.model_dir,
             var_name_to_vocab_info={
                 OCCUPATION_WEIGHT_NAME: occupation_vocab_info
@@ -2082,7 +2081,7 @@ class BaseLinearWarmStartingTest(object):
         optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0),
         # The 'age' variable corresponds to the 'age_in_years' variable in the
         # previous model.
-        warm_start_from=warm_starting_util.WarmStartSettings(
+        warm_start_from=estimator.WarmStartSettings(
             ckpt_to_initialize_from=linear_classifier.model_dir,
             var_name_to_prev_var_name={
                 AGE_WEIGHT_NAME: AGE_WEIGHT_NAME.replace('age', 'age_in_years')
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 6c402d8..41a1358 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import copy
 import os
 import tempfile
@@ -35,7 +36,6 @@ from tensorflow.python.eager import context
 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.estimator import run_config
 from tensorflow.python.estimator import util
-from tensorflow.python.estimator import warm_starting_util
 from tensorflow.python.estimator.export.export import build_all_signature_defs
 from tensorflow.python.estimator.export.export import get_temp_export_dir
 from tensorflow.python.estimator.export.export import get_timestamped_export_dir
@@ -55,6 +55,7 @@ from tensorflow.python.training import monitored_session
 from tensorflow.python.training import saver
 from tensorflow.python.training import training
 from tensorflow.python.training import training_util
+from tensorflow.python.training import warm_starting_util
 from tensorflow.python.util import compat
 from tensorflow.python.util import compat_internal
 from tensorflow.python.util import nest
@@ -217,8 +218,8 @@ class Estimator(object):
     self._params = copy.deepcopy(params or {})
 
     # pylint: disable=protected-access
-    self._warm_start_settings = (
-        warm_starting_util._get_default_warm_start_settings(warm_start_from))
+    self._warm_start_settings = _get_default_warm_start_settings(
+        warm_start_from)
     # pylint: enable=protected-access
 
   @property
@@ -830,7 +831,7 @@ class Estimator(object):
         logging.info('Warm-starting with WarmStartSettings: %s' %
                      (self._warm_start_settings,))
         # pylint: disable=protected-access
-        warm_starting_util._warm_start(self._warm_start_settings)
+        warm_starting_util.warm_start(*self._warm_start_settings)
         # pylint: enable=protected-access
       # Check if the user created a loss summary, and add one if they didn't.
       # We assume here that the summary is called 'loss'. If it is not, we will
@@ -1152,3 +1153,187 @@ class _DatasetInitializerHook(training.SessionRunHook):
   def after_create_session(self, session, coord):
     del coord
     session.run(self._initializer)
+
+VocabInfo = warm_starting_util.VocabInfo  # pylint: disable=invalid-name
+
+
+@tf_export('estimator.WarmStartSettings')
+class WarmStartSettings(
+    collections.namedtuple('WarmStartSettings', [
+        'ckpt_to_initialize_from',
+        'vars_to_warm_start',
+        'var_name_to_vocab_info',
+        'var_name_to_prev_var_name',
+    ])):
+  """Settings for warm-starting in Estimators.
+
+  Example Use with canned `DNNEstimator`:
+
+  ```
+  emb_vocab_file = tf.feature_column.embedding_column(
+      tf.feature_column.categorical_column_with_vocabulary_file(
+          "sc_vocab_file", "new_vocab.txt", vocab_size=100),
+      dimension=8)
+  emb_vocab_list = tf.feature_column.embedding_column(
+      tf.feature_column.categorical_column_with_vocabulary_list(
+          "sc_vocab_list", vocabulary_list=["a", "b"]),
+      dimension=8)
+  estimator = tf.estimator.DNNClassifier(
+    hidden_units=[128, 64], feature_columns=[emb_vocab_file, emb_vocab_list],
+    warm_start_from=ws)
+  ```
+
+  where `ws` could be defined as:
+
+  Warm-start all weights in the model (input layer and hidden weights).
+  Either the directory or a specific checkpoint can be provided (in the case
+  of the former, the latest checkpoint will be used):
+
+  ```
+  ws = WarmStartSettings(ckpt_to_initialize_from="/tmp")
+  ws = WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000")
+  ```
+
+  Warm-start only the embeddings (input layer):
+
+  ```
+  ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
+                         vars_to_warm_start=".*input_layer.*")
+  ```
+
+  Warm-start all weights but the embedding parameters corresponding to
+  `sc_vocab_file` have a different vocab from the one used in the current
+  model:
+
+  ```
+  vocab_info = tf.estimator.VocabInfo(
+      new_vocab=sc_vocab_file.vocabulary_file,
+      new_vocab_size=sc_vocab_file.vocabulary_size,
+      num_oov_buckets=sc_vocab_file.num_oov_buckets,
+      old_vocab="old_vocab.txt"
+  )
+  ws = WarmStartSettings(
+      ckpt_to_initialize_from="/tmp",
+      var_name_to_vocab_info={
+          "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
+      })
+  ```
+
+  Warm-start only `sc_vocab_file` embeddings (and no other variables), which
+  have a different vocab from the one used in the current model:
+
+  ```
+  vocab_info = tf.estimator.VocabInfo(
+      new_vocab=sc_vocab_file.vocabulary_file,
+      new_vocab_size=sc_vocab_file.vocabulary_size,
+      num_oov_buckets=sc_vocab_file.num_oov_buckets,
+      old_vocab="old_vocab.txt"
+  )
+  ws = WarmStartSettings(
+      ckpt_to_initialize_from="/tmp",
+      vars_to_warm_start=None,
+      var_name_to_vocab_info={
+          "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
+      })
+  ```
+
+  Warm-start all weights but the parameters corresponding to `sc_vocab_file`
+  have a different vocab from the one used in current checkpoint, and only
+  100 of those entries were used:
+
+  ```
+  vocab_info = tf.estimator.VocabInfo(
+      new_vocab=sc_vocab_file.vocabulary_file,
+      new_vocab_size=sc_vocab_file.vocabulary_size,
+      num_oov_buckets=sc_vocab_file.num_oov_buckets,
+      old_vocab="old_vocab.txt",
+      old_vocab_size=100
+  )
+  ws = WarmStartSettings(
+      ckpt_to_initialize_from="/tmp",
+      var_name_to_vocab_info={
+          "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
+      })
+  ```
+
+  Warm-start all weights but the parameters corresponding to `sc_vocab_file`
+  have a different vocab from the one used in current checkpoint and the
+  parameters corresponding to `sc_vocab_list` have a different name from the
+  current checkpoint:
+
+  ```
+  vocab_info = tf.estimator.VocabInfo(
+      new_vocab=sc_vocab_file.vocabulary_file,
+      new_vocab_size=sc_vocab_file.vocabulary_size,
+      num_oov_buckets=sc_vocab_file.num_oov_buckets,
+      old_vocab="old_vocab.txt",
+      old_vocab_size=100
+  )
+  ws = WarmStartSettings(
+      ckpt_to_initialize_from="/tmp",
+      var_name_to_vocab_info={
+          "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
+      },
+      var_name_to_prev_var_name={
+          "input_layer/sc_vocab_list_embedding/embedding_weights":
+              "old_tensor_name"
+      })
+  ```
+
+  Attributes:
+    ckpt_to_initialize_from: [Required] A string specifying the directory with
+      checkpoint file(s) or path to checkpoint from which to warm-start the
+      model parameters.
+    vars_to_warm_start: [Optional] A regular expression that captures which
+      variables to warm-start (see tf.get_collection).  Defaults to `'.*'`,
+      which warm-starts all variables.  If `None` is explicitly given, only
+      variables specified in `var_name_to_vocab_info` will be warm-started.
+    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
+      VocabInfo. The variable names should be "full" variables, not the names
+      of the partitions.  If not explicitly provided, the variable is assumed to
+      have no vocabulary.
+    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
+      name of the previously-trained variable in `ckpt_to_initialize_from`. If
+      not explicitly provided, the name of the variable is assumed to be the
+      same between the previous checkpoint and the current model.
+  """
+
+  def __new__(cls,
+              ckpt_to_initialize_from,
+              vars_to_warm_start='.*',
+              var_name_to_vocab_info=None,
+              var_name_to_prev_var_name=None):
+    if not ckpt_to_initialize_from:
+      raise ValueError(
+          '`ckpt_to_initialize_from` MUST be set in WarmStartSettings')
+    return super(WarmStartSettings, cls).__new__(
+        cls,
+        ckpt_to_initialize_from,
+        vars_to_warm_start,
+        var_name_to_vocab_info or {},
+        var_name_to_prev_var_name or {},
+    )
+
+
+def _get_default_warm_start_settings(warm_start_from):
+  """Returns default WarmStartSettings.
+
+  Args:
+    warm_start_from: Either a string representing the filepath of a checkpoint
+      to initialize from, or an instance of WarmStartSettings.
+
+  Returns:
+    Either None or an instance of WarmStartSettings.
+
+  Raises:
+    ValueError: If warm_start_from is not None but is neither a string nor an
+      instance of WarmStartSettings.
+  """
+  if warm_start_from is None:
+    return None
+  if isinstance(warm_start_from, six.string_types):
+    return WarmStartSettings(ckpt_to_initialize_from=warm_start_from)
+  elif isinstance(warm_start_from, WarmStartSettings):
+    return warm_start_from
+  else:
+    raise ValueError('warm_start_from must be a string or a WarmStartSettings')
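The call site above, `warm_starting_util.warm_start(*self._warm_start_settings)`, works because `WarmStartSettings` is a namedtuple whose field order mirrors `warm_start`'s positional parameters, so the settings object unpacks cleanly into the call. A self-contained toy illustration of that pattern (the names here are stand-ins, not the real implementation):

```python
import collections

# Field order deliberately matches the parameter order of warm_start below.
Settings = collections.namedtuple("Settings", [
    "ckpt_to_initialize_from",
    "vars_to_warm_start",
    "var_name_to_vocab_info",
    "var_name_to_prev_var_name",
])

def warm_start(ckpt_to_initialize_from, vars_to_warm_start=".*",
               var_name_to_vocab_info=None, var_name_to_prev_var_name=None):
  # Stand-in for tensorflow.python.training.warm_starting_util.warm_start.
  print(ckpt_to_initialize_from, vars_to_warm_start)

s = Settings("/tmp/prev_model", ".*input_layer.*", {}, {})
warm_start(*s)  # fields unpack positionally, in declaration order
```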
diff --git a/tensorflow/python/estimator/estimator_lib.py b/tensorflow/python/estimator/estimator_lib.py
index 01699e7..be8930b 100644 (file)
@@ -30,6 +30,8 @@ from tensorflow.python.estimator.canned.linear import LinearRegressor
 from tensorflow.python.estimator.canned.parsing_utils import classifier_parse_example_spec
 from tensorflow.python.estimator.canned.parsing_utils import regressor_parse_example_spec
 from tensorflow.python.estimator.estimator import Estimator
+from tensorflow.python.estimator.estimator import VocabInfo
+from tensorflow.python.estimator.estimator import WarmStartSettings
 from tensorflow.python.estimator.export import export_lib as export
 from tensorflow.python.estimator.exporter import Exporter
 from tensorflow.python.estimator.exporter import FinalExporter
@@ -41,8 +43,6 @@ from tensorflow.python.estimator.run_config import RunConfig
 from tensorflow.python.estimator.training import EvalSpec
 from tensorflow.python.estimator.training import train_and_evaluate
 from tensorflow.python.estimator.training import TrainSpec
-from tensorflow.python.estimator.warm_starting_util import VocabInfo
-from tensorflow.python.estimator.warm_starting_util import WarmStartSettings
 
 
 from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
index e623e27..6880cfc 100644 (file)
@@ -95,6 +95,8 @@ See the @{$python/train} guide.
 @@load_variable
 @@list_variables
 @@init_from_checkpoint
+@@warm_start
+@@VocabInfo
 """
 
 # Optimizers.
@@ -188,6 +190,8 @@ from tensorflow.python.training.training_util import get_global_step
 from tensorflow.python.training.training_util import assert_global_step
 from tensorflow.python.training.training_util import create_global_step
 from tensorflow.python.training.training_util import get_or_create_global_step
+from tensorflow.python.training.warm_starting_util import VocabInfo
+from tensorflow.python.training.warm_starting_util import warm_start
 from tensorflow.python.pywrap_tensorflow import do_quantize_training_on_graphdef
 from tensorflow.python.pywrap_tensorflow import NewCheckpointReader
 from tensorflow.python.util.tf_export import tf_export
diff --git a/tensorflow/python/estimator/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py
similarity index 67%
rename from tensorflow/python/estimator/warm_starting_util.py
rename to tensorflow/python/training/warm_starting_util.py
@@ -33,7 +33,7 @@ from tensorflow.python.training import saver
 from tensorflow.python.util.tf_export import tf_export
 
 
-@tf_export("estimator.VocabInfo")
+@tf_export("train.VocabInfo", "estimator.VocabInfo")
 class VocabInfo(
     collections.namedtuple("VocabInfo", [
         "new_vocab",
@@ -43,7 +43,7 @@ class VocabInfo(
         "old_vocab_size",
         "backup_initializer",
     ])):
-  """Vocabulary information for WarmStartSettings.
+  """Vocabulary information for warm-starting.
 
   See @{tf.estimator.WarmStartSettings$WarmStartSettings} for examples of using
   VocabInfo to warm-start.
@@ -83,164 +83,6 @@ class VocabInfo(
     )
 
 
-@tf_export("estimator.WarmStartSettings")
-class WarmStartSettings(
-    collections.namedtuple("WarmStartSettings", [
-        "ckpt_to_initialize_from",
-        "vars_to_warm_start",
-        "var_name_to_vocab_info",
-        "var_name_to_prev_var_name",
-    ])):
-  """Settings for warm-starting in Estimators.
-
-  Example Use with canned `DNNEstimator`:
-
-  ```
-  emb_vocab_file = tf.feature_column.embedding_column(
-      tf.feature_column.categorical_column_with_vocabulary_file(
-          "sc_vocab_file", "new_vocab.txt", vocab_size=100),
-      dimension=8)
-  emb_vocab_list = tf.feature_column.embedding_column(
-      tf.feature_column.categorical_column_with_vocabulary_list(
-          "sc_vocab_list", vocabulary_list=["a", "b"]),
-      dimension=8)
-  estimator = tf.estimator.DNNClassifier(
-    hidden_units=[128, 64], feature_columns=[emb_vocab_file, emb_vocab_list],
-    warm_start_from=ws)
-  ```
-
-  where `ws` could be defined as:
-
-  Warm-start all weights in the model (input layer and hidden weights).
-  Either the directory or a specific checkpoint can be provided (in the case
-  of the former, the latest checkpoint will be used):
-
-  ```
-  ws = WarmStartSettings(ckpt_to_initialize_from="/tmp")
-  ws = WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000")
-  ```
-
-  Warm-start only the embeddings (input layer):
-
-  ```
-  ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
-                         vars_to_warm_start=".*input_layer.*")
-  ```
-
-  Warm-start all weights but the embedding parameters corresponding to
-  `sc_vocab_file` have a different vocab from the one used in the current
-  model:
-
-  ```
-  vocab_info = ws_util.VocabInfo(
-      new_vocab=sc_vocab_file.vocabulary_file,
-      new_vocab_size=sc_vocab_file.vocabulary_size,
-      num_oov_buckets=sc_vocab_file.num_oov_buckets,
-      old_vocab="old_vocab.txt"
-  )
-  ws = WarmStartSettings(
-      ckpt_to_initialize_from="/tmp",
-      var_name_to_vocab_info={
-          "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
-      })
-  ```
-
-  Warm-start only `sc_vocab_file` embeddings (and no other variables), which
-  have a different vocab from the one used in the current model:
-
-  ```
-  vocab_info = ws_util.VocabInfo(
-      new_vocab=sc_vocab_file.vocabulary_file,
-      new_vocab_size=sc_vocab_file.vocabulary_size,
-      num_oov_buckets=sc_vocab_file.num_oov_buckets,
-      old_vocab="old_vocab.txt"
-  )
-  ws = WarmStartSettings(
-      ckpt_to_initialize_from="/tmp",
-      vars_to_warm_start=None,
-      var_name_to_vocab_info={
-          "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
-      })
-  ```
-
-  Warm-start all weights but the parameters corresponding to `sc_vocab_file`
-  have a different vocab from the one used in current checkpoint, and only
-  100 of those entries were used:
-
-  ```
-  vocab_info = ws_util.VocabInfo(
-      new_vocab=sc_vocab_file.vocabulary_file,
-      new_vocab_size=sc_vocab_file.vocabulary_size,
-      num_oov_buckets=sc_vocab_file.num_oov_buckets,
-      old_vocab="old_vocab.txt",
-      old_vocab_size=100
-  )
-  ws = WarmStartSettings(
-      ckpt_to_initialize_from="/tmp",
-      var_name_to_vocab_info={
-          "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
-      })
-  ```
-
-  Warm-start all weights but the parameters corresponding to `sc_vocab_file`
-  have a different vocab from the one used in current checkpoint and the
-  parameters corresponding to `sc_vocab_list` have a different name from the
-  current checkpoint:
-
-  ```
-  vocab_info = ws_util.VocabInfo(
-      new_vocab=sc_vocab_file.vocabulary_file,
-      new_vocab_size=sc_vocab_file.vocabulary_size,
-      num_oov_buckets=sc_vocab_file.num_oov_buckets,
-      old_vocab="old_vocab.txt",
-      old_vocab_size=100
-  )
-  ws = WarmStartSettings(
-      ckpt_to_initialize_from="/tmp",
-      var_name_to_vocab_info={
-          "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
-      },
-      var_name_to_prev_var_name={
-          "input_layer/sc_vocab_list_embedding/embedding_weights":
-              "old_tensor_name"
-      })
-  ```
-
-  Attributes:
-    ckpt_to_initialize_from: [Required] A string specifying the directory with
-      checkpoint file(s) or path to checkpoint from which to warm-start the
-      model parameters.
-    vars_to_warm_start: [Optional] A regular expression that captures which
-      variables to warm-start (see tf.get_collection).  Defaults to `'.*'`,
-      which warm-starts all variables.  If `None` is explicitly given, only
-      variables specified in `var_name_to_vocab_info` will be warm-started.
-    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
-      VocabInfo. The variable names should be "full" variables, not the names
-      of the partitions.  If not explicitly provided, the variable is assumed to
-      have no vocabulary.
-    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
-      name of the previously-trained variable in `ckpt_to_initialize_from`. If
-      not explicitly provided, the name of the variable is assumed to be same
-      between previous checkpoint and current model.
-  """
-
-  def __new__(cls,
-              ckpt_to_initialize_from,
-              vars_to_warm_start=".*",
-              var_name_to_vocab_info=None,
-              var_name_to_prev_var_name=None):
-    if not ckpt_to_initialize_from:
-      raise ValueError(
-          "`ckpt_to_initialize_from` MUST be set in WarmStartSettings")
-    return super(WarmStartSettings, cls).__new__(
-        cls,
-        ckpt_to_initialize_from,
-        vars_to_warm_start,
-        var_name_to_vocab_info or {},
-        var_name_to_prev_var_name or {},
-    )
-
-
 def _is_variable(x):
   return (isinstance(x, variables_lib.Variable) or
           isinstance(x, resource_variable_ops.ResourceVariable))
@@ -375,8 +217,7 @@ def _warm_start_var_with_vocab(var,
           full_shape=slice_info.full_shape,
           var_offset=slice_info.var_offset)
 
-    # TODO(eddz): Support WarmStartSettings where class vocabularies need
-    # remapping too.
+    # TODO(eddz): Support cases where class vocabularies need remapping too.
     init = checkpoint_ops._load_and_remap_matrix_initializer(
         ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
         old_tensor_name=prev_tensor_name,
@@ -396,32 +237,53 @@ def _warm_start_var_with_vocab(var,
 # pylint: enable=protected-access
 
 
-def _warm_start(warm_start_settings):
+@tf_export("train.warm_start")
+def warm_start(ckpt_to_initialize_from,
+               vars_to_warm_start=".*",
+               var_name_to_vocab_info=None,
+               var_name_to_prev_var_name=None):
   """Warm-starts a model using the given settings.
 
   If you are using a tf.estimator.Estimator, this will automatically be called
   during training.
 
   Args:
-    warm_start_settings: An object of `WarmStartSettings`.
+    ckpt_to_initialize_from: [Required] A string specifying the directory with
+      checkpoint file(s) or path to checkpoint from which to warm-start the
+      model parameters.
+    vars_to_warm_start: [Optional] A regular expression that captures which
+      variables to warm-start (see tf.get_collection).  Defaults to `'.*'`,
+      which warm-starts all variables.  If `None` is explicitly given, only
+      variables specified in `var_name_to_vocab_info` will be warm-started.
+    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
+      VocabInfo. The variable names should be "full" variables, not the names
+      of the partitions.  If not explicitly provided, the variable is assumed to
+      have no vocabulary.
+    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
+      name of the previously-trained variable in `ckpt_to_initialize_from`. If
+      not explicitly provided, the name of the variable is assumed to be the
+      same between the previous checkpoint and the current model.
   Raises:
     ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
       configuration for variable names that are not used.  This is to ensure
       a stronger check for variable configuration than relying on users to
       examine the logs.
   """
-  logging.info("Warm-starting from: %s",
-               (warm_start_settings.ckpt_to_initialize_from,))
+  if var_name_to_vocab_info is None:
+    var_name_to_vocab_info = {}
+  if var_name_to_prev_var_name is None:
+    var_name_to_prev_var_name = {}
+  logging.info("Warm-starting from: %s", (ckpt_to_initialize_from,))
   # We have to deal with partitioned variables, since get_collection flattens
   # out the list.
   grouped_variables = {}
-  # Both warm_start_settings.vars_to_warm_start = '.*' and
-  # warm_start_settings.vars_to_warm_start = None will match everything here.
+  # Both vars_to_warm_start = '.*' and
+  # vars_to_warm_start = None will match everything here.
   for v in ops.get_collection(
       # TODO(eddz): Allow for different collections here (to support
       # warm-starting accumulators).
       ops.GraphKeys.TRAINABLE_VARIABLES,
-      scope=warm_start_settings.vars_to_warm_start):
+      scope=vars_to_warm_start):
     if not isinstance(v, list):
       var_name = _infer_var_name([v])
     else:
@@ -437,10 +299,10 @@ def _warm_start(warm_start_settings):
   vocab_info_used = set()
 
   for var_name, variable in six.iteritems(grouped_variables):
-    prev_var_name = warm_start_settings.var_name_to_prev_var_name.get(var_name)
+    prev_var_name = var_name_to_prev_var_name.get(var_name)
     if prev_var_name:
       prev_var_name_used.add(var_name)
-    vocab_info = warm_start_settings.var_name_to_vocab_info.get(var_name)
+    vocab_info = var_name_to_vocab_info.get(var_name)
     if vocab_info:
       vocab_info_used.add(var_name)
       logging.info(
@@ -460,16 +322,16 @@ def _warm_start(warm_start_settings):
           variable,
           current_vocab_path=vocab_info.new_vocab,
           current_vocab_size=vocab_info.new_vocab_size,
-          prev_ckpt=warm_start_settings.ckpt_to_initialize_from,
+          prev_ckpt=ckpt_to_initialize_from,
           prev_vocab_path=vocab_info.old_vocab,
           previous_vocab_size=vocab_info.old_vocab_size,
           current_oov_buckets=vocab_info.num_oov_buckets,
           prev_tensor_name=prev_var_name,
           initializer=vocab_info.backup_initializer)
     else:
-      # For the special value of warm_start_settings.vars_to_warm_start = None,
+      # For the special value of vars_to_warm_start = None,
       # we only warm-start variables with explicitly specified vocabularies.
-      if warm_start_settings.vars_to_warm_start:
+      if vars_to_warm_start:
         logging.info("Warm-starting variable: {}; prev_var_name: {}".format(
             var_name, prev_var_name or "Unchanged"))
         # Because we use a default empty list in grouped_variables, single
@@ -477,48 +339,22 @@ def _warm_start(warm_start_settings):
         # for init_from_checkpoint logic to work correctly.
         if len(variable) == 1:
           variable = variable[0]
-        _warm_start_var(variable, warm_start_settings.ckpt_to_initialize_from,
-                        prev_var_name)
+        _warm_start_var(variable, ckpt_to_initialize_from, prev_var_name)
 
   prev_var_name_not_used = set(
-      warm_start_settings.var_name_to_prev_var_name.keys()) - prev_var_name_used
-  vocab_info_not_used = set(
-      warm_start_settings.var_name_to_vocab_info.keys()) - vocab_info_used
+      var_name_to_prev_var_name.keys()) - prev_var_name_used
+  vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used
 
   if prev_var_name_not_used:
     raise ValueError(
         "You provided the following variables in "
-        "warm_start_settings.var_name_to_prev_var_name that were not used: "
+        "var_name_to_prev_var_name that were not used: "
         "{0}.  Perhaps you misspelled them?  Here is the list of viable "
         "variable names: {1}".format(prev_var_name_not_used,
                                      grouped_variables.keys()))
   if vocab_info_not_used:
     raise ValueError(
         "You provided the following variables in "
-        "warm_start_settings.var_name_to_vocab_info that were not used: {0}. "
+        "var_name_to_vocab_info that were not used: {0}. "
         " Perhaps you misspelled them?  Here is the list of viable variable "
         "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))
-
-
-def _get_default_warm_start_settings(warm_start_from):
-  """Returns default WarmStartSettings.
-
-  Args:
-    warm_start_from: Either a string representing the filepath of a checkpoint
-      to initialize from, or an instance of WarmStartSettings.
-
-  Returns:
-    Either None or an instance of WarmStartSettings.
-
-  Raises:
-    ValueError: If warm_start_from is not None but is neither a string nor an
-      instance of WarmStartSettings.
-  """
-  if warm_start_from is None:
-    return None
-  if isinstance(warm_start_from, six.string_types):
-    return WarmStartSettings(ckpt_to_initialize_from=warm_start_from)
-  elif isinstance(warm_start_from, WarmStartSettings):
-    return warm_start_from
-  else:
-    raise ValueError("warm_start_from must be a string or a WarmStartSettings")
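With the settings object gone from this module, vocab-aware warm-starting is driven entirely through keyword arguments, as the updated tests below exercise. A hedged sketch under assumed names (the checkpoint path, vocab files, and variable name are illustrative):

```python
import tensorflow as tf

vocab_info = tf.train.VocabInfo(
    new_vocab="new_vocab.txt",   # vocabulary used by the current model
    new_vocab_size=100,
    num_oov_buckets=1,
    old_vocab="old_vocab.txt")   # vocabulary the checkpoint was trained with

tf.train.warm_start(
    "/tmp/prev_model",
    vars_to_warm_start=".*sc_vocab.*",
    var_name_to_vocab_info={
        "linear_model/sc_vocab/weights": vocab_info,
    })
```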
diff --git a/tensorflow/python/estimator/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py
similarity index 94%
rename from tensorflow/python/estimator/warm_starting_util_test.py
rename to tensorflow/python/training/warm_starting_util_test.py
@@ -22,7 +22,6 @@ import os
 import numpy as np
 import six
 
-from tensorflow.python.estimator import warm_starting_util as ws_util
 from tensorflow.python.feature_column import feature_column as fc
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -32,6 +31,7 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import warm_starting_util as ws_util
 
 ones = init_ops.ones_initializer
 norms = init_ops.truncated_normal_initializer
@@ -330,9 +330,7 @@ class WarmStartingUtilTest(test.TestCase):
     with ops.Graph().as_default() as g:
       with self.test_session(graph=g) as sess:
         cols_to_vars = self._create_linear_model([sc_int], partitioner)
-        ws_util._warm_start(
-            ws_util.WarmStartSettings(
-                self.get_temp_dir(), vars_to_warm_start=".*sc_int.*"))
+        ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=".*sc_int.*")
         sess.run(variables.global_variables_initializer())
         # Verify weights were correctly warm-started.
         self._assert_cols_to_vars(cols_to_vars, {sc_int: [prev_int_val]}, sess)
@@ -361,9 +359,8 @@ class WarmStartingUtilTest(test.TestCase):
     with ops.Graph().as_default() as g:
       with self.test_session(graph=g) as sess:
         cols_to_vars = self._create_linear_model([sc_hash], partitioner)
-        ws_util._warm_start(
-            ws_util.WarmStartSettings(
-                self.get_temp_dir(), vars_to_warm_start=".*sc_hash.*"))
+        ws_util.warm_start(
+            self.get_temp_dir(), vars_to_warm_start=".*sc_hash.*")
         sess.run(variables.global_variables_initializer())
         # Verify weights were correctly warm-started.
         self._assert_cols_to_vars(cols_to_vars, {sc_hash: [prev_hash_val]},
@@ -398,9 +395,8 @@ class WarmStartingUtilTest(test.TestCase):
         cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
         # Since old vocab is not explicitly set in WarmStartSettings, the old
         # vocab is assumed to be same as new vocab.
-        ws_util._warm_start(
-            ws_util.WarmStartSettings(
-                self.get_temp_dir(), vars_to_warm_start=".*sc_vocab.*"))
+        ws_util.warm_start(
+            self.get_temp_dir(), vars_to_warm_start=".*sc_vocab.*")
         sess.run(variables.global_variables_initializer())
         # Verify weights were correctly warm-started.
         self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
@@ -435,11 +431,10 @@ class WarmStartingUtilTest(test.TestCase):
         cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
         # Since old vocab is not explicitly set in WarmStartSettings, the old
         # vocab is assumed to be same as new vocab.
-        ws_util._warm_start(
-            ws_util.WarmStartSettings(
-                # Explicitly provide the file prefix instead of just the dir.
-                os.path.join(self.get_temp_dir(), "model-0"),
-                vars_to_warm_start=".*sc_vocab.*"))
+        ws_util.warm_start(
+            # Explicitly provide the file prefix instead of just the dir.
+            os.path.join(self.get_temp_dir(), "model-0"),
+            vars_to_warm_start=".*sc_vocab.*")
         sess.run(variables.global_variables_initializer())
         # Verify weights were correctly warm-started.
         self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
@@ -485,13 +480,12 @@ class WarmStartingUtilTest(test.TestCase):
             num_oov_buckets=sc_vocab.num_oov_buckets,
             old_vocab=old_vocab_path,
             old_vocab_size=old_vocab_size)
-        warm_start_settings = ws_util.WarmStartSettings(
+        ws_util.warm_start(
             ckpt_to_initialize_from=self.get_temp_dir(),
             vars_to_warm_start=".*sc_vocab.*",
             var_name_to_vocab_info={
                 "linear_model/sc_vocab/weights": vocab_info
             })
-        ws_util._warm_start(warm_start_settings)
         sess.run(variables.global_variables_initializer())
         # Verify weights were correctly warm-started.  'banana' isn't in the
         # first two entries of the old vocabulary, so it's newly initialized.
@@ -523,9 +517,8 @@ class WarmStartingUtilTest(test.TestCase):
     with ops.Graph().as_default() as g:
       with self.test_session(graph=g) as sess:
         cols_to_vars = self._create_linear_model([real_bucket], partitioner)
-        ws_util._warm_start(
-            ws_util.WarmStartSettings(
-                self.get_temp_dir(), vars_to_warm_start=".*real_bucketized.*"))
+        ws_util.warm_start(
+            self.get_temp_dir(), vars_to_warm_start=".*real_bucketized.*")
         sess.run(variables.global_variables_initializer())
         # Verify weights were correctly warm-started.
         self._assert_cols_to_vars(cols_to_vars,
@@ -606,12 +599,11 @@ class WarmStartingUtilTest(test.TestCase):
             new_vocab_size=sc_vocab.vocabulary_size,
             num_oov_buckets=sc_vocab.num_oov_buckets,
             old_vocab=vocab_path)
-        ws_util._warm_start(
-            ws_util.WarmStartSettings(
-                self.get_temp_dir(),
-                var_name_to_vocab_info={
-                    "linear_model/sc_vocab/weights": vocab_info
-                }))
+        ws_util.warm_start(
+            self.get_temp_dir(),
+            var_name_to_vocab_info={
+                "linear_model/sc_vocab/weights": vocab_info
+            })
         sess.run(variables.global_variables_initializer())
         # Verify weights were correctly warm-started.
         self._assert_cols_to_vars(cols_to_vars, {
@@ -668,7 +660,7 @@ class WarmStartingUtilTest(test.TestCase):
             new_vocab_size=sc_vocab.vocabulary_size,
             num_oov_buckets=sc_vocab.num_oov_buckets,
             old_vocab=prev_vocab_path)
-        ws_settings = ws_util.WarmStartSettings(
+        ws_util.warm_start(
             self.get_temp_dir(),
             vars_to_warm_start=".*(sc_keys|sc_vocab).*",
             var_name_to_vocab_info={
@@ -678,7 +670,6 @@ class WarmStartingUtilTest(test.TestCase):
                 ws_util._infer_var_name(cols_to_vars[sc_keys]):
                     "some_other_name"
             })
-        ws_util._warm_start(ws_settings)
         sess.run(variables.global_variables_initializer())
         # Verify weights were correctly warm-started.  Var corresponding to
         # sc_hash should not be warm-started.  Var corresponding to sc_vocab
@@ -732,7 +723,7 @@ class WarmStartingUtilTest(test.TestCase):
             new_vocab_size=sc_vocab.vocabulary_size,
             num_oov_buckets=sc_vocab.num_oov_buckets,
             old_vocab=prev_vocab_path)
-        ws_settings = ws_util.WarmStartSettings(
+        ws_util.warm_start(
             self.get_temp_dir(),
             vars_to_warm_start=".*(sc_keys|sc_vocab).*",
             var_name_to_vocab_info={
@@ -742,7 +733,6 @@ class WarmStartingUtilTest(test.TestCase):
                 ws_util._infer_var_name(cols_to_vars[sc_keys]):
                     "some_other_name"
             })
-        ws_util._warm_start(ws_settings)
         sess.run(variables.global_variables_initializer())
         # Verify weights were correctly warm-started.  Var corresponding to
         # sc_hash should not be warm-started.  Var corresponding to sc_vocab
@@ -796,7 +786,7 @@ class WarmStartingUtilTest(test.TestCase):
             new_vocab_size=sc_vocab.vocabulary_size,
             num_oov_buckets=sc_vocab.num_oov_buckets,
             old_vocab=prev_vocab_path)
-        ws_settings = ws_util.WarmStartSettings(
+        ws_util.warm_start(
             self.get_temp_dir(),
             # The special value of None here will ensure that only the variable
             # specified in var_name_to_vocab_info (sc_vocab embedding) is
@@ -812,7 +802,6 @@ class WarmStartingUtilTest(test.TestCase):
                 ws_util._infer_var_name(cols_to_vars[sc_keys]):
                     "some_other_name"
             })
-        ws_util._warm_start(ws_settings)
         sess.run(variables.global_variables_initializer())
         # Verify weights were correctly warm-started.  Var corresponding to
         # sc_vocab should be correctly warm-started after vocab remapping,
@@ -874,13 +863,12 @@ class WarmStartingUtilTest(test.TestCase):
             # use a truncated normal initializer.
             backup_initializer=init_ops.random_uniform_initializer(
                 minval=0.42, maxval=0.42))
-        ws_settings = ws_util.WarmStartSettings(
+        ws_util.warm_start(
             self.get_temp_dir(),
             var_name_to_vocab_info={
                 ws_util._infer_var_name(cols_to_vars[emb_vocab_column]):
                     vocab_info
             })
-        ws_util._warm_start(ws_settings)
         sess.run(variables.global_variables_initializer())
         # Verify weights were correctly warm-started. Var corresponding to
         # emb_vocab_column should be correctly warm-started after vocab
@@ -947,13 +935,12 @@ class WarmStartingUtilTest(test.TestCase):
             # use a truncated normal initializer.
             backup_initializer=init_ops.random_uniform_initializer(
                 minval=0.42, maxval=0.42))
-        ws_settings = ws_util.WarmStartSettings(
+        ws_util.warm_start(
             self.get_temp_dir(),
             vars_to_warm_start=".*sc_vocab.*",
             var_name_to_vocab_info={
                 "linear_model/sc_vocab_embedding/embedding_weights": vocab_info
             })
-        ws_util._warm_start(ws_settings)
         sess.run(variables.global_variables_initializer())
         # Verify weights were correctly warm-started. Var corresponding to
         # emb_vocab should be correctly warm-started after vocab remapping.
@@ -973,7 +960,6 @@ class WarmStartingUtilTest(test.TestCase):
             }, sess)
 
   def testErrorConditions(self):
-    self.assertRaises(ValueError, ws_util.WarmStartSettings, None)
     x = variable_scope.get_variable(
         "x",
         shape=[4, 1],
@@ -983,9 +969,6 @@ class WarmStartingUtilTest(test.TestCase):
     # List of PartitionedVariable is invalid type when warm-starting with vocab.
     self.assertRaises(TypeError, ws_util._warm_start_var_with_vocab, [x],
                       "/tmp", 5, "/tmp", "/tmp")
-    # Keys of type other than FeatureColumn.
-    self.assertRaises(TypeError, ws_util._warm_start, {"StringType": x},
-                      ws_util.WarmStartSettings("/tmp"))
 
     # Unused variable names raises ValueError.
     with ops.Graph().as_default():
@@ -997,18 +980,16 @@ class WarmStartingUtilTest(test.TestCase):
             partitioner=lambda shape, dtype: [2, 1])
         self._write_checkpoint(sess)
 
-    self.assertRaises(ValueError, ws_util._warm_start,
-                      ws_util.WarmStartSettings(
-                          self.get_temp_dir(),
-                          var_name_to_vocab_info={
-                              "y": ws_util.VocabInfo("", 1, 0, "")
-                          }))
-    self.assertRaises(ValueError, ws_util._warm_start,
-                      ws_util.WarmStartSettings(
-                          self.get_temp_dir(),
-                          var_name_to_prev_var_name={
-                              "y": "y2"
-                          }))
+    self.assertRaises(
+        ValueError,
+        ws_util.warm_start,
+        self.get_temp_dir(),
+        var_name_to_vocab_info={"y": ws_util.VocabInfo("", 1, 0, "")})
+    self.assertRaises(
+        ValueError,
+        ws_util.warm_start,
+        self.get_temp_dir(),
+        var_name_to_prev_var_name={"y": "y2"})
 
 
 if __name__ == "__main__":
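As `testErrorConditions` verifies, the unused-name checks make typos fail fast: a key in `var_name_to_vocab_info` or `var_name_to_prev_var_name` that matches no variable in the graph raises `ValueError` rather than being silently skipped. For example (hypothetical variable name `y`):

```python
import tensorflow as tf

# Raises ValueError: no trainable variable named "y" exists in the graph.
tf.train.warm_start(
    "/tmp/prev_model",
    var_name_to_prev_var_name={"y": "y2"})
```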
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-vocab-info.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-vocab-info.pbtxt
index a16e3ae..5301b94 100644 (file)
@@ -1,7 +1,7 @@
 path: "tensorflow.estimator.VocabInfo"
 tf_class {
-  is_instance: "<class \'tensorflow.python.estimator.warm_starting_util.VocabInfo\'>"
-  is_instance: "<class \'tensorflow.python.estimator.warm_starting_util.VocabInfo\'>"
+  is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
+  is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
   is_instance: "<type \'tuple\'>"
   member {
     name: "backup_initializer"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-warm-start-settings.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-warm-start-settings.pbtxt
index afdd6bb..43f5343 100644 (file)
@@ -1,7 +1,7 @@
 path: "tensorflow.estimator.WarmStartSettings"
 tf_class {
-  is_instance: "<class \'tensorflow.python.estimator.warm_starting_util.WarmStartSettings\'>"
-  is_instance: "<class \'tensorflow.python.estimator.warm_starting_util.WarmStartSettings\'>"
+  is_instance: "<class \'tensorflow.python.estimator.estimator.WarmStartSettings\'>"
+  is_instance: "<class \'tensorflow.python.estimator.estimator.WarmStartSettings\'>"
   is_instance: "<type \'tuple\'>"
   member {
     name: "ckpt_to_initialize_from"
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-vocab-info.pbtxt
new file mode 100644 (file)
index 0000000..4ce7cb1
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.train.-vocab-info.pbtxt
@@ -0,0 +1,39 @@
+path: "tensorflow.train.VocabInfo"
+tf_class {
+  is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
+  is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
+  is_instance: "<type \'tuple\'>"
+  member {
+    name: "backup_initializer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "new_vocab"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "new_vocab_size"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "num_oov_buckets"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "old_vocab"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "old_vocab_size"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "count"
+  }
+  member_method {
+    name: "index"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.pbtxt
index 3b06aaf..c75ee47 100644 (file)
@@ -225,6 +225,10 @@ tf_module {
     mtype: "<type \'type\'>"
   }
   member {
+    name: "VocabInfo"
+    mtype: "<type \'type\'>"
+  }
+  member {
     name: "WorkerSessionCreator"
     mtype: "<type \'type\'>"
   }
@@ -437,6 +441,10 @@ tf_module {
     argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
+    name: "warm_start"
+    argspec: "args=[\'ckpt_to_initialize_from\', \'vars_to_warm_start\', \'var_name_to_vocab_info\', \'var_name_to_prev_var_name\'], varargs=None, keywords=None, defaults=[\'.*\', \'None\', \'None\'], "
+  }
+  member_method {
     name: "write_graph"
     argspec: "args=[\'graph_or_graph_def\', \'logdir\', \'name\', \'as_text\'], varargs=None, keywords=None, defaults=[\'True\'], "
   }