From: A. Unique TensorFlower
Date: Wed, 23 May 2018 21:57:23 +0000 (-0700)
Subject: Allow vars_to_warm_start to be a list of strings or Variables, which allows for non...
X-Git-Tag: upstream/v1.9.0_rc1~38^2~4^2~159
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=aee7ebade2a975bdc3518bc47aef7d4f29614eb6;p=platform%2Fupstream%2Ftensorflow.git

Allow vars_to_warm_start to be a list of strings or Variables, which allows for non-TRAINABLE_VARIABLES to be warm-started.

PiperOrigin-RevId: 197794989
---

diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index a2e84c8..ecb5659 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -1746,10 +1746,19 @@ class WarmStartSettings(
     ckpt_to_initialize_from: [Required] A string specifying the directory with
       checkpoint file(s) or path to checkpoint from which to warm-start the
       model parameters.
-    vars_to_warm_start: [Optional] A regular expression that captures which
-      variables to warm-start (see tf.get_collection). Defaults to `'.*'`,
-      which warm-starts all variables. If `None` is explicitly given, only
-      variables specified in `var_name_to_vocab_info` will be warm-started.
+    vars_to_warm_start: [Optional] One of the following:
+
+      - A regular expression (string) that captures which variables to
+        warm-start (see tf.get_collection). This expression will only consider
+        variables in the TRAINABLE_VARIABLES collection.
+      - A list of Variables to warm-start.
+      - A list of strings, each representing a full variable name to
+        warm-start.
+      - `None`, in which case only variables specified in
+        `var_name_to_vocab_info` will be warm-started.
+
+      Defaults to `'.*'`, which warm-starts all variables in the
+      TRAINABLE_VARIABLES collection. Note that this excludes variables such
+      as accumulators and moving statistics from batch norm.
     var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
       VocabInfo. The variable names should be "full" variables, not the names
       of the partitions. If not explicitly provided, the variable is assumed to
diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py
index b0f37f8..ec740ab 100644
--- a/tensorflow/python/training/warm_starting_util.py
+++ b/tensorflow/python/training/warm_starting_util.py
@@ -237,6 +237,62 @@ def _warm_start_var_with_vocab(var,
   # pylint: enable=protected-access
 
 
+def _get_grouped_variables(vars_to_warm_start):
+  """Collects and groups (possibly partitioned) variables into a dictionary.
+
+  The variables can be provided explicitly through vars_to_warm_start, or they
+  are retrieved from collections (see below).
+
+  Args:
+    vars_to_warm_start: One of the following:
+
+      - A regular expression (string) that captures which variables to
+        warm-start (see tf.get_collection). This expression will only consider
+        variables in the TRAINABLE_VARIABLES collection.
+      - A list of Variables to warm-start.
+      - A list of strings, each representing a full variable name to
+        warm-start.
+      - `None`, in which case only variables specified in
+        `var_name_to_vocab_info` will be warm-started.
+  Returns:
+    A dictionary mapping variable names (strings) to lists of Variables.
+  Raises:
+    ValueError: If vars_to_warm_start is not a string, `None`, a list of
+      `Variables`, or a list of strings.
+ """ + if isinstance(vars_to_warm_start, str) or vars_to_warm_start is None: + # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match + # everything (in TRAINABLE_VARIABLES) here. + list_of_vars = ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES, + scope=vars_to_warm_start) + elif isinstance(vars_to_warm_start, list): + if all([isinstance(v, str) for v in vars_to_warm_start]): + list_of_vars = [] + for v in vars_to_warm_start: + list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, + scope=v) + elif all([_is_variable(v) for v in vars_to_warm_start]): + list_of_vars = vars_to_warm_start + else: + raise ValueError("If `vars_to_warm_start` is a list, it must be all " + "`Variable` or all `str`. Given types are {}".format( + [type(v) for v in vars_to_warm_start])) + else: + raise ValueError("`vars_to_warm_start must be a `list` or `str`. Given " + "type is {}".format(type(vars_to_warm_start))) + # We have to deal with partitioned variables, since get_collection flattens + # out the list. + grouped_variables = {} + for v in list_of_vars: + if not isinstance(v, list): + var_name = _infer_var_name([v]) + else: + var_name = _infer_var_name(v) + grouped_variables.setdefault(var_name, []).append(v) + + return grouped_variables + + @tf_export("train.warm_start") def warm_start(ckpt_to_initialize_from, vars_to_warm_start=".*", @@ -251,10 +307,19 @@ def warm_start(ckpt_to_initialize_from, ckpt_to_initialize_from: [Required] A string specifying the directory with checkpoint file(s) or path to checkpoint from which to warm-start the model parameters. - vars_to_warm_start: [Optional] A regular expression that captures which - variables to warm-start (see tf.get_collection). Defaults to `'.*'`, - which warm-starts all variables. If `None` is explicitly given, only - variables specified in `var_name_to_vocab_info` will be warm-started. + vars_to_warm_start: [Optional] One of the following: + + - A regular expression (string) that captures which variables to + warm-start (see tf.get_collection). This expression will only consider + variables in the TRAINABLE_VARIABLES collection. + - A list of Variables to warm-start. + - A list of strings, each representing a full variable name to warm-start. + - `None`, in which case only variables specified in + `var_name_to_vocab_info` will be warm-started. + + Defaults to `'.*'`, which warm-starts all variables in the + TRAINABLE_VARIABLES collection. Note that this excludes variables such as + accumulators and moving statistics from batch norm. var_name_to_vocab_info: [Optional] Dict of variable names (strings) to VocabInfo. The variable names should be "full" variables, not the names of the partitions. If not explicitly provided, the variable is assumed to @@ -274,21 +339,7 @@ def warm_start(ckpt_to_initialize_from, if var_name_to_prev_var_name is None: var_name_to_prev_var_name = {} logging.info("Warm-starting from: %s", (ckpt_to_initialize_from,)) - # We have to deal with partitioned variables, since get_collection flattens - # out the list. - grouped_variables = {} - # Both vars_to_warm_start = '.*' and - # vars_to_warm_start = None will match everything here. - for v in ops.get_collection( - # TODO(eddz): Allow for different collections here (to support - # warm-starting accumulators). 
-      ops.GraphKeys.TRAINABLE_VARIABLES,
-      scope=vars_to_warm_start):
-    if not isinstance(v, list):
-      var_name = _infer_var_name([v])
-    else:
-      var_name = _infer_var_name(v)
-    grouped_variables.setdefault(var_name, []).append(v)
+  grouped_variables = _get_grouped_variables(vars_to_warm_start)
 
   # Keep track of which var_names in var_name_to_prev_var_name and
   # var_name_to_vocab_info have been used. Err on the safer side by throwing an
diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py
index 7e8cbd6..6a4c207 100644
--- a/tensorflow/python/training/warm_starting_util_test.py
+++ b/tensorflow/python/training/warm_starting_util_test.py
@@ -36,6 +36,7 @@ from tensorflow.python.training import warm_starting_util as ws_util
 ones = init_ops.ones_initializer
 norms = init_ops.truncated_normal_initializer
 rand = init_ops.random_uniform_initializer
+zeros = init_ops.zeros_initializer
 
 
 class WarmStartingUtilTest(test.TestCase):
@@ -305,6 +306,46 @@ class WarmStartingUtilTest(test.TestCase):
     self.assertAllEqual([[0.5], [0.], [0.]],
                         fruit_weights_vars[1].eval(sess))
 
+  def testWarmStart_ListOfVariables(self):
+    # Save checkpoint from which to warm-start.
+    _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
+                                                initializer=ones())
+    # Verify we initialized the values correctly.
+    self.assertAllEqual(np.ones([10, 1]), prev_int_val)
+
+    # New graph, new session with warm-starting.
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        # Initialize with zeros.
+        var = variable_scope.get_variable(
+            "v1",
+            shape=[10, 1],
+            initializer=zeros())
+        ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=[var])
+        sess.run(variables.global_variables_initializer())
+        # Verify weights were correctly warm-started (init overridden to ones).
+        self.assertAllEqual(var.eval(), prev_int_val)
+
+  def testWarmStart_ListOfStrings(self):
+    # Save checkpoint from which to warm-start.
+    _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
+                                                initializer=ones())
+    # Verify we initialized the values correctly.
+    self.assertAllEqual(np.ones([10, 1]), prev_int_val)
+
+    # New graph, new session with warm-starting.
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        # Initialize with zeros.
+        var = variable_scope.get_variable(
+            "v1",
+            shape=[10, 1],
+            initializer=zeros())
+        ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=["v1"])
+        sess.run(variables.global_variables_initializer())
+        # Verify weights were correctly warm-started (init overridden to ones).
+        self.assertAllEqual(var.eval(), prev_int_val)
+
   def testWarmStart_SparseColumnIntegerized(self):
     # Create feature column.
     sc_int = fc.categorical_column_with_identity("sc_int", num_buckets=10)
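
Usage sketch (illustrative, not taken from the patch): a non-trainable
variable, such as an accumulator or a batch-norm moving statistic, lives in
GLOBAL_VARIABLES but not TRAINABLE_VARIABLES, so the regex form of
vars_to_warm_start would skip it. With this change it can be warm-started by
passing the Variable itself, or its full name, in a list. The variable name
"moving_mean" and the checkpoint directory "/tmp/prev_ckpt" below are
hypothetical; assumes TF 1.x graph mode.

    import tensorflow as tf

    # A non-trainable variable: present in GLOBAL_VARIABLES but not in
    # TRAINABLE_VARIABLES, so a vars_to_warm_start regex would not match it.
    moving_mean = tf.get_variable(
        "moving_mean", shape=[10], trainable=False,
        initializer=tf.zeros_initializer())

    # Warm-start it by passing the Variable object directly.
    # "/tmp/prev_ckpt" is a hypothetical checkpoint directory.
    tf.train.warm_start("/tmp/prev_ckpt", vars_to_warm_start=[moving_mean])

    # Equivalently, pass the full variable name; names given in a list are
    # looked up in the GLOBAL_VARIABLES collection.
    # tf.train.warm_start("/tmp/prev_ckpt", vars_to_warm_start=["moving_mean"])

As in the tests above, warm_start is called before running
global_variables_initializer(), since it works by overriding the variables'
initialization with values from the checkpoint.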