From: Yuefeng Zhou Date: Fri, 4 May 2018 05:01:39 +0000 (-0700) Subject: Add the MultiWorkerMirroredStrategy X-Git-Tag: upstream/v1.9.0_rc1~162^2^2~14 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=8ec11ae8eb7b97caced73ed3971209236e2aef5c;p=platform%2Fupstream%2Ftensorflow.git Add the MultiWorkerMirroredStrategy PiperOrigin-RevId: 195368876 --- diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index aaafc18..8dfcaf6 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -87,6 +87,19 @@ py_library( ) py_library( + name = "multi_worker_strategy", + srcs = ["multi_worker_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":mirrored_strategy", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:training", + "//tensorflow/python:util", + ], +) + +py_library( name = "one_device_strategy", srcs = ["one_device_strategy.py"], visibility = ["//tensorflow:internal"], diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 2e57b02..8237b23 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -80,6 +80,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): dict((d, i) for i, d in enumerate(devices))) self._cross_tower_ops = cross_tower_ops self._prefetch_on_device = prefetch_on_device + # TODO(yuefengz): consider setting the default device. def _create_variable(self, next_creator, *args, **kwargs): """Create a mirrored variable. See `DistributionStrategy.scope`.""" diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py new file mode 100644 index 0000000..a552b37 --- /dev/null +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy.py @@ -0,0 +1,141 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Classes implementing a mirrored DistributionStrategy for multiple workers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from functools import partial + +from tensorflow.contrib.distribute.python import values +from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.training import device_util +from tensorflow.python.training import server_lib +from tensorflow.python.util import nest + + +# TODO(yuefengz): support between-graph replication. +# TODO(yuefengz): merge this class into its base class. +# TODO(yuefengz): in some cases, we probably want to use configure method to +# configure this class. +# TODO(yuefengz): MirroredStrategy.worker_devices may be confusing after the +# class is introduced. +class MultiWorkerMirroredStrategy(MirroredStrategy): + """Mirrored strategy that works on multiple workers with in-graph replication. + + There are several important concepts for distributed TensorFlow, e.g. + `client`, `job`, 'task', `cluster`, `in-graph replication` and + 'synchronous training' and they have already been defined in the + [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed). + The distribution strategy inherits these concepts as well and in addition to + that we also clarify several more concepts: + * **In-graph replication**: the `client` creates a single `tf.Graph` that + specifies tasks for devices on all workers. The `client` then creates a + client session which will talk to the `master` service of a `worker`. Then + the `master` will parition the graph and distribute the work to all + participating workers. + * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one + physical machine. We will have multiple `worker`s with different `task` + index. They all do similar things except for one worker checkpointing model + variables, writing summaries, etc. in addition to its ordinary work. + + This class maps one tower to one device on a worker. It mirrors all model + variables on all towers. For example, if you have two `worker`s and each + `worker` has 4 GPUs, it will create 8 copies of the model variables on these 8 + GPUs. Then like in MirroredStrategy, each tower performs their computation + with their own copy of variables unless in cross-tower model where variable or + tensor reduction happens. + """ + + def __init__(self, + num_gpus_per_worker=1, + worker_job_name=None, + num_workers=None, + cluster=None, + cross_tower_ops=None, + prefetch_on_device=None): + """Initialize the strategy object. + + Args: + num_gpus_per_worker: number of GPUs per work. If it is zero, the local + CPU will be used. + worker_job_name: the job name for `worker`, typically just 'worker'. + num_workers: the number of workers. If it is 0, it regenerates to + single-worker MirroredStrategy. + cluster: a `tf.train.ClusterSpec` object or a dict that can be used to + construct a `tf.train.ClusterSpec` object or a `tf.train.ClusterDef` + proto buffer. It is an alternative way to initialize this object. + cross_tower_ops: the cross tower ops to use. If None, a default one will + be used. If configure method is called, a best one for the configuration + will be chosen. + prefetch_on_device: a boolean to specify whether to prefetech input to + each worker's devices. + + Raises: + ValueError: if got an unexpected `cluster`. + """ + if cluster is None: + self._workers = [ + '/job:%s/task:%d' % (worker_job_name, task_index) + for task_index in range(num_workers) + ] + else: + if isinstance(cluster, (dict, cluster_pb2.ClusterDef)): + cluster_spec = server_lib.ClusterSpec(cluster) + elif isinstance(cluster, server_lib.ClusterSpec): + cluster_spec = cluster + else: + raise ValueError( + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + '`tf.train.ClusterDef` object') + + self._workers = [] + for job in sorted(cluster_spec.jobs): + for task in range(cluster_spec.num_tasks(job)): + self._workers.append('/job:%s/task:%d' % (job, task)) + + self._num_gpus_per_worker = num_gpus_per_worker + if num_gpus_per_worker > 0: + self._worker_device_map = { + worker: [ + device_util.canonicalize(worker + '/device:GPU:%d' % gpu) + for gpu in range(num_gpus_per_worker) + ] for worker in self._workers + } + else: + self._worker_device_map = { + worker: [device_util.canonicalize(worker, '/device:CPU:0')] + for worker in self._workers + } + self._devices = nest.flatten(self._worker_device_map.values()) + + super(MultiWorkerMirroredStrategy, self).__init__( + devices=self._devices, prefetch_on_device=prefetch_on_device) + + # Setting `_default_device` will add a device scope in the + # distribution.scope. We set the default device to the first worker. When + # users specify device under distribution.scope by + # with tf.device("/cpu:0"): + # ... + # their ops will end up on the cpu device of its first worker, e.g. + # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode. + self._default_device = self._workers[0] + + def distribute_dataset(self, dataset_fn): + return values.MultiWorkerDataset( + partial(self._call_dataset_fn, dataset_fn), self._worker_device_map, + self._prefetch_on_device) diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py new file mode 100644 index 0000000..ee75881 --- /dev/null +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py @@ -0,0 +1,64 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MultiWorkerMirroredStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import multi_worker_strategy +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.training import server_lib + + +@test_util.with_c_api +class MultiWorkerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, + strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + return multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster=server_lib.ClusterSpec({ + 'worker': ['/job:worker/task:0', '/job:worker/task:1'] + }), + num_gpus_per_worker=context.num_gpus()) + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy()) + + +class DeviceScopeTest(test.TestCase): + """Test the device scope of MultiWorkerMirroredStrategy.""" + + def testDeviceScope(self): + with context.graph_mode(): + strategy = multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={'worker': ['/job:worker/task:0', '/job:worker/task:1']}, + num_gpus_per_worker=context.num_gpus()) + with strategy.scope(): + a = constant_op.constant(1.) + with ops.device('/cpu:0'): + b = constant_op.constant(1.) + self.assertEqual(a.device, '/job:worker/task:0') + self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 64aa369..09b6d4a 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -40,6 +40,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): super(OneDeviceStrategy, self).__init__() self._device = device self._prefetch_on_device = prefetch_on_device + self._default_device = device def _create_variable(self, next_creator, *args, **kwargs): # No need to distinguish tower-local variables when not mirroring, diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index c16b051..21f81ee 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -290,19 +290,31 @@ def _require_distribution_strategy_scope(distribution_strategy): class _CurrentDistributionContext(object): """Context manager for setting the `DistributionStrategy` and var creator.""" - def __init__(self, distribution_strategy, var_creator_scope, var_scope=None): + def __init__(self, + distribution_strategy, + var_creator_scope, + var_scope=None, + default_device=None): self._context = _CrossTowerThreadMode(distribution_strategy) self._var_creator_scope = var_creator_scope self._var_scope = var_scope + if default_device: + self._device_scope = ops.device(default_device) + else: + self._device_scope = None def __enter__(self): _push_per_thread_mode(self._context) if self._var_scope: self._var_scope.__enter__() self._var_creator_scope.__enter__() + if self._device_scope: + self._device_scope.__enter__() return self._context.distribution_strategy def __exit__(self, exception_type, exception_value, traceback): + if self._device_scope: + self._device_scope.__exit__(exception_type, exception_value, traceback) self._var_creator_scope.__exit__(exception_type, exception_value, traceback) if self._var_scope: self._var_scope.__exit__(exception_type, exception_value, traceback) @@ -557,6 +569,9 @@ class DistributionStrategy(object): # TODO(josh11b): List of towers with their worker and parameter devices # (where the parameter devices may overlap in the ps case). + def __init__(self): + self._default_device = None + def scope(self): """Returns a context manager selecting this DistributionStrategy as current. @@ -587,7 +602,8 @@ class DistributionStrategy(object): self, variable_scope.variable_creator_scope(creator_with_resource_vars), variable_scope.variable_scope( variable_scope.get_variable_scope(), - custom_getter=disable_partitioned_variables)) + custom_getter=disable_partitioned_variables), + self._default_device) def _create_variable(self, next_creator, *args, **kwargs): # Note: should support "colocate_with" argument.