Add the MultiWorkerMirroredStrategy
author     Yuefeng Zhou <yuefengz@google.com>
Fri, 4 May 2018 05:01:39 +0000 (22:01 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
Fri, 4 May 2018 17:42:01 +0000 (10:42 -0700)
PiperOrigin-RevId: 195368876

tensorflow/contrib/distribute/python/BUILD
tensorflow/contrib/distribute/python/mirrored_strategy.py
tensorflow/contrib/distribute/python/multi_worker_strategy.py [new file with mode: 0644]
tensorflow/contrib/distribute/python/multi_worker_strategy_test.py [new file with mode: 0644]
tensorflow/contrib/distribute/python/one_device_strategy.py
tensorflow/python/training/distribute.py

index aaafc18..8dfcaf6 100644 (file)
@@ -87,6 +87,19 @@ py_library(
 )
 
 py_library(
+    name = "multi_worker_strategy",
+    srcs = ["multi_worker_strategy.py"],
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        ":mirrored_strategy",
+        ":values",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:training",
+        "//tensorflow/python:util",
+    ],
+)
+
+py_library(
     name = "one_device_strategy",
     srcs = ["one_device_strategy.py"],
     visibility = ["//tensorflow:internal"],
index 2e57b02..8237b23 100644 (file)
@@ -80,6 +80,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
         dict((d, i) for i, d in enumerate(devices)))
     self._cross_tower_ops = cross_tower_ops
     self._prefetch_on_device = prefetch_on_device
+    # TODO(yuefengz): consider setting the default device.
 
   def _create_variable(self, next_creator, *args, **kwargs):
     """Create a mirrored variable. See `DistributionStrategy.scope`."""
diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py
new file mode 100644 (file)
index 0000000..a552b37
--- /dev/null
@@ -0,0 +1,141 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Classes implementing a mirrored DistributionStrategy for multiple workers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from functools import partial
+
+from tensorflow.contrib.distribute.python import values
+from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.training import device_util
+from tensorflow.python.training import server_lib
+from tensorflow.python.util import nest
+
+
+# TODO(yuefengz): support between-graph replication.
+# TODO(yuefengz): merge this class into its base class.
+# TODO(yuefengz): in some cases, we probably want to use the configure method
+# to configure this class.
+# TODO(yuefengz): MirroredStrategy.worker_devices may be confusing after this
+# class is introduced.
+class MultiWorkerMirroredStrategy(MirroredStrategy):
+  """Mirrored strategy that works on multiple workers with in-graph replication.
+
+  Distributed TensorFlow has several important concepts, e.g. `client`, `job`,
+  `task`, `cluster`, `in-graph replication` and `synchronous training`, which
+  are already defined in
+  [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed).
+  This distribution strategy inherits those concepts and additionally
+  clarifies the following:
+    * **In-graph replication**: the `client` creates a single `tf.Graph` that
+    specifies tasks for devices on all workers. The `client` then creates a
+    client session which will talk to the `master` service of a `worker`. Then
+    the `master` will partition the graph and distribute the work to all
+    participating workers.
+    * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one
+    physical machine. There are multiple `worker`s, each with a different
+    `task` index. They all do similar work, except that one worker also
+    checkpoints model variables, writes summaries, etc. in addition to its
+    ordinary work.
+
+  This class maps one tower to one device on a worker. It mirrors all model
+  variables on all towers. For example, if you have two `worker`s and each
+  `worker` has 4 GPUs, it creates 8 copies of the model variables on these 8
+  GPUs. As in MirroredStrategy, each tower then performs its computation with
+  its own copy of the variables, except in cross-tower mode where variable or
+  tensor reduction happens.
+  """
+
+  def __init__(self,
+               num_gpus_per_worker=1,
+               worker_job_name=None,
+               num_workers=None,
+               cluster=None,
+               cross_tower_ops=None,
+               prefetch_on_device=None):
+    """Initialize the strategy object.
+
+    Args:
+      num_gpus_per_worker: number of GPUs per worker. If it is zero, the local
+        CPU will be used.
+      worker_job_name: the job name for `worker`, typically just 'worker'.
+      num_workers: the number of workers. If it is 0, this degenerates into a
+        single-worker MirroredStrategy.
+      cluster: a `tf.train.ClusterSpec` object, or a dict or a
+        `tf.train.ClusterDef` protocol buffer that can be used to construct a
+        `tf.train.ClusterSpec` object. It is an alternative way to initialize
+        this object.
+      cross_tower_ops: the cross tower ops to use. If None, a default one will
+        be used. If the `configure` method is called, the best one for the
+        configuration will be chosen.
+      prefetch_on_device: a boolean specifying whether to prefetch input onto
+        each worker's devices.
+
+    Raises:
+      ValueError: if `cluster` has an unexpected type.
+    """
+    if cluster is None:
+      self._workers = [
+          '/job:%s/task:%d' % (worker_job_name, task_index)
+          for task_index in range(num_workers)
+      ]
+    else:
+      if isinstance(cluster, (dict, cluster_pb2.ClusterDef)):
+        cluster_spec = server_lib.ClusterSpec(cluster)
+      elif isinstance(cluster, server_lib.ClusterSpec):
+        cluster_spec = cluster
+      else:
+        raise ValueError(
+            '`cluster` should be a dict, a `tf.train.ClusterSpec` or a '
+            '`tf.train.ClusterDef` object')
+
+      self._workers = []
+      for job in sorted(cluster_spec.jobs):
+        for task in range(cluster_spec.num_tasks(job)):
+          self._workers.append('/job:%s/task:%d' % (job, task))
+
+    self._num_gpus_per_worker = num_gpus_per_worker
+    if num_gpus_per_worker > 0:
+      self._worker_device_map = {
+          worker: [
+              device_util.canonicalize(worker + '/device:GPU:%d' % gpu)
+              for gpu in range(num_gpus_per_worker)
+          ] for worker in self._workers
+      }
+    else:
+      self._worker_device_map = {
+          worker: [device_util.canonicalize(worker, '/device:CPU:0')]
+          for worker in self._workers
+      }
+    self._devices = nest.flatten(self._worker_device_map.values())
+
+    super(MultiWorkerMirroredStrategy, self).__init__(
+        devices=self._devices, prefetch_on_device=prefetch_on_device)
+
+    # Setting `_default_device` adds a device scope in distribution.scope. We
+    # set the default device to the first worker. When users specify a device
+    # under distribution.scope, e.g.
+    #   with tf.device("/cpu:0"):
+    #     ...
+    # their ops will end up on the CPU device of the first worker, e.g.
+    # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode.
+    self._default_device = self._workers[0]
+
+  def distribute_dataset(self, dataset_fn):
+    return values.MultiWorkerDataset(
+        partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
+        self._prefetch_on_device)
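
For reference, a minimal usage sketch of the new strategy (not part of this
change; the cluster addresses below are made up, and only construction and
scope entry are shown):

    # A minimal sketch, assuming a two-worker cluster with 4 GPUs per worker.
    from tensorflow.contrib.distribute.python import multi_worker_strategy
    from tensorflow.python.training import server_lib

    cluster = server_lib.ClusterSpec(
        {'worker': ['host1:2222', 'host2:2222']})
    strategy = multi_worker_strategy.MultiWorkerMirroredStrategy(
        cluster=cluster, num_gpus_per_worker=4)

    with strategy.scope():
      # Variables created here are mirrored across all 8 GPUs (2 workers x 4
      # GPUs); by default, ops land on the first worker's devices.
      build_model()  # hypothetical model-building function
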
diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py
new file mode 100644 (file)
index 0000000..ee75881
--- /dev/null
@@ -0,0 +1,64 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MultiWorkerMirroredStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distribute.python import multi_worker_strategy
+from tensorflow.contrib.distribute.python import multi_worker_test_base
+from tensorflow.contrib.distribute.python import strategy_test_lib
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.training import server_lib
+
+
+@test_util.with_c_api
+class MultiWorkerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
+                              strategy_test_lib.DistributionTestBase):
+
+  def _get_distribution_strategy(self):
+    return multi_worker_strategy.MultiWorkerMirroredStrategy(
+        cluster=server_lib.ClusterSpec({
+            'worker': ['/job:worker/task:0', '/job:worker/task:1']
+        }),
+        num_gpus_per_worker=context.num_gpus())
+
+  def testMinimizeLossGraph(self):
+    self._test_minimize_loss_graph(self._get_distribution_strategy())
+
+
+class DeviceScopeTest(test.TestCase):
+  """Test the device scope of MultiWorkerMirroredStrategy."""
+
+  def testDeviceScope(self):
+    with context.graph_mode():
+      strategy = multi_worker_strategy.MultiWorkerMirroredStrategy(
+          cluster={'worker': ['/job:worker/task:0', '/job:worker/task:1']},
+          num_gpus_per_worker=context.num_gpus())
+      with strategy.scope():
+        a = constant_op.constant(1.)
+        with ops.device('/cpu:0'):
+          b = constant_op.constant(1.)
+        self.assertEqual(a.device, '/job:worker/task:0')
+        self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0')
+
+
+if __name__ == '__main__':
+  test.main()
index 64aa369..09b6d4a 100644 (file)
@@ -40,6 +40,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
     super(OneDeviceStrategy, self).__init__()
     self._device = device
     self._prefetch_on_device = prefetch_on_device
+    self._default_device = device
 
   def _create_variable(self, next_creator, *args, **kwargs):
     # No need to distinguish tower-local variables when not mirroring,
index c16b051..21f81ee 100644 (file)
@@ -290,19 +290,31 @@ def _require_distribution_strategy_scope(distribution_strategy):
 class _CurrentDistributionContext(object):
   """Context manager for setting the `DistributionStrategy` and var creator."""
 
-  def __init__(self, distribution_strategy, var_creator_scope, var_scope=None):
+  def __init__(self,
+               distribution_strategy,
+               var_creator_scope,
+               var_scope=None,
+               default_device=None):
     self._context = _CrossTowerThreadMode(distribution_strategy)
     self._var_creator_scope = var_creator_scope
     self._var_scope = var_scope
+    if default_device:
+      self._device_scope = ops.device(default_device)
+    else:
+      self._device_scope = None
 
   def __enter__(self):
     _push_per_thread_mode(self._context)
     if self._var_scope:
       self._var_scope.__enter__()
     self._var_creator_scope.__enter__()
+    if self._device_scope:
+      self._device_scope.__enter__()
     return self._context.distribution_strategy
 
   def __exit__(self, exception_type, exception_value, traceback):
+    if self._device_scope:
+      self._device_scope.__exit__(exception_type, exception_value, traceback)
     self._var_creator_scope.__exit__(exception_type, exception_value, traceback)
     if self._var_scope:
       self._var_scope.__exit__(exception_type, exception_value, traceback)
@@ -557,6 +569,9 @@ class DistributionStrategy(object):
   # TODO(josh11b): List of towers with their worker and parameter devices
   #   (where the parameter devices may overlap in the ps case).
 
+  def __init__(self):
+    self._default_device = None
+
   def scope(self):
     """Returns a context manager selecting this DistributionStrategy as current.
 
@@ -587,7 +602,8 @@ class DistributionStrategy(object):
         self, variable_scope.variable_creator_scope(creator_with_resource_vars),
         variable_scope.variable_scope(
             variable_scope.get_variable_scope(),
-            custom_getter=disable_partitioned_variables))
+            custom_getter=disable_partitioned_variables),
+        self._default_device)
 
   def _create_variable(self, next_creator, *args, **kwargs):
     # Note: should support "colocate_with" argument.
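
Roughly, the new `default_device` plumbing means that entering a strategy's
scope also enters a device scope whenever the strategy sets `_default_device`,
as exercised by DeviceScopeTest above (a sketch, not additional code in this
change):

    with strategy.scope():
      a = constant_op.constant(1.)    # placed on '/job:worker/task:0'
      with ops.device('/cpu:0'):
        b = constant_op.constant(1.)  # '/job:worker/task:0/device:CPU:0'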