Move `loss_reduction` argument from `replicate_model_fn` to `TowerOptimizer`.
author    Igor Saprykin <isaprykin@google.com>
          Mon, 12 Mar 2018 20:07:12 +0000 (13:07 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Mon, 12 Mar 2018 20:15:10 +0000 (13:15 -0700)
PiperOrigin-RevId: 188766477
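
For context, a minimal sketch of the call-site change this commit implies.
`my_model_fn` and the learning rate are hypothetical; the API names follow
the diff below:

    # Before: reduction was configured on the replication wrapper.
    model_fn = replicate_model_fn._replicate_model_fn(
        my_model_fn, loss_reduction=losses.Reduction.SUM)

    # After: reduction is configured on the _TowerOptimizer wrapper that
    # the model_fn constructs; the replication wrapper no longer takes it.
    def my_model_fn(features, labels, mode, params):
        loss = ...  # per-tower loss
        optimizer = replicate_model_fn._TowerOptimizer(
            gradient_descent.GradientDescentOptimizer(0.1),
            loss_reduction=losses.Reduction.SUM)
        return model_fn_lib.EstimatorSpec(
            mode=mode, loss=loss, train_op=optimizer.minimize(loss))

    model_fn = replicate_model_fn._replicate_model_fn(my_model_fn)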

tensorflow/python/estimator/replicate_model_fn.py
tensorflow/python/estimator/replicate_model_fn_test.py

index 7418852..144d89a 100644
--- a/tensorflow/python/estimator/replicate_model_fn.py
+++ b/tensorflow/python/estimator/replicate_model_fn.py
@@ -50,7 +50,6 @@ from tensorflow.python.training import optimizer as optimizer_lib
 
 
 def _replicate_model_fn(model_fn,
-                        loss_reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
                         devices=None):
   """Replicate `Estimator.model_fn` over GPUs.
 
@@ -109,8 +108,9 @@ def _replicate_model_fn(model_fn,
   On reduction algorithms:
   Certain algorithms were chosen for aggregating results of computations on
   multiple towers:
-    - Losses from all towers are reduced according to `loss_reduction`.
-    - Gradients from all towers are reduced according to `loss_reduction`
+    - Losses from all towers are reduced according to the `loss_reduction`
+      argument of `TowerOptimizer`.
+    - Gradients from all towers are reduced according to the `loss_reduction`
       for each trainable variable.
    - `eval_metric_ops` are reduced per metric using `reduce_mean`.
     - `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are
@@ -134,16 +134,11 @@ def _replicate_model_fn(model_fn,
   Args:
     model_fn: `model_fn` as defined in `Estimator`.  See the section above about
       the train_op argument of `EstimatorSpec`.
-    loss_reduction: controls whether losses are summed or averaged.
     devices: Optional list of devices to replicate the model across.  This
      argument can be used to replicate on only a subset of the available GPUs.
       If `None`, then all available GPUs are going to be used for replication.
       If no GPUs are available, then the model is going to be placed on the CPU.
 
-  Raises:
-    ValueError: if there is no `loss_reduction` or if _TowerOptimizer is
-      mis-used.
-
   Returns:
    A replicated version of the supplied `model_fn`. The returned function
      conforms to the requirements of `Estimator`'s `model_fn` and can be used
@@ -151,7 +146,6 @@ def _replicate_model_fn(model_fn,
   """
   return _replicate_model_fn_with_mode(
       model_fn,
-      loss_reduction,
       devices,
       # TODO(isaprykin): Query the system configuration to choose modes other
       # than `SHARED_LOCAL_PARAMETER_SERVER`, even though it is often
@@ -186,13 +180,9 @@ class _VariableDistributionMode(object):
 
 def _replicate_model_fn_with_mode(
     model_fn,
-    loss_reduction,
     devices=None,
     mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER):
   """A version of `replicate_model_fn` that allows to specify a `mode`."""
-  if loss_reduction == losses.Reduction.NONE:
-    raise ValueError('Tower losses need to be reduced in some way, yet {} '
-                     'reduction is specified.'.format(loss_reduction))
   if not devices:
     devices = _get_local_devices('GPU') or _get_local_devices('CPU')
 
@@ -215,7 +205,6 @@ def _replicate_model_fn_with_mode(
         features=[features],
         labels=[labels],
         params=params,
-        loss_reduction=loss_reduction,
         config=config,
         devices=devices,
         local_ps_devices=ps_devices)[0]  # One device, so one spec is out.
@@ -230,7 +219,6 @@ def _replicate_model_fn_with_mode(
         features=feature_shards,
         labels=label_shards,
         params=params,
-        loss_reduction=loss_reduction,
         config=config,
         devices=devices,
         local_ps_devices=ps_devices)
@@ -255,7 +243,8 @@ class _TowerOptimizer(optimizer_lib.Optimizer):
 
   COLLECTION_FOR_GRAPH_STATES = 'replicate_model_fn_graph_states'
 
-  def __init__(self, optimizer_or_optimizer_fn):
+  def __init__(self, optimizer_or_optimizer_fn,
+               loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE):
     """Wrap an existing optimizer for gathering gradients across towers.
 
     Each invocation of model_fn has to call the same optimizers in the same
@@ -275,8 +264,10 @@ class _TowerOptimizer(optimizer_lib.Optimizer):
       optimizer_or_optimizer_fn: an instance of optimizer to wrap.  That
         instance is going to be used for optimizer-specific logic.  This can
         also be a no-argument function that returns such an optimizer instance.
+      loss_reduction: controls whether losses are summed or averaged.
     """
     self._optimizer_or_optimizer_fn = optimizer_or_optimizer_fn
+    self._loss_reduction = loss_reduction
 
   @staticmethod
   def has_been_used():
@@ -296,8 +287,9 @@ class _TowerOptimizer(optimizer_lib.Optimizer):
 
   def compute_gradients(self, loss, *args, **kwargs):
     """Compute gradients, but first, if needed, scale the loss."""
+    _TowerOptimizer._graph_state().set_loss_reduction(self._loss_reduction)
     loss = _scale_loss(loss,
-                       self._graph_state().loss_reduction,
+                       self._loss_reduction,
                        self._graph_state().number_of_towers)
     return self._get_optimizer().compute_gradients(loss, *args, **kwargs)
 
@@ -402,10 +394,12 @@ class _TowerOptimizer(optimizer_lib.Optimizer):
             self._collected_grads_and_vars[tower_id][index_of_last_gradients])
       return grads_and_vars
 
-    def set_reduction_across_towers(self, loss_reduction, number_of_towers):
-      self._loss_reduction = loss_reduction
+    def set_number_of_towers(self, number_of_towers):
       self._number_of_towers = number_of_towers
 
+    def set_loss_reduction(self, loss_reduction):
+      self._loss_reduction = loss_reduction
+
     @contextmanager
     def tower(self, tower_id, var_scope, name_scope):
       if tower_id == 0:
@@ -509,7 +503,6 @@ def _get_loss_towers(model_fn,
                      config,
                      devices,
                      local_ps_devices,
-                     loss_reduction,
                      name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
   """Replicate the loss computation across devices."""
   tower_specs = []
@@ -524,8 +517,7 @@ def _get_loss_towers(model_fn,
   # pylint: disable=protected-access
   round_robin_strategy = device_setter_lib._RoundRobinStrategy(
       num_tasks=len(local_ps_devices))
-  _TowerOptimizer._graph_state().set_reduction_across_towers(
-      loss_reduction, len(devices))
+  _TowerOptimizer._graph_state().set_number_of_towers(len(devices))
 
   for i, device in enumerate(devices):
     is_the_first_tower = (i == 0)
@@ -567,7 +559,9 @@ def _get_loss_towers(model_fn,
             # Scaling the loss here doesn't actually affect gradients.  Another
             # instance of scaling happens inside the _TowerOptimizer.
             tower_spec = _scale_tower_loss(
-                tower_spec, loss_reduction, number_of_towers=len(devices))
+                tower_spec,
+                _TowerOptimizer._graph_state().loss_reduction,
+                number_of_towers=len(devices))
             tower_specs.append(tower_spec)
 
   if not _TowerOptimizer._did_towers_have_same_optimizer_calls():
@@ -607,20 +601,27 @@ def _scale_tower_loss(tower_spec, loss_reduction, number_of_towers):
     return tower_spec
 
   estimator_spec = _asdict(tower_spec)
-  estimator_spec['loss'] = _scale_loss(tower_spec.loss, loss_reduction,
-                                       number_of_towers)
+  estimator_spec['loss'] = _scale_loss(
+      tower_spec.loss,
+      loss_reduction,
+      number_of_towers,
+      reduced_loss_name='averaged_loss')
   return model_fn_lib.EstimatorSpec(**estimator_spec)
 
 
-def _scale_loss(loss, loss_reduction, number_of_towers):
+def _scale_loss(loss, loss_reduction, number_of_towers, reduced_loss_name=None):
   """If needed, scale down the loss for averaging loss by summing."""
   if loss is None:
     return None
   if number_of_towers == 1:
     return loss
 
+  if loss_reduction == losses.Reduction.NONE:
+    raise ValueError('Tower losses need to be reduced in some way, yet {} '
+                     'reduction is specified.'.format(loss_reduction))
+
   if loss_reduction != losses.Reduction.SUM:
-    return math_ops.div(loss, 1.0 * number_of_towers, name='averaged_loss')
+    return math_ops.div(loss, 1.0 * number_of_towers, name=reduced_loss_name)
   else:
     return loss
 
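A worked illustration of the `_scale_loss` rule above; the numbers are
hypothetical, the behavior follows the code in this hunk:

    # With 2 towers, each tower reporting a loss of 6.0:
    #   Reduction.SUM    -> loss kept as 6.0; summing towers gives 12.0.
    #   other reductions -> 6.0 / 2 = 3.0, so summed tower losses yield the
    #                       average (op named 'averaged_loss' when called
    #                       from _scale_tower_loss).
    #   Reduction.NONE   -> ValueError, now raised here rather than in
    #                       _replicate_model_fn_with_mode.
    #   (With number_of_towers == 1 the loss is returned before any check.)
    scaled = _scale_loss(loss, losses.Reduction.MEAN, number_of_towers=2,
                         reduced_loss_name='averaged_loss')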
index b6dd4e9..ad1f9c0 100644
--- a/tensorflow/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/python/estimator/replicate_model_fn_test.py
@@ -121,8 +121,9 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
     estimator = dnn.DNNClassifier(
         hidden_units=(2, 2),
         # Adagrad is configured with `get_optimizer_instance`, so the function
-        # form of `_TowerOptimizer.__init__` is used.
-        optimizer=replicate_model_fn._TowerOptimizer(optimizer_fn),
+        # form of `TowerOptimizer.__init__` is used.
+        optimizer=replicate_model_fn._TowerOptimizer(
+            optimizer_fn, loss_reduction=losses.Reduction.SUM),
         feature_columns=feature_columns,
         n_classes=n_classes,
         model_dir=self._model_dir)
@@ -134,7 +135,6 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
       model_fn = replicate_model_fn._replicate_model_fn_with_mode(
           estimator.model_fn,
           devices=['/gpu:0', '/gpu:1', '/gpu:2'],
-          loss_reduction=losses.Reduction.SUM,
           mode=mode)
 
     estimator = estimator_lib.Estimator(
@@ -178,32 +178,39 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
 
 class ReplicateModelTest(test_util.TensorFlowTestCase):
 
-  def model_fn(self, mode, features, labels, params):
-    c = variable_scope.get_variable(
-        'c',
-        initializer=constant_op.constant(10, dtype=dtypes.float64),
-        dtype=dtypes.float64)
+  def create_model_fn_with_loss_reduction(self, loss_reduction):
 
-    predictions = math_ops.multiply(features, c)
+    def model_fn(mode, features, labels, params):
+      c = variable_scope.get_variable(
+          'c',
+          initializer=constant_op.constant(10, dtype=dtypes.float64),
+          dtype=dtypes.float64)
 
-    loss = losses.absolute_difference(
-        labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
-    loss = math_ops.reduce_sum(loss)
+      predictions = math_ops.multiply(features, c)
 
-    metrics = {
-        'accuracy': metrics_lib.accuracy(labels, predictions),
-        'auc': metrics_lib.auc(labels, predictions)
-    }
+      loss = losses.absolute_difference(
+          labels=labels,
+          predictions=predictions,
+          reduction=losses.Reduction.SUM)
+      loss = math_ops.reduce_sum(loss)
 
-    optimizer = replicate_model_fn._TowerOptimizer(
-        gradient_descent.GradientDescentOptimizer(params['learning_rate']))
+      metrics = {
+          'accuracy': metrics_lib.accuracy(labels, predictions),
+          'auc': metrics_lib.auc(labels, predictions)
+      }
 
-    return model_fn_lib.EstimatorSpec(
-        mode=mode,
-        loss=loss,
-        eval_metric_ops=metrics,
-        predictions={'probabilities': predictions},
-        train_op=optimizer.minimize(loss))
+      optimizer = replicate_model_fn._TowerOptimizer(
+          gradient_descent.GradientDescentOptimizer(params['learning_rate']),
+          loss_reduction=loss_reduction)
+
+      return model_fn_lib.EstimatorSpec(
+          mode=mode,
+          loss=loss,
+          eval_metric_ops=metrics,
+          predictions={'probabilities': predictions},
+          train_op=optimizer.minimize(loss))
+
+    return model_fn
 
   @property
   def params(self):
@@ -217,8 +224,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session() as session:
       replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.model_fn,
-          loss_reduction=losses.Reduction.SUM,
+          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
           devices=['/gpu:0', '/gpu:1'])
       estimator_spec = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
@@ -248,7 +254,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
           dtype=dtypes.float64)
 
       replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1'])
+          self.create_model_fn_with_loss_reduction(losses.Reduction.MEAN),
+          devices=['/gpu:0', '/gpu:1'])
       estimator_spec = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
       session.run(variables.global_variables_initializer())
@@ -284,8 +291,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
       with self.test_session() as session, variable_scope.variable_scope(
           '', reuse=variable_scope.AUTO_REUSE):
         replicated_model_fn = replicate_model_fn._replicate_model_fn(
-            self.model_fn,
-            loss_reduction=losses.Reduction.SUM,
+            self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
             devices=['/gpu:0', '/gpu:1'])
         estimator_spec = replicated_model_fn(
             features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
@@ -307,8 +313,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session() as session:
       replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.model_fn,
-          loss_reduction=losses.Reduction.SUM,
+          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
           devices=['/gpu:0', '/gpu:1'])
       estimator_spec = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
@@ -338,7 +343,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session() as session:
       replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1'])
+          self.create_model_fn_with_loss_reduction(losses.Reduction.MEAN),
+          devices=['/gpu:0', '/gpu:1'])
       estimator_spec = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
       session.run(variables.local_variables_initializer())
@@ -367,7 +373,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session() as session:
       replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.model_fn, devices=['/gpu:0', '/gpu:1'])
+          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
+          devices=['/gpu:0', '/gpu:1'])
       estimator_spec = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.PREDICT, self.params)
       session.run(variables.global_variables_initializer())
@@ -382,7 +389,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session() as session:
       replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.model_fn, devices=['/gpu:0'])
+          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
+          devices=['/gpu:0'])
       estimator_spec = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
       session.run(variables.global_variables_initializer())
@@ -404,7 +412,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session() as session:
       replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.model_fn, devices=['/gpu:0'])
+          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
+          devices=['/gpu:0'])
       estimator_spec = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
       session.run(variables.local_variables_initializer())
@@ -432,7 +441,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session() as session:
       replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.model_fn, devices=['/gpu:0'])
+          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
+          devices=['/gpu:0'])
       estimator_spec = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.PREDICT, self.params)
       session.run(variables.global_variables_initializer())
@@ -448,15 +458,22 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
     with self.assertRaisesRegexp(
         ValueError, '.*Batch.+size.+needs.+to.+be.+divisible.+by.+GPUs.+'):
       replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.model_fn, devices=['/gpu:0', '/gpu:1'])
+          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
+          devices=['/gpu:0', '/gpu:1'])
       _ = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
 
   def test_unsupported_loss_reduction(self):
+    features = np.array([[1.0], [2.0], [3.0]])
+    labels = np.array([[1.0], [2.0], [3.0]])
+
     with self.assertRaisesRegexp(ValueError,
                                  '.+none.+reduction.+is.+specified.+'):
-      _ = replicate_model_fn._replicate_model_fn(self.model_fn,
-                                                 losses.Reduction.NONE)
+      replicated_model_fn = replicate_model_fn._replicate_model_fn(
+          self.create_model_fn_with_loss_reduction(losses.Reduction.NONE),
+          devices=['/gpu:0', '/gpu:1', '/gpu:2'])
+      _ = replicated_model_fn(
+          features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
 
   def test_places_on_gpu_with_upper_case_spelling(self):
     features = np.array([[0.01], [0.002]])
@@ -464,7 +481,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session():
       replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.model_fn, devices=['/GPU:0'])
+          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
+          devices=['/GPU:0'])
       _ = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
 
@@ -478,7 +496,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session():
       replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.model_fn, devices=['/gpu:0'])
+          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
+          devices=['/gpu:0'])
       _ = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
 
@@ -624,7 +643,8 @@ class MakeSureSyncReplicasOptimizerWorks(test_util.TensorFlowTestCase):
     optimizer = training.SyncReplicasOptimizer(
         optimizer, replicas_to_aggregate=1)
     sync_hook = optimizer.make_session_run_hook(True)
-    optimizer = replicate_model_fn._TowerOptimizer(optimizer)
+    optimizer = replicate_model_fn._TowerOptimizer(
+        optimizer, loss_reduction=losses.Reduction.SUM)
 
     return model_fn_lib.EstimatorSpec(
         mode=mode,
@@ -650,7 +670,6 @@ class MakeSureSyncReplicasOptimizerWorks(test_util.TensorFlowTestCase):
 
     model_fn = replicate_model_fn._replicate_model_fn(
         self.model_fn,
-        loss_reduction=losses.Reduction.SUM,
         devices=['/gpu:0', '/gpu:1'])
 
     estimator = estimator_lib.Estimator(
@@ -687,9 +706,10 @@ class ReplicateWithTwoOptimizersTest(test_util.TensorFlowTestCase):
     }
 
     first_optimizer = replicate_model_fn._TowerOptimizer(
-        gradient_descent.GradientDescentOptimizer(1.0))
+        gradient_descent.GradientDescentOptimizer(1.0),
+        loss_reduction=losses.Reduction.SUM)
     second_optimizer = replicate_model_fn._TowerOptimizer(
-        adam.AdamOptimizer(1.0))
+        adam.AdamOptimizer(1.0), loss_reduction=losses.Reduction.SUM)
 
     with ops_lib.control_dependencies([side_effects.assign_add(1.0)]):
       first_grads_and_vars = first_optimizer.compute_gradients(loss)
@@ -712,7 +732,6 @@ class ReplicateWithTwoOptimizersTest(test_util.TensorFlowTestCase):
     with self.test_session() as session:
       replicated_model_fn = replicate_model_fn._replicate_model_fn(
           self.model_fn,
-          loss_reduction=losses.Reduction.SUM,
           devices=['/gpu:0', '/gpu:1'])
       estimator_spec = replicated_model_fn(features, labels,
                                            model_fn_lib.ModeKeys.TRAIN, {})
@@ -787,11 +806,13 @@ class ReplicateWithTwoLossesAndOneOptimizer(test_util.TensorFlowTestCase):
     train_ops = []
 
     optimizer = replicate_model_fn._TowerOptimizer(
-        gradient_descent.GradientDescentOptimizer(1.0))
+        gradient_descent.GradientDescentOptimizer(1.0),
+        loss_reduction=losses.Reduction.SUM)
     train_ops.append(optimizer.minimize(loss, var_list=[c]))
     if not self.should_skip_optimizer():
       another_optimizer = replicate_model_fn._TowerOptimizer(
-          gradient_descent.GradientDescentOptimizer(1.0))
+          gradient_descent.GradientDescentOptimizer(1.0),
+          loss_reduction=losses.Reduction.SUM)
       train_ops.append(another_optimizer.minimize(another_loss, var_list=[d]))
 
     train_op = control_flow_ops.group(train_ops)
@@ -806,10 +827,9 @@ class ReplicateWithTwoLossesAndOneOptimizer(test_util.TensorFlowTestCase):
     features = np.array([[1.0], [2.0]])
     labels = np.array([[1.0], [2.0]])
 
-    with self.test_session() as session:
+    with ops_lib.Graph().as_default(), self.test_session() as session:
       replicated_model_fn = replicate_model_fn._replicate_model_fn(
           self.model_fn,
-          loss_reduction=losses.Reduction.SUM,
           devices=['/gpu:0', '/gpu:1'])
       estimator_spec = replicated_model_fn(features, labels,
                                            model_fn_lib.ModeKeys.TRAIN, {})
@@ -881,7 +901,7 @@ class FailToWrapOptimizerInTheModelFn(test_util.TensorFlowTestCase):
 
     with self.test_session():
       with self.assertRaisesRegexp(ValueError,
-                                   'Please.+wrap.+with.+_TowerOptimizer'):
+                                   'Please.+wrap.+with.+TowerOptimizer'):
         replicated_model_fn = replicate_model_fn._replicate_model_fn(
             self.model_fn, devices=['/gpu:0', '/gpu:1'])
         _ = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN,
@@ -890,30 +910,43 @@ class FailToWrapOptimizerInTheModelFn(test_util.TensorFlowTestCase):
 
 class GetLossTowersTest(test_util.TensorFlowTestCase):
 
-  def model_fn(self, mode, features, labels, params):
-    c = variable_scope.get_variable(
-        'c',
-        initializer=constant_op.constant(0.25, dtype=dtypes.float64),
-        dtype=dtypes.float64)
+  def create_model_fn_with_loss_reduction(self, loss_reduction):
 
-    predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c)
-    labels = np.array([0.1, 0.2, 0.3, labels[0]])
+    def model_fn(mode, features, labels, params):
+      del params
+      c = variable_scope.get_variable(
+          'c',
+          initializer=constant_op.constant(0.25, dtype=dtypes.float64),
+          dtype=dtypes.float64)
 
-    loss = losses.absolute_difference(
-        labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
+      predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c)
+      labels = np.array([0.1, 0.2, 0.3, labels[0]])
 
-    return model_fn_lib.EstimatorSpec(mode=mode, loss=math_ops.reduce_sum(loss))
+      loss = losses.absolute_difference(
+          labels=labels,
+          predictions=predictions,
+          reduction=losses.Reduction.SUM)
+
+      optimizer = replicate_model_fn._TowerOptimizer(
+          gradient_descent.GradientDescentOptimizer(1.0),
+          loss_reduction=loss_reduction)
+
+      return model_fn_lib.EstimatorSpec(
+          mode=mode,
+          loss=math_ops.reduce_sum(loss),
+          train_op=optimizer.minimize(loss))
+
+    return model_fn
 
   def test_gradients_are_computed(self):
     with self.test_session() as session:
       tower_specs = replicate_model_fn._get_loss_towers(
-          self.model_fn,
+          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
           mode=None,
           features=[[0.6], [1.6]],
           labels=[[0.6], [0.6]],
           params=None,
           config=None,
-          loss_reduction=losses.Reduction.SUM,
           devices=['/gpu:0', '/gpu:1'],
           local_ps_devices=['/gpu:0'],
           name_scope_pattern='test_tower_{}')
@@ -941,12 +974,11 @@ class GetLossTowersTest(test_util.TensorFlowTestCase):
   def test_gradients_are_computed_with_mean_reduction(self):
     with self.test_session() as session:
       tower_specs = replicate_model_fn._get_loss_towers(
-          self.model_fn,
+          self.create_model_fn_with_loss_reduction(losses.Reduction.MEAN),
           mode=model_fn_lib.ModeKeys.EVAL,
           features=[[0.6], [1.6]],
           labels=[[0.6], [0.6]],
           params=None,
-          loss_reduction=losses.Reduction.MEAN,
           config=None,
           devices=['/gpu:0', '/gpu:1'],
           local_ps_devices=['/gpu:0'],
@@ -999,7 +1031,6 @@ class GetLossTowersTest(test_util.TensorFlowTestCase):
           features=[[0.6], [1.6], [2.6]],
           labels=[[0.6], [0.6], [2.6]],
           params=None,
-          loss_reduction=losses.Reduction.SUM,
           config=None,
           devices=['/gpu:0', '/gpu:1', '/gpu:3'],
           local_ps_devices=['/gpu:0', '/gpu:1', '/gpu:3'],
@@ -1296,7 +1327,6 @@ class PredictSpecTest(test_util.TensorFlowTestCase):
           self.model_fn,
           mode=None,
           features=[[0.1], [0.2]],
-          loss_reduction=losses.Reduction.SUM,
           labels=[[], []],
           params=None,
           config=None,
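
In summary, the division of responsibility after this change, condensed from
the hunks above:

    # The replication machinery records only the tower count:
    _TowerOptimizer._graph_state().set_number_of_towers(len(devices))

    # Each _TowerOptimizer owns its reduction and publishes it to the shared
    # graph state whenever gradients are computed:
    def compute_gradients(self, loss, *args, **kwargs):
        _TowerOptimizer._graph_state().set_loss_reduction(self._loss_reduction)
        loss = _scale_loss(loss, self._loss_reduction,
                           self._graph_state().number_of_towers)
        return self._get_optimizer().compute_gradients(loss, *args, **kwargs)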