From 804f98e5bc0a75284f5f92569e5c82fe88b455ad Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 22 Mar 2018 14:01:10 -0700 Subject: [PATCH] Supports optimizer arg in head.create_estimator_spec. PiperOrigin-RevId: 190121386 --- tensorflow/contrib/estimator/BUILD | 2 + .../contrib/estimator/python/estimator/head.py | 28 +++-- .../estimator/python/estimator/head_test.py | 36 ++++++ .../estimator/python/estimator/multi_head.py | 34 ++++-- .../estimator/python/estimator/multi_head_test.py | 38 +++++++ tensorflow/python/estimator/BUILD | 2 +- tensorflow/python/estimator/canned/dnn.py | 9 +- .../python/estimator/canned/dnn_testing_utils.py | 15 ++- tensorflow/python/estimator/canned/head.py | 126 ++++++++++++++------- tensorflow/python/estimator/canned/head_test.py | 102 +++++++++++++++++ tensorflow/python/estimator/canned/linear.py | 9 +- 11 files changed, 322 insertions(+), 79 deletions(-) diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 676d602..2437426 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -175,6 +175,7 @@ py_library( "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:summary", + "//tensorflow/python:training", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:metric_keys", @@ -292,6 +293,7 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:metrics", "//tensorflow/python:summary", + "//tensorflow/python:training", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:metric_keys", diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index f95fcc8..42e1b7b 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -36,10 +36,12 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.losses import losses from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary +from tensorflow.python.training import training_util _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY +# TODO(b/65403806): Switch loss_reduction default to SUM_OVER_BATCH_SIZE. def multi_class_head(n_classes, weight_column=None, label_vocabulary=None, @@ -489,8 +491,8 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access processed_labels=processed_labels) def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None, - regularization_losses=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, regularization_losses=None): """Returns an `EstimatorSpec`. Args: @@ -502,8 +504,11 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access with shape `[D0, D1, ... DN, n_classes]` or `SparseTensor` with `dense_shape` `[D0, D1, ... DN, ?]`. `labels` is required argument when `mode` equals `TRAIN` or `EVAL`. + optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. + Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which + updates variables and increments `global_step`. train_op_fn: Function that takes a scalar loss `Tensor` and returns - `train_op`. Required in TRAIN mode. + `train_op`. Used if `optimizer` is `None`. regularization_losses: A list of additional scalar losses to be added to the training loss, such as regularization losses. These losses are usually expressed as a batch average, so for best results users need to @@ -513,7 +518,8 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access Returns: `EstimatorSpec`. Raises: - ValueError: If `train_op_fn` is `None` in TRAIN mode. + ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN + mode, or if both are set. """ with ops.name_scope(self._name, 'head'): logits = head_lib._check_logits_final_dim(logits, self.logits_dimension) # pylint:disable=protected-access @@ -565,8 +571,16 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access regularization_loss=regularization_loss)) # Train. - if train_op_fn is None: - raise ValueError('train_op_fn can not be None.') + if optimizer is not None: + if train_op_fn is not None: + raise ValueError('train_op_fn and optimizer cannot both be set.') + train_op = optimizer.minimize( + regularized_training_loss, + global_step=training_util.get_global_step()) + elif train_op_fn is not None: + train_op = train_op_fn(regularized_training_loss) + else: + raise ValueError('train_op_fn and optimizer cannot both be None.') # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -592,7 +606,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, - train_op=train_op_fn(regularized_training_loss)) + train_op=train_op) def _eval_metric_ops( self, labels, probabilities, weights, unreduced_loss, diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index dc30dde..776f0ee 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -863,6 +863,42 @@ class MultiLabelHead(test.TestCase): self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) + def test_train_with_optimizer(self): + head = head_lib.multi_label_head(n_classes=2) + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + # For large logits, sigmoid cross entropy loss is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits => + # expected_unweighted_loss = [[10., 10.], [15., 0.]] + # Average over classes, sum over weights. + expected_loss = 17.5 + expected_train_result = 'my_train_op' + + class _Optimizer(object): + + def minimize(self, loss, global_step): + del global_step + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=3)]) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + optimizer=_Optimizer()) + + tol = 1e-3 + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) + self.assertEqual( + six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), + train_result) + def test_train_with_regularization_losses(self): head = head_lib.multi_label_head( n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index 23d3714..bbbc19c 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary +from tensorflow.python.training import training_util _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -227,8 +228,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access weights=example_weights_by_head, processed_labels=labels_by_head) + # TODO(b/65403806): Support regularization_losses arg. def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None): """See `_Head`.""" if isinstance(logits, dict): logits_dict = logits @@ -249,9 +252,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access train_op_fn=_no_op_train_fn)) if mode == model_fn.ModeKeys.TRAIN: - if train_op_fn is None: - raise ValueError('train_op_fn can not be None in TRAIN mode.') - spec = self._merge_train(all_estimator_spec, train_op_fn) + spec = self._merge_train( + all_estimator_spec=all_estimator_spec, + optimizer=optimizer, + train_op_fn=train_op_fn) with ops.name_scope(''): summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss) return spec @@ -280,16 +284,21 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access begin_idx += head.logits_dimension return logits_dict - def _merge_train(self, all_estimator_spec, train_op_fn): + def _merge_train(self, all_estimator_spec, optimizer, train_op_fn): """Merges list of `EstimatorSpec` for training. Args: all_estimator_spec: list of `EstimatorSpec` for the individual heads. - train_op_fn: Function to create train op. See `create_estimator_spec` - documentation for more details. + optimizer: `Optimizer` instance to create train op. See + `create_estimator_spec` documentation for more details. + train_op_fn: Function to create train op. Used if `optimizer` is `None`. Returns: `EstimatorSpec` that merges all heads for TRAIN. + + Raises: + ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN + mode. """ losses = [] metrics = {} @@ -298,11 +307,20 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access # Metric keys already contain head.name. metrics.update(spec.eval_metric_ops or {}) loss = _merge_losses(losses, self._head_weights) + if optimizer is not None: + if train_op_fn is not None: + raise ValueError('train_op_fn and optimizer cannot both be set.') + train_op = optimizer.minimize( + loss, global_step=training_util.get_global_step()) + elif train_op_fn is not None: + train_op = train_op_fn(loss) + else: + raise ValueError('train_op_fn and optimizer cannot both be None.') return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, loss=loss, - train_op=train_op_fn(loss), + train_op=train_op, eval_metric_ops=metrics) def _merge_predict(self, all_estimator_spec): diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index 8e788a9..43cc157 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -550,6 +550,44 @@ class MultiHeadTest(test.TestCase): metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss / 2, }, summary_str, tol) + def test_train_one_head_with_optimizer(self): + head1 = head_lib.multi_label_head(n_classes=2, name='head1') + multi_head = multi_head_lib.multi_head([head1]) + + logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)} + labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)} + # For large logits, sigmoid cross entropy loss is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits => + # expected_unweighted_loss = [[10., 10.], [15., 0.]] + # Average over classes, sum over weights. + expected_loss = 17.5 + expected_train_result = 'my_train_op' + + class _Optimizer(object): + + def minimize(self, loss, global_step): + del global_step + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=3)]) + + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + optimizer=_Optimizer()) + + tol = 1e-3 + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) + self.assertEqual( + six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), + train_result) + def test_train_two_heads_with_weights(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index b25f9d2..5afb5a7 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -265,7 +265,6 @@ py_library( "//tensorflow/python:nn", "//tensorflow/python:partitioned_variables", "//tensorflow/python:summary", - "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python/feature_column", "//tensorflow/python/ops/losses", @@ -617,6 +616,7 @@ py_library( "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:summary", + "//tensorflow/python:training", "//tensorflow/python:weights_broadcast_ops", "//tensorflow/python/feature_column", "//tensorflow/python/ops/losses", diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index 7043da8..6382622 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -32,7 +32,6 @@ from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary -from tensorflow.python.training import training_util from tensorflow.python.util.tf_export import tf_export # The default learning rate of 0.05 is a historical artifact of the initial @@ -183,17 +182,11 @@ def _dnn_model_fn(features, input_layer_partitioner=input_layer_partitioner) logits = logit_fn(features=features, mode=mode) - def _train_op_fn(loss): - """Returns the op to optimize the loss.""" - return optimizer.minimize( - loss, - global_step=training_util.get_global_step()) - return head.create_estimator_spec( features=features, mode=mode, labels=labels, - train_op_fn=_train_op_fn, + optimizer=optimizer, logits=logits) diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py index 85b058c..44545c0 100644 --- a/tensorflow/python/estimator/canned/dnn_testing_utils.py +++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py @@ -53,7 +53,7 @@ from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import checkpoint_utils from tensorflow.python.training import gradient_descent from tensorflow.python.training import monitored_session -from tensorflow.python.training import optimizer +from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import saver from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util @@ -134,7 +134,8 @@ def mock_head(testcase, hidden_units, logits_dimension, expected_logits): hidden_weights_names + hidden_biases_names + [LOGITS_WEIGHTS_NAME + '/part_0:0', LOGITS_BIASES_NAME + '/part_0:0']) - def _create_estimator_spec(features, mode, logits, labels, train_op_fn): + def _create_estimator_spec( + features, mode, logits, labels, train_op_fn=None, optimizer=None): del features, labels # Not used. trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) testcase.assertItemsEqual(expected_var_names, @@ -144,8 +145,12 @@ def mock_head(testcase, hidden_units, logits_dimension, expected_logits): expected_logits, logits, message='Failed for mode={}. '.format(mode)) with ops.control_dependencies([assert_logits]): if mode == model_fn.ModeKeys.TRAIN: + if train_op_fn is not None: + train_op = train_op_fn(loss) + elif optimizer is not None: + train_op = optimizer.minimize(loss, global_step=None) return model_fn.EstimatorSpec( - mode=mode, loss=loss, train_op=train_op_fn(loss)) + mode=mode, loss=loss, train_op=train_op) elif mode == model_fn.ModeKeys.EVAL: return model_fn.EstimatorSpec(mode=mode, loss=array_ops.identity(loss)) elif mode == model_fn.ModeKeys.PREDICT: @@ -203,8 +208,8 @@ def mock_optimizer(testcase, hidden_units, expected_loss=None): return control_flow_ops.no_op() optimizer_mock = test.mock.NonCallableMagicMock( - spec=optimizer.Optimizer, - wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) + spec=optimizer_lib.Optimizer, + wraps=optimizer_lib.Optimizer(use_locking=False, name='my_optimizer')) optimizer_mock.minimize = test.mock.MagicMock(wraps=_minimize) return optimizer_mock diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index f68204a..c9635a9 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.ops.losses import losses from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary +from tensorflow.python.training import training_util _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -85,40 +86,39 @@ class _Head(object): ```python def _my_dnn_model_fn(features, labels, mode, params, config=None): # Optionally your callers can pass head to model_fn as a param. - head = tf.contrib.learn.regression_head(...) - input = tf.contrib.layers.input_from_feature_columns(features, ...) - last_hidden_layer_out = tf.contrib.layers.stack( - input, tf.contrib.layers.fully_connected, [1000, 500]) - logits = tf.contrib.layers.fully_connected( - last_hidden_layer_out, head.logits_dimension, activation_fn=None) - - def _train_op_fn(loss): - return optimizer.minimize(loss) + head = tf.contrib.estimator.regression_head(...) + inputs = tf.feature_column.input_layer(features, ...) + hidden_layer0 = tf.layers.dense( + inputs, units=1000, activation=tf.nn.relu) + hidden_layer1 = tf.layers.dense( + hidden_layer0, units=500, activation=tf.nn.relu) + logits = tf.layers.dense( + hidden_layer1, units=head.logits_dimension, activation=None) return head.create_estimator_spec( features=features, labels=labels, mode=mode, logits=logits, - train_op_fn=_train_op_fn) + optimizer=optimizer) ``` There are cases where computing and applying gradients can not be meaningfully - captured with train_op_fn we support (for example, with sync optimizer). In - such case, you can take the responsibility on your own. Here is a common - use case, + captured with optimizer or train_op_fn we support (for example, with sync + optimizer). In such case, you can take the responsibility on your own. Here is + a common use case, ```python estimator_spec = head.create_estimator_spec( features=features, labels=labels, mode=mode, logits=logits, - train_op_fn=tf.contrib.learn.no_op_train_fn) + train_op_fn=lambda _: tf.no_op()) if mode == model_fn.ModeKeys.TRAIN: optimizer = ... sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...) - update_op = tf.contrib.layers.optimize_loss(optimizer=sync, - loss=estimator_spec.loss, ...) + update_op = sync.minimize( + estimator_spec.loss, global_step=tf.get_global_step()) hooks = [sync.make_session_run_hook(is_chief)] ... update train_op and hooks in EstimatorSpec and return ``` @@ -172,10 +172,12 @@ class _Head(object): """ raise NotImplementedError('Calling an abstract method.') + # TODO(b/65403806): By default, collect regularization_losses from + # GraphKeys.REGULARIZATION_LOSSES collection. @abc.abstractmethod def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None, - regularization_losses=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, regularization_losses=None): """Returns `EstimatorSpec` that a model_fn can return. Please note that, @@ -186,10 +188,14 @@ class _Head(object): mode: Estimator's `ModeKeys`. logits: logits `Tensor` to be used by the head. labels: Labels `Tensor`, or `dict` of same. + optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. + Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which + updates variables and increments `global_step`. train_op_fn: Function that takes a scalar loss `Tensor` and returns an op - to optimize the model with the loss. This is used in TRAIN mode and - must not be None. None is allowed in other modes. If you want to - optimize loss yourself you can pass `no_op_train_fn` and then use + to optimize the model with the loss in TRAIN mode. Used if `optimizer` + is `None`. Exactly one of `train_op_fn` and `optimizer` must be set in + TRAIN mode. None is allowed in other modes. If you want to optimize loss + yourself you can pass `lambda _: tf.no_op()` and then use EstimatorSpec.loss to compute and apply gradients. regularization_losses: A list of additional scalar losses to be added to the training loss, such as regularization losses. @@ -694,8 +700,8 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): processed_labels=label_ids) def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None, - regularization_losses=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, regularization_losses=None): """Returns an `EstimatorSpec`. Args: @@ -706,8 +712,11 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): labels: Labels integer or string `Tensor` with shape matching `logits`, namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is required argument when `mode` equals `TRAIN` or `EVAL`. + optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. + Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which + updates variables and increments `global_step`. train_op_fn: Function that takes a scalar loss `Tensor` and returns - `train_op`. Required in TRAIN mode. + `train_op`. Used if `optimizer` is `None`. regularization_losses: A list of additional scalar losses to be added to the training loss, such as regularization losses. These losses are usually expressed as a batch average, so for best results users need to @@ -717,7 +726,8 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): Returns: `EstimatorSpec`. Raises: - ValueError: If `train_op_fn` is `None` in TRAIN mode. + ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN + mode, or if both are set. """ with ops.name_scope(self._name, 'head'): logits = _check_logits_final_dim(logits, self.logits_dimension) @@ -780,8 +790,16 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): regularization_loss=regularization_loss)) # Train. - if train_op_fn is None: - raise ValueError('train_op_fn cannot be None.') + if optimizer is not None: + if train_op_fn is not None: + raise ValueError('train_op_fn and optimizer cannot both be set.') + train_op = optimizer.minimize( + regularized_training_loss, + global_step=training_util.get_global_step()) + elif train_op_fn is not None: + train_op = train_op_fn(regularized_training_loss) + else: + raise ValueError('train_op_fn and optimizer cannot both be None.') # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -807,7 +825,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, - train_op=train_op_fn(regularized_training_loss)) + train_op=train_op) def _binary_logistic_head_with_sigmoid_cross_entropy_loss( @@ -1039,8 +1057,8 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): processed_labels=labels) def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None, - regularization_losses=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, regularization_losses=None): """Returns an `EstimatorSpec`. Args: @@ -1051,8 +1069,11 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): labels: Labels integer or string `Tensor` with shape matching `logits`, namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is required argument when `mode` equals `TRAIN` or `EVAL`. + optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. + Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which + updates variables and increments `global_step`. train_op_fn: Function that takes a scalar loss `Tensor` and returns - `train_op`. Required in TRAIN mode. + `train_op`. Used if `optimizer` is `None`. regularization_losses: A list of additional scalar losses to be added to the training loss, such as regularization losses. These losses are usually expressed as a batch average, so for best results users need to @@ -1062,7 +1083,8 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): Returns: `EstimatorSpec`. Raises: - ValueError: If `train_op_fn` is `None` in TRAIN mode. + ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN + mode, or if both are set. """ # Predict. with ops.name_scope(self._name, 'head'): @@ -1134,8 +1156,16 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): regularization_loss=regularization_loss)) # Train. - if train_op_fn is None: - raise ValueError('train_op_fn can not be None.') + if optimizer is not None: + if train_op_fn is not None: + raise ValueError('train_op_fn and optimizer cannot both be set.') + train_op = optimizer.minimize( + regularized_training_loss, + global_step=training_util.get_global_step()) + elif train_op_fn is not None: + train_op = train_op_fn(regularized_training_loss) + else: + raise ValueError('train_op_fn and optimizer cannot both be None.') # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -1160,7 +1190,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, - train_op=train_op_fn(regularized_training_loss)) + train_op=train_op) def _regression_head_with_mean_squared_error_loss( @@ -1289,8 +1319,8 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): processed_labels=labels) def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None, - regularization_losses=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, regularization_losses=None): """Returns an `EstimatorSpec`. Args: @@ -1302,8 +1332,11 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): `[D0, D1, ... DN, logits_dimension]`. When `logits_dimension=1`, shape `[D0, D1, ... DN]` is also supported. `labels` is required argument when `mode` equals `TRAIN` or `EVAL`. + optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. + Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which + updates variables and increments `global_step`. train_op_fn: Function that takes a scalar loss `Tensor` and returns - `train_op`. Required in TRAIN mode. + `train_op`. Used if `optimizer` is `None`. regularization_losses: A list of additional scalar losses to be added to the training loss, such as regularization losses. These losses are usually expressed as a batch average, so for best results users need to @@ -1313,7 +1346,8 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): Returns: `EstimatorSpec`. Raises: - ValueError: If `train_op_fn` is `None` in TRAIN mode. + ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN + mode, or if both are set. """ # Predict. with ops.name_scope(self._name, 'head'): @@ -1373,8 +1407,16 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): eval_metric_ops=eval_metric_ops) # Train. - if train_op_fn is None: - raise ValueError('train_op_fn can not be None.') + if optimizer is not None: + if train_op_fn is not None: + raise ValueError('train_op_fn and optimizer cannot both be set.') + train_op = optimizer.minimize( + regularized_training_loss, + global_step=training_util.get_global_step()) + elif train_op_fn is not None: + train_op = train_op_fn(regularized_training_loss) + else: + raise ValueError('train_op_fn and optimizer cannot both be None.') # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -1399,7 +1441,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, - train_op=train_op_fn(regularized_training_loss)) + train_op=train_op) def _assert_range(labels, n_classes, message=None): diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index b5d35c9..fe6ee07 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -842,6 +842,41 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2, }, summary_str, tol) + def test_train_with_optimizer(self): + n_classes = 3 + head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes) + + logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32) + labels = np.array(((1,), (1,)), dtype=np.int64) + features = {'x': np.array(((42,),), dtype=np.int32)} + expected_train_result = 'my_train_op' + + class _Optimizer(object): + + def minimize(self, loss, global_step): + del global_step + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=2)]) + + # loss = sum(cross_entropy(labels, logits)) = sum(10, 0) = 10. + expected_loss = 10. + spec = head.create_estimator_spec( + features=features, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + optimizer=_Optimizer()) + + tol = 1e-2 + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) + self.assertEqual( + six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)), + train_result) + def test_train_summaries_with_head_name(self): n_classes = 3 head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( @@ -1942,6 +1977,39 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): metric_keys.MetricKeys.LOSS_MEAN: 20.5, }, summary_str) + def test_train_with_optimizer(self): + head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss() + + logits = np.array(((45,), (-41,),), dtype=np.float32) + labels = np.array(((1,), (1,),), dtype=np.float64) + expected_train_result = b'my_train_op' + features = {'x': np.array(((42,),), dtype=np.float32)} + # loss = sum(cross_entropy(labels, logits)) = sum(0, 41) = 41 + expected_loss = 41. + + class _Optimizer(object): + + def minimize(self, loss, global_step): + del global_step + with ops.control_dependencies((check_ops.assert_equal( + math_ops.to_float(expected_loss), math_ops.to_float(loss), + name='assert_loss'),)): + return constant_op.constant(expected_train_result) + + # Create estimator spec. + spec = head.create_estimator_spec( + features=features, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + optimizer=_Optimizer()) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss) + self.assertEqual(expected_train_result, train_result) + def test_train_summaries_with_head_name(self): head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( name='some_binary_head') @@ -3076,6 +3144,40 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): metric_keys.MetricKeys.LOSS_MEAN: 6.5, }, summary_str) + def test_train_with_optimizer(self): + head = head_lib._regression_head_with_mean_squared_error_loss() + self.assertEqual(1, head.logits_dimension) + + # Create estimator spec. + logits = np.array(((45,), (41,),), dtype=np.float32) + labels = np.array(((43.,), (44.,),), dtype=np.float64) + expected_train_result = b'my_train_op' + features = {'x': np.array(((42.,),), dtype=np.float32)} + # loss = (43-45)^2 + (44-41)^2 = 4 + 9 = 13 + expected_loss = 13 + + class _Optimizer(object): + + def minimize(self, loss, global_step): + del global_step + with ops.control_dependencies((check_ops.assert_equal( + math_ops.to_float(expected_loss), math_ops.to_float(loss), + name='assert_loss'),)): + return constant_op.constant(expected_train_result) + + spec = head.create_estimator_spec( + features=features, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + optimizer=_Optimizer()) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss) + self.assertEqual(expected_train_result, train_result) + def test_train_summaries_with_head_name(self): head = head_lib._regression_head_with_mean_squared_error_loss( name='some_regression_head') diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py index a2f24ef..e7ec417 100644 --- a/tensorflow/python/estimator/canned/linear.py +++ b/tensorflow/python/estimator/canned/linear.py @@ -33,7 +33,6 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary from tensorflow.python.training import ftrl -from tensorflow.python.training import training_util from tensorflow.python.util.tf_export import tf_export @@ -157,17 +156,11 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer, units=head.logits_dimension, feature_columns=feature_columns) logits = logit_fn(features=features) - def _train_op_fn(loss): - """Returns the op to optimize the loss.""" - return optimizer.minimize( - loss, - global_step=training_util.get_global_step()) - return head.create_estimator_spec( features=features, mode=mode, labels=labels, - train_op_fn=_train_op_fn, + optimizer=optimizer, logits=logits) -- 2.7.4