Support multiple losses and multiple optimizers in replicate_model_fn.
author Igor Saprykin <isaprykin@google.com>
Fri, 22 Dec 2017 07:43:57 +0000 (23:43 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 22 Dec 2017 07:47:34 +0000 (23:47 -0800)
Instead of supplying `optimizer_fn`, the user is now expected to wrap their optimizer in `GatheringOptimizer`.  The latter gathers gradients from all towers, reduces them, and applies them in the last tower.
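
For reference, here is a minimal sketch of the new calling convention, adapted from the updated `replicate_model_fn` docstring and tests.  The input pipeline, the dense layer, the 'x' feature key, and the use of `tf.train.get_global_step()` are illustrative only and not part of this change:

    import tensorflow as tf

    def model_fn(features, labels, mode, params):
      predictions = tf.layers.dense(features['x'], 1)
      loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)

      # Wrap the optimizer so that gradients from all towers are gathered,
      # reduced, and applied in the last tower.
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
      optimizer = tf.contrib.estimator.GatheringOptimizer(optimizer)

      if mode == tf.estimator.ModeKeys.TRAIN:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss,
            train_op=optimizer.minimize(
                loss, global_step=tf.train.get_global_step()))
      return tf.estimator.EstimatorSpec(mode=mode, loss=loss)

    # There is no optimizer_fn argument anymore; the wrapped optimizer lives
    # inside model_fn.
    classifier = tf.estimator.Estimator(
        model_fn=tf.contrib.estimator.replicate_model_fn(model_fn))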

There can be multiple instances of GatheringOptimizer inside the model, as long as each wraps an optimizer of a different `__class__`.
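
A sketch of that case follows, mirroring the new ReplicateWithTwoOptimizersTest; the names and learning rates are illustrative, and the snippet would replace the single `minimize` call inside a `model_fn` like the one above.  The wrapped optimizers have to be of different `__class__` so that the global step is incremented only once per training step:

    first_optimizer = tf.contrib.estimator.GatheringOptimizer(
        tf.train.GradientDescentOptimizer(1.0))
    second_optimizer = tf.contrib.estimator.GatheringOptimizer(
        tf.train.AdamOptimizer(1.0))

    # Every tower calls both optimizers in the same order; gradients for each
    # optimizer type are aggregated separately and applied in the last tower.
    train_op = tf.group(first_optimizer.minimize(loss),
                        second_optimizer.minimize(loss))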

PiperOrigin-RevId: 179899422

tensorflow/contrib/estimator/__init__.py
tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py

index 28c1f8b1809d27db697365b7bb50441f7820d2b4..7533943d4834fd1211c0907b322854338e6d08f3 100644 (file)
@@ -47,6 +47,7 @@ _allowed_symbols = [
     'dnn_logit_fn_builder',
     'linear_logit_fn_builder',
     'replicate_model_fn',
+    'GatheringOptimizer',
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
index 598bd549c5cef7edde6bf94605aa8839b611e185..f4caa60460dbd4a9149afb45ab4d4e0148f76bca 100644 (file)
@@ -23,6 +23,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from contextlib import contextmanager
 import copy
 
 import six
@@ -44,11 +45,10 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops.losses import losses
 from tensorflow.python.platform import tf_logging
 from tensorflow.python.training import device_setter as device_setter_lib
-from tensorflow.python.training import training_util
+from tensorflow.python.training import optimizer as optimizer_lib
 
 
 def replicate_model_fn(model_fn,
-                       optimizer_fn,
                        loss_reduction=losses.Reduction.SUM,
                        devices=None):
   """Replicate `Estimator.model_fn` over GPUs within a single host.
@@ -74,30 +74,29 @@ def replicate_model_fn(model_fn,
 
   Here is an example of how one might use their `model_fn` to run over GPUs:
     ```python
-       def optimizer_fn():
-         return tf.train.GradientDescentOptimizer(learning_rate=0.001)
        ...
        def model_fn(...):  # See `model_fn` in `Estimator`.
          loss = ...
+         optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
+         optimizer = tf.contrib.estimator.GatheringOptimizer(optimizer)
          if mode == tf.estimator.ModeKeys.TRAIN:
            #  See the section below on `EstimatorSpec.train_op`.
-           return EstimatorSpec(mode=mode, loss=loss, train_op=tf.noop())
+           return EstimatorSpec(mode=mode, loss=loss,
+                                train_op=optimizer.minimize(loss))
 
          #  No change for `ModeKeys.EVAL` or `ModeKeys.PREDICT`.
          return EstimatorSpec(...)
        ...
        classifier = tf.estimator.Estimator(
-         model_fn=replicate_model_fn.replicate_model_fn(model_fn, optimizer_fn))
+         model_fn=tf.contrib.estimator.replicate_model_fn(model_fn))
     ```
 
   On `EstimatorSpec.train_op`:
   `model_fn` returns `EstimatorSpec.train_op` for
  `tf.estimator.ModeKeys.TRAIN`. It is typically derived using an optimizer.
-  `replicate_model_fn` ignores the returned `EstimatorSpec.train_op`, so there
-  is no need to use an optimizer inside the user's `model_fn`.  The
-  `EstimatorSpec.loss` subgraph is going to be executed, while
-  `EstimatorSpec.train_op` isn't going to be executed. One could pass
-  `train_op=tf.noop()` to `EstimatorSpec`.
+  Each tower is expected to populate it in the same way, using an optimizer
+  wrapped in `GatheringOptimizer`.  Gradients from all towers are then reduced
+  and applied in the last tower.  See `GatheringOptimizer` for details.
 
   On sharding input features and labels:
   Input features and labels are split for consumption by each tower. They are
@@ -125,9 +124,6 @@ def replicate_model_fn(model_fn,
   Args:
     model_fn: `model_fn` as defined in `Estimator`.  See the section above about
       the train_op argument of `EstimatorSpec`.
-    optimizer_fn: a function that returns an optimizer instance.  The function
-      may accept one `params` argument.  This is the `params` argument as
-      defined by `Estimator`.  See  the `Estimator` documentation for details.
     loss_reduction: controls whether losses are summed or averaged.
     devices: Optional list of devices to replicate the model across.  This
      argument can be used to replicate only on a subset of available GPUs.
@@ -141,7 +137,6 @@ def replicate_model_fn(model_fn,
   """
   return _replicate_model_fn_with_mode(
       model_fn,
-      optimizer_fn,
       loss_reduction,
       devices,
       # TODO(isaprykin): Query the system configuration to choose modes other
@@ -177,7 +172,6 @@ class _VariableDistributionMode(object):
 
 def _replicate_model_fn_with_mode(
     model_fn,
-    optimizer_fn,
     loss_reduction=losses.Reduction.SUM,
     devices=None,
     mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER):
@@ -216,8 +210,7 @@ def _replicate_model_fn_with_mode(
         local_ps_devices=ps_devices)
 
     if mode == model_fn_lib.ModeKeys.TRAIN:
-      train_op = _minimize_towers(tower_specs,
-                                  _call_optimizer_fn(optimizer_fn, params))
+      train_op = _minimize_towers(tower_specs)
       return _train_spec(
           tower_specs, train_op, aggregation_device=consolidation_device)
     elif mode == model_fn_lib.ModeKeys.EVAL:
@@ -228,6 +221,188 @@ def _replicate_model_fn_with_mode(
   return replicated_model_fn
 
 
+class GatheringOptimizer(optimizer_lib.Optimizer):
+  """Gathers gradients from all towers and reduces them in the last one."""
+
+  COLLECTION_FOR_GRAPH_STATES = 'replicate_model_fn_graph_states'
+
+  def __init__(self, optimizer_or_optimizer_fn):
+    """Wrap an existing optimizer for gathering gradients across towers.
+
+    Every invocation of `model_fn` has to call its optimizers in the same order.
+
+    Multiple optimizers that use the same or different losses are supported.
+    The optimizers, however, need to be of different types (per `__class__`)
+    so that the global_step is incremented correctly.
+
+    Args:
+      optimizer_or_optimizer_fn: an instance of optimizer to wrap.  That
+        instance is going to be used for optimizer-specific logic.  This can
+        also be a no-argument function that returns such an optimizer instance.
+    """
+    self._optimizer_or_optimizer_fn = optimizer_or_optimizer_fn
+
+  @staticmethod
+  def has_been_used():
+    return GatheringOptimizer._graph_state().has_gathering_optimizer_been_used
+
+  def get_slot(self, *args, **kwargs):
+    return self._get_optimizer().get_slot(*args, **kwargs)
+
+  def get_slot_names(self, *args, **kwargs):
+    return self._get_optimizer().get_slot_names(*args, **kwargs)
+
+  def get_name(self, *args, **kwargs):
+    return self._get_optimizer().get_name(*args, **kwargs)
+
+  def variables(self, *args, **kwargs):
+    return self._get_optimizer().variables(*args, **kwargs)
+
+  def compute_gradients(self, loss, *args, **kwargs):
+    """Compute gradients, but first, if needed, scale the loss."""
+    loss = _scale_loss(loss,
+                       self._graph_state().loss_reduction,
+                       self._graph_state().number_of_towers)
+    return self._get_optimizer().compute_gradients(loss, *args, **kwargs)
+
+  def apply_gradients(self, grads_and_vars, global_step=None, **kwargs):
+    """Collect gradients updates to apply them with the last tower."""
+    self._graph_state().collect_gradients(grads_and_vars,
+                                          self._get_optimizer())
+
+    if not self._graph_state().is_the_last_tower:
+      return self._construct_no_op_train_op()
+    else:
+      # Gradients need to be gathered and applied in the scope of the first
+      # tower, so that the tensors are accessible via names without prefixes.
+      var_scope, name_scope = self._graph_state().scopes_of_the_first_tower
+      with variable_scope.variable_scope(var_scope):
+        with ops_lib.name_scope(name_scope):
+          return self._apply_gathered_gradients(global_step, **kwargs)
+
+  def _apply_gathered_gradients(self, global_step, **kwargs):
+    graph_state = self._graph_state()
+    optimizer = self._get_optimizer()
+    train_ops = []
+
+    grad_lists = {}
+    # Only aggregate gradients collected for this optimizer's `__class__`.
+    for grad, var in graph_state.get_grad_and_vars_for_optimizer(optimizer):
+      if grad is not None:
+        grad_lists.setdefault(var, []).append(grad)
+
+    aggregated_grads = []
+    with ops_lib.name_scope('gradient_aggregating'):
+      for var, grads in six.iteritems(grad_lists):
+        grad = _compute_sum_on_device(grads, var.device)
+        aggregated_grads.append((grad, var))
+    train_ops.append(optimizer.apply_gradients(aggregated_grads))
+
+    # A model might use multiple optimizers.  We only want to increment the
+    # global step after `apply_gradients` of the last optimizer in the tower.
+    if global_step and graph_state.is_the_last_optimizer_within_a_tower(
+        optimizer):
+      with ops_lib.control_dependencies(train_ops):
+        with ops_lib.colocate_with(global_step):
+          return state_ops.assign_add(global_step, 1)
+    else:
+      return control_flow_ops.group(train_ops)
+
+  def _get_optimizer(self):
+    if not isinstance(self._optimizer_or_optimizer_fn, optimizer_lib.Optimizer):
+      # If the optimizer is given as a function, wait until we are under the
+      # right graph context before constructing it.
+      self._optimizer_or_optimizer_fn = self._optimizer_or_optimizer_fn()
+    self._graph_state().has_gathering_optimizer_been_used = True
+    return self._optimizer_or_optimizer_fn
+
+  def _construct_no_op_train_op(self):
+    return control_flow_ops.no_op(name='train_op_placeholder')
+
+  @staticmethod
+  def _graph_state():
+    graph_states = ops_lib.get_default_graph().get_collection_ref(
+        GatheringOptimizer.COLLECTION_FOR_GRAPH_STATES)
+    if not graph_states:
+      graph_states.append(GatheringOptimizer._PerGraphState())
+    return graph_states[-1]
+
+  @staticmethod
+  def _clear_graph_state():
+    # Clearing a collection in Graph will prevent _PerGraphState from being
+    # serialized.
+    ops_lib.get_default_graph().clear_collection(
+        GatheringOptimizer.COLLECTION_FOR_GRAPH_STATES)
+
+  class _PerGraphState(object):
+    """Gradient reduction related state of a Tensorflow graph."""
+
+    def __init__(self):
+      # For every type of optimizer, collect all gradients and variables.
+      self._optimizer_grads_and_vars = {}
+      # In what order were optimizers invoked within each tower?
+      self._ordered_optimizer_types = []
+      self._number_of_towers = 0
+      self._is_the_last_tower = False
+      self._loss_reduction = None
+      # Scopes of the first tower that don't have a prefix:
+      self._variable_scope = None
+      self._name_scope = None
+      # Tracks whether a GatheringOptimizer has been used with the model_fn.
+      self._has_gathering_optimizer_been_used = False
+
+    def collect_gradients(self, grads_and_vars, optimizer):
+      if optimizer.__class__ not in self._ordered_optimizer_types:
+        self._ordered_optimizer_types.append(optimizer.__class__)
+
+      self._optimizer_grads_and_vars.setdefault(optimizer.__class__,
+                                                []).extend(grads_and_vars)
+
+    def get_grad_and_vars_for_optimizer(self, optimizer):
+      return self._optimizer_grads_and_vars[optimizer.__class__]
+
+    def set_reduction_across_towers(self, loss_reduction, number_of_towers):
+      self._loss_reduction = loss_reduction
+      self._number_of_towers = number_of_towers
+
+    @contextmanager
+    def tower(self, tower_id, var_scope, name_scope):
+      if tower_id == 0:
+        self._variable_scope = var_scope
+        self._name_scope = name_scope
+      if tower_id == (self._number_of_towers - 1):
+        self._is_the_last_tower = True
+      yield
+      self._is_the_last_tower = False
+
+    @property
+    def scopes_of_the_first_tower(self):
+      return self._variable_scope, self._name_scope
+
+    @property
+    def is_the_last_tower(self):
+      return self._is_the_last_tower
+
+    def is_the_last_optimizer_within_a_tower(self, optimizer):
+      return optimizer.__class__ == self._ordered_optimizer_types[-1]
+
+    @property
+    def number_of_towers(self):
+      return self._number_of_towers
+
+    @property
+    def loss_reduction(self):
+      return self._loss_reduction
+
+    @property
+    def has_gathering_optimizer_been_used(self):
+      return self._has_gathering_optimizer_been_used
+
+    @has_gathering_optimizer_been_used.setter
+    def has_gathering_optimizer_been_used(self, value):
+      self._has_gathering_optimizer_been_used = value
+
+
 def _get_local_devices(device_type):
   local_device_protos = device_lib.list_local_devices()
   return [
@@ -296,7 +471,8 @@ def _get_loss_towers(model_fn,
   # pylint: disable=protected-access
   round_robin_strategy = device_setter_lib._RoundRobinStrategy(
       num_tasks=len(local_ps_devices))
-  # pylint: enable=protected-access
+  GatheringOptimizer._graph_state().set_reduction_across_towers(
+      loss_reduction, len(devices))
 
   for i, device in enumerate(devices):
     is_the_first_tower = (i == 0)
@@ -313,22 +489,35 @@ def _get_loss_towers(model_fn,
     if is_the_first_tower:
       name_scope = ''
 
-    with variable_scope.variable_scope('', reuse=not is_the_first_tower):
-      with ops_lib.name_scope(name_scope.format(i)):
-        with ops_lib.device(device_setter):
-          labels_shard = None
-          if labels:
-            labels_shard = labels[i]
-
-          tower_spec = model_fn(
-              mode=mode,
-              features=features[i],
-              labels=labels_shard,
-              **optional_params)
-          if loss_reduction != losses.Reduction.SUM:
+    with variable_scope.variable_scope(
+        '', reuse=not is_the_first_tower) as var_scope:
+      with ops_lib.name_scope(name_scope.format(i)) as name_scope:
+        with GatheringOptimizer._graph_state().tower(
+            tower_id=i, var_scope=var_scope, name_scope=name_scope):
+          with ops_lib.device(device_setter):
+            labels_shard = None
+            if labels:
+              labels_shard = labels[i]
+
+            tower_spec = model_fn(
+                mode=mode,
+                features=features[i],
+                labels=labels_shard,
+                **optional_params)
+
+            if (tower_spec.train_op is not None and
+                not GatheringOptimizer.has_been_used()):
+              raise ValueError('Please wrap optimizers with GatheringOptimizer'
+                               ' in order to use replicate_model_fn.')
+
+            # Scaling the loss here doesn't actually affect gradients.  Another
+            # instance of scaling happens inside the GatheringOptimizer.
             tower_spec = _scale_tower_loss(
-                tower_spec, number_of_towers=len(devices))
-          tower_specs.append(tower_spec)
+                tower_spec, loss_reduction, number_of_towers=len(devices))
+            tower_specs.append(tower_spec)
+
+  GatheringOptimizer._clear_graph_state()
+  # pylint: enable=protected-access
   return tower_specs
 
 
@@ -355,44 +544,31 @@ def _local_device_setter(worker_device, ps_devices, ps_strategy):
   return local_device_chooser
 
 
-def _scale_tower_loss(tower_spec, number_of_towers):
-  """Scale down the loss for arriving at the average loss by summing."""
+def _scale_tower_loss(tower_spec, loss_reduction, number_of_towers):
+  """Produce an EstimatorSpec with approproriately scaled loss."""
   if tower_spec.loss is None:
     return tower_spec
 
   estimator_spec = _asdict(tower_spec)
-  estimator_spec['loss'] = math_ops.div(
-      tower_spec.loss, 1.0 * number_of_towers, name='averaged_loss')
+  estimator_spec['loss'] = _scale_loss(tower_spec.loss, loss_reduction,
+                                       number_of_towers)
   return model_fn_lib.EstimatorSpec(**estimator_spec)
 
 
-def _minimize_towers(tower_specs, optimizer):
-  """Aggregate and apply gradients for computed losses."""
-  grad_lists = {}
-  for tower_spec in tower_specs:
-    with ops_lib.device(tower_spec.loss.device):
-      for grad, var in optimizer.compute_gradients(tower_spec.loss):
-        if grad is not None:
-          grad_lists.setdefault(var, []).append(grad)
-
-  aggregated_grads = []
-  with ops_lib.name_scope('gradient_aggregating'):
-    for var, grads in six.iteritems(grad_lists):
-      grad = _compute_sum_on_device(grads, var.device)
-      aggregated_grads.append((grad, var))
-
-  train_op = optimizer.apply_gradients(
-      aggregated_grads, global_step=training_util.get_global_step())
+def _scale_loss(loss, loss_reduction, number_of_towers):
+  """If needed, scale down the loss for averaging loss by summing."""
+  if loss is None:
+    return None
 
-  return train_op
+  if loss_reduction != losses.Reduction.SUM:
+    return math_ops.div(loss, 1.0 * number_of_towers, name='averaged_loss')
+  else:
+    return loss
 
 
-def _call_optimizer_fn(optimizer_fn, params):
-  arguments = {}
-  optimizer_fn_arguments = util.fn_args(optimizer_fn)
-  if 'params' in optimizer_fn_arguments:
-    arguments['params'] = params
-  return optimizer_fn(**arguments)
+def _minimize_towers(tower_specs):
+  """`train_op` of the last tower applies aggregated gradients."""
+  return tower_specs[-1].train_op
 
 
 def _compute_sum_on_device(values, device, name=None):
@@ -450,7 +626,7 @@ def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'):
 def _reduce_metric_variables(number_of_towers):
   """Aggregate local variables used in metrics into the first tower."""
   if number_of_towers == 1:
-    return control_flow_ops.no_op()
+    return control_flow_ops.no_op(name='no_eval_metric_reduction')
 
   metric_variables = ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)
   variables_per_tower = len(metric_variables) // number_of_towers
index b452e5c7359a973bea670f5760b229cf72d032f5..0b6c0e957b75d641ca93b946267a5797083d16ab 100644 (file)
@@ -50,6 +50,7 @@ from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import adam
 from tensorflow.python.training import device_setter
 from tensorflow.python.training import gradient_descent
 
@@ -112,26 +113,24 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
                     0., len(x_data), len(x_data), dtype=np.int64)), 1)
     ]
 
+    def optimizer_fn():
+      return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05)
+
     estimator = dnn.DNNClassifier(
         hidden_units=(2, 2),
+        # Adagrad is configured with `get_optimizer_instance`, so the function
+        # form of `GatheringOptimizer.__init__` is used.
+        optimizer=replicate_model_fn.GatheringOptimizer(optimizer_fn),
         feature_columns=feature_columns,
         n_classes=n_classes,
         model_dir=self._model_dir)
 
-    def optimizer_fn():
-      return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05)
-
     if not mode:  # Use the public `replicate_model_fn`.
       model_fn = replicate_model_fn.replicate_model_fn(
-          estimator.model_fn,
-          optimizer_fn,
-          devices=['/gpu:0', '/gpu:1', '/gpu:2'])
+          estimator.model_fn, devices=['/gpu:0', '/gpu:1', '/gpu:2'])
     else:
       model_fn = replicate_model_fn._replicate_model_fn_with_mode(
-          estimator.model_fn,
-          optimizer_fn,
-          devices=['/gpu:0', '/gpu:1', '/gpu:2'],
-          mode=mode)
+          estimator.model_fn, devices=['/gpu:0', '/gpu:1', '/gpu:2'], mode=mode)
 
     estimator = estimator_lib.Estimator(
         model_fn=model_fn,
@@ -159,6 +158,10 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
                                              serving_input_receiver_fn)
     self.assertTrue(gfile.Exists(export_dir))
 
+    # The collection should be empty so graph state doesn't get serialized.
+    self.assertFalse(ops_lib.get_default_graph().get_collection_ref(
+        replicate_model_fn.GatheringOptimizer.COLLECTION_FOR_GRAPH_STATES))
+
   def _as_label(self, data_in_float):
     return np.rint(data_in_float).astype(np.int64)
 
@@ -178,28 +181,24 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     predictions = math_ops.multiply(features, c)
 
-    loss = None
-    if mode is not model_fn_lib.ModeKeys.PREDICT:
-      loss = losses.absolute_difference(
-          labels=labels,
-          predictions=predictions,
-          reduction=losses.Reduction.SUM)
-      loss = math_ops.reduce_sum(loss)
+    loss = losses.absolute_difference(
+        labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
+    loss = math_ops.reduce_sum(loss)
 
     metrics = {
         'accuracy': metrics_lib.accuracy(labels, predictions),
         'auc': metrics_lib.auc(labels, predictions)
     }
 
+    optimizer = replicate_model_fn.GatheringOptimizer(
+        gradient_descent.GradientDescentOptimizer(params['learning_rate']))
+
     return model_fn_lib.EstimatorSpec(
         mode=mode,
         loss=loss,
         eval_metric_ops=metrics,
         predictions={'probabilities': predictions},
-        train_op=control_flow_ops.no_op())  # This train_op isn't actually used.
-
-  def optimizer_fn(self, params):
-    return gradient_descent.GradientDescentOptimizer(params['learning_rate'])
+        train_op=optimizer.minimize(loss))
 
   @property
   def params(self):
@@ -213,7 +212,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session() as session:
       replicated_model_fn = replicate_model_fn.replicate_model_fn(
-          self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
+          self.model_fn, devices=['/gpu:0', '/gpu:1'])
       estimator_spec = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
       session.run(variables.global_variables_initializer())
@@ -235,10 +234,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session() as session:
       replicated_model_fn = replicate_model_fn.replicate_model_fn(
-          self.model_fn,
-          self.optimizer_fn,
-          losses.Reduction.MEAN,
-          devices=['/gpu:0', '/gpu:1'])
+          self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1'])
       estimator_spec = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
       session.run(variables.global_variables_initializer())
@@ -256,24 +252,38 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
         c = variable_scope.get_variable('c', dtype=dtypes.float64)
         self.assertEqual(8.5, session.run(c))
 
-  def test_train_spec_with_optimizer_without_params(self):
-
-    def optimizer_fn_without_params():
-      return gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
+  def test_train_two_steps_collected_gradients_are_reset_between_steps(self):
+    with ops_lib.Graph().as_default():
+      features = array_ops.placeholder(dtypes.float64)
+      labels = array_ops.placeholder(dtypes.float64)
 
-    features = np.array([[1.0], [2.0]])
-    labels = np.array([[1.0], [2.0]])
+      feature_inputs = np.array([[1.0], [2.0]]), np.array([[1.5], [2.5]])
+      label_inputs = np.array([[1.0], [2.0]]), np.array([[1.5], [2.5]])
 
-    with self.test_session() as session:  # pylint: disable=unused-variable
-      replicated_model_fn = replicate_model_fn.replicate_model_fn(
-          self.model_fn,
-          optimizer_fn_without_params,
-          devices=['/gpu:0', '/gpu:1'])
-      # This call is going to fail if `replicated_model_fn` is still passing
-      # `params` inside `optimizer_fn`, even though the latter doesn't take any:
-      estimator_spec = replicated_model_fn(
-          features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
-      del estimator_spec
+      # loss = feature * c - label
+      expected_losses = ((1.0 * 10 - 1.0) + (2.0 * 10 - 2.0),
+                         (1.5 * 7.0 - 1.5) + (2.5 * 7.0 - 2.5))
+      # Derivative of the loss is 1.0 + 2.0 for the first step and 1.5 + 2.5
+      # for the second.
+      expected_c = 10.0 - 3.0, 7.0 - 4.0
+
+      with self.test_session() as session, variable_scope.variable_scope(
+          '', reuse=variable_scope.AUTO_REUSE):
+        replicated_model_fn = replicate_model_fn.replicate_model_fn(
+            self.model_fn, devices=['/gpu:0', '/gpu:1'])
+        estimator_spec = replicated_model_fn(
+            features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
+        session.run(variables.global_variables_initializer())
+
+        for feature_input, label_input, loss, weight in zip(
+            feature_inputs, label_inputs, expected_losses, expected_c):
+          feeds = {features: feature_input, labels: label_input}
+
+          self.assertEqual(loss, session.run(estimator_spec.loss, feeds))
+
+          session.run(estimator_spec.train_op, feeds)
+          c = variable_scope.get_variable('c', dtype=dtypes.float64)
+          self.assertEqual(weight, session.run(c, feeds))
 
   def test_eval(self):
     features = np.array([[0.01], [0.002]])
@@ -281,7 +291,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session() as session:
       replicated_model_fn = replicate_model_fn.replicate_model_fn(
-          self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
+          self.model_fn, devices=['/gpu:0', '/gpu:1'])
       estimator_spec = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
       session.run(variables.local_variables_initializer())
@@ -310,10 +320,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session() as session:
       replicated_model_fn = replicate_model_fn.replicate_model_fn(
-          self.model_fn,
-          self.optimizer_fn,
-          losses.Reduction.MEAN,
-          devices=['/gpu:0', '/gpu:1'])
+          self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1'])
       estimator_spec = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
       session.run(variables.local_variables_initializer())
@@ -342,7 +349,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session() as session:
       replicated_model_fn = replicate_model_fn.replicate_model_fn(
-          self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
+          self.model_fn, devices=['/gpu:0', '/gpu:1'])
       estimator_spec = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.PREDICT, self.params)
       session.run(variables.global_variables_initializer())
@@ -357,7 +364,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session() as session:
       replicated_model_fn = replicate_model_fn.replicate_model_fn(
-          self.model_fn, self.optimizer_fn, devices=['/gpu:0'])
+          self.model_fn, devices=['/gpu:0'])
       estimator_spec = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
       session.run(variables.global_variables_initializer())
@@ -379,7 +386,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session() as session:
       replicated_model_fn = replicate_model_fn.replicate_model_fn(
-          self.model_fn, self.optimizer_fn, devices=['/gpu:0'])
+          self.model_fn, devices=['/gpu:0'])
       estimator_spec = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
       session.run(variables.local_variables_initializer())
@@ -407,7 +414,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
 
     with self.test_session() as session:
       replicated_model_fn = replicate_model_fn.replicate_model_fn(
-          self.model_fn, self.optimizer_fn, devices=['/gpu:0'])
+          self.model_fn, devices=['/gpu:0'])
       estimator_spec = replicated_model_fn(
           features, labels, model_fn_lib.ModeKeys.PREDICT, self.params)
       session.run(variables.global_variables_initializer())
@@ -417,9 +424,191 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
       }, session.run(estimator_spec.predictions))
 
   def test_unsupported_loss_reduction(self):
-    with self.assertRaisesRegexp(ValueError, ''):
-      _ = replicate_model_fn.replicate_model_fn(
-          self.model_fn, self.optimizer_fn, losses.Reduction.NONE)
+    with self.assertRaisesRegexp(ValueError,
+                                 '.+none.+reduction.+is.+specified.+'):
+      _ = replicate_model_fn.replicate_model_fn(self.model_fn,
+                                                losses.Reduction.NONE)
+
+
+class ReplicateWithTwoOptimizersTest(test_util.TensorFlowTestCase):
+
+  def model_fn(self, mode, features, labels, params):
+    c = variable_scope.get_variable(
+        'c',
+        initializer=constant_op.constant(10, dtype=dtypes.float64),
+        dtype=dtypes.float64)
+
+    predictions = math_ops.multiply(features, c)
+
+    loss = losses.absolute_difference(
+        labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
+    loss = math_ops.reduce_sum(loss)
+
+    metrics = {
+        'accuracy': metrics_lib.accuracy(labels, predictions),
+        'auc': metrics_lib.auc(labels, predictions)
+    }
+
+    first_optimizer = replicate_model_fn.GatheringOptimizer(
+        gradient_descent.GradientDescentOptimizer(1.0))
+    second_optimizer = replicate_model_fn.GatheringOptimizer(
+        adam.AdamOptimizer(1.0))
+
+    train_op = control_flow_ops.group(
+        [first_optimizer.minimize(loss),
+         second_optimizer.minimize(loss)])
+
+    return model_fn_lib.EstimatorSpec(
+        mode=mode,
+        loss=loss,
+        eval_metric_ops=metrics,
+        predictions={'probabilities': predictions},
+        train_op=train_op)
+
+  def test_train(self):
+    features = np.array([[1.0], [2.0]])
+    labels = np.array([[1.0], [2.0]])
+
+    with self.test_session() as session:
+      replicated_model_fn = replicate_model_fn.replicate_model_fn(
+          self.model_fn, devices=['/gpu:0', '/gpu:1'])
+      estimator_spec = replicated_model_fn(features, labels,
+                                           model_fn_lib.ModeKeys.TRAIN, {})
+      session.run(variables.global_variables_initializer())
+
+      # loss = feature * c - label
+      total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
+      self.assertEqual(total_loss, session.run(estimator_spec.loss))
+
+      # loss' of c is 3.
+      # new value of c = 10 - learning rate * 3 = 7.0.
+      # Adam subtracts another ~1.
+      session.run(estimator_spec.train_op)
+      with variable_scope.variable_scope('', reuse=True):
+        c = variable_scope.get_variable('c', dtype=dtypes.float64)
+        self.assertNear(6.0, session.run(c), 0.000001)
+
+
+class ReplicateWithTwoLossesAndOneOptimizer(test_util.TensorFlowTestCase):
+
+  class AnotherOptimizer(gradient_descent.GradientDescentOptimizer):
+    pass
+
+  def model_fn(self, mode, features, labels, params):
+    c = variable_scope.get_variable(
+        'c',
+        initializer=constant_op.constant(10, dtype=dtypes.float64),
+        dtype=dtypes.float64)
+    d = variable_scope.get_variable(
+        'd',
+        initializer=constant_op.constant(2, dtype=dtypes.float64),
+        dtype=dtypes.float64)
+
+    predictions = math_ops.multiply(features, c)
+
+    loss = losses.absolute_difference(
+        labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
+    loss = math_ops.reduce_sum(loss)
+
+    another_predictions = math_ops.multiply(features, d)
+    another_loss = losses.absolute_difference(
+        labels=labels,
+        predictions=another_predictions,
+        reduction=losses.Reduction.SUM)
+    another_loss = math_ops.reduce_sum(another_loss)
+
+    total_loss = math_ops.add(loss, another_loss)
+
+    metrics = {
+        'accuracy': metrics_lib.accuracy(labels, predictions),
+        'auc': metrics_lib.auc(labels, predictions)
+    }
+
+    optimizer = replicate_model_fn.GatheringOptimizer(
+        gradient_descent.GradientDescentOptimizer(1.0))
+    another_optimizer = replicate_model_fn.GatheringOptimizer(
+        self.AnotherOptimizer(1.0))
+
+    train_op = control_flow_ops.group([
+        optimizer.minimize(loss, var_list=[c]),
+        another_optimizer.minimize(another_loss, var_list=[d])
+    ])
+
+    return model_fn_lib.EstimatorSpec(
+        mode=mode,
+        loss=total_loss,
+        eval_metric_ops=metrics,
+        predictions={'probabilities': predictions},
+        train_op=train_op)
+
+  def test_train(self):
+    features = np.array([[1.0], [2.0]])
+    labels = np.array([[1.0], [2.0]])
+
+    with self.test_session() as session:
+      replicated_model_fn = replicate_model_fn.replicate_model_fn(
+          self.model_fn, devices=['/gpu:0', '/gpu:1'])
+      estimator_spec = replicated_model_fn(features, labels,
+                                           model_fn_lib.ModeKeys.TRAIN, {})
+      session.run(variables.global_variables_initializer())
+
+      # For each tower, loss = (feature * c - label) + (feature * d - label).
+      total_loss = (1.0 * 10 - 1.0 + 1.0 * 2.0 - 1.0) + (
+          2.0 * 10 - 2.0 + 2.0 * 2.0 - 2.0)
+      self.assertEqual(total_loss, session.run(estimator_spec.loss))
+
+      session.run(estimator_spec.train_op)
+
+      # loss' of c or loss' of d is 3.
+      # new value of c = 10 - learning rate * 3 = 7.0.
+      # new value of d = 2  - learning rate * 3 = -1.0.
+      with variable_scope.variable_scope('', reuse=True):
+        c = variable_scope.get_variable('c', dtype=dtypes.float64)
+        self.assertNear(7.0, session.run(c), 0.000001)
+        d = variable_scope.get_variable('d', dtype=dtypes.float64)
+        self.assertNear(-1.0, session.run(d), 0.000001)
+
+
+class FailToWrapOptimizerInTheModelFn(test_util.TensorFlowTestCase):
+
+  def model_fn(self, mode, features, labels, params):
+    c = variable_scope.get_variable(
+        'c',
+        initializer=constant_op.constant(10, dtype=dtypes.float64),
+        dtype=dtypes.float64)
+
+    predictions = math_ops.multiply(features, c)
+
+    loss = losses.absolute_difference(
+        labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
+    loss = math_ops.reduce_sum(loss)
+
+    metrics = {
+        'accuracy': metrics_lib.accuracy(labels, predictions),
+        'auc': metrics_lib.auc(labels, predictions)
+    }
+
+    optimizer = gradient_descent.GradientDescentOptimizer(1.0)
+    train_op = optimizer.minimize(loss)
+
+    return model_fn_lib.EstimatorSpec(
+        mode=mode,
+        loss=loss,
+        eval_metric_ops=metrics,
+        predictions={'probabilities': predictions},
+        train_op=train_op)
+
+  def test_train(self):
+    features = np.array([[1.0], [2.0]])
+    labels = np.array([[1.0], [2.0]])
+
+    with self.test_session():
+      with self.assertRaisesRegexp(ValueError,
+                                   'Please.+wrap.+with.+GatheringOptimizer'):
+        replicated_model_fn = replicate_model_fn.replicate_model_fn(
+            self.model_fn, devices=['/gpu:0', '/gpu:1'])
+        _ = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN,
+                                {})
 
 
 class GetLossTowersTest(test_util.TensorFlowTestCase):
@@ -889,16 +1078,14 @@ class ReduceMetricVariablesTest(test_util.TensorFlowTestCase):
           variables.variables_initializer(
               ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
 
-      with self.assertRaisesRegexp(ValueError, ''):
+      with self.assertRaisesRegexp(
+          ValueError, '.+Expected.+local.+variables.+but.+got.+instead.+'):
         session.run(
             replicate_model_fn._reduce_metric_variables(number_of_towers=3))
 
 
 class MergeExportOutputsTest(test_util.TensorFlowTestCase):
 
-  def optimizer_fn(self):
-    return gradient_descent.GradientDescentOptimizer(1.0)
-
   def model_fn(self, mode, features, labels, params):
     c = variable_scope.get_variable(
         'c',
@@ -940,7 +1127,6 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase):
         loss=math_ops.reduce_sum(loss),
         eval_metric_ops=metrics,
         predictions=predictions,
-        train_op=loss,  # This train_op isn't actually used.
         export_outputs=export_outputs)
 
   def replicate_estimator_spec(self, session):
@@ -948,13 +1134,13 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase):
     labels = np.array([0.01, 0.02])
 
     replicated_model_fn = replicate_model_fn.replicate_model_fn(
-        self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
+        self.model_fn, devices=['/gpu:0', '/gpu:1'])
     estimator_spec = replicated_model_fn(features, labels,
                                          model_fn_lib.ModeKeys.PREDICT, {})
     session.run(variables.global_variables_initializer())
     return estimator_spec
 
-  def test_merde_predict_output(self):
+  def test_merge_predict_output(self):
     with self.test_session() as session:
       estimator_spec = self.replicate_estimator_spec(session)
       self.assertAllClose(
@@ -1151,7 +1337,7 @@ class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
         dense_shape=constant_op.constant([2]))
     b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
 
-    with self.assertRaisesRegexp(ValueError, ''):
+    with self.assertRaisesRegexp(ValueError, '.+name.+not.+expected.+'):
       _ = replicate_model_fn._compute_sum_on_device(
           [a, b], device='/device:GPU:0', name='cant_name_indexslices')