)
py_test(
+ name = "warm_starting_util_test",
+ size = "small",
+ srcs = ["training/warm_starting_util_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":array_ops",
+ ":client_testlib",
+ ":dtypes",
+ ":framework_ops",
+ ":init_ops",
+ ":training",
+ ":variable_scope",
+ ":variables",
+ "//tensorflow/python/feature_column",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "monitored_session_test",
size = "medium",
srcs = ["training/monitored_session_test.py"],
":parsing_utils",
":run_config",
":training",
- ":warm_starting_util",
"//tensorflow/python:util",
],
)
srcs = ["canned/dnn_testing_utils.py"],
srcs_version = "PY2AND3",
deps = [
+ ":estimator",
":head",
":metric_keys",
":model_fn",
":numpy_io",
":prediction_keys",
- ":warm_starting_util",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
":model_fn",
":run_config",
":util",
- ":warm_starting_util",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:training",
],
)
-
-py_library(
- name = "warm_starting_util",
- srcs = ["warm_starting_util.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//tensorflow/python/feature_column",
- ],
-)
-
-py_test(
- name = "warm_starting_util_test",
- size = "small",
- srcs = ["warm_starting_util_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":warm_starting_util",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//tensorflow/python/feature_column",
- "//third_party/py/numpy",
- ],
-)
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
-from tensorflow.python.estimator import warm_starting_util
+from tensorflow.python.estimator import estimator
from tensorflow.python.estimator.canned import dnn_linear_combined
from tensorflow.python.estimator.canned import dnn_testing_utils
from tensorflow.python.estimator.canned import linear_testing_utils
learning_rate=0.0),
# The provided regular expression will only warm-start the deep
# portion of the model.
- warm_start_from=warm_starting_util.WarmStartSettings(
+ warm_start_from=estimator.WarmStartSettings(
ckpt_to_initialize_from=dnn_lc_classifier.model_dir,
vars_to_warm_start='.*(dnn).*')))
from tensorflow.core.framework import summary_pb2
from tensorflow.python.client import session as tf_session
+from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn
-from tensorflow.python.estimator import warm_starting_util
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0),
# The provided regular expression will only warm-start the city
# embedding, not the kernels and biases of the hidden weights.
- warm_start_from=warm_starting_util.WarmStartSettings(
+ warm_start_from=estimator.WarmStartSettings(
ckpt_to_initialize_from=dnn_classifier.model_dir,
vars_to_warm_start='.*(city).*'))
dimension=2)
# We can create our VocabInfo object from the new and old occupation
# FeatureColumns.
- occupation_vocab_info = warm_starting_util.VocabInfo(
+ occupation_vocab_info = estimator.VocabInfo(
new_vocab=new_occupation.categorical_column.vocabulary_file,
new_vocab_size=new_occupation.categorical_column.vocabulary_size,
num_oov_buckets=new_occupation.categorical_column.num_oov_buckets,
feature_columns=[occupation],
n_classes=4,
optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0),
- warm_start_from=warm_starting_util.WarmStartSettings(
+ warm_start_from=estimator.WarmStartSettings(
ckpt_to_initialize_from=dnn_classifier.model_dir,
var_name_to_vocab_info={
OCCUPATION_EMBEDDING_NAME: occupation_vocab_info
optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0),
# The 'city' variable corresponds to the 'locality' variable in the
# previous model.
- warm_start_from=warm_starting_util.WarmStartSettings(
+ warm_start_from=estimator.WarmStartSettings(
ckpt_to_initialize_from=dnn_classifier.model_dir,
var_name_to_prev_var_name={
CITY_EMBEDDING_NAME:
from tensorflow.python.client import session as tf_session
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import run_config
-from tensorflow.python.estimator import warm_starting_util
from tensorflow.python.estimator.canned import linear
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.export import export
optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0),
# The provided regular expression will only warm-start the age variable
# and not the bias.
- warm_start_from=warm_starting_util.WarmStartSettings(
+ warm_start_from=estimator.WarmStartSettings(
ckpt_to_initialize_from=linear_classifier.model_dir,
vars_to_warm_start='.*(age).*'))
vocabulary_size=len(new_vocab_list))
# We can create our VocabInfo object from the new and old occupation
# FeatureColumns.
- occupation_vocab_info = warm_starting_util.VocabInfo(
+ occupation_vocab_info = estimator.VocabInfo(
new_vocab=new_occupation.vocabulary_file,
new_vocab_size=new_occupation.vocabulary_size,
num_oov_buckets=new_occupation.num_oov_buckets,
feature_columns=[occupation],
n_classes=4,
optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0),
- warm_start_from=warm_starting_util.WarmStartSettings(
+ warm_start_from=estimator.WarmStartSettings(
ckpt_to_initialize_from=linear_classifier.model_dir,
var_name_to_vocab_info={
OCCUPATION_WEIGHT_NAME: occupation_vocab_info
optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0),
# The 'age' variable corresponds to the 'age_in_years' variable in the
# previous model.
- warm_start_from=warm_starting_util.WarmStartSettings(
+ warm_start_from=estimator.WarmStartSettings(
ckpt_to_initialize_from=linear_classifier.model_dir,
var_name_to_prev_var_name={
AGE_WEIGHT_NAME: AGE_WEIGHT_NAME.replace('age', 'age_in_years')
from __future__ import division
from __future__ import print_function
+import collections
import copy
import os
import tempfile
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator import util
-from tensorflow.python.estimator import warm_starting_util
from tensorflow.python.estimator.export.export import build_all_signature_defs
from tensorflow.python.estimator.export.export import get_temp_export_dir
from tensorflow.python.estimator.export.export import get_timestamped_export_dir
from tensorflow.python.training import saver
from tensorflow.python.training import training
from tensorflow.python.training import training_util
+from tensorflow.python.training import warm_starting_util
from tensorflow.python.util import compat
from tensorflow.python.util import compat_internal
from tensorflow.python.util import nest
self._params = copy.deepcopy(params or {})
- # pylint: disable=protected-access
- self._warm_start_settings = (
- warm_starting_util._get_default_warm_start_settings(warm_start_from))
- # pylint: enable=protected-access
+ self._warm_start_settings = _get_default_warm_start_settings(
+ warm_start_from)
@property
logging.info('Warm-starting with WarmStartSettings: %s' %
(self._warm_start_settings,))
- # pylint: disable=protected-access
- warm_starting_util._warm_start(self._warm_start_settings)
- # pylint: enable=protected-access
+ warm_starting_util.warm_start(*self._warm_start_settings)
# Check if the user created a loss summary, and add one if they didn't.
# We assume here that the summary is called 'loss'. If it is not, we will
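The unpacked call above works because `WarmStartSettings` is a namedtuple whose field order matches `warm_start`'s parameter order, so the splat passes each field positionally. A minimal sketch of the equivalent explicit call (the checkpoint directory is hypothetical):

```
# Equivalent to warm_starting_util.warm_start(*settings): the namedtuple
# fields unpack positionally into warm_start's parameters.
settings = WarmStartSettings(ckpt_to_initialize_from='/tmp/model_dir')
warm_starting_util.warm_start(
    settings.ckpt_to_initialize_from,    # '/tmp/model_dir'
    settings.vars_to_warm_start,         # '.*' (default: all trainables)
    settings.var_name_to_vocab_info,     # {} (default)
    settings.var_name_to_prev_var_name)  # {} (default)
```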
def after_create_session(self, session, coord):
del coord
session.run(self._initializer)
+
+VocabInfo = warm_starting_util.VocabInfo # pylint: disable=invalid-name
+
+
+@tf_export('estimator.WarmStartSettings')
+class WarmStartSettings(
+ collections.namedtuple('WarmStartSettings', [
+ 'ckpt_to_initialize_from',
+ 'vars_to_warm_start',
+ 'var_name_to_vocab_info',
+ 'var_name_to_prev_var_name',
+ ])):
+ """Settings for warm-starting in Estimators.
+
+ Example Use with canned `DNNEstimator`:
+
+ ```
+ emb_vocab_file = tf.feature_column.embedding_column(
+ tf.feature_column.categorical_column_with_vocabulary_file(
+ "sc_vocab_file", "new_vocab.txt", vocab_size=100),
+ dimension=8)
+ emb_vocab_list = tf.feature_column.embedding_column(
+ tf.feature_column.categorical_column_with_vocabulary_list(
+ "sc_vocab_list", vocabulary_list=["a", "b"]),
+ dimension=8)
+ estimator = tf.estimator.DNNClassifier(
+ hidden_units=[128, 64], feature_columns=[emb_vocab_file, emb_vocab_list],
+ warm_start_from=ws)
+ ```
+
+ where `ws` could be defined as:
+
+ Warm-start all weights in the model (input layer and hidden weights).
+ Either the directory or a specific checkpoint can be provided (in the case
+ of the former, the latest checkpoint will be used):
+
+ ```
+ ws = WarmStartSettings(ckpt_to_initialize_from="/tmp")
+ ws = WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000")
+ ```
+
+ Warm-start only the embeddings (input layer):
+
+ ```
+ ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
+ vars_to_warm_start=".*input_layer.*")
+ ```
+
+ Warm-start all weights, but the embedding parameters corresponding to
+ `sc_vocab_file` have a different vocab from the one used in the current
+ model:
+
+ ```
+ vocab_info = tf.estimator.VocabInfo(
+ new_vocab=sc_vocab_file.vocabulary_file,
+ new_vocab_size=sc_vocab_file.vocabulary_size,
+ num_oov_buckets=sc_vocab_file.num_oov_buckets,
+ old_vocab="old_vocab.txt"
+ )
+ ws = WarmStartSettings(
+ ckpt_to_initialize_from="/tmp",
+ var_name_to_vocab_info={
+ "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
+ })
+ ```
+
+ Warm-start only `sc_vocab_file` embeddings (and no other variables), which
+ have a different vocab from the one used in the current model:
+
+ ```
+ vocab_info = tf.estimator.VocabInfo(
+ new_vocab=sc_vocab_file.vocabulary_file,
+ new_vocab_size=sc_vocab_file.vocabulary_size,
+ num_oov_buckets=sc_vocab_file.num_oov_buckets,
+ old_vocab="old_vocab.txt"
+ )
+ ws = WarmStartSettings(
+ ckpt_to_initialize_from="/tmp",
+ vars_to_warm_start=None,
+ var_name_to_vocab_info={
+ "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
+ })
+ ```
+
+ Warm-start all weights, but the parameters corresponding to `sc_vocab_file`
+ have a different vocab from the one used in the current checkpoint, and only
+ 100 of those entries were used:
+
+ ```
+ vocab_info = tf.estimator.VocabInfo(
+ new_vocab=sc_vocab_file.vocabulary_file,
+ new_vocab_size=sc_vocab_file.vocabulary_size,
+ num_oov_buckets=sc_vocab_file.num_oov_buckets,
+ old_vocab="old_vocab.txt",
+ old_vocab_size=100
+ )
+ ws = WarmStartSettings(
+ ckpt_to_initialize_from="/tmp",
+ var_name_to_vocab_info={
+ "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
+ })
+ ```
+
+ Warm-start all weights, but the parameters corresponding to `sc_vocab_file`
+ have a different vocab from the one used in the current checkpoint, and the
+ parameters corresponding to `sc_vocab_list` have a different name from the
+ one in the current checkpoint:
+
+ ```
+ vocab_info = tf.estimator.VocabInfo(
+ new_vocab=sc_vocab_file.vocabulary_file,
+ new_vocab_size=sc_vocab_file.vocabulary_size,
+ num_oov_buckets=sc_vocab_file.num_oov_buckets,
+ old_vocab="old_vocab.txt",
+ old_vocab_size=100
+ )
+ ws = WarmStartSettings(
+ ckpt_to_initialize_from="/tmp",
+ var_name_to_vocab_info={
+ "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
+ },
+ var_name_to_prev_var_name={
+ "input_layer/sc_vocab_list_embedding/embedding_weights":
+ "old_tensor_name"
+ })
+ ```
+
+ Attributes:
+ ckpt_to_initialize_from: [Required] A string specifying the directory with
+ checkpoint file(s) or path to checkpoint from which to warm-start the
+ model parameters.
+ vars_to_warm_start: [Optional] A regular expression that captures which
+ variables to warm-start (see tf.get_collection). Defaults to `'.*'`,
+ which warm-starts all variables. If `None` is explicitly given, only
+ variables specified in `var_name_to_vocab_info` will be warm-started.
+ var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
+ VocabInfo. The variable names should be "full" variables, not the names
+ of the partitions. If not explicitly provided, the variable is assumed to
+ have no vocabulary.
+ var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
+ the name of the previously-trained variable in `ckpt_to_initialize_from`.
+ If not explicitly provided, the name of the variable is assumed to be the
+ same in both the previous checkpoint and the current model.
+ """
+
+ def __new__(cls,
+ ckpt_to_initialize_from,
+ vars_to_warm_start='.*',
+ var_name_to_vocab_info=None,
+ var_name_to_prev_var_name=None):
+ if not ckpt_to_initialize_from:
+ raise ValueError(
+ '`ckpt_to_initialize_from` MUST be set in WarmStartSettings')
+ return super(WarmStartSettings, cls).__new__(
+ cls,
+ ckpt_to_initialize_from,
+ vars_to_warm_start,
+ var_name_to_vocab_info or {},
+ var_name_to_prev_var_name or {},
+ )
+
+
+def _get_default_warm_start_settings(warm_start_from):
+ """Returns default WarmStartSettings.
+
+ Args:
+ warm_start_from: Either a string representing the filepath of a checkpoint
+ to initialize from, or an instance of WarmStartSettings.
+
+ Returns:
+ Either None or an instance of WarmStartSettings.
+
+ Raises:
+ ValueError: If warm_start_from is not None but is neither a string nor an
+ instance of WarmStartSettings.
+ """
+ if warm_start_from is None:
+ return None
+ if isinstance(warm_start_from, six.string_types):
+ return WarmStartSettings(ckpt_to_initialize_from=warm_start_from)
+ elif isinstance(warm_start_from, WarmStartSettings):
+ return warm_start_from
+ else:
+ raise ValueError('warm_start_from must be a string or a WarmStartSettings')
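For reference, a small sketch of the normalization `_get_default_warm_start_settings` performs (the checkpoint path is hypothetical):

```
# A bare checkpoint string and an explicit WarmStartSettings normalize to
# equal settings; namedtuples compare by value.
assert (_get_default_warm_start_settings('/tmp/prev_model') ==
        _get_default_warm_start_settings(
            WarmStartSettings(ckpt_to_initialize_from='/tmp/prev_model')))
# None disables warm-starting and passes through unchanged.
assert _get_default_warm_start_settings(None) is None
```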
from tensorflow.python.estimator.canned.parsing_utils import classifier_parse_example_spec
from tensorflow.python.estimator.canned.parsing_utils import regressor_parse_example_spec
from tensorflow.python.estimator.estimator import Estimator
+from tensorflow.python.estimator.estimator import VocabInfo
+from tensorflow.python.estimator.estimator import WarmStartSettings
from tensorflow.python.estimator.export import export_lib as export
from tensorflow.python.estimator.exporter import Exporter
from tensorflow.python.estimator.exporter import FinalExporter
from tensorflow.python.estimator.training import EvalSpec
from tensorflow.python.estimator.training import train_and_evaluate
from tensorflow.python.estimator.training import TrainSpec
-from tensorflow.python.estimator.warm_starting_util import VocabInfo
-from tensorflow.python.estimator.warm_starting_util import WarmStartSettings
from tensorflow.python.util.all_util import remove_undocumented
@@load_variable
@@list_variables
@@init_from_checkpoint
+@@warm_start
+@@VocabInfo
"""
# Optimizers.
from tensorflow.python.training.training_util import assert_global_step
from tensorflow.python.training.training_util import create_global_step
from tensorflow.python.training.training_util import get_or_create_global_step
+from tensorflow.python.training.warm_starting_util import VocabInfo
+from tensorflow.python.training.warm_starting_util import warm_start
from tensorflow.python.pywrap_tensorflow import do_quantize_training_on_graphdef
from tensorflow.python.pywrap_tensorflow import NewCheckpointReader
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.util.tf_export import tf_export
-@tf_export("estimator.VocabInfo")
+@tf_export("train.VocabInfo", "estimator.VocabInfo")
class VocabInfo(
collections.namedtuple("VocabInfo", [
"new_vocab",
"old_vocab_size",
"backup_initializer",
])):
- """Vocabulary information for WarmStartSettings.
+ """Vocabulary information for warm-starting.
See @{tf.estimator.WarmStartSettings$WarmStartSettings} for examples of using
VocabInfo to warm-start.
)
-@tf_export("estimator.WarmStartSettings")
-class WarmStartSettings(
- collections.namedtuple("WarmStartSettings", [
- "ckpt_to_initialize_from",
- "vars_to_warm_start",
- "var_name_to_vocab_info",
- "var_name_to_prev_var_name",
- ])):
- """Settings for warm-starting in Estimators.
-
- Example Use with canned `DNNEstimator`:
-
- ```
- emb_vocab_file = tf.feature_column.embedding_column(
- tf.feature_column.categorical_column_with_vocabulary_file(
- "sc_vocab_file", "new_vocab.txt", vocab_size=100),
- dimension=8)
- emb_vocab_list = tf.feature_column.embedding_column(
- tf.feature_column.categorical_column_with_vocabulary_list(
- "sc_vocab_list", vocabulary_list=["a", "b"]),
- dimension=8)
- estimator = tf.estimator.DNNClassifier(
- hidden_units=[128, 64], feature_columns=[emb_vocab_file, emb_vocab_list],
- warm_start_from=ws)
- ```
-
- where `ws` could be defined as:
-
- Warm-start all weights in the model (input layer and hidden weights).
- Either the directory or a specific checkpoint can be provided (in the case
- of the former, the latest checkpoint will be used):
-
- ```
- ws = WarmStartSettings(ckpt_to_initialize_from="/tmp")
- ws = WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000")
- ```
-
- Warm-start only the embeddings (input layer):
-
- ```
- ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
- vars_to_warm_start=".*input_layer.*")
- ```
-
- Warm-start all weights but the embedding parameters corresponding to
- `sc_vocab_file` have a different vocab from the one used in the current
- model:
-
- ```
- vocab_info = ws_util.VocabInfo(
- new_vocab=sc_vocab_file.vocabulary_file,
- new_vocab_size=sc_vocab_file.vocabulary_size,
- num_oov_buckets=sc_vocab_file.num_oov_buckets,
- old_vocab="old_vocab.txt"
- )
- ws = WarmStartSettings(
- ckpt_to_initialize_from="/tmp",
- var_name_to_vocab_info={
- "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
- })
- ```
-
- Warm-start only `sc_vocab_file` embeddings (and no other variables), which
- have a different vocab from the one used in the current model:
-
- ```
- vocab_info = ws_util.VocabInfo(
- new_vocab=sc_vocab_file.vocabulary_file,
- new_vocab_size=sc_vocab_file.vocabulary_size,
- num_oov_buckets=sc_vocab_file.num_oov_buckets,
- old_vocab="old_vocab.txt"
- )
- ws = WarmStartSettings(
- ckpt_to_initialize_from="/tmp",
- vars_to_warm_start=None,
- var_name_to_vocab_info={
- "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
- })
- ```
-
- Warm-start all weights but the parameters corresponding to `sc_vocab_file`
- have a different vocab from the one used in current checkpoint, and only
- 100 of those entries were used:
-
- ```
- vocab_info = ws_util.VocabInfo(
- new_vocab=sc_vocab_file.vocabulary_file,
- new_vocab_size=sc_vocab_file.vocabulary_size,
- num_oov_buckets=sc_vocab_file.num_oov_buckets,
- old_vocab="old_vocab.txt",
- old_vocab_size=100
- )
- ws = WarmStartSettings(
- ckpt_to_initialize_from="/tmp",
- var_name_to_vocab_info={
- "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
- })
- ```
-
- Warm-start all weights but the parameters corresponding to `sc_vocab_file`
- have a different vocab from the one used in current checkpoint and the
- parameters corresponding to `sc_vocab_list` have a different name from the
- current checkpoint:
-
- ```
- vocab_info = ws_util.VocabInfo(
- new_vocab=sc_vocab_file.vocabulary_file,
- new_vocab_size=sc_vocab_file.vocabulary_size,
- num_oov_buckets=sc_vocab_file.num_oov_buckets,
- old_vocab="old_vocab.txt",
- old_vocab_size=100
- )
- ws = WarmStartSettings(
- ckpt_to_initialize_from="/tmp",
- var_name_to_vocab_info={
- "input_layer/sc_vocab_file_embedding/embedding_weights": vocab_info
- },
- var_name_to_prev_var_name={
- "input_layer/sc_vocab_list_embedding/embedding_weights":
- "old_tensor_name"
- })
- ```
-
- Attributes:
- ckpt_to_initialize_from: [Required] A string specifying the directory with
- checkpoint file(s) or path to checkpoint from which to warm-start the
- model parameters.
- vars_to_warm_start: [Optional] A regular expression that captures which
- variables to warm-start (see tf.get_collection). Defaults to `'.*'`,
- which warm-starts all variables. If `None` is explicitly given, only
- variables specified in `var_name_to_vocab_info` will be warm-started.
- var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
- VocabInfo. The variable names should be "full" variables, not the names
- of the partitions. If not explicitly provided, the variable is assumed to
- have no vocabulary.
- var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
- name of the previously-trained variable in `ckpt_to_initialize_from`. If
- not explicitly provided, the name of the variable is assumed to be same
- between previous checkpoint and current model.
- """
-
- def __new__(cls,
- ckpt_to_initialize_from,
- vars_to_warm_start=".*",
- var_name_to_vocab_info=None,
- var_name_to_prev_var_name=None):
- if not ckpt_to_initialize_from:
- raise ValueError(
- "`ckpt_to_initialize_from` MUST be set in WarmStartSettings")
- return super(WarmStartSettings, cls).__new__(
- cls,
- ckpt_to_initialize_from,
- vars_to_warm_start,
- var_name_to_vocab_info or {},
- var_name_to_prev_var_name or {},
- )
-
-
def _is_variable(x):
return (isinstance(x, variables_lib.Variable) or
isinstance(x, resource_variable_ops.ResourceVariable))
full_shape=slice_info.full_shape,
var_offset=slice_info.var_offset)
- # TODO(eddz): Support WarmStartSettings where class vocabularies need
- # remapping too.
+ # TODO(eddz): Support cases where class vocabularies need remapping too.
init = checkpoint_ops._load_and_remap_matrix_initializer(
ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
old_tensor_name=prev_tensor_name,
# pylint: enable=protected-access
-def _warm_start(warm_start_settings):
+@tf_export("train.warm_start")
+def warm_start(ckpt_to_initialize_from,
+ vars_to_warm_start=".*",
+ var_name_to_vocab_info=None,
+ var_name_to_prev_var_name=None):
"""Warm-starts a model using the given settings.
If you are using a tf.estimator.Estimator, this will automatically be called
during training.
Args:
- warm_start_settings: An object of `WarmStartSettings`.
+ ckpt_to_initialize_from: [Required] A string specifying the directory with
+ checkpoint file(s) or path to checkpoint from which to warm-start the
+ model parameters.
+ vars_to_warm_start: [Optional] A regular expression that captures which
+ variables to warm-start (see tf.get_collection). Defaults to `'.*'`,
+ which warm-starts all variables. If `None` is explicitly given, only
+ variables specified in `var_name_to_vocab_info` will be warm-started.
+ var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
+ VocabInfo. The variable names should be "full" variables, not the names
+ of the partitions. If not explicitly provided, the variable is assumed to
+ have no vocabulary.
+ var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
+ the name of the previously-trained variable in `ckpt_to_initialize_from`.
+ If not explicitly provided, the name of the variable is assumed to be the
+ same in both the previous checkpoint and the current model.
Raises:
- ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
- configuration for variable names that are not used. This is to ensure
+ ValueError: If `var_name_to_prev_var_name` or `var_name_to_vocab_info`
+ contain configuration for variable names that are not used. This is to ensure
a stronger check for variable configuration than relying on users to
examine the logs.
"""
- logging.info("Warm-starting from: %s",
- (warm_start_settings.ckpt_to_initialize_from,))
+ if var_name_to_vocab_info is None:
+ var_name_to_vocab_info = {}
+ if var_name_to_prev_var_name is None:
+ var_name_to_prev_var_name = {}
+ logging.info("Warm-starting from: %s", (ckpt_to_initialize_from,))
# We have to deal with partitioned variables, since get_collection flattens
# out the list.
grouped_variables = {}
- # Both warm_start_settings.vars_to_warm_start = '.*' and
- # warm_start_settings.vars_to_warm_start = None will match everything here.
+ # Both vars_to_warm_start = '.*' and
+ # vars_to_warm_start = None will match everything here.
for v in ops.get_collection(
# TODO(eddz): Allow for different collections here (to support
# warm-starting accumulators).
ops.GraphKeys.TRAINABLE_VARIABLES,
- scope=warm_start_settings.vars_to_warm_start):
+ scope=vars_to_warm_start):
if not isinstance(v, list):
var_name = _infer_var_name([v])
else:
vocab_info_used = set()
for var_name, variable in six.iteritems(grouped_variables):
- prev_var_name = warm_start_settings.var_name_to_prev_var_name.get(var_name)
+ prev_var_name = var_name_to_prev_var_name.get(var_name)
if prev_var_name:
prev_var_name_used.add(var_name)
- vocab_info = warm_start_settings.var_name_to_vocab_info.get(var_name)
+ vocab_info = var_name_to_vocab_info.get(var_name)
if vocab_info:
vocab_info_used.add(var_name)
logging.info(
variable,
current_vocab_path=vocab_info.new_vocab,
current_vocab_size=vocab_info.new_vocab_size,
- prev_ckpt=warm_start_settings.ckpt_to_initialize_from,
+ prev_ckpt=ckpt_to_initialize_from,
prev_vocab_path=vocab_info.old_vocab,
previous_vocab_size=vocab_info.old_vocab_size,
current_oov_buckets=vocab_info.num_oov_buckets,
prev_tensor_name=prev_var_name,
initializer=vocab_info.backup_initializer)
else:
- # For the special value of warm_start_settings.vars_to_warm_start = None,
+ # For the special value of vars_to_warm_start = None,
# we only warm-start variables with explicitly specified vocabularies.
- if warm_start_settings.vars_to_warm_start:
+ if vars_to_warm_start:
logging.info("Warm-starting variable: {}; prev_var_name: {}".format(
var_name, prev_var_name or "Unchanged"))
# Because we use a default empty list in grouped_variables, single
# for init_from_checkpoint logic to work correctly.
if len(variable) == 1:
variable = variable[0]
- _warm_start_var(variable, warm_start_settings.ckpt_to_initialize_from,
- prev_var_name)
+ _warm_start_var(variable, ckpt_to_initialize_from, prev_var_name)
prev_var_name_not_used = set(
- warm_start_settings.var_name_to_prev_var_name.keys()) - prev_var_name_used
- vocab_info_not_used = set(
- warm_start_settings.var_name_to_vocab_info.keys()) - vocab_info_used
+ var_name_to_prev_var_name.keys()) - prev_var_name_used
+ vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used
if prev_var_name_not_used:
raise ValueError(
"You provided the following variables in "
- "warm_start_settings.var_name_to_prev_var_name that were not used: "
+ "var_name_to_prev_var_name that were not used: "
"{0}. Perhaps you misspelled them? Here is the list of viable "
"variable names: {1}".format(prev_var_name_not_used,
grouped_variables.keys()))
if vocab_info_not_used:
raise ValueError(
"You provided the following variables in "
- "warm_start_settings.var_name_to_vocab_info that were not used: {0}. "
+ "var_name_to_vocab_info that were not used: {0}. "
" Perhaps you misspelled them? Here is the list of viable variable "
"names: {1}".format(vocab_info_not_used, grouped_variables.keys()))
-
-
-def _get_default_warm_start_settings(warm_start_from):
- """Returns default WarmStartSettings.
-
- Args:
- warm_start_from: Either a string representing the filepath of a checkpoint
- to initialize from, or an instance of WarmStartSettings.
-
- Returns:
- Either None or an instance of WarmStartSettings.
-
- Raises:
- ValueError: If warm_start_from is not None but is neither a string nor an
- instance of WarmStartSettings.
- """
- if warm_start_from is None:
- return None
- if isinstance(warm_start_from, six.string_types):
- return WarmStartSettings(ckpt_to_initialize_from=warm_start_from)
- elif isinstance(warm_start_from, WarmStartSettings):
- return warm_start_from
- else:
- raise ValueError("warm_start_from must be a string or a WarmStartSettings")
import numpy as np
import six
-from tensorflow.python.estimator import warm_starting_util as ws_util
from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import warm_starting_util as ws_util
ones = init_ops.ones_initializer
norms = init_ops.truncated_normal_initializer
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
cols_to_vars = self._create_linear_model([sc_int], partitioner)
- ws_util._warm_start(
- ws_util.WarmStartSettings(
- self.get_temp_dir(), vars_to_warm_start=".*sc_int.*"))
+ ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=".*sc_int.*")
sess.run(variables.global_variables_initializer())
# Verify weights were correctly warm-started.
self._assert_cols_to_vars(cols_to_vars, {sc_int: [prev_int_val]}, sess)
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
cols_to_vars = self._create_linear_model([sc_hash], partitioner)
- ws_util._warm_start(
- ws_util.WarmStartSettings(
- self.get_temp_dir(), vars_to_warm_start=".*sc_hash.*"))
+ ws_util.warm_start(
+ self.get_temp_dir(), vars_to_warm_start=".*sc_hash.*")
sess.run(variables.global_variables_initializer())
# Verify weights were correctly warm-started.
self._assert_cols_to_vars(cols_to_vars, {sc_hash: [prev_hash_val]},
cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
- # Since old vocab is not explicitly set in WarmStartSettings, the old
- # vocab is assumed to be same as new vocab.
+ # Since the old vocab is not explicitly set (no VocabInfo is given), it
+ # is assumed to be the same as the new vocab.
- ws_util._warm_start(
- ws_util.WarmStartSettings(
- self.get_temp_dir(), vars_to_warm_start=".*sc_vocab.*"))
+ ws_util.warm_start(
+ self.get_temp_dir(), vars_to_warm_start=".*sc_vocab.*")
sess.run(variables.global_variables_initializer())
# Verify weights were correctly warm-started.
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
- # Since old vocab is not explicitly set in WarmStartSettings, the old
- # vocab is assumed to be same as new vocab.
+ # Since the old vocab is not explicitly set (no VocabInfo is given), it
+ # is assumed to be the same as the new vocab.
- ws_util._warm_start(
- ws_util.WarmStartSettings(
- # Explicitly provide the file prefix instead of just the dir.
- os.path.join(self.get_temp_dir(), "model-0"),
- vars_to_warm_start=".*sc_vocab.*"))
+ ws_util.warm_start(
+ # Explicitly provide the file prefix instead of just the dir.
+ os.path.join(self.get_temp_dir(), "model-0"),
+ vars_to_warm_start=".*sc_vocab.*")
sess.run(variables.global_variables_initializer())
# Verify weights were correctly warm-started.
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
num_oov_buckets=sc_vocab.num_oov_buckets,
old_vocab=old_vocab_path,
old_vocab_size=old_vocab_size)
- warm_start_settings = ws_util.WarmStartSettings(
+ ws_util.warm_start(
ckpt_to_initialize_from=self.get_temp_dir(),
vars_to_warm_start=".*sc_vocab.*",
var_name_to_vocab_info={
"linear_model/sc_vocab/weights": vocab_info
})
- ws_util._warm_start(warm_start_settings)
sess.run(variables.global_variables_initializer())
# Verify weights were correctly warm-started. 'banana' isn't in the
# first two entries of the old vocabulary, so it's newly initialized.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
cols_to_vars = self._create_linear_model([real_bucket], partitioner)
- ws_util._warm_start(
- ws_util.WarmStartSettings(
- self.get_temp_dir(), vars_to_warm_start=".*real_bucketized.*"))
+ ws_util.warm_start(
+ self.get_temp_dir(), vars_to_warm_start=".*real_bucketized.*")
sess.run(variables.global_variables_initializer())
# Verify weights were correctly warm-started.
self._assert_cols_to_vars(cols_to_vars,
new_vocab_size=sc_vocab.vocabulary_size,
num_oov_buckets=sc_vocab.num_oov_buckets,
old_vocab=vocab_path)
- ws_util._warm_start(
- ws_util.WarmStartSettings(
- self.get_temp_dir(),
- var_name_to_vocab_info={
- "linear_model/sc_vocab/weights": vocab_info
- }))
+ ws_util.warm_start(
+ self.get_temp_dir(),
+ var_name_to_vocab_info={
+ "linear_model/sc_vocab/weights": vocab_info
+ })
sess.run(variables.global_variables_initializer())
# Verify weights were correctly warm-started.
self._assert_cols_to_vars(cols_to_vars, {
new_vocab_size=sc_vocab.vocabulary_size,
num_oov_buckets=sc_vocab.num_oov_buckets,
old_vocab=prev_vocab_path)
- ws_settings = ws_util.WarmStartSettings(
+ ws_util.warm_start(
self.get_temp_dir(),
vars_to_warm_start=".*(sc_keys|sc_vocab).*",
var_name_to_vocab_info={
ws_util._infer_var_name(cols_to_vars[sc_keys]):
"some_other_name"
})
- ws_util._warm_start(ws_settings)
sess.run(variables.global_variables_initializer())
# Verify weights were correctly warm-started. Var corresponding to
# sc_hash should not be warm-started. Var corresponding to sc_vocab
new_vocab_size=sc_vocab.vocabulary_size,
num_oov_buckets=sc_vocab.num_oov_buckets,
old_vocab=prev_vocab_path)
- ws_settings = ws_util.WarmStartSettings(
+ ws_util.warm_start(
self.get_temp_dir(),
vars_to_warm_start=".*(sc_keys|sc_vocab).*",
var_name_to_vocab_info={
ws_util._infer_var_name(cols_to_vars[sc_keys]):
"some_other_name"
})
- ws_util._warm_start(ws_settings)
sess.run(variables.global_variables_initializer())
# Verify weights were correctly warm-started. Var corresponding to
# sc_hash should not be warm-started. Var corresponding to sc_vocab
new_vocab_size=sc_vocab.vocabulary_size,
num_oov_buckets=sc_vocab.num_oov_buckets,
old_vocab=prev_vocab_path)
- ws_settings = ws_util.WarmStartSettings(
+ ws_util.warm_start(
self.get_temp_dir(),
# The special value of None here will ensure that only the variable
# specified in var_name_to_vocab_info (sc_vocab embedding) is
ws_util._infer_var_name(cols_to_vars[sc_keys]):
"some_other_name"
})
- ws_util._warm_start(ws_settings)
sess.run(variables.global_variables_initializer())
# Verify weights were correctly warm-started. Var corresponding to
# sc_vocab should be correctly warm-started after vocab remapping,
# use a truncated normal initializer.
backup_initializer=init_ops.random_uniform_initializer(
minval=0.42, maxval=0.42))
- ws_settings = ws_util.WarmStartSettings(
+ ws_util.warm_start(
self.get_temp_dir(),
var_name_to_vocab_info={
ws_util._infer_var_name(cols_to_vars[emb_vocab_column]):
vocab_info
})
- ws_util._warm_start(ws_settings)
sess.run(variables.global_variables_initializer())
# Verify weights were correctly warm-started. Var corresponding to
# emb_vocab_column should be correctly warm-started after vocab
# use a truncated normal initializer.
backup_initializer=init_ops.random_uniform_initializer(
minval=0.42, maxval=0.42))
- ws_settings = ws_util.WarmStartSettings(
+ ws_util.warm_start(
self.get_temp_dir(),
vars_to_warm_start=".*sc_vocab.*",
var_name_to_vocab_info={
"linear_model/sc_vocab_embedding/embedding_weights": vocab_info
})
- ws_util._warm_start(ws_settings)
sess.run(variables.global_variables_initializer())
# Verify weights were correctly warm-started. Var corresponding to
# emb_vocab should be correctly warm-started after vocab remapping.
}, sess)
def testErrorConditions(self):
- self.assertRaises(ValueError, ws_util.WarmStartSettings, None)
x = variable_scope.get_variable(
"x",
shape=[4, 1],
# List of PartitionedVariable is invalid type when warm-starting with vocab.
self.assertRaises(TypeError, ws_util._warm_start_var_with_vocab, [x],
"/tmp", 5, "/tmp", "/tmp")
- # Keys of type other than FeatureColumn.
- self.assertRaises(TypeError, ws_util._warm_start, {"StringType": x},
- ws_util.WarmStartSettings("/tmp"))
# Unused variable names raises ValueError.
with ops.Graph().as_default():
partitioner=lambda shape, dtype: [2, 1])
self._write_checkpoint(sess)
- self.assertRaises(ValueError, ws_util._warm_start,
- ws_util.WarmStartSettings(
- self.get_temp_dir(),
- var_name_to_vocab_info={
- "y": ws_util.VocabInfo("", 1, 0, "")
- }))
- self.assertRaises(ValueError, ws_util._warm_start,
- ws_util.WarmStartSettings(
- self.get_temp_dir(),
- var_name_to_prev_var_name={
- "y": "y2"
- }))
+ self.assertRaises(
+ ValueError,
+ ws_util.warm_start,
+ self.get_temp_dir(),
+ var_name_to_vocab_info={"y": ws_util.VocabInfo("", 1, 0, "")})
+ self.assertRaises(
+ ValueError,
+ ws_util.warm_start,
+ self.get_temp_dir(),
+ var_name_to_prev_var_name={"y": "y2"})
if __name__ == "__main__":
path: "tensorflow.estimator.VocabInfo"
tf_class {
- is_instance: "<class \'tensorflow.python.estimator.warm_starting_util.VocabInfo\'>"
- is_instance: "<class \'tensorflow.python.estimator.warm_starting_util.VocabInfo\'>"
+ is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
+ is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<type \'tuple\'>"
member {
name: "backup_initializer"
path: "tensorflow.estimator.WarmStartSettings"
tf_class {
- is_instance: "<class \'tensorflow.python.estimator.warm_starting_util.WarmStartSettings\'>"
- is_instance: "<class \'tensorflow.python.estimator.warm_starting_util.WarmStartSettings\'>"
+ is_instance: "<class \'tensorflow.python.estimator.estimator.WarmStartSettings\'>"
+ is_instance: "<class \'tensorflow.python.estimator.estimator.WarmStartSettings\'>"
is_instance: "<type \'tuple\'>"
member {
name: "ckpt_to_initialize_from"
--- /dev/null
+path: "tensorflow.train.VocabInfo"
+tf_class {
+ is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
+ is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
+ is_instance: "<type \'tuple\'>"
+ member {
+ name: "backup_initializer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "new_vocab"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "new_vocab_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "num_oov_buckets"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "old_vocab"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "old_vocab_size"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "count"
+ }
+ member_method {
+ name: "index"
+ }
+}
mtype: "<type \'type\'>"
}
member {
+ name: "VocabInfo"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "WorkerSessionCreator"
mtype: "<type \'type\'>"
}
argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
+ name: "warm_start"
+ argspec: "args=[\'ckpt_to_initialize_from\', \'vars_to_warm_start\', \'var_name_to_vocab_info\', \'var_name_to_prev_var_name\'], varargs=None, keywords=None, defaults=[\'.*\', \'None\', \'None\'], "
+ }
+ member_method {
name: "write_graph"
argspec: "args=[\'graph_or_graph_def\', \'logdir\', \'name\', \'as_text\'], varargs=None, keywords=None, defaults=[\'True\'], "
}