Remove the hidden replicate_model_fn copy from core.
authorIgor Saprykin <isaprykin@google.com>
Mon, 16 Apr 2018 18:27:09 +0000 (11:27 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 16 Apr 2018 18:30:27 +0000 (11:30 -0700)
PiperOrigin-RevId: 193070799

tensorflow/python/estimator/BUILD
tensorflow/python/estimator/replicate_model_fn.py [deleted file]
tensorflow/python/estimator/replicate_model_fn_test.py [deleted file]

index a34405c..7bf4447 100644 (file)
@@ -7,7 +7,6 @@ package(
 licenses(["notice"])  # Apache 2.0
 
 load("//tensorflow:tensorflow.bzl", "py_test")
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 
 py_library(
     name = "estimator_py",
@@ -25,7 +24,6 @@ py_library(
         ":linear",
         ":model_fn",
         ":parsing_utils",
-        ":replicate_model_fn",
         ":run_config",
         ":training",
         "//tensorflow/python:util",
@@ -909,68 +907,3 @@ py_test(
         "//tensorflow/python:training",
     ],
 )
-
-py_library(
-    name = "replicate_model_fn",
-    srcs = [
-        "replicate_model_fn.py",
-    ],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":export_output",
-        ":model_fn",
-        ":util",
-        "//tensorflow/core:protos_all_py",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:control_flow_ops",
-        "//tensorflow/python:device",
-        "//tensorflow/python:device_lib",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:platform",
-        "//tensorflow/python:sparse_ops",
-        "//tensorflow/python:sparse_tensor",
-        "//tensorflow/python:state_ops",
-        "//tensorflow/python:training",
-        "//tensorflow/python:variable_scope",
-        "//tensorflow/python/ops/losses",
-        "@six_archive//:six",
-    ],
-)
-
-cuda_py_test(
-    name = "replicate_model_fn_test",
-    size = "medium",
-    srcs = ["replicate_model_fn_test.py"],
-    additional_deps = [
-        "//tensorflow/python/estimator",
-        ":dnn",
-        ":export_export",
-        ":export_output",
-        ":model_fn",
-        ":numpy_io",
-        ":optimizers",
-        ":prediction_keys",
-        "//tensorflow/python/feature_column",
-        "//tensorflow/python/ops/losses",
-        "//tensorflow/python/saved_model:signature_constants",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:control_flow_ops",
-        "//tensorflow/python:framework_for_generated_wrappers",
-        "//tensorflow/python:framework_test_lib",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:metrics",
-        "//tensorflow/python:platform",
-        "//tensorflow/python:summary",
-        "//tensorflow/python:training",
-        "//tensorflow/python:variable_scope",
-        "//tensorflow/python:variables",
-        ":replicate_model_fn",
-    ],
-    tags = [
-        "multi_gpu",
-        "noasan",  # flaky time outs
-        "notsan",  # flaky
-    ],
-)
diff --git a/tensorflow/python/estimator/replicate_model_fn.py b/tensorflow/python/estimator/replicate_model_fn.py
deleted file mode 100644 (file)
index 144d89a..0000000
+++ /dev/null
@@ -1,824 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utilities to replicate model_fn's over local GPUs.
-
-This file contains util that allow to replicate `Estimator.model_fn` over
-GPUs.  Replicated version of a `model_fn` is returned that can subsequently
-be used with `Estimator`.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from collections import defaultdict
-from contextlib import contextmanager
-import copy
-
-import six
-
-from tensorflow.core.framework import node_def_pb2
-from tensorflow.python.client import device_lib
-from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import util
-from tensorflow.python.estimator.export import export_output as export_output_lib
-from tensorflow.python.framework import device as framework_device
-from tensorflow.python.framework import ops as ops_lib
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops.losses import losses
-from tensorflow.python.platform import tf_logging
-from tensorflow.python.training import device_setter as device_setter_lib
-from tensorflow.python.training import optimizer as optimizer_lib
-
-
-def _replicate_model_fn(model_fn,
-                        devices=None):
-  """Replicate `Estimator.model_fn` over GPUs.
-
-  The given `model_fn` specifies a single forward pass of a model.  To replicate
-  such a model over GPUs, each GPU gets its own instance of the forward pass
-  (a.k.a. a tower).  The input features and labels get sharded into the chunks
-  that correspond to the number of GPUs.  Each tower computes a loss based
-  on its input.  For each such loss, gradients are computed.  After that, the
-  available losses are aggregated to form aggregated loss.  Available
-  gradients are summed.  Then, they update weights using the specified
-  optimizer.
-
-  If `devices` are `None`, then all available GPUs are going to be used for
-  replication.  If no GPUs are available, then the model is going to be
-  placed on the CPU.
-
-  Two modes of local replication over available GPUs are supported:
-    1)  If exactly 1 GPU is detected, then variables and operations are placed
-        onto the GPU.
-    2)  If more than 1 GPU is detected, then variables are going to be placed on
-        the CPU.  Replicas of operations are placed on each individual GPU.
-
-  Here is an example of how one might use their `model_fn` to run over GPUs:
-    ```python
-       ...
-       def model_fn(...):  # See `model_fn` in `Estimator`.
-         loss = ...
-         optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
-         optimizer = tf.contrib.estimator._TowerOptimizer(optimizer)
-         if mode == tf.estimator.ModeKeys.TRAIN:
-           #  See the section below on `EstimatorSpec.train_op`.
-           return EstimatorSpec(mode=mode, loss=loss,
-                                train_op=optimizer.minimize(loss))
-
-         #  No change for `ModeKeys.EVAL` or `ModeKeys.PREDICT`.
-         return EstimatorSpec(...)
-       ...
-       classifier = tf.estimator.Estimator(
-         model_fn=tf.contrib.estimator.replicate_model_fn(model_fn))
-    ```
-
-  Please see `DNNClassifierIntegrationTest` for an example with a canned
-  Estimator.
-
-  On `EstimatorSpec.train_op`:
-  `model_fn` returns `EstimatorSpec.train_op` for
-  `tf.estimator.GraphKeys.TRAIN`. It is typically derived using an optimizer.
-  Towers are expected to populate it in the same way.  Gradients from all towers
-  are reduced and applied in the last tower.  To achieve that in the case of
-  multiple towers, `_TowerOptimizer` needs to be used.  See `_TowerOptimizer`.
-
-  On sharding input features and labels:
-  Input features and labels are split for consumption by each tower. They are
-  split across the dimension 0.  Features and labels need to be batch major.
-
-  On reduction algorithms:
-  Certain algorithms were chosen for aggregating results of computations on
-  multiple towers:
-    - Losses from all towers are reduced according to `loss_reduction` argument
-      to TowerOptimizer..
-    - Gradients from all towers are reduced according to the `loss_reduction`
-      for each trainable variable.
-    - `eval_metrics_ops` are reduced per metric using `reduce_mean`.
-    - `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are
-      reduced using concatenation.
-    - For all other fields of `EstimatorSpec` the values of the first tower
-      are taken.
-
-  On distribution of variables:
-  Variables are not duplicated between towers.  Instead, they are placed on a
-  single device as defined above and shared across towers.
-
-  On overhead:
-  If only one device is specified, then aggregation of loss and gradients
-  doesn't happen. Replication consists of placing `model_fn` onto the
-  specified device.
-
-  On current limitations:
-    - `predictions` are not supported for `ModeKeys.EVAL`.  They are required
-       for `tf.contrib.estimator.add_metrics`.
-
-  Args:
-    model_fn: `model_fn` as defined in `Estimator`.  See the section above about
-      the train_op argument of `EstimatorSpec`.
-    devices: Optional list of devices to replicate the model across.  This
-      argument can be used to replice only on the subset of available GPUs.
-      If `None`, then all available GPUs are going to be used for replication.
-      If no GPUs are available, then the model is going to be placed on the CPU.
-
-  Returns:
-    A replicated version of the supplied `model_fn`. Returned function that
-      conforms to the requirements of `Estimator`'s `model_fn` and can be used
-      instead of the supplied `model_fn`.
-  """
-  return _replicate_model_fn_with_mode(
-      model_fn,
-      devices,
-      # TODO(isaprykin): Query the system configuration to choose modes other
-      # than `SHARED_LOCAL_PARAMETER_SERVER`, even though it is often
-      # appropriate.
-      mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER)
-
-
-class _VariableDistributionMode(object):
-  """Modes for variable distribution used for forcing a particular one.
-
-  Forcing a mode is meant for performance experimentation purposes rather than
-  for general use cases.
-  """
-
-  SHARED_LOCAL_PARAMETER_SERVER = 1
-  """Variables are placed on a single device and shared across all devices.
-
-  Two ways to achieve this distribution over available GPUs are supported:
-    1)  If exactly 1 GPU is detected, then variables and operations are placed
-        onto GPU.
-    2)  If more than 1 GPU is detected, then variables are going to be placed on
-        the CPU.  Replicas of operations are placed on each individual GPU.
-  """
-
-  SHARED_ROUND_ROBIN = 2
-  """Variables are placed on all devices in a round-robin fashion.
-
-  Every subsequent variable is placed on the next device.  There is only one
-  copy of each variable that is shared across all devices.
-  """
-
-
-def _replicate_model_fn_with_mode(
-    model_fn,
-    devices=None,
-    mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER):
-  """A version of `replicate_model_fn` that allows to specify a `mode`."""
-  if not devices:
-    devices = _get_local_devices('GPU') or _get_local_devices('CPU')
-
-  is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0].upper()
-  consolidation_device = devices[0] if is_a_single_gpu_case else '/CPU:0'
-
-  ps_devices = [consolidation_device]
-  if mode == _VariableDistributionMode.SHARED_ROUND_ROBIN:
-    ps_devices = devices
-
-  tf_logging.info('Replicating the `model_fn` across {}.  Variables are going '
-                  'to be placed on {}.  Consolidation device is going to be {}.'
-                  .format(devices, ps_devices, consolidation_device))
-
-  def single_device_model_fn(features, labels, mode, params=None, config=None):
-    """`model_fn` on a single device without reduction overhead."""
-    return _get_loss_towers(
-        model_fn=model_fn,
-        mode=mode,
-        features=[features],
-        labels=[labels],
-        params=params,
-        config=config,
-        devices=devices,
-        local_ps_devices=ps_devices)[0]  # One device, so one spec is out.
-
-  def replicated_model_fn(features, labels, mode, params=None, config=None):
-    """Replicated version of `model_fn` to be used instead."""
-    feature_shards, label_shards = _split_batch(
-        features, labels, len(devices), device=consolidation_device)
-    tower_specs = _get_loss_towers(
-        model_fn=model_fn,
-        mode=mode,
-        features=feature_shards,
-        labels=label_shards,
-        params=params,
-        config=config,
-        devices=devices,
-        local_ps_devices=ps_devices)
-
-    if mode == model_fn_lib.ModeKeys.TRAIN:
-      train_op = _minimize_towers(tower_specs)
-      return _train_spec(
-          tower_specs, train_op, aggregation_device=consolidation_device)
-    elif mode == model_fn_lib.ModeKeys.EVAL:
-      return _eval_spec(tower_specs, aggregation_device=consolidation_device)
-    elif mode == model_fn_lib.ModeKeys.PREDICT:
-      return _predict_spec(tower_specs, aggregation_device=consolidation_device)
-
-  if len(devices) == 1:
-    return single_device_model_fn
-  else:
-    return replicated_model_fn
-
-
-class _TowerOptimizer(optimizer_lib.Optimizer):
-  """Gathers gradients from all towers and reduces them in the last one."""
-
-  COLLECTION_FOR_GRAPH_STATES = 'replicate_model_fn_graph_states'
-
-  def __init__(self, optimizer_or_optimizer_fn,
-               loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE):
-    """Wrap an existing optimizer for gathering gradients across towers.
-
-    Each invocation of model_fn has to call the same optimizers in the same
-    order.
-
-    Multiple optimizers that use the same or different losses are supported.
-
-    If _TowerOptimizer is used but `replicate_model_fn` isn't, then no
-    aggregation will happen.  All calls will simply be forwarded to the
-    underlying optimizer. The behavior is similar if there is only one tower.
-
-    If _TowerOptimizer is used together with SyncReplicasOptimizer that wraps
-    the user's optimizer, then it's the SyncReplicasOptimizer that needs to be
-    wrapped with _TowerOptimizer.
-
-    Args:
-      optimizer_or_optimizer_fn: an instance of optimizer to wrap.  That
-        instance is going to be used for optimizer-specific logic.  This can
-        also be a no-argument function that returns such an optimizer instance.
-      loss_reduction: controls whether losses are summed or averaged.
-    """
-    self._optimizer_or_optimizer_fn = optimizer_or_optimizer_fn
-    self._loss_reduction = loss_reduction
-
-  @staticmethod
-  def has_been_used():
-    return _TowerOptimizer._graph_state().has_tower_optimizer_been_used
-
-  def get_slot(self, *args, **kwargs):
-    return self._get_optimizer().get_slot(*args, **kwargs)
-
-  def get_slot_names(self, *args, **kwargs):
-    return self._get_optimizer().get_slot_names(*args, **kwargs)
-
-  def get_name(self, *args, **kwargs):
-    return self._get_optimizer().get_name(*args, **kwargs)
-
-  def variables(self, *args, **kwargs):
-    return self._get_optimizer().variables(*args, **kwargs)
-
-  def compute_gradients(self, loss, *args, **kwargs):
-    """Compute gradients, but first, if needed, scale the loss."""
-    _TowerOptimizer._graph_state().set_loss_reduction(self._loss_reduction)
-    loss = _scale_loss(loss,
-                       self._loss_reduction,
-                       self._graph_state().number_of_towers)
-    return self._get_optimizer().compute_gradients(loss, *args, **kwargs)
-
-  def apply_gradients(self, grads_and_vars, global_step=None, **kwargs):
-    """Collect gradients updates to apply them with the last tower."""
-    if self._graph_state().number_of_towers == 1:
-      # Avoid the overhead of reduction if there's only one tower.
-      #
-      # There assumed to be only one tower if aggregation-related methods were
-      # not called by `_get_loss_towers`, for example if the model_fn uses
-      # TowerEstimator, but `replicate_model_fn` isn't used.
-      return self._get_optimizer().apply_gradients(grads_and_vars, global_step,
-                                                   **kwargs)
-
-    self._graph_state().collect_gradients(grads_and_vars)
-
-    if not self._graph_state().is_the_last_tower:
-      with ops_lib.control_dependencies(_extract_tensors(grads_and_vars)):
-        return self._construct_no_op_train_op()
-    else:
-      # Gradients need to be gathered and applied in the scope of the first
-      # tower, so that the tensors are accessible via names without prefixes.
-      var_scope, name_scope = self._graph_state().scopes_of_the_first_tower
-      with variable_scope.variable_scope(var_scope):
-        with ops_lib.name_scope(name_scope):
-          return self._apply_gathered_gradients(global_step, **kwargs)
-
-  def _apply_gathered_gradients(self, global_step, **kwargs):
-    graph_state = self._graph_state()
-    optimizer = self._get_optimizer()
-
-    grad_lists = {}
-    for grad, var in graph_state.get_latest_gradients_from_all_towers():
-      if grad is not None:
-        grad_lists.setdefault(var, []).append(grad)
-
-    aggregated_grads = []
-    with ops_lib.name_scope('gradient_aggregating'):
-      for var, grads in six.iteritems(grad_lists):
-        grad = _compute_sum_on_device(grads, var.device)
-        aggregated_grads.append((grad, var))
-    return optimizer.apply_gradients(
-        aggregated_grads, global_step=global_step, **kwargs)
-
-  def _get_optimizer(self):
-    if callable(self._optimizer_or_optimizer_fn):
-      # If optimizer is given as a function then we need to wait till we are
-      # under the right graph context before constructing it.  That's why the
-      # optimizer is constructed in _get_optimizer() rather than __init__().
-      self._optimizer_or_optimizer_fn = self._optimizer_or_optimizer_fn()
-    self._graph_state().has_tower_optimizer_been_used = True
-    return self._optimizer_or_optimizer_fn
-
-  def _construct_no_op_train_op(self):
-    return control_flow_ops.no_op(name='train_op_placeholder')
-
-  @staticmethod
-  def _graph_state():
-    graph_states = ops_lib.get_default_graph().get_collection_ref(
-        _TowerOptimizer.COLLECTION_FOR_GRAPH_STATES)
-    if not graph_states:
-      graph_states.append(_TowerOptimizer._PerGraphState())
-    return graph_states[-1]
-
-  @staticmethod
-  def _did_towers_have_same_optimizer_calls():
-    graph_state = _TowerOptimizer._graph_state()
-    return graph_state.did_towers_have_same_optimizer_calls()
-
-  @staticmethod
-  def _clear_graph_state():
-    # Clearing the Graph collection will prevent _PerGraphState from being
-    # serialized.
-    ops_lib.get_default_graph().clear_collection(
-        _TowerOptimizer.COLLECTION_FOR_GRAPH_STATES)
-
-  class _PerGraphState(object):
-    """Gradient reduction related state of a Tensorflow graph."""
-
-    def __init__(self):
-      self._collected_grads_and_vars = defaultdict(list)
-      self._current_tower_index = 0
-      self._number_of_towers = 1
-      self._loss_reduction = None
-      # Scopes of the first tower that don't have a prefix:
-      self._variable_scope = None
-      self._name_scope = None
-      # If needed, alert that _TowerOptimizer needs to be used with model_fn.
-      self._has_tower_optimizer_been_used = False
-
-    def collect_gradients(self, grads_and_vars):
-      self._collected_grads_and_vars[self._current_tower_index].append(
-          grads_and_vars)
-
-    def get_latest_gradients_from_all_towers(self):
-      """Get gradients across towers for the last called optimizer."""
-      grads_and_vars = []
-      index_of_last_gradients = len(
-          self._collected_grads_and_vars[self._current_tower_index]) - 1
-      for tower_id in range(self._current_tower_index + 1):
-        grads_and_vars.extend(
-            self._collected_grads_and_vars[tower_id][index_of_last_gradients])
-      return grads_and_vars
-
-    def set_number_of_towers(self, number_of_towers):
-      self._number_of_towers = number_of_towers
-
-    def set_loss_reduction(self, loss_reduction):
-      self._loss_reduction = loss_reduction
-
-    @contextmanager
-    def tower(self, tower_id, var_scope, name_scope):
-      if tower_id == 0:
-        self._variable_scope = var_scope
-        self._name_scope = name_scope
-      self._current_tower_index = tower_id
-      yield
-
-    @property
-    def scopes_of_the_first_tower(self):
-      return self._variable_scope, self._name_scope
-
-    @property
-    def is_the_last_tower(self):
-      return self._current_tower_index == (self._number_of_towers - 1)
-
-    @property
-    def number_of_towers(self):
-      return self._number_of_towers
-
-    @property
-    def loss_reduction(self):
-      return self._loss_reduction
-
-    @property
-    def has_tower_optimizer_been_used(self):
-      return self._has_tower_optimizer_been_used
-
-    @has_tower_optimizer_been_used.setter
-    def has_tower_optimizer_been_used(self, value):
-      self._has_tower_optimizer_been_used = value
-
-    def did_towers_have_same_optimizer_calls(self):
-      total_number_of_grads = sum([
-          len(grads)
-          for _, grads in six.iteritems(self._collected_grads_and_vars)
-      ])
-      return total_number_of_grads % self._number_of_towers == 0
-
-
-def _get_local_devices(device_type):
-  local_device_protos = device_lib.list_local_devices()
-  return [
-      device.name
-      for device in local_device_protos
-      if device.device_type == device_type
-  ]
-
-
-def _split_batch(features, labels, number_of_shards, device):
-  """Split input features and labes into batches."""
-
-  def ensure_divisible_by_shards(sequence):
-    batch_size = ops_lib.convert_to_tensor(sequence).get_shape()[0]
-    if batch_size % number_of_shards != 0:
-      raise ValueError(
-          'Batch size {} needs to be divisible by the number of GPUs, which '
-          'is {}.'.format(batch_size, number_of_shards))
-
-  def split_dictionary(dictionary):
-    """Split a dictionary into shards."""
-    shards = [{} for _ in range(number_of_shards)]
-    for name, tensor in six.iteritems(dictionary):
-      if isinstance(tensor, sparse_tensor.SparseTensor):
-        for i, shard in enumerate(
-            sparse_ops.sparse_split(
-                sp_input=tensor, num_split=number_of_shards, axis=0)):
-          shards[i][name] = shard
-      else:
-        ensure_divisible_by_shards(tensor)
-        for i, shard in enumerate(array_ops.split(tensor, number_of_shards)):
-          shards[i][name] = shard
-    return shards
-
-  with ops_lib.name_scope('split_inputs'):
-    with ops_lib.device(device):
-      if isinstance(features, dict):
-        feature_shards = split_dictionary(features)
-      else:
-        ensure_divisible_by_shards(features)
-        feature_shards = array_ops.split(features, number_of_shards)
-
-      if labels is None:
-        label_shards = None
-      elif isinstance(labels, dict):
-        label_shards = split_dictionary(labels)
-      else:
-        ensure_divisible_by_shards(labels)
-        label_shards = array_ops.split(labels, number_of_shards)
-  return feature_shards, label_shards
-
-
-_DEFAULT_NAME_SCOPE_PATTERN = 'tower_{}'
-
-
-def _get_loss_towers(model_fn,
-                     mode,
-                     features,
-                     labels,
-                     params,
-                     config,
-                     devices,
-                     local_ps_devices,
-                     name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
-  """Replicate the loss computation across devices."""
-  tower_specs = []
-
-  model_fn_args = util.fn_args(model_fn)
-  optional_params = {}
-  if 'params' in model_fn_args:
-    optional_params['params'] = copy.deepcopy(params)
-  if 'config' in model_fn_args:
-    optional_params['config'] = copy.deepcopy(config)
-
-  # pylint: disable=protected-access
-  round_robin_strategy = device_setter_lib._RoundRobinStrategy(
-      num_tasks=len(local_ps_devices))
-  _TowerOptimizer._graph_state().set_number_of_towers(len(devices))
-
-  for i, device in enumerate(devices):
-    is_the_first_tower = (i == 0)
-
-    device_setter = _local_device_setter(
-        worker_device=device,
-        ps_devices=local_ps_devices,
-        ps_strategy=round_robin_strategy)
-
-    # We would like to preserve the names of the variables and ops that the user
-    # might be relying on. Names without a prefix are going to resolve to
-    # variables and ops of the first tower.
-    name_scope = name_scope_pattern
-    if is_the_first_tower:
-      name_scope = ''
-
-    with variable_scope.variable_scope(
-        '', reuse=not is_the_first_tower) as var_scope:
-      with ops_lib.name_scope(name_scope.format(i)) as name_scope:
-        with _TowerOptimizer._graph_state().tower(
-            tower_id=i, var_scope=var_scope, name_scope=name_scope):
-          with ops_lib.device(device_setter):
-            labels_shard = None
-            if labels:
-              labels_shard = labels[i]
-
-            tower_spec = model_fn(
-                mode=mode,
-                features=features[i],
-                labels=labels_shard,
-                **optional_params)
-
-            if (tower_spec.train_op is not None and len(devices) > 1 and
-                not _TowerOptimizer.has_been_used()):
-              raise ValueError('Please wrap optimizers with _TowerOptimizer'
-                               ' in order to use replicate_model_fn with'
-                               ' multiple `devices`.')
-
-            # Scaling the loss here doesn't actually affect gradients.  Another
-            # instance of scaling happens inside the _TowerOptimizer.
-            tower_spec = _scale_tower_loss(
-                tower_spec,
-                _TowerOptimizer._graph_state().loss_reduction,
-                number_of_towers=len(devices))
-            tower_specs.append(tower_spec)
-
-  if not _TowerOptimizer._did_towers_have_same_optimizer_calls():
-    raise ValueError('Each invocation of model_fn was supposed to make the same'
-                     ' optimizer calls.')
-  _TowerOptimizer._clear_graph_state()
-  # pylint: enable=protected-access
-  return tower_specs
-
-
-def _local_device_setter(worker_device, ps_devices, ps_strategy):
-  """A device setter that puts distributes Var/Ops to PS/workers."""
-  ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']
-
-  def local_device_chooser(op):
-    current_device = framework_device.DeviceSpec.from_string(op.device or '')
-
-    node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
-    if node_def.op in ps_ops:
-      ps_device_spec = framework_device.DeviceSpec.from_string(
-          '{}'.format(ps_devices[ps_strategy(op)]))
-
-      ps_device_spec.merge_from(current_device)
-      return ps_device_spec.to_string()
-    else:
-      worker_device_spec = framework_device.DeviceSpec.from_string(
-          worker_device or '')
-      worker_device_spec.merge_from(current_device)
-      return worker_device_spec.to_string()
-
-  return local_device_chooser
-
-
-def _scale_tower_loss(tower_spec, loss_reduction, number_of_towers):
-  """Produce an EstimatorSpec with approproriately scaled loss."""
-  if tower_spec.loss is None:
-    return tower_spec
-
-  estimator_spec = _asdict(tower_spec)
-  estimator_spec['loss'] = _scale_loss(
-      tower_spec.loss,
-      loss_reduction,
-      number_of_towers,
-      reduced_loss_name='averaged_loss')
-  return model_fn_lib.EstimatorSpec(**estimator_spec)
-
-
-def _scale_loss(loss, loss_reduction, number_of_towers, reduced_loss_name=None):
-  """If needed, scale down the loss for averaging loss by summing."""
-  if loss is None:
-    return None
-  if number_of_towers == 1:
-    return loss
-
-  if loss_reduction == losses.Reduction.NONE:
-    raise ValueError('Tower losses need to be reduced in some way, yet {} '
-                     'reduction is specified.'.format(loss_reduction))
-
-  if loss_reduction != losses.Reduction.SUM:
-    return math_ops.div(loss, 1.0 * number_of_towers, name=reduced_loss_name)
-  else:
-    return loss
-
-
-def _minimize_towers(tower_specs):
-  """`train_op` of the last tower applies aggregated gradients."""
-  return tower_specs[-1].train_op
-
-
-def _compute_sum_on_device(values, device, name=None):
-  with ops_lib.device(device):
-    if isinstance(values[0], ops_lib.IndexedSlices):
-      if name:
-        raise ValueError('The name {} is not expected to be given to '
-                         'IndexedSlices {}'.format(name, values))
-
-      values_concat = array_ops.concat([v.values for v in values], axis=0)
-      indices_concat = array_ops.concat([v.indices for v in values], axis=0)
-      return ops_lib.IndexedSlices(values_concat, indices_concat,
-                                   values[0].dense_shape)
-    else:
-      return math_ops.add_n(values, name=name)
-
-
-def _train_spec(tower_specs,
-                train_op,
-                aggregation_device,
-                aggregated_loss_name='loss'):
-  """Populate replicated EstimatorSpec for `GraphKeys.TRAIN`."""
-  # Spec of the last tower is used as the template for the final spec, because
-  # some `EstimatorSpec.training_hooks` rely on calls made in model_fn.  For
-  # example, `SyncReplicasOptimizerHook` validates the
-  # `SyncReplicasOptimizer.apply_gradients` call. `TowerEstimator` makes that
-  # call only in the last tower.
-  estimator_spec = _asdict(tower_specs[-1])
-  estimator_spec['mode'] = model_fn_lib.ModeKeys.TRAIN
-  estimator_spec['train_op'] = train_op
-  estimator_spec['loss'] = _compute_sum_on_device(
-      [spec.loss for spec in tower_specs], aggregation_device,
-      aggregated_loss_name)
-  return model_fn_lib.EstimatorSpec(**estimator_spec)
-
-
-def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'):
-  """Populate replicated EstimatorSpec for `GraphKeys.EVAL`."""
-  estimator_spec = _asdict(tower_specs[0])
-  estimator_spec['mode'] = model_fn_lib.ModeKeys.EVAL
-  estimator_spec['loss'] = _compute_sum_on_device(
-      [spec.loss for spec in tower_specs], aggregation_device,
-      aggregated_loss_name)
-
-  update_ops = []
-  for tower_spec in tower_specs:
-    for name, (_, update_op) in six.iteritems(tower_spec.eval_metric_ops):
-      update_ops.append(update_op)
-
-  with ops_lib.control_dependencies(update_ops):
-    reduced_update_op = _reduce_metric_variables(len(tower_specs))
-
-  eval_metric_ops = {}
-  for name, (metric_tensor, _) in six.iteritems(tower_specs[0].eval_metric_ops):
-    eval_metric_ops[name] = (metric_tensor, reduced_update_op)
-  estimator_spec['eval_metric_ops'] = eval_metric_ops
-  return model_fn_lib.EstimatorSpec(**estimator_spec)
-
-
-def _reduce_metric_variables(number_of_towers):
-  """Aggregate local variables used in metrics into the first tower."""
-  if number_of_towers == 1:
-    return control_flow_ops.no_op(name='no_eval_metric_reduction')
-
-  metric_variables = ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)
-  variables_per_tower = len(metric_variables) // number_of_towers
-
-  if len(metric_variables) % number_of_towers != 0:
-    raise ValueError(
-        'Different `EstimatorSpec.eval_metric_ops` across `model_fn()` calls.'
-        ' Expected {} local variables, but got {} instead.'.format(
-            variables_per_tower * number_of_towers, len(metric_variables)))
-
-  # `metric_variables` has the size of `variables_per_tower` x
-  #  number_of_towers.  Each tower is produced by calling the same model_fn.
-  #  First `variables_per_tower` correspond to the first tower.  Each such
-  #  variable has an replica at the `(variables_per_tower * i)` position, where
-  #  `i` is `[1.. number_of_towers]`.  We are going to add values from replicas
-  #  to each variable of the first tower.  We then zero out replica values, so
-  #  that `_reduce_metric_variables` operation is idempotent.  If a metric
-  #  is then computed based on local variables from the first tower, then the
-  #  resulting metric is an estimate for all `number_of_towers` towers.
-  ops = []
-  for i in range(0, variables_per_tower):
-    next_replica_id = i + variables_per_tower
-    replicas = [
-        metric_variables[replica_id]
-        for replica_id in range(next_replica_id, len(metric_variables),
-                                variables_per_tower)
-    ]  #  `replicas` doesn't contain the first-tower variable.
-
-    reduce_op = state_ops.assign_add(metric_variables[i],
-                                     math_ops.add_n(replicas))
-
-    with ops_lib.control_dependencies([reduce_op]):
-      for replica in replicas:
-        zeros_for_replica = array_ops.zeros(
-            array_ops.shape(replica), dtype=replica.dtype)
-        zero_out_replica_op = state_ops.assign(replica, zeros_for_replica)
-        ops.append(zero_out_replica_op)
-
-  return control_flow_ops.group(*ops)
-
-
-def _predict_spec(tower_specs, aggregation_device):
-  """Populate replicated EstimatorSpec for `GraphKeys.PREDICT`."""
-  estimator_spec = _asdict(tower_specs[0])
-  estimator_spec['mode'] = model_fn_lib.ModeKeys.PREDICT
-
-  with ops_lib.device(aggregation_device):
-    estimator_spec['predictions'] = _concat_tensor_dicts(
-        *[tower_spec.predictions for tower_spec in tower_specs])
-
-    export_outputs_dict = _dict_concat(
-        *[tower_spec.export_outputs for tower_spec in tower_specs])
-
-    export_outputs = {}
-    for name, export_output_list in six.iteritems(export_outputs_dict):
-      if isinstance(export_output_list[0], export_output_lib.PredictOutput):
-        export_outputs[name] = export_output_lib.PredictOutput(
-            outputs=_concat_tensor_dicts(*[
-                export_output.outputs for export_output in export_output_list
-            ]))
-      elif isinstance(export_output_list[0],
-                      export_output_lib.RegressionOutput):
-        export_outputs[name] = export_output_lib.RegressionOutput(
-            value=array_ops.concat(
-                [export_output.value for export_output in export_output_list],
-                axis=0))
-      elif isinstance(export_output_list[0],
-                      export_output_lib.ClassificationOutput):
-        scores = None
-        if export_output_list[0].scores is not None:
-          scores = array_ops.concat(
-              [export_output.scores for export_output in export_output_list],
-              axis=0)
-
-        classes = None
-        if export_output_list[0].classes is not None:
-          classes = array_ops.stack(
-              [export_output.classes for export_output in export_output_list],
-              axis=0)
-
-        export_outputs[name] = export_output_lib.ClassificationOutput(
-            scores=scores, classes=classes)
-
-  estimator_spec['export_outputs'] = export_outputs
-  return model_fn_lib.EstimatorSpec(**estimator_spec)
-
-
-def _concat_tensor_dicts(*tensor_dicts):
-  return {
-      name: array_ops.concat(tensors, axis=0, name=name)
-      for name, tensors in six.iteritems(_dict_concat(*tensor_dicts))
-  }
-
-
-def _extract_tensors(tensors_and_vars):
-  tensors = []
-  for tensor_and_var in tensors_and_vars:
-    tensor, _ = tensor_and_var
-    if isinstance(tensor, ops_lib.IndexedSlices):
-      tensors.append(tensor.values)
-    elif tensor is not None:
-      tensors.append(tensor)
-  return tensors
-
-
-def _dict_concat(*dicts):
-  list_dict = {}
-  for d in dicts:
-    if d is None:
-      continue
-
-    for k, v in six.iteritems(d):
-      list_dict.setdefault(k, []).append(v)
-  return list_dict
-
-
-def _asdict(namedtuple):
-  """Returns a namedtuple as a dictionary.
-
-  This is required because `_asdict()` in Python 3.x.x is broken in classes
-  that inherit from `collections.namedtuple`. See
-  https://bugs.python.org/issue24931 for more details.
-
-  Args:
-    namedtuple: An object that inherits from `collections.namedtuple`.
-
-  Returns:
-    A dictionary version of the tuple.
-  """
-  return {k: getattr(namedtuple, k) for k in namedtuple._fields}
diff --git a/tensorflow/python/estimator/replicate_model_fn_test.py b/tensorflow/python/estimator/replicate_model_fn_test.py
deleted file mode 100644 (file)
index ad1f9c0..0000000
+++ /dev/null
@@ -1,1739 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for utilities that replicate `Estimator.model_fn` over GPUs."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import re
-import shutil
-import tempfile
-import numpy as np
-import six
-
-from tensorflow.python.estimator import estimator as estimator_lib
-from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import replicate_model_fn
-from tensorflow.python.estimator.canned import dnn
-from tensorflow.python.estimator.canned import optimizers
-from tensorflow.python.estimator.canned import prediction_keys
-from tensorflow.python.estimator.export import export
-from tensorflow.python.estimator.export import export_output
-from tensorflow.python.estimator.inputs import numpy_io
-from tensorflow.python.feature_column import feature_column
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops as ops_lib
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import losses
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import metrics as metrics_lib
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.ops.losses import losses
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import test
-from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.summary.writer import writer_cache
-from tensorflow.python.training import adam
-from tensorflow.python.training import device_setter
-from tensorflow.python.training import gradient_descent
-from tensorflow.python.training import training
-
-
-# TODO(isaprykin):  Parametrize all the tests on
-#   replicate_model_fn._VariableDistributionMode when it's supported.
-class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
-
-  def setUp(self):
-    self._model_dir = tempfile.mkdtemp()
-
-  def test_complete_flow_with_public_version(self):
-    return self._complete_flow_with_mode(mode=None)
-
-  def test_complete_flow_with_mode_local_ps_server(self):
-    return self._complete_flow_with_mode(
-        replicate_model_fn._VariableDistributionMode.
-        SHARED_LOCAL_PARAMETER_SERVER)
-
-  def test_complete_flow_with_mode_round_robin(self):
-    return self._complete_flow_with_mode(
-        replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN)
-
-  def _complete_flow_with_mode(self, mode):
-    n_classes = 3
-    input_dimension = 2
-    batch_size = 12
-
-    data = np.linspace(
-        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
-    x_data = data.reshape(batch_size, input_dimension)
-    categorical_data = np.random.random_integers(
-        0, len(x_data), size=len(x_data))
-    y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
-    train_input_fn = numpy_io.numpy_input_fn(
-        x={'x': x_data,
-           'categories': categorical_data},
-        y=y_data,
-        batch_size=batch_size,
-        num_epochs=None,
-        shuffle=True)
-    eval_input_fn = numpy_io.numpy_input_fn(
-        x={'x': x_data,
-           'categories': categorical_data},
-        y=y_data,
-        batch_size=batch_size,
-        shuffle=False)
-    predict_input_fn = numpy_io.numpy_input_fn(
-        x={'x': x_data,
-           'categories': categorical_data},
-        batch_size=batch_size,
-        shuffle=False)
-
-    feature_columns = [
-        feature_column.numeric_column('x', shape=(input_dimension,)),
-        feature_column.embedding_column(
-            feature_column.categorical_column_with_vocabulary_list(
-                'categories',
-                vocabulary_list=np.linspace(
-                    0., len(x_data), len(x_data), dtype=np.int64)), 1)
-    ]
-
-    def optimizer_fn():
-      return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05)
-
-    estimator = dnn.DNNClassifier(
-        hidden_units=(2, 2),
-        # Adagrad is configured with `get_optimizer_instance`, so the function
-        # form of `TowerOptimizer.__init__` is used.
-        optimizer=replicate_model_fn._TowerOptimizer(
-            optimizer_fn, loss_reduction=losses.Reduction.SUM),
-        feature_columns=feature_columns,
-        n_classes=n_classes,
-        model_dir=self._model_dir)
-
-    if not mode:  # Use the public `replicate_model_fn`.
-      model_fn = replicate_model_fn._replicate_model_fn(
-          estimator.model_fn, devices=['/gpu:0', '/gpu:1', '/gpu:2'])
-    else:
-      model_fn = replicate_model_fn._replicate_model_fn_with_mode(
-          estimator.model_fn,
-          devices=['/gpu:0', '/gpu:1', '/gpu:2'],
-          mode=mode)
-
-    estimator = estimator_lib.Estimator(
-        model_fn=model_fn,
-        model_dir=estimator.model_dir,
-        config=estimator.config,
-        params=estimator.params)
-
-    num_steps = 10
-    estimator.train(train_input_fn, steps=num_steps)
-
-    scores = estimator.evaluate(eval_input_fn)
-    self.assertEqual(num_steps, scores[ops_lib.GraphKeys.GLOBAL_STEP])
-    self.assertIn('loss', six.iterkeys(scores))
-
-    predicted_proba = np.array([
-        x[prediction_keys.PredictionKeys.PROBABILITIES]
-        for x in estimator.predict(predict_input_fn)
-    ])
-    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
-
-    feature_spec = feature_column.make_parse_example_spec(feature_columns)
-    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
-        feature_spec)
-    export_dir = estimator.export_savedmodel(tempfile.mkdtemp(),
-                                             serving_input_receiver_fn)
-    self.assertTrue(gfile.Exists(export_dir))
-
-    # Nothing should be left in the graph so that it doesn't get serialized.
-    self.assertFalse(ops_lib.get_default_graph().get_collection_ref(
-        replicate_model_fn._TowerOptimizer.COLLECTION_FOR_GRAPH_STATES))
-
-  def _as_label(self, data_in_float):
-    return np.rint(data_in_float).astype(np.int64)
-
-  def tearDown(self):
-    if self._model_dir:
-      writer_cache.FileWriterCache.clear()
-      shutil.rmtree(self._model_dir)
-
-
-class ReplicateModelTest(test_util.TensorFlowTestCase):
-
-  def create_model_fn_with_loss_reduction(self, loss_reduction):
-
-    def model_fn(mode, features, labels, params):
-      c = variable_scope.get_variable(
-          'c',
-          initializer=constant_op.constant(10, dtype=dtypes.float64),
-          dtype=dtypes.float64)
-
-      predictions = math_ops.multiply(features, c)
-
-      loss = losses.absolute_difference(
-          labels=labels,
-          predictions=predictions,
-          reduction=losses.Reduction.SUM)
-      loss = math_ops.reduce_sum(loss)
-
-      metrics = {
-          'accuracy': metrics_lib.accuracy(labels, predictions),
-          'auc': metrics_lib.auc(labels, predictions)
-      }
-
-      optimizer = replicate_model_fn._TowerOptimizer(
-          gradient_descent.GradientDescentOptimizer(params['learning_rate']),
-          loss_reduction=loss_reduction)
-
-      return model_fn_lib.EstimatorSpec(
-          mode=mode,
-          loss=loss,
-          eval_metric_ops=metrics,
-          predictions={'probabilities': predictions},
-          train_op=optimizer.minimize(loss))
-
-    return model_fn
-
-  @property
-  def params(self):
-    params = {}
-    params['learning_rate'] = 1.0
-    return params
-
-  def test_train(self):
-    features = np.array([[1.0], [2.0]])
-    labels = np.array([[1.0], [2.0]])
-
-    with self.test_session() as session:
-      replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
-          devices=['/gpu:0', '/gpu:1'])
-      estimator_spec = replicated_model_fn(
-          features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
-      session.run(variables.global_variables_initializer())
-
-      # loss = feature * c - label
-      total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
-      self.assertEqual(total_loss, session.run(estimator_spec.loss))
-
-      # derivative of loss = (1*c - 1) + (2*c - 2) is 3.
-      # new value of c = 10 - learning rate * 3 = 7.0.
-      session.run(estimator_spec.train_op)
-      with variable_scope.variable_scope('', reuse=True):
-        c = variable_scope.get_variable('c', dtype=dtypes.float64)
-        self.assertEqual(7.0, session.run(c))
-
-  def test_train_with_mean_reduction(self):
-    features = np.array([[1.0], [2.0]])
-    labels = np.array([[1.0], [2.0]])
-
-    with self.test_session() as session:
-      # Add another trainable variable that doesn't produce a gradient to
-      # verify that None gradients are supported.
-      _ = variable_scope.get_variable(
-          'another_variable',
-          initializer=constant_op.constant(1, dtype=dtypes.float64),
-          dtype=dtypes.float64)
-
-      replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.create_model_fn_with_loss_reduction(losses.Reduction.MEAN),
-          devices=['/gpu:0', '/gpu:1'])
-      estimator_spec = replicated_model_fn(
-          features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
-      session.run(variables.global_variables_initializer())
-
-      # loss = feature * c - label
-      total_loss = ((1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)) / 2.0
-      self.assertEqual(total_loss, session.run(estimator_spec.loss))
-
-      # derivative of loss = (1*c - 1)/2 + (2*c - 2)/2 is 1.5.
-      # It's the same computation as without mean reduction, but the
-      # loss from every tower is scaled by 1/<number of towers>.
-      # new value of c = 10 - learning rate * 1.5 = 8.5
-      session.run(estimator_spec.train_op)
-      with variable_scope.variable_scope('', reuse=True):
-        c = variable_scope.get_variable('c', dtype=dtypes.float64)
-        self.assertEqual(8.5, session.run(c))
-
-  def test_train_two_steps_collected_gradients_are_reset_between_steps(self):
-    with ops_lib.Graph().as_default():
-      features = array_ops.placeholder(dtypes.float64)
-      labels = array_ops.placeholder(dtypes.float64)
-
-      feature_inputs = np.array([[1.0], [2.0]]), np.array([[1.5], [2.5]])
-      label_inputs = np.array([[1.0], [2.0]]), np.array([[1.5], [2.5]])
-
-      # loss = feature * c - label
-      expected_losses = ((1.0 * 10 - 1.0) + (2.0 * 10 - 2.0),
-                         (1.5 * 7.0 - 1.5) + (2.5 * 7.0 - 2.5))
-      # Derivative of the loss is 1.0 + 2.0 for the first step and 1.5 + 2.5
-      # for the second.
-      expected_c = 10.0 - 3.0, 7.0 - 4.0
-
-      with self.test_session() as session, variable_scope.variable_scope(
-          '', reuse=variable_scope.AUTO_REUSE):
-        replicated_model_fn = replicate_model_fn._replicate_model_fn(
-            self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
-            devices=['/gpu:0', '/gpu:1'])
-        estimator_spec = replicated_model_fn(
-            features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
-        session.run(variables.global_variables_initializer())
-
-        for feature_input, label_input, loss, weight in zip(
-            feature_inputs, label_inputs, expected_losses, expected_c):
-          feeds = {features: feature_input, labels: label_input}
-
-          self.assertEqual(loss, session.run(estimator_spec.loss, feeds))
-
-          session.run(estimator_spec.train_op, feeds)
-          c = variable_scope.get_variable('c', dtype=dtypes.float64)
-          self.assertEqual(weight, session.run(c, feeds))
-
-  def test_eval(self):
-    features = np.array([[0.01], [0.002]])
-    labels = np.array([[0.01], [0.02]])
-
-    with self.test_session() as session:
-      replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
-          devices=['/gpu:0', '/gpu:1'])
-      estimator_spec = replicated_model_fn(
-          features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
-      session.run(variables.local_variables_initializer())
-      session.run(variables.global_variables_initializer())
-
-      accuracy, a = estimator_spec.eval_metric_ops['accuracy']
-      auc, b = estimator_spec.eval_metric_ops['auc']
-
-      session.run([a, b])
-      accuracy = session.run(accuracy)
-      auc = session.run(auc)
-
-      # loss[i] = features[i] * 10 - labels[i].
-      # Accuracy is 0.0 (no match) in the first tower.
-      # Accuracy is 1.0 (match) in the second tower, since the feature
-      # times weight "c" happened to be equal to the label.
-      total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02))
-
-      self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
-      self.assertEqual(0, auc)
-      self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
-
-  def test_eval_with_mean_reduction(self):
-    features = np.array([[0.01], [0.002]])
-    labels = np.array([[0.01], [0.02]])
-
-    with self.test_session() as session:
-      replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.create_model_fn_with_loss_reduction(losses.Reduction.MEAN),
-          devices=['/gpu:0', '/gpu:1'])
-      estimator_spec = replicated_model_fn(
-          features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
-      session.run(variables.local_variables_initializer())
-      session.run(variables.global_variables_initializer())
-
-      accuracy, a = estimator_spec.eval_metric_ops['accuracy']
-      auc, b = estimator_spec.eval_metric_ops['auc']
-
-      session.run([a, b])
-      accuracy = session.run(accuracy)
-      auc = session.run(auc)
-
-      # loss[i] = features[i] * 10 - labels[i].
-      # Accuracy is 0.0 (no match) in the first tower.
-      # Accuracy is 1.0 (match) in the second tower, since the feature
-      # times weight "c" happened to be equal to the label.
-      total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) / 2.0
-
-      self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
-      self.assertEqual(0, auc)
-      self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
-
-  def test_predict(self):
-    features = np.array([[0.01], [0.002]])
-    labels = np.array([[0.01], [0.02]])
-
-    with self.test_session() as session:
-      replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
-          devices=['/gpu:0', '/gpu:1'])
-      estimator_spec = replicated_model_fn(
-          features, labels, model_fn_lib.ModeKeys.PREDICT, self.params)
-      session.run(variables.global_variables_initializer())
-
-      self.assertAllClose({
-          'probabilities': np.array([[0.1], [0.02]])
-      }, session.run(estimator_spec.predictions))
-
-  def test_train_single_tower(self):
-    features = np.array([[1.0], [2.0]])
-    labels = np.array([[1.0], [2.0]])
-
-    with self.test_session() as session:
-      replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
-          devices=['/gpu:0'])
-      estimator_spec = replicated_model_fn(
-          features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
-      session.run(variables.global_variables_initializer())
-
-      # loss = feature * c - label
-      total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
-      self.assertEqual(total_loss, session.run(estimator_spec.loss))
-
-      # loss' of c is 3.
-      # new value of c = 10 - learning rate * 3 = 7.0.
-      session.run(estimator_spec.train_op)
-      with variable_scope.variable_scope('', reuse=True):
-        c = variable_scope.get_variable('c', dtype=dtypes.float64)
-        self.assertEqual(7.0, session.run(c))
-
-  def test_eval_single_tower(self):
-    features = np.array([[0.01], [0.002]])
-    labels = np.array([[0.01], [0.02]])
-
-    with self.test_session() as session:
-      replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
-          devices=['/gpu:0'])
-      estimator_spec = replicated_model_fn(
-          features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
-      session.run(variables.local_variables_initializer())
-      session.run(variables.global_variables_initializer())
-
-      accuracy, a = estimator_spec.eval_metric_ops['accuracy']
-      auc, b = estimator_spec.eval_metric_ops['auc']
-
-      session.run([a, b])
-      accuracy = session.run(accuracy)
-      auc = session.run(auc)
-
-      # Accuracy is 0.0 (no match) in the first tower.
-      # Accuracy is 1.0 (match) in the second tower, since the feature
-      # times weight "c" happened to be equal to the label.
-      total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02))
-
-      self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
-      self.assertEqual(0, auc)
-      self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
-
-  def test_predict_single_tower(self):
-    features = np.array([[0.01], [0.002]])
-    labels = np.array([[0.01], [0.02]])
-
-    with self.test_session() as session:
-      replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
-          devices=['/gpu:0'])
-      estimator_spec = replicated_model_fn(
-          features, labels, model_fn_lib.ModeKeys.PREDICT, self.params)
-      session.run(variables.global_variables_initializer())
-
-      self.assertAllClose({
-          'probabilities': np.array([[0.1], [0.02]])
-      }, session.run(estimator_spec.predictions))
-
-  def test_batch_size_that_is_not_divisible_by_the_number_of_gpus(self):
-    features = np.array([[1.0], [2.0], [3.0]])
-    labels = np.array([[1.0], [2.0], [3.0]])
-
-    with self.assertRaisesRegexp(
-        ValueError, '.*Batch.+size.+needs.+to.+be.+divisible.+by.+GPUs.+'):
-      replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
-          devices=['/gpu:0', '/gpu:1'])
-      _ = replicated_model_fn(
-          features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
-
-  def test_unsupported_loss_reduction(self):
-    features = np.array([[1.0], [2.0], [3.0]])
-    labels = np.array([[1.0], [2.0], [3.0]])
-
-    with self.assertRaisesRegexp(ValueError,
-                                 '.+none.+reduction.+is.+specified.+'):
-      replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.create_model_fn_with_loss_reduction(losses.Reduction.NONE),
-          devices=['/gpu:0', '/gpu:1', '/gpu:2'])
-      _ = replicated_model_fn(
-          features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
-
-  def test_places_on_gpu_with_upper_case_spelling(self):
-    features = np.array([[0.01], [0.002]])
-    labels = np.array([[0.01], [0.02]])
-
-    with self.test_session():
-      replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
-          devices=['/GPU:0'])
-      _ = replicated_model_fn(
-          features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
-
-      with variable_scope.variable_scope('', reuse=True):
-        c = variable_scope.get_variable('c', dtype=dtypes.float64)
-        self.assertEqual('/device:GPU:0', c.device)
-
-  def test_places_on_gpu_with_lower_case_spelling(self):
-    features = np.array([[0.01], [0.002]])
-    labels = np.array([[0.01], [0.02]])
-
-    with self.test_session():
-      replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
-          devices=['/gpu:0'])
-      _ = replicated_model_fn(
-          features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
-
-      with variable_scope.variable_scope('', reuse=True):
-        c = variable_scope.get_variable('c', dtype=dtypes.float64)
-        self.assertEqual('/device:GPU:0', c.device)
-
-
-class ReplicateAcrossASingleDeviceWithoutTowerOptimizer(
-    test_util.TensorFlowTestCase):
-
-  def model_fn(self, mode, features, labels, params):
-    c = variable_scope.get_variable(
-        'c',
-        initializer=constant_op.constant(10, dtype=dtypes.float64),
-        dtype=dtypes.float64)
-
-    predictions = math_ops.multiply(features, c)
-
-    loss = losses.absolute_difference(
-        labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
-    loss = math_ops.reduce_sum(loss)
-
-    metrics = {
-        'accuracy': metrics_lib.accuracy(labels, predictions),
-        'auc': metrics_lib.auc(labels, predictions)
-    }
-
-    optimizer = gradient_descent.GradientDescentOptimizer(
-        params['learning_rate'])
-
-    return model_fn_lib.EstimatorSpec(
-        mode=mode,
-        loss=loss,
-        eval_metric_ops=metrics,
-        predictions={'probabilities': predictions},
-        train_op=optimizer.minimize(loss))
-
-  @property
-  def params(self):
-    params = {}
-    params['learning_rate'] = 1.0
-    return params
-
-  def test_train_single_tower(self):
-    features = np.array([[1.0], [2.0]])
-    labels = np.array([[1.0], [2.0]])
-
-    with self.test_session() as session:
-      replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.model_fn, devices=['/gpu:0'])
-      estimator_spec = replicated_model_fn(
-          features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
-      session.run(variables.global_variables_initializer())
-
-      # loss = feature * c - label
-      total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
-      self.assertEqual(total_loss, session.run(estimator_spec.loss))
-
-      # loss' of c is 3.
-      # new value of c = 10 - learning rate * 3 = 7.0.
-      session.run(estimator_spec.train_op)
-      with variable_scope.variable_scope('', reuse=True):
-        c = variable_scope.get_variable('c', dtype=dtypes.float64)
-        self.assertEqual(7.0, session.run(c))
-
-
-class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase):
-
-  def model_fn(self, mode, features, labels, params):
-    c = variable_scope.get_variable(
-        'c',
-        initializer=constant_op.constant(10, dtype=dtypes.float64),
-        dtype=dtypes.float64)
-
-    features = features['features']
-    predictions = math_ops.multiply(features, c)
-
-    loss = losses.absolute_difference(
-        labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
-    loss = math_ops.reduce_sum(loss)
-
-    metrics = {
-        'accuracy': metrics_lib.accuracy(labels, predictions),
-        'auc': metrics_lib.auc(labels, predictions)
-    }
-
-    optimizer = replicate_model_fn._TowerOptimizer(
-        gradient_descent.GradientDescentOptimizer(params['learning_rate']))
-
-    return model_fn_lib.EstimatorSpec(
-        mode=mode,
-        loss=loss,
-        eval_metric_ops=metrics,
-        predictions={'probabilities': predictions},
-        train_op=optimizer.minimize(loss))
-
-  @property
-  def params(self):
-    params = {}
-    params['learning_rate'] = 1.0
-    return params
-
-  def test_train_single_tower(self):
-    features = np.array([[1.0], [2.0]])
-    labels = np.array([[1.0], [2.0]])
-
-    train_input_fn = numpy_io.numpy_input_fn(
-        x={'features': features}, y=labels, batch_size=2, shuffle=False)
-
-    with self.test_session():
-      estimator = estimator_lib.Estimator(
-          model_fn=self.model_fn,
-          model_dir=tempfile.mkdtemp(),
-          params=self.params)
-      estimator.train(train_input_fn, steps=1)
-
-      self.assertEqual(7.0, estimator.get_variable_value('c'))
-
-
-class MakeSureSyncReplicasOptimizerWorks(test_util.TensorFlowTestCase):
-
-  def model_fn(self, mode, features, labels, params):
-    c = variable_scope.get_variable(
-        'c',
-        initializer=constant_op.constant(10, dtype=dtypes.float64),
-        dtype=dtypes.float64)
-
-    features = features['features']
-    predictions = math_ops.multiply(features, c)
-
-    loss = losses.absolute_difference(
-        labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
-    loss = math_ops.reduce_sum(loss)
-
-    metrics = {
-        'accuracy': metrics_lib.accuracy(labels, predictions),
-        'auc': metrics_lib.auc(labels, predictions)
-    }
-
-    optimizer = gradient_descent.GradientDescentOptimizer(
-        params['learning_rate'])
-    optimizer = training.SyncReplicasOptimizer(
-        optimizer, replicas_to_aggregate=1)
-    sync_hook = optimizer.make_session_run_hook(True)
-    optimizer = replicate_model_fn._TowerOptimizer(
-        optimizer, loss_reduction=losses.Reduction.SUM)
-
-    return model_fn_lib.EstimatorSpec(
-        mode=mode,
-        loss=loss,
-        eval_metric_ops=metrics,
-        training_hooks=[sync_hook],
-        predictions={'probabilities': predictions},
-        train_op=optimizer.minimize(
-            loss, global_step=training.get_global_step()))
-
-  @property
-  def params(self):
-    params = {}
-    params['learning_rate'] = 1.0
-    return params
-
-  def test_train_multiple_towers(self):
-    features = np.array([[1.0], [2.0]])
-    labels = np.array([[1.0], [2.0]])
-
-    train_input_fn = numpy_io.numpy_input_fn(
-        x={'features': features}, y=labels, batch_size=2, shuffle=False)
-
-    model_fn = replicate_model_fn._replicate_model_fn(
-        self.model_fn,
-        devices=['/gpu:0', '/gpu:1'])
-
-    estimator = estimator_lib.Estimator(
-        model_fn=model_fn, model_dir=tempfile.mkdtemp(), params=self.params)
-    estimator.train(train_input_fn, steps=1)
-
-    self.assertEqual(7.0, estimator.get_variable_value('c'))
-
-
-class ReplicateWithTwoOptimizersTest(test_util.TensorFlowTestCase):
-
-  def model_fn(self, mode, features, labels, params):
-    c = variable_scope.get_variable(
-        'c',
-        initializer=constant_op.constant(10, dtype=dtypes.float64),
-        dtype=dtypes.float64)
-
-    side_effects = variable_scope.get_variable(
-        'side_effects',
-        initializer=constant_op.constant(0, dtype=dtypes.float64),
-        dtype=dtypes.float64,
-        use_resource=True,
-        trainable=False)
-
-    predictions = math_ops.multiply(features, c)
-
-    loss = losses.absolute_difference(
-        labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
-    loss = math_ops.reduce_sum(loss)
-
-    metrics = {
-        'accuracy': metrics_lib.accuracy(labels, predictions),
-        'auc': metrics_lib.auc(labels, predictions)
-    }
-
-    first_optimizer = replicate_model_fn._TowerOptimizer(
-        gradient_descent.GradientDescentOptimizer(1.0),
-        loss_reduction=losses.Reduction.SUM)
-    second_optimizer = replicate_model_fn._TowerOptimizer(
-        adam.AdamOptimizer(1.0), loss_reduction=losses.Reduction.SUM)
-
-    with ops_lib.control_dependencies([side_effects.assign_add(1.0)]):
-      first_grads_and_vars = first_optimizer.compute_gradients(loss)
-
-    train_op = control_flow_ops.group(
-        [first_optimizer.apply_gradients(first_grads_and_vars),
-         second_optimizer.minimize(loss)])
-
-    return model_fn_lib.EstimatorSpec(
-        mode=mode,
-        loss=loss,
-        eval_metric_ops=metrics,
-        predictions={'probabilities': predictions},
-        train_op=train_op)
-
-  def test_train(self):
-    features = np.array([[1.0], [2.0]])
-    labels = np.array([[1.0], [2.0]])
-
-    with self.test_session() as session:
-      replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.model_fn,
-          devices=['/gpu:0', '/gpu:1'])
-      estimator_spec = replicated_model_fn(features, labels,
-                                           model_fn_lib.ModeKeys.TRAIN, {})
-      session.run(variables.global_variables_initializer())
-
-      # loss = feature * c - label
-      total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
-      self.assertEqual(total_loss, session.run(estimator_spec.loss))
-
-      # loss' of c is 3.
-      # new value of c = 10 - learning rate * 3 = 7.0.
-      # Adam subtracts another ~1.
-      session.run(estimator_spec.train_op)
-      with variable_scope.variable_scope('', reuse=True):
-        c = variable_scope.get_variable('c', dtype=dtypes.float64)
-        self.assertNear(6.0, session.run(c), 0.000001)
-
-        side_effects = variable_scope.get_variable(
-            'side_effects', dtype=dtypes.float64)
-        self.assertNear(2.0, session.run(side_effects), 0.000001)
-
-
-class ReplicateWithTwoLossesAndOneOptimizer(test_util.TensorFlowTestCase):
-
-  def setUp(self):
-    self._should_skip_optimizer = False
-    self._towers_left_before_skipping_optimizer = -1
-
-  def incorrectly_skip_optimizer_for_tower(self, tower_number):
-    self._should_skip_optimizer = True
-    self._towers_left_before_skipping_optimizer = tower_number
-
-  def should_skip_optimizer(self):
-    if not self._should_skip_optimizer:
-      return False
-    if self._towers_left_before_skipping_optimizer == 0:
-      return True
-    else:
-      self._towers_left_before_skipping_optimizer -= 1
-      return False
-
-  def model_fn(self, mode, features, labels, params):
-    c = variable_scope.get_variable(
-        'c',
-        initializer=constant_op.constant(10, dtype=dtypes.float64),
-        dtype=dtypes.float64)
-    d = variable_scope.get_variable(
-        'd',
-        initializer=constant_op.constant(2, dtype=dtypes.float64),
-        dtype=dtypes.float64)
-
-    predictions = math_ops.multiply(features, c)
-
-    loss = losses.absolute_difference(
-        labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
-    loss = math_ops.reduce_sum(loss)
-
-    another_predictions = math_ops.multiply(features, d)
-    another_loss = losses.absolute_difference(
-        labels=labels,
-        predictions=another_predictions,
-        reduction=losses.Reduction.SUM)
-    another_loss = math_ops.reduce_sum(another_loss)
-
-    total_loss = math_ops.add(loss, another_loss)
-
-    metrics = {
-        'accuracy': metrics_lib.accuracy(labels, predictions),
-        'auc': metrics_lib.auc(labels, predictions)
-    }
-
-    train_ops = []
-
-    optimizer = replicate_model_fn._TowerOptimizer(
-        gradient_descent.GradientDescentOptimizer(1.0),
-        loss_reduction=losses.Reduction.SUM)
-    train_ops.append(optimizer.minimize(loss, var_list=[c]))
-    if not self.should_skip_optimizer():
-      another_optimizer = replicate_model_fn._TowerOptimizer(
-          gradient_descent.GradientDescentOptimizer(1.0),
-          loss_reduction=losses.Reduction.SUM)
-      train_ops.append(another_optimizer.minimize(another_loss, var_list=[d]))
-
-    train_op = control_flow_ops.group(train_ops)
-    return model_fn_lib.EstimatorSpec(
-        mode=mode,
-        loss=total_loss,
-        eval_metric_ops=metrics,
-        predictions={'probabilities': predictions},
-        train_op=train_op)
-
-  def test_train(self):
-    features = np.array([[1.0], [2.0]])
-    labels = np.array([[1.0], [2.0]])
-
-    with ops_lib.Graph().as_default(), self.test_session() as session:
-      replicated_model_fn = replicate_model_fn._replicate_model_fn(
-          self.model_fn,
-          devices=['/gpu:0', '/gpu:1'])
-      estimator_spec = replicated_model_fn(features, labels,
-                                           model_fn_lib.ModeKeys.TRAIN, {})
-      session.run(variables.global_variables_initializer())
-
-      # For each tower, loss = (feature * c - label) + (feature * d - label).
-      total_loss = (1.0 * 10 - 1.0 + 1.0 * 2.0 - 1.0) + (
-          2.0 * 10 - 2.0 + 2.0 * 2.0 - 2.0)
-      self.assertEqual(total_loss, session.run(estimator_spec.loss))
-
-      session.run(estimator_spec.train_op)
-
-      # loss' of c or loss' of d is 3.
-      # new value of c = 10 - learning rate * 3 = 7.0.
-      # new value of d = 2  - learning rate * 3 = -1.0.
-      with variable_scope.variable_scope('', reuse=True):
-        c = variable_scope.get_variable('c', dtype=dtypes.float64)
-        self.assertNear(7.0, session.run(c), 0.000001)
-        d = variable_scope.get_variable('d', dtype=dtypes.float64)
-        self.assertNear(-1.0, session.run(d), 0.000001)
-
-  def test_different_optimizer_calls_within_towers(self):
-    self.incorrectly_skip_optimizer_for_tower(1)
-
-    features = np.array([[1.0], [2.0]])
-    labels = np.array([[1.0], [2.0]])
-
-    with self.test_session(), ops_lib.Graph().as_default():
-      with self.assertRaisesRegexp(
-          ValueError, '.+was.+supposed.+to.+make.+same.+optimizer.+calls.+'):
-        replicated_model_fn = replicate_model_fn._replicate_model_fn(
-            self.model_fn, devices=['/gpu:0', '/gpu:1'])
-        _ = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN,
-                                {})
-
-
-class FailToWrapOptimizerInTheModelFn(test_util.TensorFlowTestCase):
-
-  def model_fn(self, mode, features, labels, params):
-    c = variable_scope.get_variable(
-        'c',
-        initializer=constant_op.constant(10, dtype=dtypes.float64),
-        dtype=dtypes.float64)
-
-    predictions = math_ops.multiply(features, c)
-
-    loss = losses.absolute_difference(
-        labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
-    loss = math_ops.reduce_sum(loss)
-
-    metrics = {
-        'accuracy': metrics_lib.accuracy(labels, predictions),
-        'auc': metrics_lib.auc(labels, predictions)
-    }
-
-    optimizer = gradient_descent.GradientDescentOptimizer(1.0)
-    train_op = optimizer.minimize(loss)
-
-    return model_fn_lib.EstimatorSpec(
-        mode=mode,
-        loss=loss,
-        eval_metric_ops=metrics,
-        predictions={'probabilities': predictions},
-        train_op=train_op)
-
-  def test_train(self):
-    features = np.array([[1.0], [2.0]])
-    labels = np.array([[1.0], [2.0]])
-
-    with self.test_session():
-      with self.assertRaisesRegexp(ValueError,
-                                   'Please.+wrap.+with.+TowerOptimizer'):
-        replicated_model_fn = replicate_model_fn._replicate_model_fn(
-            self.model_fn, devices=['/gpu:0', '/gpu:1'])
-        _ = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN,
-                                {})
-
-
-class GetLossTowersTest(test_util.TensorFlowTestCase):
-
-  def create_model_fn_with_loss_reduction(self, loss_reduction):
-
-    def model_fn(mode, features, labels, params):
-      del params
-      c = variable_scope.get_variable(
-          'c',
-          initializer=constant_op.constant(0.25, dtype=dtypes.float64),
-          dtype=dtypes.float64)
-
-      predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c)
-      labels = np.array([0.1, 0.2, 0.3, labels[0]])
-
-      loss = losses.absolute_difference(
-          labels=labels,
-          predictions=predictions,
-          reduction=losses.Reduction.SUM)
-
-      optimizer = replicate_model_fn._TowerOptimizer(
-          gradient_descent.GradientDescentOptimizer(1.0),
-          loss_reduction)
-
-      return model_fn_lib.EstimatorSpec(
-          mode=mode,
-          loss=math_ops.reduce_sum(loss),
-          train_op=optimizer.minimize(loss))
-
-    return model_fn
-
-  def test_gradients_are_computed(self):
-    with self.test_session() as session:
-      tower_specs = replicate_model_fn._get_loss_towers(
-          self.create_model_fn_with_loss_reduction(losses.Reduction.SUM),
-          mode=None,
-          features=[[0.6], [1.6]],
-          labels=[[0.6], [0.6]],
-          params=None,
-          config=None,
-          devices=['/gpu:0', '/gpu:1'],
-          local_ps_devices=['/gpu:0'],
-          name_scope_pattern='test_tower_{}')
-      session.run(variables.global_variables_initializer())
-
-      self.assertEqual(len(tower_specs), 2)
-
-      self.assertEqual('/device:GPU:0', tower_specs[0].loss.device)
-      self.assertEqual('Sum:0', tower_specs[0].loss.name)
-      self.assertEqual(1.0, session.run(tower_specs[0].loss))
-
-      self.assertEqual('/device:GPU:1', tower_specs[1].loss.device)
-      self.assertEqual('test_tower_1/Sum:0', tower_specs[1].loss.name)
-      # The input batch for the second tower had a loss that is 1.0
-      # bigger: 0.6 vs 1.6.
-      self.assertEqual(2.0, session.run(tower_specs[1].loss))
-
-      self.assertEqual(1, len(variables.global_variables()))
-      self.assertEqual(1, len(variables.trainable_variables()))
-
-      with variable_scope.variable_scope('', reuse=True):
-        c = variable_scope.get_variable('c', dtype=dtypes.float64)
-        self.assertEqual(0.25, session.run(c))
-
-  def test_gradients_are_computed_with_mean_reduction(self):
-    with self.test_session() as session:
-      tower_specs = replicate_model_fn._get_loss_towers(
-          self.create_model_fn_with_loss_reduction(losses.Reduction.MEAN),
-          mode=model_fn_lib.ModeKeys.EVAL,
-          features=[[0.6], [1.6]],
-          labels=[[0.6], [0.6]],
-          params=None,
-          config=None,
-          devices=['/gpu:0', '/gpu:1'],
-          local_ps_devices=['/gpu:0'],
-          name_scope_pattern='test_tower_{}')
-      session.run(variables.global_variables_initializer())
-
-      self.assertEqual(len(tower_specs), 2)
-
-      self.assertEqual('/device:GPU:0', tower_specs[0].loss.device)
-      self.assertEqual('averaged_loss:0', tower_specs[0].loss.name)
-      self.assertEqual(0.5, session.run(tower_specs[0].loss))
-
-      self.assertEqual('/device:GPU:1', tower_specs[1].loss.device)
-      self.assertEqual('test_tower_1/averaged_loss:0', tower_specs[1].loss.name)
-      # The input batch for the second tower had a loss that is 1.0
-      # bigger: 0.6 vs 1.6.
-      self.assertEqual(1.0, session.run(tower_specs[1].loss))
-
-      self.assertEqual(1, len(variables.global_variables()))
-      self.assertEqual(1, len(variables.trainable_variables()))
-
-      with variable_scope.variable_scope('', reuse=True):
-        c = variable_scope.get_variable('c', dtype=dtypes.float64)
-        self.assertEqual(0.25, session.run(c))
-
-  def test_variables_are_round_robined_correctly(self):
-    """Test that creates multiple variables and tests round-robin placement."""
-
-    def model_fn(mode, features, labels, params):
-      del params
-      for variable_name in ['a', 'b', 'c', 'd']:
-        c = variable_scope.get_variable(
-            variable_name,
-            initializer=constant_op.constant(0.25, dtype=dtypes.float64),
-            dtype=dtypes.float64)
-
-      predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c)
-      labels = np.array([0.1, 0.2, 0.3, labels[0]])
-      loss = losses.absolute_difference(
-          labels=labels,
-          predictions=predictions,
-          reduction=losses.Reduction.SUM)
-      return model_fn_lib.EstimatorSpec(
-          mode=mode, loss=math_ops.reduce_sum(loss))
-
-    with self.test_session() as session:
-      tower_specs = replicate_model_fn._get_loss_towers(
-          model_fn,
-          mode=None,
-          features=[[0.6], [1.6], [2.6]],
-          labels=[[0.6], [0.6], [2.6]],
-          params=None,
-          config=None,
-          devices=['/gpu:0', '/gpu:1', '/gpu:3'],
-          local_ps_devices=['/gpu:0', '/gpu:1', '/gpu:3'],
-          name_scope_pattern='test_tower_{}')
-      session.run(variables.global_variables_initializer())
-
-      self.assertEqual(len(tower_specs), 3)
-      self.assertEqual('/device:GPU:0', tower_specs[0].loss.device)
-      self.assertEqual('/device:GPU:1', tower_specs[1].loss.device)
-      self.assertEqual('/device:GPU:3', tower_specs[2].loss.device)
-
-      with variable_scope.variable_scope('', reuse=True):
-        a = variable_scope.get_variable('a', dtype=dtypes.float64)
-        self.assertEqual('/device:GPU:0', a.device)
-        b = variable_scope.get_variable('b', dtype=dtypes.float64)
-        self.assertEqual('/device:GPU:1', b.device)
-        c = variable_scope.get_variable('c', dtype=dtypes.float64)
-        self.assertEqual('/device:GPU:3', c.device)
-        d = variable_scope.get_variable('d', dtype=dtypes.float64)
-        self.assertEqual('/device:GPU:0', d.device)
-
-
-class SplitBatchTest(test_util.TensorFlowTestCase):
-
-  def evaluate_shards(self, first_list, second_list):
-    evaluate_items = lambda x: x.eval()
-    return list(map(evaluate_items, first_list)), list(
-        map(evaluate_items, second_list))
-
-  def assertSparseValuesEqual(self, a, b):
-    self.assertAllEqual(a.indices, b.indices)
-    self.assertAllEqual(a.values, b.values)
-    self.assertAllEqual(a.dense_shape, b.dense_shape)
-
-  def test_simple_half_split(self):
-    with self.test_session():
-      features = [0.0, 1.0, 2.0, 3.0]
-      labels = [10.0, 11.0, 12.0, 13.0]
-      feature_shards, label_shards = replicate_model_fn._split_batch(
-          features, labels, 2, device='/gpu:0')
-
-      feature_shards, label_shards = self.evaluate_shards(
-          feature_shards, label_shards)
-
-      self.assertAllEqual([[0.0, 1.0], [2.0, 3.0]], feature_shards)
-      self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards)
-
-  def test_to_each_their_own(self):
-    with self.test_session():
-      features = [0.0, 1.0, 2.0, 3.0]
-      labels = [10.0, 11.0, 12.0, 13.0]
-      feature_shards, label_shards = replicate_model_fn._split_batch(
-          features, labels, 4, device='/gpu:0')
-
-      feature_shards, label_shards = self.evaluate_shards(
-          feature_shards, label_shards)
-
-      self.assertAllEqual([[0.0], [1.0], [2.0], [3.0]], feature_shards)
-      self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards)
-
-  def test_one_batch(self):
-    with self.test_session():
-      features = [0.0, 1.0, 2.0, 3.0]
-      labels = [10.0, 11.0, 12.0, 13.0]
-      feature_shards, label_shards = replicate_model_fn._split_batch(
-          features, labels, 1, device='/gpu:0')
-
-      feature_shards, label_shards = self.evaluate_shards(
-          feature_shards, label_shards)
-
-      self.assertAllEqual([[0.0, 1.0, 2.0, 3.0]], feature_shards)
-      self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards)
-
-  def test_half_split_in_dictionary(self):
-    with self.test_session():
-      features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
-      labels = [10.0, 11.0, 12.0, 13.0]
-
-      feature_shards, label_shards = replicate_model_fn._split_batch(
-          features, labels, 2, device='/gpu:0')
-
-      self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval())
-      self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval())
-      self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval())
-      self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval())
-      self.assertAllEqual([10.0, 11.0], label_shards[0].eval())
-      self.assertAllEqual([12.0, 13.0], label_shards[1].eval())
-
-  def test_sparse_tensor_can_be_split_unevenly(self):
-    with self.test_session():
-      features = {
-          'x':
-              sparse_tensor.SparseTensor(
-                  indices=[[0, 0], [1, 2], [2, 2]],
-                  values=[1.0, 2.0, 3.0],
-                  dense_shape=[3, 4])
-      }
-      labels = np.array([[1.0], [2.0]])
-
-      feature_shards, label_shards = replicate_model_fn._split_batch(
-          features, labels, 2, device='/gpu:0')
-
-      self.assertSparseValuesEqual(
-          sparse_tensor.SparseTensorValue(
-              indices=[[0, 0], [1, 2]], values=[1., 2.], dense_shape=[2, 4]),
-          feature_shards[0]['x'].eval())
-      self.assertSparseValuesEqual(
-          sparse_tensor.SparseTensorValue(
-              indices=[[0, 2]], values=[3.], dense_shape=[1, 4]),
-          feature_shards[1]['x'].eval())
-      self.assertAllEqual([[1.0]], label_shards[0].eval())
-      self.assertAllEqual([[2.0]], label_shards[1].eval())
-
-  def test_sparse_tensor_can_be_split_unevenly_repeated_row(self):
-    with self.test_session():
-      features = {
-          'x':
-              sparse_tensor.SparseTensor(
-                  indices=[[0, 0], [1, 0], [1, 1]],
-                  values=[1.0, 2.0, 3.0],
-                  dense_shape=[3, 4])
-      }
-      labels = np.array([[1.0], [2.0]])
-
-      feature_shards, label_shards = replicate_model_fn._split_batch(
-          features, labels, 2, device='/gpu:0')
-
-      self.assertSparseValuesEqual(
-          sparse_tensor.SparseTensorValue(
-              indices=[[0, 0], [1, 0], [1, 1]],
-              values=[1., 2., 3.],
-              dense_shape=[2, 4]), feature_shards[0]['x'].eval())
-
-      second_batch = feature_shards[1]['x'].eval()
-      self.assertFalse(len(second_batch.indices))
-      self.assertFalse(len(second_batch.values))
-      self.assertAllEqual([1, 4], second_batch.dense_shape)
-      self.assertAllEqual([[1.0]], label_shards[0].eval())
-      self.assertAllEqual([[2.0]], label_shards[1].eval())
-
-  def test_one_batch_in_dictionary(self):
-    with self.test_session() as session:  # pylint: disable=unused-variable
-      features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
-      labels = [10.0, 11.0, 12.0, 13.0]
-
-      feature_shards, label_shards = replicate_model_fn._split_batch(
-          features, labels, 1, device='/gpu:0')
-
-      self.assertAllEqual([0.0, 1.0, 2.0, 3.0],
-                          feature_shards[0]['first'].eval())
-      self.assertAllEqual([4.0, 5.0, 6.0, 7.0],
-                          feature_shards[0]['second'].eval())
-      self.assertAllEqual([10.0, 11.0, 12.0, 13.0], label_shards[0].eval())
-
-  def test_feature_and_label_dictionaries(self):
-    with self.test_session() as session:  # pylint: disable=unused-variable
-      features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
-      labels = {'first': [10.0, 11.0], 'second': [12.0, 13.0]}
-
-      feature_shards, label_shards = replicate_model_fn._split_batch(
-          features, labels, 2, device='/gpu:0')
-
-      self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval())
-      self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval())
-      self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval())
-      self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval())
-      self.assertAllEqual([10.0], label_shards[0]['first'].eval())
-      self.assertAllEqual([12.0], label_shards[0]['second'].eval())
-      self.assertAllEqual([11], label_shards[1]['first'].eval())
-      self.assertAllEqual([13.0], label_shards[1]['second'].eval())
-
-
-class TrainSpecTest(test_util.TensorFlowTestCase):
-
-  expected_predictions = {}
-
-  def create_estimator_spec(self, loss):
-    return model_fn_lib.EstimatorSpec(
-        mode=model_fn_lib.ModeKeys.TRAIN,
-        loss=loss,
-        train_op=loss,  # Not used; currently required.
-        predictions=self.expected_predictions)
-
-  def create_constant_loss(self, loss_value):
-    return constant_op.constant(loss_value, dtype=dtypes.float64)
-
-  def test_example(self):
-    with self.test_session() as session:
-      tower_losses = list(map(self.create_constant_loss, [2, 4, 6]))
-      tower_specs = list(map(self.create_estimator_spec, tower_losses))
-
-      expected_train_op = tower_losses[1]
-
-      estimator_spec = replicate_model_fn._train_spec(
-          tower_specs, expected_train_op, aggregation_device='/gpu:0')
-
-      self.assertEqual(expected_train_op, estimator_spec.train_op)
-      self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
-      self.assertEqual(self.expected_predictions, estimator_spec.predictions)
-
-
-class EvalSpecTest(test_util.TensorFlowTestCase):
-
-  def create_estimator_spec(self, loss, metrics):
-    return model_fn_lib.EstimatorSpec(
-        mode=model_fn_lib.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics)
-
-  def create_constant_loss(self, loss_value):
-    return constant_op.constant(loss_value, dtype=dtypes.float64)
-
-  def create_eval_metrics(self, noise):
-    predictions = np.array([0.1, 0.2, 0.3, 0.6 + noise])
-    labels = np.array([0.1, 0.2, 0.3, 0.6])
-
-    metrics = {
-        'accuracy': metrics_lib.accuracy(labels, predictions),
-        'auc': metrics_lib.auc(labels, predictions)
-    }
-    return metrics
-
-  def test_example(self):
-    with self.test_session() as session:
-      tower_losses = map(self.create_constant_loss, [2, 4, 6])
-      tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3])
-      tower_specs = [
-          self.create_estimator_spec(l, m)
-          for l, m in zip(tower_losses, tower_metrics)
-      ]
-      session.run(variables.local_variables_initializer())
-
-      estimator_spec = replicate_model_fn._eval_spec(
-          tower_specs, aggregation_device='/device:GPU:0')
-
-      accuracy, a = estimator_spec.eval_metric_ops['accuracy']
-      auc, b = estimator_spec.eval_metric_ops['auc']
-
-      self.assertEqual('/device:CPU:0', accuracy.device)
-      self.assertEqual('/device:CPU:0', auc.device)
-
-      session.run([a, b])
-      accuracy, auc = session.run([accuracy, auc])
-
-      self.assertNear((12 - 2) / 12, accuracy, 0.01)
-      self.assertEqual(0, auc)
-      self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
-
-  def test_handles_single_tower(self):
-    with self.test_session() as session:
-      tower_losses = map(self.create_constant_loss, [5])
-      tower_metrics = map(self.create_eval_metrics, [0.2])
-      tower_specs = [
-          self.create_estimator_spec(l, m)
-          for l, m in zip(tower_losses, tower_metrics)
-      ]
-      session.run(variables.local_variables_initializer())
-
-      estimator_spec = replicate_model_fn._eval_spec(
-          tower_specs, aggregation_device='/device:GPU:0')
-
-      accuracy, a = estimator_spec.eval_metric_ops['accuracy']
-      auc, b = estimator_spec.eval_metric_ops['auc']
-
-      self.assertEqual('/device:CPU:0', accuracy.device)
-      self.assertEqual('/device:CPU:0', auc.device)
-
-      session.run([a, b])
-      accuracy = session.run(accuracy)
-      auc = session.run(auc)
-
-      self.assertNear((4 - 1) / 4, accuracy, 0.01)
-      self.assertEqual(0, auc)
-      self.assertEqual(5, session.run(estimator_spec.loss))
-
-
-class PredictSpecTest(test_util.TensorFlowTestCase):
-
-  def model_fn(self, mode, features, labels, params):
-    c = variable_scope.get_variable(
-        'c',
-        initializer=constant_op.constant(0.25, dtype=dtypes.float64),
-        dtype=dtypes.float64)
-
-    predictions = math_ops.add(np.array([features[0], features[0]]), c)
-
-    return model_fn_lib.EstimatorSpec(
-        mode=model_fn_lib.ModeKeys.PREDICT,
-        predictions={
-            'probabilities': predictions
-        })
-
-  def test_example(self):
-    with self.test_session() as session:
-      tower_specs = replicate_model_fn._get_loss_towers(
-          self.model_fn,
-          mode=None,
-          features=[[0.1], [0.2]],
-          labels=[[], []],
-          params=None,
-          config=None,
-          devices=['/gpu:0', '/gpu:1'],
-          local_ps_devices=['/gpu:0'],
-      )
-      session.run(variables.global_variables_initializer())
-
-      estimator_spec = replicate_model_fn._predict_spec(
-          tower_specs, aggregation_device='/gpu:0')
-
-      self.assertEqual('/device:GPU:0',
-                       estimator_spec.predictions['probabilities'].device)
-      self.assertAllClose({
-          'probabilities': np.array([0.35, 0.35, 0.45, 0.45])
-      }, session.run(estimator_spec.predictions))
-
-
-class ReduceMetricVariablesTest(test_util.TensorFlowTestCase):
-
-  def create_metric_variable(self, initial_value, name):
-    return variable_scope.variable(
-        initial_value,
-        trainable=False,
-        collections=[ops_lib.GraphKeys.METRIC_VARIABLES],
-        validate_shape=True,
-        name=name)
-
-  def create_tower_metrics(self, tower_id):
-    with variable_scope.variable_scope('', reuse=(tower_id != 0)):
-      self.create_metric_variable(1.3 * (tower_id + 1), 'total')
-      self.create_metric_variable(2.3 * (tower_id + 1), 'count')
-      self.create_metric_variable(
-          np.array([3.3, 3.5, 3.7]) * (tower_id + 1), 'total')
-
-  def test_example(self):
-    with self.test_session() as session:
-      for tower_id in range(3):
-        self.create_tower_metrics(tower_id)
-
-      session.run(
-          variables.variables_initializer(
-              ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
-
-      session.run(
-          replicate_model_fn._reduce_metric_variables(number_of_towers=3))
-
-      # 1st tower = 1.3, 2.3,  [3.3, 3.5, 3.7]
-      # 2nd tower = 2.6, 4.6,  [6.6, 7.0, 7.4]
-      # 3rd tower = 3.9, 6.9,  [9.9, 10.5, 11.1]
-      # Reduced =   7.8, 13.8, [19.8, 21.0, 22.2]
-      # Towers are accumulated in the first tower.
-      local_metrics = session.run(
-          ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
-
-      self.assertNear(7.8, local_metrics[0], 0.01)
-      self.assertNear(13.8, local_metrics[1], 0.01)
-      self.assertAllClose([19.8, 21., 22.1], local_metrics[2], 0.01)
-      self.assertNear(0.0, local_metrics[3], 0.01)
-      self.assertNear(0.0, local_metrics[4], 0.01)
-      self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
-      self.assertNear(0.0, local_metrics[6], 0.01)
-      self.assertNear(0.0, local_metrics[7], 0.01)
-      self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
-
-  def test_reduce_is_idempotent(self):
-    with self.test_session() as session:
-      for tower_id in range(3):
-        self.create_tower_metrics(tower_id)
-
-      session.run(
-          variables.variables_initializer(
-              ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
-
-      for _ in range(20):
-        session.run(
-            replicate_model_fn._reduce_metric_variables(number_of_towers=3))
-
-      local_metrics = session.run(
-          ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
-
-      self.assertNear(7.8, local_metrics[0], 0.01)
-      self.assertNear(13.8, local_metrics[1], 0.01)
-      self.assertAllClose([19.8, 21., 22.1], local_metrics[2], 0.01)
-      self.assertNear(0.0, local_metrics[3], 0.01)
-      self.assertNear(0.0, local_metrics[4], 0.01)
-      self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
-      self.assertNear(0.0, local_metrics[6], 0.01)
-      self.assertNear(0.0, local_metrics[7], 0.01)
-      self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
-
-  def test_handles_single_tower(self):
-    with self.test_session() as session:
-      self.create_tower_metrics(0)
-      session.run(
-          variables.variables_initializer(
-              ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
-
-      session.run(
-          replicate_model_fn._reduce_metric_variables(number_of_towers=1))
-
-      local_metrics = session.run(
-          ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
-
-      self.assertNear(1.3, local_metrics[0], 0.01)
-      self.assertNear(2.3, local_metrics[1], 0.01)
-      self.assertAllClose([3.3, 3.5, 3.7], local_metrics[2], 0.01)
-
-  def test_doesnt_accept_uneven_number_of_variables(self):
-    with self.test_session() as session:
-      for tower_id in range(3):
-        self.create_tower_metrics(tower_id)
-      self.create_metric_variable(-1.0, 'oddball')
-
-      session.run(
-          variables.variables_initializer(
-              ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
-
-      with self.assertRaisesRegexp(
-          ValueError, '.+Expected.+local.+variables.+but.+got.+instead.+'):
-        session.run(
-            replicate_model_fn._reduce_metric_variables(number_of_towers=3))
-
-
-class MergeExportOutputsTest(test_util.TensorFlowTestCase):
-
-  def model_fn(self, mode, features, labels, params):
-    c = variable_scope.get_variable(
-        'c',
-        initializer=constant_op.constant(10, dtype=dtypes.float64),
-        dtype=dtypes.float64)
-
-    predictions = {'probabilities': math_ops.multiply(features, c)}
-    loss = losses.absolute_difference(
-        labels=labels,
-        predictions=predictions['probabilities'],
-        reduction=losses.Reduction.SUM)
-
-    metrics = {
-        'accuracy': metrics_lib.accuracy(labels, predictions['probabilities']),
-        'auc': metrics_lib.auc(labels, predictions['probabilities'])
-    }
-    tensor_string_repr = str(features)
-    classes = constant_op.constant(
-        re.search('(split_inputs/split:[0-9])', tensor_string_repr).group(1),
-        dtype=dtypes.string)
-
-    export_outputs = {
-        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
-            export_output.PredictOutput(predictions),
-        'classification_output':
-            export_output.ClassificationOutput(predictions['probabilities'],
-                                               classes),
-        'classification_scores':
-            export_output.ClassificationOutput(
-                scores=predictions['probabilities']),
-        'classification_classes':
-            export_output.ClassificationOutput(classes=classes),
-        'regression_output':
-            export_output.RegressionOutput(predictions['probabilities']),
-    }
-
-    return model_fn_lib.EstimatorSpec(
-        mode=mode,
-        loss=math_ops.reduce_sum(loss),
-        eval_metric_ops=metrics,
-        predictions=predictions,
-        export_outputs=export_outputs)
-
-  def replicate_estimator_spec(self, session):
-    features = np.array([0.01, 0.002])
-    labels = np.array([0.01, 0.02])
-
-    replicated_model_fn = replicate_model_fn._replicate_model_fn(
-        self.model_fn, devices=['/gpu:0', '/gpu:1'])
-    estimator_spec = replicated_model_fn(features, labels,
-                                         model_fn_lib.ModeKeys.PREDICT, {})
-    session.run(variables.global_variables_initializer())
-    return estimator_spec
-
-  def test_merge_predict_output(self):
-    with self.test_session() as session:
-      estimator_spec = self.replicate_estimator_spec(session)
-      self.assertAllClose(
-          {
-              'probabilities': np.array([0.1, 0.02])
-          },
-          session.run(estimator_spec.export_outputs[
-              signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs))
-
-  def test_merge_classification_output_scores_classes(self):
-    with self.test_session() as session:
-      estimator_spec = self.replicate_estimator_spec(session)
-      self.assertAllClose(
-          [0.1, 0.02],
-          session.run(
-              estimator_spec.export_outputs['classification_output'].scores))
-      self.assertAllEqual(
-          [b'split_inputs/split:0', b'split_inputs/split:1'],
-          session.run(
-              estimator_spec.export_outputs['classification_output'].classes))
-
-  def test_merge_classification_output_scores(self):
-    with self.test_session() as session:
-      estimator_spec = self.replicate_estimator_spec(session)
-      self.assertAllClose(
-          [0.1, 0.02],
-          session.run(
-              estimator_spec.export_outputs['classification_scores'].scores))
-      self.assertEqual(
-          None, estimator_spec.export_outputs['classification_scores'].classes)
-
-  def test_merge_classification_output_classes(self):
-    with self.test_session() as session:
-      estimator_spec = self.replicate_estimator_spec(session)
-      self.assertAllEqual(
-          [b'split_inputs/split:0', b'split_inputs/split:1'],
-          session.run(
-              estimator_spec.export_outputs['classification_classes'].classes))
-      self.assertEqual(
-          None, estimator_spec.export_outputs['classification_classes'].scores)
-
-  def test_merge_regression_output(self):
-    with self.test_session() as session:
-      estimator_spec = self.replicate_estimator_spec(session)
-      self.assertAllClose(
-          [0.1, 0.02],
-          session.run(estimator_spec.export_outputs['regression_output'].value))
-
-
-class GetLocalDevicesTest(test_util.TensorFlowTestCase):
-
-  def test_there_is_at_least_a_cpu(self):
-    self.assertTrue(replicate_model_fn._get_local_devices('CPU'))
-
-  def test_there_is_no_xpu(self):
-    self.assertFalse(
-        replicate_model_fn._get_local_devices('XPU'))  # XPU doesn't exist.
-
-  def test_whether_there_is_a_gpu(self):
-    if test.is_gpu_available():
-      self.assertTrue(len(replicate_model_fn._get_local_devices('GPU')))
-
-
-class LocalDeviceSetterTest(test_util.TensorFlowTestCase):
-
-  def test_vars_are_on_ps_but_ops_are_on_workers(self):
-    ps_devices = ['/device:GPU:3']
-    round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices))
-
-    local_device_setter = replicate_model_fn._local_device_setter(
-        ps_devices=ps_devices,
-        ps_strategy=round_robin,
-        worker_device='/device:GPU:2')
-
-    with ops_lib.device(local_device_setter):
-      a = variables.Variable(0.01)
-      self.assertEqual('/device:GPU:3', a.device)
-
-      b = variables.Variable(0.02)
-      self.assertEqual('/device:GPU:3', b.device)
-
-      c = variables.Variable(0.03)
-      self.assertEqual('/device:GPU:3', c.device)
-
-      a_op = array_ops.concat(a, axis=0)
-      self.assertEqual('/device:GPU:2', a_op.device)
-
-      b_op = array_ops.concat(b, axis=0)
-      self.assertEqual('/device:GPU:2', b_op.device)
-
-  def test_round_robin_placement(self):
-    ps_devices = [
-        '/device:GPU:0', '/device:GPU:1', '/device:GPU:3', '/device:GPU:4'
-    ]
-    round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices))
-
-    local_device_setter = replicate_model_fn._local_device_setter(
-        ps_devices=ps_devices,
-        ps_strategy=round_robin,
-        worker_device='/device:GPU:2')
-
-    with ops_lib.device(local_device_setter):
-      a = variables.Variable(0.01)
-      self.assertEqual('/device:GPU:0', a.device)
-
-      b = variables.Variable(0.02)
-      self.assertEqual('/device:GPU:1', b.device)
-
-      c = variables.Variable(0.03)
-      self.assertEqual('/device:GPU:3', c.device)
-
-      a_op = array_ops.concat(a, axis=0)
-      self.assertEqual('/device:GPU:2', a_op.device)
-
-      b_op = array_ops.concat(b, axis=0)
-      self.assertEqual('/device:GPU:2', b_op.device)
-
-      c = variables.Variable(0.03)
-      self.assertEqual('/device:GPU:4', c.device)
-
-      d = variables.Variable(0.03)
-      self.assertEqual('/device:GPU:0', d.device)
-
-      c_op = array_ops.concat(c, axis=0)
-      self.assertEqual('/device:GPU:2', c_op.device)
-
-
-class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
-
-  def test_vectors(self):
-    with self.test_session() as session:
-      total = replicate_model_fn._compute_sum_on_device(
-          [1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum')
-
-      self.assertEqual('/device:GPU:0', total.device)
-      self.assertEqual('test_sum', total.op.name)
-      self.assertEqual(10.0, session.run(total))
-
-  def test_tensors(self):
-    with self.test_session() as session:
-      total = replicate_model_fn._compute_sum_on_device(
-          [[1.0, 2.0], [3.0, 4.0]], device='/device:GPU:0', name='test_sum')
-
-      self.assertEqual('/device:GPU:0', total.device)
-      self.assertEqual('test_sum', total.op.name)
-      self.assertAllEqual([4.0, 6.0], session.run(total))
-
-  def test_indexedslices(self):
-    with self.test_session() as session:
-      a = ops_lib.IndexedSlices(
-          constant_op.constant([1.0, 2.0]), [0, 1],
-          dense_shape=constant_op.constant([2]))
-      b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
-
-      total = replicate_model_fn._compute_sum_on_device(
-          [a, b], device='/device:GPU:0')
-
-      self.assertEqual('/device:GPU:0', total.device)
-      self.assertAllEqual([4.0, 6.0],
-                          session.run(ops_lib.convert_to_tensor(total)))
-
-  def test_indexedslices_higher_dimensions(self):
-    with self.test_session() as session:
-      a = ops_lib.IndexedSlices(
-          constant_op.constant([[1.0, 5.0], [2.0, 6.0]]), [0, 1],
-          dense_shape=constant_op.constant([2, 4]))
-      b = ops_lib.IndexedSlices(
-          constant_op.constant([[3.0, 7.0], [4.0, 8.0]]), [0, 1])
-
-      total = replicate_model_fn._compute_sum_on_device(
-          [a, b], device='/device:GPU:0')
-
-      self.assertEqual('/device:GPU:0', total.device)
-      self.assertAllEqual([[4.0, 12.0], [6.0, 14.0]],
-                          session.run(ops_lib.convert_to_tensor(total)))
-
-  def test_indexedslices_some_dont_overlap(self):
-    with self.test_session() as session:
-      a = ops_lib.IndexedSlices(
-          constant_op.constant([1.0, 2.0]), [0, 3],
-          dense_shape=constant_op.constant([4]))
-      b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
-
-      total = replicate_model_fn._compute_sum_on_device(
-          [a, b], device='/device:GPU:0')
-
-      self.assertEqual('/device:GPU:0', total.device)
-      self.assertAllEqual([4.0, 4.0, 0.0, 2.0],
-                          session.run(ops_lib.convert_to_tensor(total)))
-
-  def test_no_name_for_indexslices(self):
-    a = ops_lib.IndexedSlices(
-        constant_op.constant([1.0, 2.0]), [0, 1],
-        dense_shape=constant_op.constant([2]))
-    b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
-
-    with self.assertRaisesRegexp(ValueError, '.+name.+not.+expected.+'):
-      _ = replicate_model_fn._compute_sum_on_device(
-          [a, b], device='/device:GPU:0', name='cant_name_indexslices')
-
-
-class ConcatTensorDictsTest(test_util.TensorFlowTestCase):
-
-  def test_example(self):
-    tensor_dicts = [
-        {
-            'a': np.array([1.0, 2.0]),
-            'b': np.array([11.0]),
-            'c': np.array([21.0]),
-        },
-        {
-            'a': np.array([3.0]),
-            'b': np.array([12.0, 13.0]),
-        },
-        {
-            'b': np.array([14.0]),
-        },
-    ]
-
-    with self.test_session() as session:
-      self.assertAllClose({
-          'a': np.array([1.0, 2.0, 3.0]),
-          'b': np.array([11.0, 12.0, 13.0, 14.0]),
-          'c': np.array([21.0]),
-      }, session.run(replicate_model_fn._concat_tensor_dicts(*tensor_dicts)))
-
-
-if __name__ == '__main__':
-  test.main()