Allow vars_to_warm_start to be a list of strings or Variables, which allows for non-trainable variables to be warm-started.
author A. Unique TensorFlower <gardener@tensorflow.org>
Wed, 23 May 2018 21:57:23 +0000 (14:57 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 23 May 2018 22:02:33 +0000 (15:02 -0700)
PiperOrigin-RevId: 197794989
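
For context, a minimal sketch of what this change enables. The checkpoint path and variable names below are hypothetical; the regex form already existed, and the two list forms are what this commit adds:

import tensorflow as tf

with tf.variable_scope("linear"):
  weights = tf.get_variable("weights", shape=[10, 1])
  bias = tf.get_variable("bias", shape=[1])

# Existing behavior: a regex matched against TRAINABLE_VARIABLES.
tf.train.warm_start("/tmp/prev_ckpt", vars_to_warm_start=".*linear.*")

# New: a list of Variable objects.
tf.train.warm_start("/tmp/prev_ckpt", vars_to_warm_start=[weights, bias])

# New: a list of full variable names, looked up in GLOBAL_VARIABLES, so
# non-trainable variables can be warm-started too.
tf.train.warm_start("/tmp/prev_ckpt",
                    vars_to_warm_start=["linear/weights", "linear/bias"])

(In practice one would call warm_start once; the three calls are shown together only for comparison.)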

tensorflow/python/estimator/estimator.py
tensorflow/python/training/warm_starting_util.py
tensorflow/python/training/warm_starting_util_test.py

tensorflow/python/estimator/estimator.py
index a2e84c8..ecb5659 100644
@@ -1746,10 +1746,19 @@ class WarmStartSettings(
     ckpt_to_initialize_from: [Required] A string specifying the directory with
       checkpoint file(s) or path to checkpoint from which to warm-start the
       model parameters.
-    vars_to_warm_start: [Optional] A regular expression that captures which
-      variables to warm-start (see tf.get_collection).  Defaults to `'.*'`,
-      which warm-starts all variables.  If `None` is explicitly given, only
-      variables specified in `var_name_to_vocab_info` will be warm-started.
+    vars_to_warm_start: [Optional] One of the following:
+
+      - A regular expression (string) that captures which variables to
+        warm-start (see tf.get_collection).  This expression will only consider
+        variables in the TRAINABLE_VARIABLES collection.
+      - A list of Variables to warm-start.
+      - A list of strings, each representing a full variable name to warm-start.
+      - `None`, in which case only variables specified in
+        `var_name_to_vocab_info` will be warm-started.
+
+      Defaults to `'.*'`, which warm-starts all variables in the
+      TRAINABLE_VARIABLES collection.  Note that this excludes non-trainable
+      variables such as optimizer accumulators and the moving statistics of
+      batch norm.
     var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
       VocabInfo. The variable names should be "full" variables, not the names
       of the partitions.  If not explicitly provided, the variable is assumed to
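
Since the same docstring change lands on WarmStartSettings, here is a sketch of the list form used with a canned Estimator. The checkpoint path is hypothetical, and the variable names depend on the model (the "dnn/hiddenlayer_0/..." names assume DNNClassifier's naming scheme):

ws = tf.estimator.WarmStartSettings(
    ckpt_to_initialize_from="/tmp/prev_model",
    vars_to_warm_start=["dnn/hiddenlayer_0/kernel",
                        "dnn/hiddenlayer_0/bias"])
classifier = tf.estimator.DNNClassifier(
    hidden_units=[16],
    feature_columns=[tf.feature_column.numeric_column("x", shape=[10])],
    warm_start_from=ws)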
tensorflow/python/training/warm_starting_util.py
index b0f37f8..ec740ab 100644
@@ -237,6 +237,62 @@ def _warm_start_var_with_vocab(var,
 # pylint: enable=protected-access
 
 
+def _get_grouped_variables(vars_to_warm_start):
+  """Collects and groups (possibly partitioned) variables into a dictionary.
+
+  The variables can be provided explicitly through vars_to_warm_start, or they
+  are retrieved from collections (see below).
+
+  Args:
+    vars_to_warm_start: One of the following:
+
+      - A regular expression (string) that captures which variables to
+        warm-start (see tf.get_collection).  This expression will only consider
+        variables in the TRAINABLE_VARIABLES collection.
+      - A list of Variables to warm-start.
+      - A list of strings, each representing a full variable name to warm-start.
+      - `None`, which here matches all variables in the TRAINABLE_VARIABLES
+        collection; `warm_start` itself then only warm-starts the variables
+        named in `var_name_to_vocab_info`.
+  Returns:
+    A dictionary mapping variable names (strings) to lists of Variables.
+  Raises:
+    ValueError: If vars_to_warm_start is not a string, `None`, a list of
+      `Variables`, or a list of strings.
+  """
+  if isinstance(vars_to_warm_start, str) or vars_to_warm_start is None:
+    # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
+    # everything (in TRAINABLE_VARIABLES) here.
+    list_of_vars = ops.get_collection(
+        ops.GraphKeys.TRAINABLE_VARIABLES,
+        scope=vars_to_warm_start)
+  elif isinstance(vars_to_warm_start, list):
+    if all([isinstance(v, str) for v in vars_to_warm_start]):
+      list_of_vars = []
+      for v in vars_to_warm_start:
+        list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
+                                           scope=v)
+    elif all([_is_variable(v) for v in vars_to_warm_start]):
+      list_of_vars = vars_to_warm_start
+    else:
+      raise ValueError("If `vars_to_warm_start` is a list, it must be all "
+                       "`Variable` or all `str`.  Given types are {}".format(
+                           [type(v) for v in vars_to_warm_start]))
+  else:
+    raise ValueError("`vars_to_warm_start` must be a `str`, `None`, or a "
+                     "`list`.  Given type is {}".format(
+                         type(vars_to_warm_start)))
+  # We have to deal with partitioned variables, since get_collection flattens
+  # out the list.
+  grouped_variables = {}
+  for v in list_of_vars:
+    if not isinstance(v, list):
+      var_name = _infer_var_name([v])
+    else:
+      var_name = _infer_var_name(v)
+    grouped_variables.setdefault(var_name, []).append(v)
+
+  return grouped_variables
+
+
 @tf_export("train.warm_start")
 def warm_start(ckpt_to_initialize_from,
                vars_to_warm_start=".*",
@@ -251,10 +307,19 @@ def warm_start(ckpt_to_initialize_from,
     ckpt_to_initialize_from: [Required] A string specifying the directory with
       checkpoint file(s) or path to checkpoint from which to warm-start the
       model parameters.
-    vars_to_warm_start: [Optional] A regular expression that captures which
-      variables to warm-start (see tf.get_collection).  Defaults to `'.*'`,
-      which warm-starts all variables.  If `None` is explicitly given, only
-      variables specified in `var_name_to_vocab_info` will be warm-started.
+    vars_to_warm_start: [Optional] One of the following:
+
+      - A regular expression (string) that captures which variables to
+        warm-start (see tf.get_collection).  This expression will only consider
+        variables in the TRAINABLE_VARIABLES collection.
+      - A list of Variables to warm-start.
+      - A list of strings, each representing a full variable name to warm-start.
+      - `None`, in which case only variables specified in
+        `var_name_to_vocab_info` will be warm-started.
+
+      Defaults to `'.*'`, which warm-starts all variables in the
+      TRAINABLE_VARIABLES collection.  Note that this excludes non-trainable
+      variables such as optimizer accumulators and the moving statistics of
+      batch norm.
     var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
       VocabInfo. The variable names should be "full" variables, not the names
       of the partitions.  If not explicitly provided, the variable is assumed to
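
The `None` case is most useful together with var_name_to_vocab_info. A sketch, with hypothetical vocabulary files and variable name:

vocab_info = tf.train.VocabInfo(
    new_vocab="new_vocab.txt",
    new_vocab_size=100,
    num_oov_buckets=1,
    old_vocab="old_vocab.txt")

# Only "linear/sc_vocab/weights" is warm-started (with vocabulary
# remapping); all other variables keep their fresh initialization.
tf.train.warm_start(
    "/tmp/prev_ckpt",
    vars_to_warm_start=None,
    var_name_to_vocab_info={"linear/sc_vocab/weights": vocab_info})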
@@ -274,21 +339,7 @@ def warm_start(ckpt_to_initialize_from,
   if var_name_to_prev_var_name is None:
     var_name_to_prev_var_name = {}
   logging.info("Warm-starting from: %s", (ckpt_to_initialize_from,))
-  # We have to deal with partitioned variables, since get_collection flattens
-  # out the list.
-  grouped_variables = {}
-  # Both vars_to_warm_start = '.*' and
-  # vars_to_warm_start = None will match everything here.
-  for v in ops.get_collection(
-      # TODO(eddz): Allow for different collections here (to support
-      # warm-starting accumulators).
-      ops.GraphKeys.TRAINABLE_VARIABLES,
-      scope=vars_to_warm_start):
-    if not isinstance(v, list):
-      var_name = _infer_var_name([v])
-    else:
-      var_name = _infer_var_name(v)
-    grouped_variables.setdefault(var_name, []).append(v)
+  grouped_variables = _get_grouped_variables(vars_to_warm_start)
 
   # Keep track of which var_names in var_name_to_prev_var_name and
   # var_name_to_vocab_info have been used.  Err on the safer side by throwing an
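
The grouping step exists because get_collection flattens a partitioned variable into its individual partitions. A sketch of what _get_grouped_variables would produce in that case (scope name and shape are illustrative):

with tf.variable_scope("scope", partitioner=tf.fixed_size_partitioner(2)):
  v = tf.get_variable("v", shape=[10, 1])

# TRAINABLE_VARIABLES now contains two Variables, scope/v/part_0 and
# scope/v/part_1.  _get_grouped_variables re-groups them under the full
# variable name:
#   {"scope/v": [<scope/v/part_0>, <scope/v/part_1>]}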
tensorflow/python/training/warm_starting_util_test.py
index 7e8cbd6..6a4c207 100644
@@ -36,6 +36,7 @@ from tensorflow.python.training import warm_starting_util as ws_util
 ones = init_ops.ones_initializer
 norms = init_ops.truncated_normal_initializer
 rand = init_ops.random_uniform_initializer
+zeros = init_ops.zeros_initializer
 
 
 class WarmStartingUtilTest(test.TestCase):
@@ -305,6 +306,46 @@ class WarmStartingUtilTest(test.TestCase):
         self.assertAllEqual([[0.5], [0.], [0.]],
                             fruit_weights_vars[1].eval(sess))
 
+  def testWarmStart_ListOfVariables(self):
+    # Save checkpoint from which to warm-start.
+    _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
+                                                initializer=ones())
+    # Verify we initialized the values correctly.
+    self.assertAllEqual(np.ones([10, 1]), prev_int_val)
+
+    # New graph, new session with warm-starting.
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        # Initialize with zeros.
+        var = variable_scope.get_variable(
+            "v1",
+            shape=[10, 1],
+            initializer=zeros())
+        ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=[var])
+        sess.run(variables.global_variables_initializer())
+        # Verify weights were correctly warm-started (init overridden to ones).
+        self.assertAllEqual(var.eval(), prev_int_val)
+
+  def testWarmStart_ListOfStrings(self):
+    # Save checkpoint from which to warm-start.
+    _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
+                                                initializer=ones())
+    # Verify we initialized the values correctly.
+    self.assertAllEqual(np.ones([10, 1]), prev_int_val)
+
+    # New graph, new session with warm-starting.
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        # Initialize with zeros.
+        var = variable_scope.get_variable(
+            "v1",
+            shape=[10, 1],
+            initializer=zeros())
+        ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=["v1"])
+        sess.run(variables.global_variables_initializer())
+        # Verify weights were correctly warm-started (init overridden to ones).
+        self.assertAllEqual(var.eval(), prev_int_val)
+
   def testWarmStart_SparseColumnIntegerized(self):
     # Create feature column.
     sc_int = fc.categorical_column_with_identity("sc_int", num_buckets=10)