Supports optimizer arg in head.create_estimator_spec.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 22 Mar 2018 21:01:10 +0000 (14:01 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Mar 2018 21:03:24 +0000 (14:03 -0700)
PiperOrigin-RevId: 190121386

tensorflow/contrib/estimator/BUILD
tensorflow/contrib/estimator/python/estimator/head.py
tensorflow/contrib/estimator/python/estimator/head_test.py
tensorflow/contrib/estimator/python/estimator/multi_head.py
tensorflow/contrib/estimator/python/estimator/multi_head_test.py
tensorflow/python/estimator/BUILD
tensorflow/python/estimator/canned/dnn.py
tensorflow/python/estimator/canned/dnn_testing_utils.py
tensorflow/python/estimator/canned/head.py
tensorflow/python/estimator/canned/head_test.py
tensorflow/python/estimator/canned/linear.py

index 676d602..2437426 100644 (file)
@@ -175,6 +175,7 @@ py_library(
         "//tensorflow/python:sparse_ops",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:summary",
+        "//tensorflow/python:training",
         "//tensorflow/python/estimator:export_output",
         "//tensorflow/python/estimator:head",
         "//tensorflow/python/estimator:metric_keys",
@@ -292,6 +293,7 @@ py_library(
         "//tensorflow/python:math_ops",
         "//tensorflow/python:metrics",
         "//tensorflow/python:summary",
+        "//tensorflow/python:training",
         "//tensorflow/python/estimator:export_output",
         "//tensorflow/python/estimator:head",
         "//tensorflow/python/estimator:metric_keys",
index f95fcc8..42e1b7b 100644 (file)
@@ -36,10 +36,12 @@ from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops.losses import losses
 from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.summary import summary
+from tensorflow.python.training import training_util
 
 _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
 
 
+# TODO(b/65403806): Switch loss_reduction default to SUM_OVER_BATCH_SIZE.
 def multi_class_head(n_classes,
                      weight_column=None,
                      label_vocabulary=None,
@@ -489,8 +491,8 @@ class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
         processed_labels=processed_labels)
 
   def create_estimator_spec(
-      self, features, mode, logits, labels=None, train_op_fn=None,
-      regularization_losses=None):
+      self, features, mode, logits, labels=None, optimizer=None,
+      train_op_fn=None, regularization_losses=None):
     """Returns an `EstimatorSpec`.
 
     Args:
@@ -502,8 +504,11 @@ class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
         with shape `[D0, D1, ... DN, n_classes]` or `SparseTensor` with
         `dense_shape` `[D0, D1, ... DN, ?]`. `labels` is required argument when
         `mode` equals `TRAIN` or `EVAL`.
+      optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.
+        Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which
+        updates variables and increments `global_step`.
       train_op_fn: Function that takes a scalar loss `Tensor` and returns
-        `train_op`. Required in TRAIN mode.
+        `train_op`. Used if `optimizer` is `None`.
       regularization_losses: A list of additional scalar losses to be added to
         the training loss, such as regularization losses. These losses are
         usually expressed as a batch average, so for best results users need to
@@ -513,7 +518,8 @@ class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
     Returns:
       `EstimatorSpec`.
     Raises:
-      ValueError: If `train_op_fn` is `None` in TRAIN mode.
+      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
+        mode, or if both are set.
     """
     with ops.name_scope(self._name, 'head'):
       logits = head_lib._check_logits_final_dim(logits, self.logits_dimension)  # pylint:disable=protected-access
@@ -565,8 +571,16 @@ class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
                 regularization_loss=regularization_loss))
 
       # Train.
-      if train_op_fn is None:
-        raise ValueError('train_op_fn can not be None.')
+      if optimizer is not None:
+        if train_op_fn is not None:
+          raise ValueError('train_op_fn and optimizer cannot both be set.')
+        train_op = optimizer.minimize(
+            regularized_training_loss,
+            global_step=training_util.get_global_step())
+      elif train_op_fn is not None:
+        train_op = train_op_fn(regularized_training_loss)
+      else:
+        raise ValueError('train_op_fn and optimizer cannot both be None.')
       # Only summarize mean_loss for SUM reduction to preserve backwards
       # compatibility. Otherwise skip it to avoid unnecessary computation.
       if self._loss_reduction == losses.Reduction.SUM:
@@ -592,7 +606,7 @@ class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
         mode=model_fn.ModeKeys.TRAIN,
         predictions=predictions,
         loss=regularized_training_loss,
-        train_op=train_op_fn(regularized_training_loss))
+        train_op=train_op)
 
   def _eval_metric_ops(
       self, labels, probabilities, weights, unreduced_loss,
index dc30dde..776f0ee 100644 (file)
@@ -863,6 +863,42 @@ class MultiLabelHead(test.TestCase):
     self._test_train(
         head=head, logits=logits, labels=labels, expected_loss=expected_loss)
 
+  def test_train_with_optimizer(self):
+    head = head_lib.multi_label_head(n_classes=2)
+    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
+    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
+    # For large logits, sigmoid cross entropy loss is approximated as:
+    # loss = labels * (logits < 0) * (-logits) +
+    #        (1 - labels) * (logits > 0) * logits =>
+    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
+    # Average over classes, sum over weights.
+    expected_loss = 17.5
+    expected_train_result = 'my_train_op'
+
+    class _Optimizer(object):
+
+      def minimize(self, loss, global_step):
+        del global_step
+        return string_ops.string_join(
+            [constant_op.constant(expected_train_result),
+             string_ops.as_string(loss, precision=3)])
+
+    spec = head.create_estimator_spec(
+        features={'x': np.array(((42,),), dtype=np.int32)},
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels,
+        optimizer=_Optimizer())
+
+    tol = 1e-3
+    with self.test_session() as sess:
+      _initialize_variables(self, spec.scaffold)
+      loss, train_result = sess.run((spec.loss, spec.train_op))
+      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
+      self.assertEqual(
+          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
+          train_result)
+
   def test_train_with_regularization_losses(self):
     head = head_lib.multi_label_head(
         n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
index 23d3714..bbbc19c 100644 (file)
@@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import metrics as metrics_lib
 from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.summary import summary
+from tensorflow.python.training import training_util
 
 
 _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -227,8 +228,10 @@ class _MultiHead(head_lib._Head):  # pylint:disable=protected-access
         weights=example_weights_by_head,
         processed_labels=labels_by_head)
 
+  # TODO(b/65403806): Support regularization_losses arg.
   def create_estimator_spec(
-      self, features, mode, logits, labels=None, train_op_fn=None):
+      self, features, mode, logits, labels=None, optimizer=None,
+      train_op_fn=None):
     """See `_Head`."""
     if isinstance(logits, dict):
       logits_dict = logits
@@ -249,9 +252,10 @@ class _MultiHead(head_lib._Head):  # pylint:disable=protected-access
               train_op_fn=_no_op_train_fn))
 
     if mode == model_fn.ModeKeys.TRAIN:
-      if train_op_fn is None:
-        raise ValueError('train_op_fn can not be None in TRAIN mode.')
-      spec = self._merge_train(all_estimator_spec, train_op_fn)
+      spec = self._merge_train(
+          all_estimator_spec=all_estimator_spec,
+          optimizer=optimizer,
+          train_op_fn=train_op_fn)
       with ops.name_scope(''):
         summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss)
       return spec
@@ -280,16 +284,21 @@ class _MultiHead(head_lib._Head):  # pylint:disable=protected-access
         begin_idx += head.logits_dimension
     return logits_dict
 
-  def _merge_train(self, all_estimator_spec, train_op_fn):
+  def _merge_train(self, all_estimator_spec, optimizer, train_op_fn):
     """Merges list of `EstimatorSpec` for training.
 
     Args:
       all_estimator_spec: list of `EstimatorSpec` for the individual heads.
-      train_op_fn: Function to create train op. See `create_estimator_spec`
-        documentation for more details.
+      optimizer: `Optimizer` instance to create train op. See
+        `create_estimator_spec` documentation for more details.
+      train_op_fn: Function to create train op. Used if `optimizer` is `None`.
 
     Returns:
       `EstimatorSpec` that merges all heads for TRAIN.
+
+    Raises:
+      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
+        mode.
     """
     losses = []
     metrics = {}
@@ -298,11 +307,20 @@ class _MultiHead(head_lib._Head):  # pylint:disable=protected-access
       # Metric keys already contain head.name.
       metrics.update(spec.eval_metric_ops or {})
     loss = _merge_losses(losses, self._head_weights)
+    if optimizer is not None:
+      if train_op_fn is not None:
+        raise ValueError('train_op_fn and optimizer cannot both be set.')
+      train_op = optimizer.minimize(
+          loss, global_step=training_util.get_global_step())
+    elif train_op_fn is not None:
+      train_op = train_op_fn(loss)
+    else:
+      raise ValueError('train_op_fn and optimizer cannot both be None.')
 
     return model_fn.EstimatorSpec(
         mode=model_fn.ModeKeys.TRAIN,
         loss=loss,
-        train_op=train_op_fn(loss),
+        train_op=train_op,
         eval_metric_ops=metrics)
 
   def _merge_predict(self, all_estimator_spec):
index 8e788a9..43cc157 100644 (file)
@@ -550,6 +550,44 @@ class MultiHeadTest(test.TestCase):
           metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss / 2,
       }, summary_str, tol)
 
+  def test_train_one_head_with_optimizer(self):
+    head1 = head_lib.multi_label_head(n_classes=2, name='head1')
+    multi_head = multi_head_lib.multi_head([head1])
+
+    logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)}
+    labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)}
+    # For large logits, sigmoid cross entropy loss is approximated as:
+    # loss = labels * (logits < 0) * (-logits) +
+    #        (1 - labels) * (logits > 0) * logits =>
+    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
+    # Average over classes, sum over weights.
+    expected_loss = 17.5
+    expected_train_result = 'my_train_op'
+
+    class _Optimizer(object):
+
+      def minimize(self, loss, global_step):
+        del global_step
+        return string_ops.string_join(
+            [constant_op.constant(expected_train_result),
+             string_ops.as_string(loss, precision=3)])
+
+    spec = multi_head.create_estimator_spec(
+        features={'x': np.array(((42,),), dtype=np.int32)},
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels,
+        optimizer=_Optimizer())
+
+    tol = 1e-3
+    with self.test_session() as sess:
+      _initialize_variables(self, spec.scaffold)
+      loss, train_result = sess.run((spec.loss, spec.train_op))
+      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
+      self.assertEqual(
+          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
+          train_result)
+
   def test_train_two_heads_with_weights(self):
     head1 = head_lib.multi_label_head(n_classes=2, name='head1')
     head2 = head_lib.multi_label_head(n_classes=3, name='head2')
index b25f9d2..5afb5a7 100644 (file)
@@ -265,7 +265,6 @@ py_library(
         "//tensorflow/python:nn",
         "//tensorflow/python:partitioned_variables",
         "//tensorflow/python:summary",
-        "//tensorflow/python:training",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/feature_column",
         "//tensorflow/python/ops/losses",
@@ -617,6 +616,7 @@ py_library(
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:string_ops",
         "//tensorflow/python:summary",
+        "//tensorflow/python:training",
         "//tensorflow/python:weights_broadcast_ops",
         "//tensorflow/python/feature_column",
         "//tensorflow/python/ops/losses",
index 7043da8..6382622 100644 (file)
@@ -32,7 +32,6 @@ from tensorflow.python.ops import partitioned_variables
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops.losses import losses
 from tensorflow.python.summary import summary
-from tensorflow.python.training import training_util
 from tensorflow.python.util.tf_export import tf_export
 
 # The default learning rate of 0.05 is a historical artifact of the initial
@@ -183,17 +182,11 @@ def _dnn_model_fn(features,
         input_layer_partitioner=input_layer_partitioner)
     logits = logit_fn(features=features, mode=mode)
 
-    def _train_op_fn(loss):
-      """Returns the op to optimize the loss."""
-      return optimizer.minimize(
-          loss,
-          global_step=training_util.get_global_step())
-
     return head.create_estimator_spec(
         features=features,
         mode=mode,
         labels=labels,
-        train_op_fn=_train_op_fn,
+        optimizer=optimizer,
         logits=logits)
 
 
index 85b058c..44545c0 100644 (file)
@@ -53,7 +53,7 @@ from tensorflow.python.summary.writer import writer_cache
 from tensorflow.python.training import checkpoint_utils
 from tensorflow.python.training import gradient_descent
 from tensorflow.python.training import monitored_session
-from tensorflow.python.training import optimizer
+from tensorflow.python.training import optimizer as optimizer_lib
 from tensorflow.python.training import saver
 from tensorflow.python.training import session_run_hook
 from tensorflow.python.training import training_util
@@ -134,7 +134,8 @@ def mock_head(testcase, hidden_units, logits_dimension, expected_logits):
       hidden_weights_names + hidden_biases_names +
       [LOGITS_WEIGHTS_NAME + '/part_0:0', LOGITS_BIASES_NAME + '/part_0:0'])
 
-  def _create_estimator_spec(features, mode, logits, labels, train_op_fn):
+  def _create_estimator_spec(
+      features, mode, logits, labels, train_op_fn=None, optimizer=None):
     del features, labels  # Not used.
     trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
     testcase.assertItemsEqual(expected_var_names,
@@ -144,8 +145,12 @@ def mock_head(testcase, hidden_units, logits_dimension, expected_logits):
         expected_logits, logits, message='Failed for mode={}. '.format(mode))
     with ops.control_dependencies([assert_logits]):
       if mode == model_fn.ModeKeys.TRAIN:
+        if train_op_fn is not None:
+          train_op = train_op_fn(loss)
+        elif optimizer is not None:
+          train_op = optimizer.minimize(loss, global_step=None)
         return model_fn.EstimatorSpec(
-            mode=mode, loss=loss, train_op=train_op_fn(loss))
+            mode=mode, loss=loss, train_op=train_op)
       elif mode == model_fn.ModeKeys.EVAL:
         return model_fn.EstimatorSpec(mode=mode, loss=array_ops.identity(loss))
       elif mode == model_fn.ModeKeys.PREDICT:
@@ -203,8 +208,8 @@ def mock_optimizer(testcase, hidden_units, expected_loss=None):
       return control_flow_ops.no_op()
 
   optimizer_mock = test.mock.NonCallableMagicMock(
-      spec=optimizer.Optimizer,
-      wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer'))
+      spec=optimizer_lib.Optimizer,
+      wraps=optimizer_lib.Optimizer(use_locking=False, name='my_optimizer'))
   optimizer_mock.minimize = test.mock.MagicMock(wraps=_minimize)
 
   return optimizer_mock
index f68204a..c9635a9 100644 (file)
@@ -44,6 +44,7 @@ from tensorflow.python.ops import weights_broadcast_ops
 from tensorflow.python.ops.losses import losses
 from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.summary import summary
+from tensorflow.python.training import training_util
 
 _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
 
@@ -85,40 +86,39 @@ class _Head(object):
     ```python
     def _my_dnn_model_fn(features, labels, mode, params, config=None):
       # Optionally your callers can pass head to model_fn as a param.
-      head = tf.contrib.learn.regression_head(...)
-      input = tf.contrib.layers.input_from_feature_columns(features, ...)
-      last_hidden_layer_out = tf.contrib.layers.stack(
-          input, tf.contrib.layers.fully_connected, [1000, 500])
-      logits = tf.contrib.layers.fully_connected(
-          last_hidden_layer_out, head.logits_dimension, activation_fn=None)
-
-      def _train_op_fn(loss):
-        return optimizer.minimize(loss)
+      head = tf.contrib.estimator.regression_head(...)
+      inputs = tf.feature_column.input_layer(features, ...)
+      hidden_layer0 = tf.layers.dense(
+          inputs, units=1000, activation=tf.nn.relu)
+      hidden_layer1 = tf.layers.dense(
+          hidden_layer0, units=500, activation=tf.nn.relu)
+      logits = tf.layers.dense(
+          hidden_layer1, units=head.logits_dimension, activation=None)
 
       return head.create_estimator_spec(
           features=features,
           labels=labels,
           mode=mode,
           logits=logits,
-          train_op_fn=_train_op_fn)
+          optimizer=optimizer)
     ```
 
   There are cases where computing and applying gradients can not be meaningfully
-  captured with train_op_fn we support (for example, with sync optimizer). In
-  such case, you can take the responsibility on your own. Here is a common
-  use case,
+  captured with optimizer or train_op_fn we support (for example, with sync
+  optimizer). In such case, you can take the responsibility on your own. Here is
+  a common use case,
     ```python
     estimator_spec = head.create_estimator_spec(
         features=features,
         labels=labels,
         mode=mode,
         logits=logits,
-        train_op_fn=tf.contrib.learn.no_op_train_fn)
+        train_op_fn=lambda _: tf.no_op())
     if mode == model_fn.ModeKeys.TRAIN:
       optimizer = ...
       sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...)
-      update_op = tf.contrib.layers.optimize_loss(optimizer=sync,
-                                                  loss=estimator_spec.loss, ...)
+      update_op = sync.minimize(
+          estimator_spec.loss, global_step=tf.get_global_step())
       hooks = [sync.make_session_run_hook(is_chief)]
       ... update train_op and hooks in EstimatorSpec and return
     ```
@@ -172,10 +172,12 @@ class _Head(object):
     """
     raise NotImplementedError('Calling an abstract method.')
 
+  # TODO(b/65403806): By default, collect regularization_losses from
+  # GraphKeys.REGULARIZATION_LOSSES collection.
   @abc.abstractmethod
   def create_estimator_spec(
-      self, features, mode, logits, labels=None, train_op_fn=None,
-      regularization_losses=None):
+      self, features, mode, logits, labels=None, optimizer=None,
+      train_op_fn=None, regularization_losses=None):
     """Returns `EstimatorSpec` that a model_fn can return.
 
     Please note that,
@@ -186,10 +188,14 @@ class _Head(object):
       mode: Estimator's `ModeKeys`.
       logits: logits `Tensor` to be used by the head.
       labels: Labels `Tensor`, or `dict` of same.
+      optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.
+        Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which
+        updates variables and increments `global_step`.
       train_op_fn: Function that takes a scalar loss `Tensor` and returns an op
-        to optimize the model with the loss. This is used in TRAIN mode and
-        must not be None. None is allowed in other modes. If you want to
-        optimize loss yourself you can pass `no_op_train_fn` and then use
+        to optimize the model with the loss in TRAIN mode. Used if `optimizer`
+        is `None`. Exactly one of `train_op_fn` and `optimizer` must be set in
+        TRAIN mode. None is allowed in other modes. If you want to optimize loss
+        yourself you can pass `lambda _: tf.no_op()` and then use
         EstimatorSpec.loss to compute and apply gradients.
       regularization_losses: A list of additional scalar losses to be added to
         the training loss, such as regularization losses.
@@ -694,8 +700,8 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
         processed_labels=label_ids)
 
   def create_estimator_spec(
-      self, features, mode, logits, labels=None, train_op_fn=None,
-      regularization_losses=None):
+      self, features, mode, logits, labels=None, optimizer=None,
+      train_op_fn=None, regularization_losses=None):
     """Returns an `EstimatorSpec`.
 
     Args:
@@ -706,8 +712,11 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
       labels: Labels integer or string `Tensor` with shape matching `logits`,
         namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is
         required argument when `mode` equals `TRAIN` or `EVAL`.
+      optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.
+        Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which
+        updates variables and increments `global_step`.
       train_op_fn: Function that takes a scalar loss `Tensor` and returns
-        `train_op`. Required in TRAIN mode.
+        `train_op`. Used if `optimizer` is `None`.
       regularization_losses: A list of additional scalar losses to be added to
         the training loss, such as regularization losses. These losses are
         usually expressed as a batch average, so for best results users need to
@@ -717,7 +726,8 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
     Returns:
       `EstimatorSpec`.
     Raises:
-      ValueError: If `train_op_fn` is `None` in TRAIN mode.
+      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
+        mode, or if both are set.
     """
     with ops.name_scope(self._name, 'head'):
       logits = _check_logits_final_dim(logits, self.logits_dimension)
@@ -780,8 +790,16 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
                 regularization_loss=regularization_loss))
 
       # Train.
-      if train_op_fn is None:
-        raise ValueError('train_op_fn cannot be None.')
+      if optimizer is not None:
+        if train_op_fn is not None:
+          raise ValueError('train_op_fn and optimizer cannot both be set.')
+        train_op = optimizer.minimize(
+            regularized_training_loss,
+            global_step=training_util.get_global_step())
+      elif train_op_fn is not None:
+        train_op = train_op_fn(regularized_training_loss)
+      else:
+        raise ValueError('train_op_fn and optimizer cannot both be None.')
       # Only summarize mean_loss for SUM reduction to preserve backwards
       # compatibility. Otherwise skip it to avoid unnecessary computation.
       if self._loss_reduction == losses.Reduction.SUM:
@@ -807,7 +825,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
         mode=model_fn.ModeKeys.TRAIN,
         predictions=predictions,
         loss=regularized_training_loss,
-        train_op=train_op_fn(regularized_training_loss))
+        train_op=train_op)
 
 
 def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
@@ -1039,8 +1057,8 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
         processed_labels=labels)
 
   def create_estimator_spec(
-      self, features, mode, logits, labels=None, train_op_fn=None,
-      regularization_losses=None):
+      self, features, mode, logits, labels=None, optimizer=None,
+      train_op_fn=None, regularization_losses=None):
     """Returns an `EstimatorSpec`.
 
     Args:
@@ -1051,8 +1069,11 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
       labels: Labels integer or string `Tensor` with shape matching `logits`,
         namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is required
         argument when `mode` equals `TRAIN` or `EVAL`.
+      optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.
+        Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which
+        updates variables and increments `global_step`.
       train_op_fn: Function that takes a scalar loss `Tensor` and returns
-        `train_op`. Required in TRAIN mode.
+        `train_op`. Used if `optimizer` is `None`.
       regularization_losses: A list of additional scalar losses to be added to
         the training loss, such as regularization losses. These losses are
         usually expressed as a batch average, so for best results users need to
@@ -1062,7 +1083,8 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
     Returns:
       `EstimatorSpec`.
     Raises:
-      ValueError: If `train_op_fn` is `None` in TRAIN mode.
+      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
+        mode, or if both are set.
     """
     # Predict.
     with ops.name_scope(self._name, 'head'):
@@ -1134,8 +1156,16 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
                 regularization_loss=regularization_loss))
 
       # Train.
-      if train_op_fn is None:
-        raise ValueError('train_op_fn can not be None.')
+      if optimizer is not None:
+        if train_op_fn is not None:
+          raise ValueError('train_op_fn and optimizer cannot both be set.')
+        train_op = optimizer.minimize(
+            regularized_training_loss,
+            global_step=training_util.get_global_step())
+      elif train_op_fn is not None:
+        train_op = train_op_fn(regularized_training_loss)
+      else:
+        raise ValueError('train_op_fn and optimizer cannot both be None.')
       # Only summarize mean_loss for SUM reduction to preserve backwards
       # compatibility. Otherwise skip it to avoid unnecessary computation.
       if self._loss_reduction == losses.Reduction.SUM:
@@ -1160,7 +1190,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
         mode=model_fn.ModeKeys.TRAIN,
         predictions=predictions,
         loss=regularized_training_loss,
-        train_op=train_op_fn(regularized_training_loss))
+        train_op=train_op)
 
 
 def _regression_head_with_mean_squared_error_loss(
@@ -1289,8 +1319,8 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
         processed_labels=labels)
 
   def create_estimator_spec(
-      self, features, mode, logits, labels=None, train_op_fn=None,
-      regularization_losses=None):
+      self, features, mode, logits, labels=None, optimizer=None,
+      train_op_fn=None, regularization_losses=None):
     """Returns an `EstimatorSpec`.
 
     Args:
@@ -1302,8 +1332,11 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
         `[D0, D1, ... DN, logits_dimension]`. When `logits_dimension=1`, shape
         `[D0, D1, ... DN]` is also supported. `labels` is required argument when
         `mode` equals `TRAIN` or `EVAL`.
+      optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.
+        Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which
+        updates variables and increments `global_step`.
       train_op_fn: Function that takes a scalar loss `Tensor` and returns
-        `train_op`. Required in TRAIN mode.
+        `train_op`. Used if `optimizer` is `None`.
       regularization_losses: A list of additional scalar losses to be added to
         the training loss, such as regularization losses. These losses are
         usually expressed as a batch average, so for best results users need to
@@ -1313,7 +1346,8 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
     Returns:
       `EstimatorSpec`.
     Raises:
-      ValueError: If `train_op_fn` is `None` in TRAIN mode.
+      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
+        mode, or if both are set.
     """
     # Predict.
     with ops.name_scope(self._name, 'head'):
@@ -1373,8 +1407,16 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
             eval_metric_ops=eval_metric_ops)
 
       # Train.
-      if train_op_fn is None:
-        raise ValueError('train_op_fn can not be None.')
+      if optimizer is not None:
+        if train_op_fn is not None:
+          raise ValueError('train_op_fn and optimizer cannot both be set.')
+        train_op = optimizer.minimize(
+            regularized_training_loss,
+            global_step=training_util.get_global_step())
+      elif train_op_fn is not None:
+        train_op = train_op_fn(regularized_training_loss)
+      else:
+        raise ValueError('train_op_fn and optimizer cannot both be None.')
       # Only summarize mean_loss for SUM reduction to preserve backwards
       # compatibility. Otherwise skip it to avoid unnecessary computation.
       if self._loss_reduction == losses.Reduction.SUM:
@@ -1399,7 +1441,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
         mode=model_fn.ModeKeys.TRAIN,
         predictions=predictions,
         loss=regularized_training_loss,
-        train_op=train_op_fn(regularized_training_loss))
+        train_op=train_op)
 
 
 def _assert_range(labels, n_classes, message=None):
index b5d35c9..fe6ee07 100644 (file)
@@ -842,6 +842,41 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
           metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
       }, summary_str, tol)
 
+  def test_train_with_optimizer(self):
+    n_classes = 3
+    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)
+
+    logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)
+    labels = np.array(((1,), (1,)), dtype=np.int64)
+    features = {'x': np.array(((42,),), dtype=np.int32)}
+    expected_train_result = 'my_train_op'
+
+    class _Optimizer(object):
+
+      def minimize(self, loss, global_step):
+        del global_step
+        return string_ops.string_join(
+            [constant_op.constant(expected_train_result),
+             string_ops.as_string(loss, precision=2)])
+
+    # loss = sum(cross_entropy(labels, logits)) = sum(10, 0) = 10.
+    expected_loss = 10.
+    spec = head.create_estimator_spec(
+        features=features,
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels,
+        optimizer=_Optimizer())
+
+    tol = 1e-2
+    with self.test_session() as sess:
+      _initialize_variables(self, spec.scaffold)
+      loss, train_result = sess.run((spec.loss, spec.train_op))
+      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
+      self.assertEqual(
+          six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),
+          train_result)
+
   def test_train_summaries_with_head_name(self):
     n_classes = 3
     head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
@@ -1942,6 +1977,39 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
           metric_keys.MetricKeys.LOSS_MEAN: 20.5,
       }, summary_str)
 
+  def test_train_with_optimizer(self):
+    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()
+
+    logits = np.array(((45,), (-41,),), dtype=np.float32)
+    labels = np.array(((1,), (1,),), dtype=np.float64)
+    expected_train_result = b'my_train_op'
+    features = {'x': np.array(((42,),), dtype=np.float32)}
+    # loss = sum(cross_entropy(labels, logits)) = sum(0, 41) = 41
+    expected_loss = 41.
+
+    class _Optimizer(object):
+
+      def minimize(self, loss, global_step):
+        del global_step
+        with ops.control_dependencies((check_ops.assert_equal(
+            math_ops.to_float(expected_loss), math_ops.to_float(loss),
+            name='assert_loss'),)):
+          return constant_op.constant(expected_train_result)
+
+    # Create estimator spec.
+    spec = head.create_estimator_spec(
+        features=features,
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels,
+        optimizer=_Optimizer())
+
+    with self.test_session() as sess:
+      _initialize_variables(self, spec.scaffold)
+      loss, train_result = sess.run((spec.loss, spec.train_op))
+      self.assertAllClose(expected_loss, loss)
+      self.assertEqual(expected_train_result, train_result)
+
   def test_train_summaries_with_head_name(self):
     head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
         name='some_binary_head')
@@ -3076,6 +3144,40 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
           metric_keys.MetricKeys.LOSS_MEAN: 6.5,
       }, summary_str)
 
+  def test_train_with_optimizer(self):
+    head = head_lib._regression_head_with_mean_squared_error_loss()
+    self.assertEqual(1, head.logits_dimension)
+
+    # Create estimator spec.
+    logits = np.array(((45,), (41,),), dtype=np.float32)
+    labels = np.array(((43.,), (44.,),), dtype=np.float64)
+    expected_train_result = b'my_train_op'
+    features = {'x': np.array(((42.,),), dtype=np.float32)}
+    # loss = (43-45)^2 + (44-41)^2 = 4 + 9 = 13
+    expected_loss = 13
+
+    class _Optimizer(object):
+
+      def minimize(self, loss, global_step):
+        del global_step
+        with ops.control_dependencies((check_ops.assert_equal(
+            math_ops.to_float(expected_loss), math_ops.to_float(loss),
+            name='assert_loss'),)):
+          return constant_op.constant(expected_train_result)
+
+    spec = head.create_estimator_spec(
+        features=features,
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels,
+        optimizer=_Optimizer())
+
+    with self.test_session() as sess:
+      _initialize_variables(self, spec.scaffold)
+      loss, train_result = sess.run((spec.loss, spec.train_op))
+      self.assertAllClose(expected_loss, loss)
+      self.assertEqual(expected_train_result, train_result)
+
   def test_train_summaries_with_head_name(self):
     head = head_lib._regression_head_with_mean_squared_error_loss(
         name='some_regression_head')
index a2f24ef..e7ec417 100644 (file)
@@ -33,7 +33,6 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops.losses import losses
 from tensorflow.python.summary import summary
 from tensorflow.python.training import ftrl
-from tensorflow.python.training import training_util
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -157,17 +156,11 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer,
         units=head.logits_dimension, feature_columns=feature_columns)
     logits = logit_fn(features=features)
 
-    def _train_op_fn(loss):
-      """Returns the op to optimize the loss."""
-      return optimizer.minimize(
-          loss,
-          global_step=training_util.get_global_step())
-
     return head.create_estimator_spec(
         features=features,
         mode=mode,
         labels=labels,
-        train_op_fn=_train_op_fn,
+        optimizer=optimizer,
         logits=logits)