From 30bb88dba14d7bfb0472bd79f949014b6f2902a7 Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Sun, 11 Feb 2018 15:54:39 -0800 Subject: [PATCH] [TPUEstimator] Automatically detect the TPU system information, including topology for model parallelism. PiperOrigin-RevId: 185318852 --- tensorflow/contrib/tpu/BUILD | 2 + tensorflow/contrib/tpu/python/tpu/tpu_config.py | 39 +- .../contrib/tpu/python/tpu/tpu_config_test.py | 5 - tensorflow/contrib/tpu/python/tpu/tpu_context.py | 492 +++++++++++++++++++++ tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 393 ++-------------- .../contrib/tpu/python/tpu/tpu_system_metadata.py | 139 ++++++ 6 files changed, 676 insertions(+), 394 deletions(-) create mode 100644 tensorflow/contrib/tpu/python/tpu/tpu_context.py create mode 100644 tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index a7d54d8..c48e84d 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -36,7 +36,9 @@ py_library( name = "tpu_estimator", srcs = [ "python/tpu/tpu_config.py", + "python/tpu/tpu_context.py", "python/tpu/tpu_estimator.py", + "python/tpu/tpu_system_metadata.py", "python/tpu/util.py", ], srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index a1076b7..6440702 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -55,17 +55,18 @@ class TPUConfig( system before returning to CPU host for each `Session.run`. This means global step is increased `iterations_per_loop` times in one `Session.run`. It is recommended to be set as number of global steps for next checkpoint. - num_shards: The number of model replicas in the system. For - non-model-parallelism case, this number equals the total number of TPU - cores. 
For model-parallelism, the total number of TPU cores equals + num_shards: (Deprecated, ignored by TPUEstimator). + The number of model replicas in the system. For non-model-parallelism + case, this number equals the total number of TPU cores. For + model-parallelism, the total number of TPU cores equals product(computation_shape) * num_shards. - computation_shape: A list of size 3 which describes the shape of a model - replica's block of cores. This is required by model-parallelism which - enables partitioning the model to multiple cores. For example, [2, 2, 1] - means the model - is partitioned across 4 cores which span two cores in both x and y - coordinates. Set it to `None` for non-model-parallelism. Please refer to - ${tf.contrib.tpu.TopologyProto} for the geometry of a TPU mesh. + computation_shape: Defaults to `None`, which disables model parallelism. A + list of size 3 which describes the shape of a model replica's block of + cores. This is required by model-parallelism which enables partitioning + the model to multiple cores. For example, [2, 2, 1] means the model is + partitioned across 4 cores which span two cores in both x and y + coordinates. Please refer to ${tf.contrib.tpu.TopologyProto} for the + geometry of a TPU mesh. per_host_input_for_training: If `True`, `input_fn` is invoked Per-Host rather than Per-Core. With Per-Host input pipeline deployment, `input_fn` is invoked once on each host. With Per-Core input pipeline deployment, it @@ -80,13 +81,14 @@ class TPUConfig( initial_infeed_sleep_secs: The number of seconds the infeed thread should wait before enqueueing the first batch. This helps avoid timeouts for models that require a long compilation time. + Raises: ValueError: If `computation_shape` or `computation_shape` are invalid. 
""" def __new__(cls, iterations_per_loop=2, - num_shards=2, + num_shards=None, computation_shape=None, per_host_input_for_training=True, tpu_job_name=None, @@ -97,7 +99,8 @@ class TPUConfig( 'TPUConfig iterations_per_loop') # Check num_shards. - util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards') + if num_shards is not None: + util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards') # Check computation_shape if computation_shape is not None and len(computation_shape) != 3: @@ -112,18 +115,6 @@ class TPUConfig( if any(computation_shape_array < 1) or any(computation_shape_array > 2): raise ValueError('computation_shape elements can only be 1 or 2; got ' 'computation_shape={}'.format(computation_shape)) - max_replicas_per_host = ( - _NUM_CORES_PER_HOST // np.prod(computation_shape_array)) - if num_shards > max_replicas_per_host and ( - num_shards % max_replicas_per_host != 0): - raise ValueError( - '{0} shards can not be evenly distributed across' - ' multiple hosts. Each shard needs {1} cores and each' - ' host has {2} cores. Thus {0} shards needs {3} hosts.' - ' Please adjust num shards so that num_shards is' - ' divisible by {4} or <= {4}.'.format( - num_shards, np.prod(computation_shape), _NUM_CORES_PER_HOST, - num_shards / max_replicas_per_host, max_replicas_per_host)) # Check initial_infeed_sleep_secs. 
if initial_infeed_sleep_secs: diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py index 2c9d7be..37ef3db 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py @@ -53,11 +53,6 @@ class TPURunConfigTest(test.TestCase): 'computation_shape elements can only be'): tpu_config_lib.TPUConfig(computation_shape=[1, 3, 1]) - def test_fail_with_invalid_shards(self): - with self.assertRaisesRegexp(ValueError, - 'shards can not be evenly distributed across'): - tpu_config_lib.TPUConfig(num_shards=6, computation_shape=[1, 1, 2]) - class TPURunConfigMasterTest(test.TestCase): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py new file mode 100644 index 0000000..8c65018 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -0,0 +1,492 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =================================================================== +"""TPU system metadata and associated tooling.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from contextlib import contextmanager +import copy + +import numpy as np + +from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment +from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.platform import tf_logging as logging + + +_DEFAULT_JOB_NAME = 'tpu_worker' +_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' +_LOCAL_MASTERS = ('', 'local') + + +class _TPUContext(object): + """A context holds immutable states of TPU computation. + + This immutable object holds TPUEstimator config, train/eval batch size, and + `TPUEstimator.use_tpu`, which is expected to be passed around. It also + provides utility functions, based on the current state, to determine other + information commonly required by TPU computation, such as TPU device names, + TPU hosts, shard batch size, etc. + + N.B. As `mode` is not immutable state in Estimator, but essential to + distinguish between TPU training and evaluation, a common usage for + _TPUContext with `mode` is as follows: + ``` + with _ctx.with_mode(mode) as ctx: + if ctx.is_running_on_cpu(): + ... 
+ ``` + """ + + def __init__(self, config, train_batch_size, eval_batch_size, + predict_batch_size, use_tpu): + self._config = config + self._train_batch_size = train_batch_size + self._eval_batch_size = eval_batch_size + self._predict_batch_size = predict_batch_size + self._use_tpu = use_tpu + self._model_parallelism_enabled = ( + use_tpu and config.tpu_config.computation_shape) + self._mode = None + + self._lazy_tpu_system_metadata_dict = {} # key by master address + self._lazy_device_assignment_dict = {} # key by master address + self._lazy_validation_dict = {} # key by ModeKeys + + def _assert_mode(self): + if self._mode is None: + raise RuntimeError( + '`mode` needs to be set via contextmanager `with_mode`.') + return self._mode + + @contextmanager + def with_mode(self, mode): + # NOTE(xiejw): Shallow copy is enough. It will share the lazy dictionaries, + # such as _lazy_tpu_system_metadata_dict between new copy and the original + # one. Note that all lazy states stored in properties _lazy_foo are sort of + # immutable as they should be same for the process lifetime. 
+ new_ctx = copy.copy(self) + new_ctx._mode = mode # pylint: disable=protected-access + yield new_ctx + + @property + def mode(self): + return self._assert_mode() + + def _get_master_address(self): + mode = self._assert_mode() + config = self._config + master = ( + config.master + if mode != model_fn_lib.ModeKeys.EVAL else config.evaluation_master) + return master + + def _get_tpu_system_metadata(self): + """Gets the (maybe cached) TPU system metadata.""" + master = self._get_master_address() + tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master) + if tpu_system_metadata is not None: + return tpu_system_metadata + + # pylint: disable=protected-access + tpu_system_metadata = ( + tpu_system_metadata_lib._query_tpu_system_metadata( + master, query_topology=self.model_parallelism_enabled)) + + self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata + return tpu_system_metadata + + def _get_device_assignment(self): + """Gets the (maybe cached) TPU device assignment.""" + master = self._get_master_address() + device_assignment = self._lazy_device_assignment_dict.get(master) + if device_assignment is not None: + return device_assignment + + tpu_system_metadata = self._get_tpu_system_metadata() + + device_assignment = tpu_device_assignment.device_assignment( + tpu_system_metadata.topology, + computation_shape=self._config.tpu_config.computation_shape, + num_replicas=self.num_replicas) + + logging.info('computation_shape: %s', + str(self._config.tpu_config.computation_shape)) + logging.info('num_replicas: %d', self.num_replicas) + logging.info('device_assignment.topology.device_coordinates: %s', + str(device_assignment.topology.device_coordinates)) + logging.info('device_assignment.core_assignment: %s', + str(device_assignment.core_assignment)) + + self._lazy_device_assignment_dict[master] = device_assignment + return device_assignment + + @property + def model_parallelism_enabled(self): + return self._model_parallelism_enabled + + @property + def 
device_assignment(self): + return (self._get_device_assignment() + if self._model_parallelism_enabled else None) + + @property + def num_of_cores_per_host(self): + metadata = self._get_tpu_system_metadata() + return metadata.num_of_cores_per_host + + @property + def num_cores(self): + metadata = self._get_tpu_system_metadata() + return metadata.num_cores + + @property + def num_of_replicas_per_host(self): + if self.model_parallelism_enabled: + return self.num_replicas // self.num_hosts + else: + return self.num_of_cores_per_host + + @property + def num_replicas(self): + num_cores_in_system = self.num_cores + + if self.model_parallelism_enabled: + computation_shape_array = np.asarray( + self._config.tpu_config.computation_shape, dtype=np.int32) + num_cores_per_replica = np.prod(computation_shape_array) + if num_cores_per_replica > num_cores_in_system: + raise ValueError( + 'The num of cores required by the model parallelism, specified by ' + 'TPUConfig.computation_shape, is larger than the total num of ' + 'TPU cores in the system. computation_shape: {}, num cores ' + 'in the system: {}'.format( + self._config.tpu_config.computation_shape, + num_cores_in_system)) + + if num_cores_in_system % num_cores_per_replica != 0: + raise RuntimeError( + 'The num of cores in the system ({}) is not divisible by the num ' + 'of cores ({}) required by the model parallelism, specified by ' + 'TPUConfig.computation_shape. 
This should never happen!'.format( + num_cores_in_system, num_cores_per_replica)) + + return num_cores_in_system // num_cores_per_replica + else: + return num_cores_in_system + + @property + def num_hosts(self): + metadata = self._get_tpu_system_metadata() + return metadata.num_hosts + + @property + def config(self): + return self._config + + def is_input_sharded_per_core(self): + """Return true if input_fn is invoked per-core (other than per-host).""" + mode = self._assert_mode() + return (mode == model_fn_lib.ModeKeys.TRAIN and + not self._config.tpu_config.per_host_input_for_training) + + def is_running_on_cpu(self, is_export_mode=False): + """Determines whether the input_fn and model_fn should be invoked on CPU. + + This API also validates user provided configuration, such as batch size, + according to the lazily initialized TPU system metadata. + + Args: + is_export_mode: Indicates whether the current mode is for exporting the + model, when mode == PREDICT. Only with this bool, we could + tell whether user is calling the Estimator.predict or + Estimator.export_savedmodel, which are running on TPU and CPU + respectively. Parent class Estimator does not distinguish these two. + + Returns: + bool, whether current input_fn or model_fn should be running on CPU. + + Raises: + ValueError: any configuration is invalid. + """ + + is_running_on_cpu = self._is_running_on_cpu(is_export_mode) + if not is_running_on_cpu: + self._validate_tpu_configuration() + return is_running_on_cpu + + def _is_running_on_cpu(self, is_export_mode): + """Determines whether the input_fn and model_fn should be invoked on CPU.""" + mode = self._assert_mode() + + if not self._use_tpu: + return True + + if mode != model_fn_lib.ModeKeys.PREDICT: + return False + + # There are actually 2 use cases when running with mode.PREDICT: prediction + # and saving the model. We run actual predictions on the TPU, but + # model export is run on the CPU. 
+ if is_export_mode: + return True + + return False + + @property + def global_batch_size(self): + mode = self._assert_mode() + if mode == model_fn_lib.ModeKeys.TRAIN: + return self._train_batch_size + elif mode == model_fn_lib.ModeKeys.EVAL: + return self._eval_batch_size + elif mode == model_fn_lib.ModeKeys.PREDICT: + return self._predict_batch_size + else: + return None + + @property + def batch_size_for_input_fn(self): + """Returns the shard batch size for `input_fn`.""" + global_batch_size = self.global_batch_size + + if self.is_running_on_cpu(): + return global_batch_size + + # On TPU + if self.is_input_sharded_per_core(): + # We prohibit per core input sharding for the model parallelism case, + # therefore it is safe to use num_cores here. + return global_batch_size // self.num_cores + else: + return global_batch_size // self.num_hosts + + @property + def batch_size_for_model_fn(self): + """Returns the shard batch size for `model_fn`.""" + global_batch_size = self.global_batch_size + + if self.is_running_on_cpu(): + return global_batch_size + + # On TPU. always sharded per shard. + return global_batch_size // self.num_replicas + + @property + def master_job(self): + """Returns the job name to use to place TPU computations on. + + Returns: + A string containing the job name, or None if no job should be specified. + + Raises: + ValueError: If the user needs to specify a tpu_job_name, because we are + unable to infer the job name automatically, or if the user-specified job + names are inappropriate. + """ + run_config = self._config + # If the user specifies the tpu_job_name, use that. + if run_config.tpu_config.tpu_job_name: + return run_config.tpu_config.tpu_job_name + + # The tpu job is determined by the run_config. Right now, this method is + # required as tpu_config is not part of the RunConfig. 
+ mode = self._assert_mode() + master = ( + run_config.evaluation_master + if mode == model_fn_lib.ModeKeys.EVAL else run_config.master) + if master in _LOCAL_MASTERS: + return None + + if (not run_config.session_config or + not run_config.session_config.cluster_def.job): + return _DEFAULT_JOB_NAME + cluster_def = run_config.session_config.cluster_def + job_names = set([job.name for job in cluster_def.job]) + if _DEFAULT_JOB_NAME in job_names: + # b/37868888 tracks allowing ClusterSpec propagation to reuse job names. + raise ValueError('Currently, tpu_worker is not an allowed job name.') + if len(job_names) == 1: + return cluster_def.job[0].name + if len(job_names) == 2: + if _DEFAULT_COORDINATOR_JOB_NAME in job_names: + job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME) + return job_names.pop() + # TODO(b/67716447): Include more sophisticated heuristics. + raise ValueError( + 'Could not infer TPU job name. Please specify a tpu_job_name as part ' + 'of your TPUConfig.') + + @property + def tpu_host_placement_function(self): + """Returns the TPU host place function.""" + master = self.master_job + + def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name + assert _sentinal is None + if core_id is not None and host_id is not None: + raise RuntimeError( + 'core_id and host_id can have only one non-None value.') + + if master is None: + return '/replica:0/task:0/device:CPU:0' + else: + if core_id is not None: + host_id = core_id / self.num_of_cores_per_host + return '/job:%s/task:%d/device:CPU:0' % (master, host_id) + + return _placement_function + + @property + def tpu_device_placement_function(self): + """Returns a TPU device placement Fn.""" + master = self.master_job + job_device = '' if master is None else ('/job:%s' % master) + + def _placement_function(i): + if self.model_parallelism_enabled: + return self.device_assignment.tpu_device(replica=i, job=master) + else: + num_of_cores_per_host = self.num_of_cores_per_host + 
host_id = i / num_of_cores_per_host + ordinal_id = i % num_of_cores_per_host + return '%s/task:%d/device:TPU:%d' % (job_device, host_id, ordinal_id) + + return _placement_function + + @property + def tpu_ordinal_function(self): + """Returns the TPU ordinal fn.""" + + def _tpu_ordinal_function(index): + """Return the TPU ordinal associated with a shard. + + Required because the enqueue ops are placed on CPU. + + Args: + index: the shard index + + Returns: + The ordinal of the TPU device the shard's infeed should be placed on. + """ + if self.model_parallelism_enabled: + return self.device_assignment.tpu_ordinal(replica=index) + else: + return index % self.num_of_cores_per_host + + return _tpu_ordinal_function + + def _validate_tpu_configuration(self): + """Validates the configuration based on the TPU system metadata.""" + mode = self._assert_mode() + if self._lazy_validation_dict.get(mode): + return + + # All following information is obtained from TPU system metadata. + num_cores = self.num_cores + num_replicas = self.num_replicas + num_hosts = self.num_hosts + + if not num_cores: + tpu_system_metadata = self._get_tpu_system_metadata() + raise RuntimeError( + 'Cannot find any TPU cores in the system. Please double check ' + 'Tensorflow master address and TPU worker(s). 
Available devices ' + 'are {}.'.format(tpu_system_metadata.devices)) + + if mode == model_fn_lib.ModeKeys.TRAIN: + if self._train_batch_size % num_replicas != 0: + raise ValueError( + 'train batch size {} must be divisible by number of replicas {}' + .format(self._train_batch_size, num_replicas)) + + elif mode == model_fn_lib.ModeKeys.EVAL: + if self._eval_batch_size is None: + raise ValueError( + 'eval_batch_size in TPUEstimator constructor cannot be `None`' + 'if .evaluate is running on TPU.') + if self._eval_batch_size % num_replicas != 0: + raise ValueError( + 'eval batch size {} must be divisible by number of replicas {}' + .format(self._eval_batch_size, num_replicas)) + if num_hosts > 1: + raise ValueError( + 'TPUEstimator.evaluate should be running on single TPU worker. ' + 'got {}.'.format(num_hosts)) + else: + assert mode == model_fn_lib.ModeKeys.PREDICT + if self._predict_batch_size is None: + raise ValueError( + 'predict_batch_size in TPUEstimator constructor should not be ' + '`None` if .predict is running on TPU.') + if self._predict_batch_size % num_replicas != 0: + raise ValueError( + 'predict batch size {} must be divisible by number of replicas {}' + .format(self._predict_batch_size, num_replicas)) + if num_hosts > 1: + raise ValueError( + 'TPUEstimator.predict should be running on single TPU worker. ' + 'got {}.'.format(num_hosts)) + + # Record the state "validated" into lazy dictionary. 
+ self._lazy_validation_dict[mode] = True + + +class _OneCoreTPUContext(_TPUContext): + """Special _TPUContext for one core usage.""" + + def __init__(self, config, train_batch_size, eval_batch_size, + predict_batch_size, use_tpu): + + super(_OneCoreTPUContext, self).__init__( + config, train_batch_size, eval_batch_size, + predict_batch_size, use_tpu) + + def _get_tpu_system_metadata(self): + """Gets the (maybe cached) TPU system metadata.""" + master = self._get_master_address() + tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master) + if tpu_system_metadata is not None: + return tpu_system_metadata + + tpu_system_metadata = ( + tpu_system_metadata_lib._TPUSystemMetadata( # pylint: disable=protected-access + num_cores=1, + num_hosts=1, + num_of_cores_per_host=1, + topology=None, + devices=[])) + + self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata + return tpu_system_metadata + + +def _get_tpu_context(config, train_batch_size, eval_batch_size, + predict_batch_size, use_tpu): + """Returns an instance of `_TPUContext`.""" + + if (config.tpu_config.num_shards == 1 and + config.tpu_config.computation_shape is None): + logging.warning( + 'Setting TPUConfig.num_shards==1 is an unsupported behavior. 
' + 'Please fix as soon as possible (leaving num_shards as None.') + return _OneCoreTPUContext(config, train_batch_size, eval_batch_size, + predict_batch_size, use_tpu) + + return _TPUContext(config, train_batch_size, eval_batch_size, + predict_batch_size, use_tpu) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 7d2f655..ff53fe4 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import collections -from contextlib import contextmanager import copy import signal import threading @@ -34,13 +33,12 @@ from tensorflow.contrib.summary import summary_ops as contrib_summary from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_config +from tensorflow.contrib.tpu.python.tpu import tpu_context from tensorflow.contrib.tpu.python.tpu import tpu_feed from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.contrib.tpu.python.tpu import util as util_lib -from tensorflow.contrib.tpu.python.tpu.device_assignment import device_assignment as tpu_device_assignment from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session as tf_session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib @@ -71,7 +69,7 @@ _TPU_ESTIMATOR = 'tpu_estimator' _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' _BATCH_SIZE_KEY = 'batch_size' _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' -_PINGING_MASTER_TIMEOUT_IN_MS = 300 * 1000 # 5 minutes + _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY] @@ -149,285 +147,6 @@ def 
_increase_eval_step_op(iterations_per_loop): use_locking=True) -_DEFAULT_JOB_NAME = 'tpu_worker' -_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' -_LOCAL_MASTERS = ('', 'local') - - -class _TPUContext(object): - """A context holds immutable states of TPU computation. - - This immutable object holds TPUEstimator config, train/eval batch size, and - `TPUEstimator.use_tpu`, which is expected to be passed around. It also - provides utility functions, basded on the current state, to determine other - information commonly required by TPU computation, such as TPU device names, - TPU hosts, shard batch size, etc. - - N.B. As `mode` is not immutable state in Estimator, but essential to - distinguish between TPU training and evaluation, a common usage for - _TPUContext with `mode` is as follows: - ``` - with _ctx.with_mode(mode) as ctx: - if ctx.is_running_on_cpu(): - ... - ``` - """ - - def __init__(self, config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, device_assignment): - self._config = config - self._train_batch_size = train_batch_size - self._eval_batch_size = eval_batch_size - self._predict_batch_size = predict_batch_size - self._use_tpu = use_tpu - self._mode = None - self._device_assignment = device_assignment - self._max_cores_per_host = 8 - - def _assert_mode(self): - if self._mode is None: - raise RuntimeError( - '`mode` needs to be set via contextmanager `with_mode`.') - return self._mode - - @property - def num_of_cores_per_host(self): - num_cores = self.num_cores - return min(num_cores, self._max_cores_per_host) - - @property - def num_of_shards_per_host(self): - if self._device_assignment: - maximum_shards_per_host = ( - self._max_cores_per_host // - self._device_assignment.num_cores_per_replica) - return min(self.num_shards, maximum_shards_per_host) - else: - num_cores = self.num_cores - return min(num_cores, self._max_cores_per_host) - - @contextmanager - def with_mode(self, mode): - new_ctx = copy.copy(self) # Shallow copy is enough. 
- new_ctx._mode = mode # pylint: disable=protected-access - yield new_ctx - - @property - def mode(self): - return self._assert_mode() - - @property - def num_cores(self): - # TODO(xiejw): Adds lazy num_shards initialization. - if self._device_assignment: - return self._device_assignment.num_cores_per_replica * self.num_shards - else: - return self.num_shards - - @property - def num_shards(self): - return self._config.tpu_config.num_shards - - @property - def num_hosts(self): - return self.num_cores // self.num_of_cores_per_host - - @property - def config(self): - return self._config - - def is_input_sharded_per_core(self): - """Return true if input_fn is invoked per-core (other than per-host).""" - self._assert_mode() - return (self._mode == model_fn_lib.ModeKeys.TRAIN and - not self._config.tpu_config.per_host_input_for_training) - - def is_running_on_cpu(self, is_export_mode=False): - """Determines whether the input_fn and model_fn should be invoked on CPU. - - Args: - is_export_mode: Indicates whether the current mode is for exporting the - model, when mode == PREDICT. Only with this bool, we could - tell whether user is calling the Estimator.predict or - Estimator.export_savedmodel, which are running on TPU and CPU - respectively. Parent class Estimator does not distingush these two. - - Returns: - bool, whether current input_fn or model_fn should be running on CPU. - - Raises: - ValueError: any configuration is invalid. - """ - mode = self._assert_mode() - - if not self._use_tpu: - return True - - if mode != model_fn_lib.ModeKeys.PREDICT: - return False - - # There are actually 2 use cases when running with mode.PREDICT: prediction - # and saving the model. We run actual predictions on the TPU, but - # model export is run on the CPU. 
- if is_export_mode: - return True - - if self._predict_batch_size is None: - raise ValueError( - 'predict_batch_size in TPUEstimator constructor should not be ' - '`None` if .predict is running on TPU.') - if self.num_hosts > 1: - raise ValueError( - 'TPUEstimator.predict should be running on single host.') - - return False - - @property - def global_batch_size(self): - mode = self._assert_mode() - if mode == model_fn_lib.ModeKeys.TRAIN: - return self._train_batch_size - elif mode == model_fn_lib.ModeKeys.EVAL: - return self._eval_batch_size - elif mode == model_fn_lib.ModeKeys.PREDICT: - return self._predict_batch_size - else: - return None - - @property - def batch_size_for_input_fn(self): - """Returns the shard batch size for `input_fn`.""" - global_batch_size = self.global_batch_size - - if self.is_running_on_cpu(): - return global_batch_size - - # On TPU - if self.is_input_sharded_per_core(): - # We prohibit per core input sharding for the model parallelism case, - # therefore it is safe to use num_cores here. - return global_batch_size // self.num_cores - else: - return global_batch_size // self.num_hosts - - @property - def batch_size_for_model_fn(self): - """Returns the shard batch size for `model_fn`.""" - global_batch_size = self.global_batch_size - - if self.is_running_on_cpu(): - return global_batch_size - - return global_batch_size // self.num_shards - - @property - def master_job(self): - """Returns the job name to use to place TPU computations on. - - Returns: - A string containing the job name, or None if no job should be specified. - - Raises: - ValueError: If the user needs to specify a tpu_job_name, because we are - unable to infer the job name automatically, or if the user-specified job - names are inappropriate. - """ - run_config = self._config - # If the user specifies the tpu_job_name, use that. - if run_config.tpu_config.tpu_job_name: - return run_config.tpu_config.tpu_job_name - - # The tpu job is determined by the run_config. 
Right now, this method is - # required as tpu_config is not part of the RunConfig. - mode = self._assert_mode() - master = ( - run_config.evaluation_master - if mode == model_fn_lib.ModeKeys.EVAL else run_config.master) - if master in _LOCAL_MASTERS: - return None - - if (not run_config.session_config or - not run_config.session_config.cluster_def.job): - return _DEFAULT_JOB_NAME - cluster_def = run_config.session_config.cluster_def - job_names = set([job.name for job in cluster_def.job]) - if _DEFAULT_JOB_NAME in job_names: - # b/37868888 tracks allowing ClusterSpec propagation to reuse job names. - raise ValueError('Currently, tpu_worker is not an allowed job name.') - if len(job_names) == 1: - return cluster_def.job[0].name - if len(job_names) == 2: - if _DEFAULT_COORDINATOR_JOB_NAME in job_names: - job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME) - return job_names.pop() - # TODO(b/67716447): Include more sophisticated heuristics. - raise ValueError( - 'Could not infer TPU job name. Please specify a tpu_job_name as part ' - 'of your TPUConfig.') - - @property - def tpu_host_placement_function(self): - """Returns the TPU host place function.""" - master = self.master_job - - def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name - assert _sentinal is None - if core_id is not None and host_id is not None: - raise RuntimeError( - 'core_id and host_id can have only one non-None value.') - - if master is None: - return '/replica:0/task:0/device:CPU:0' - else: - # This assumes that if using more than 8 shards, - # the job configuration varies 'task'. 
- if core_id is not None: - host_id = core_id / 8 - return '/job:%s/task:%d/device:CPU:0' % (master, host_id) - - return _placement_function - - @property - def tpu_device_placement_function(self): - """Returns the TPU device place function.""" - master = self.master_job - job_device = '' if master is None else ('/job:%s' % master) - - def _placement_function(i): - if self._device_assignment: - return self._device_assignment.tpu_device(replica=i, job=master) - else: - return '%s/task:%d/device:TPU:%d' % (job_device, i / 8, i % 8) - - return _placement_function - - @property - def tpu_ordinal_function(self): - """Returns the TPU ordinal fn.""" - - def _tpu_ordinal_function(index): - """Return the TPU ordinal associated with a shard. - - Required because the enqueue ops are placed on CPU. - - Args: - index: the shard index - - Returns: - The ordinal of the TPU device the shard's infeed should be placed on. - """ - if self._device_assignment: - return self._device_assignment.tpu_ordinal(replica=index) - else: - return index % 8 - - return _tpu_ordinal_function - - @property - def device_assignment(self): - return self._device_assignment - - class _SIGNAL(object): """Signal used to control the thread of infeed/outfeed. @@ -963,20 +682,23 @@ def generate_per_host_enqueue_ops_fn_for_host( if is_dataset: hooks.append(inputs.dataset_initializer_hook()) + # TODO(ylc): Refactoring the code to merge the tpu ordinal logic here and the + # _TPUContext.tpu_ordinal_function. We should either introduce another + # abstraction or a different helper method. def _tpu_ordinal_function_impl(shard_index_in_host): # We put both enqueue/dequeue op at tpu.core(0) in each replica. 
replica = ctx.device_assignment.lookup_replicas( host_id, (0, 0, 0))[shard_index_in_host] return ctx.device_assignment.tpu_ordinal(replica=replica) - if ctx.device_assignment: + if ctx.model_parallelism_enabled: tpu_ordinal_function = _tpu_ordinal_function_impl else: tpu_ordinal_function = None def enqueue_ops_fn(): with ops.device(device): - num_of_shards_per_host = ctx.num_of_shards_per_host + num_of_replicas_per_host = ctx.num_of_replicas_per_host # Convert user input to features and labels. If the user returns a # dataset, it is initialized and the features and labels extracted via # `dataset.iterator.get_next()` @@ -994,7 +716,7 @@ def generate_per_host_enqueue_ops_fn_for_host( tuple_shapes=[t.shape for t in unsharded_tensor_list], shard_dimensions=batch_axis) captured_infeed_queue.capture(infeed_queue) - infeed_queue.set_number_of_shards(num_of_shards_per_host) + infeed_queue.set_number_of_shards(num_of_replicas_per_host) per_host_enqueue_ops = ( infeed_queue.split_inputs_and_generate_enqueue_ops( unsharded_tensor_list, @@ -1665,7 +1387,7 @@ class _OutfeedHostCall(object): # constraint it such that we have at most one outfeed dequeue and enqueue # per replica. tpu_device_placement_fn = self._ctx.tpu_device_placement_function - for i in xrange(self._ctx.num_shards): + for i in xrange(self._ctx.num_replicas): with ops.device(tpu_device_placement_fn(i)): outfeed_tensors = tpu_ops.outfeed_dequeue_tuple( dtypes=tensor_dtypes, shapes=tensor_shapes) @@ -1890,12 +1612,12 @@ class TPUEstimator(estimator_lib.Estimator): train_batch_size: An int representing the global training batch size. TPUEstimator transforms this global batch size to a per-shard batch size, as params['batch_size'], when calling `input_fn` and `model_fn`. - Cannot be `None` if `use_tpu` is `True`. Must be divisible by - `config.tpu_config.num_shards`. + Cannot be `None` if `use_tpu` is `True`. + Must be divisible by total number of replicas. 
eval_batch_size: An int representing evaluation batch size. - Must be divisible by `config.tpu_config.num_shards`. + Must be divisible by total number of replicas. predict_batch_size: An int representing the prediction batch size. - Must be divisible by `config.tpu_config.num_shards`. + Must be divisible by total number of replicas. batch_axis: A python tuple of int values describing how each tensor produced by the Estimator `input_fn` should be split across the TPU compute shards. For example, if your input_fn produced (images, labels) @@ -1919,45 +1641,24 @@ class TPUEstimator(estimator_lib.Estimator): _RESERVED_PARAMS_KEYS, params)) if use_tpu: + # Perform some very basic validations. More validations will be found in + # _TPUContext. if train_batch_size is None: raise ValueError('`train_batch_size` cannot be `None`') - if not isinstance(train_batch_size, int): - raise ValueError('`train_batch_size` must be an int') - if train_batch_size < 1: - raise ValueError('`train_batch_size` must be positive') - - # The specified batch size is the batch size for the entire computation. - # The input_fn and model_fn are called per-shard, so we want to calculate - # the per-shard batch size and pass that. - if train_batch_size % config.tpu_config.num_shards != 0: - raise ValueError( - 'train batch size {} must be divisible by number of shards {}' - .format(train_batch_size, config.tpu_config.num_shards)) + util_lib.check_positive_integer(train_batch_size, 'train_batch_size') if (not config.tpu_config.per_host_input_for_training and config.tpu_config.computation_shape): raise ValueError( - 'Model parallelism only supports per host input for training.') + 'Model parallelism only supports per host input for training. 
' + 'Please adjust TPURunconfig.per_host_input_for_training.') if eval_batch_size is not None: - if not isinstance(eval_batch_size, int): - raise ValueError('`eval_batch_size` must be an int') - if eval_batch_size < 1: - raise ValueError('`eval_batch_size` must be positive') - if eval_batch_size % config.tpu_config.num_shards != 0: - raise ValueError( - 'eval batch size {} must be divisible by number of shards {}' - .format(eval_batch_size, config.tpu_config.num_shards)) + util_lib.check_positive_integer(eval_batch_size, 'eval_batch_size') if predict_batch_size is not None: - if not isinstance(predict_batch_size, int): - raise ValueError('`predict_batch_size` must be an int') - if predict_batch_size < 1: - raise ValueError('`predict_batch_size` must be positive') - if predict_batch_size % config.tpu_config.num_shards != 0: - raise ValueError( - 'predict batch size {} must be divisible by number of shards {}' - .format(predict_batch_size, config.tpu_config.num_shards)) + util_lib.check_positive_integer(predict_batch_size, + 'predict_batch_size') # Verifies the model_fn signature according to Estimator framework. 
estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access @@ -1976,35 +1677,12 @@ class TPUEstimator(estimator_lib.Estimator): self._iterations_per_training_loop = ( self._config.tpu_config.iterations_per_loop) - if use_tpu and self._config.tpu_config.computation_shape: - try: - with tf_session.Session( - self._config.master, - config=config_pb2.ConfigProto( - operation_timeout_in_ms=_PINGING_MASTER_TIMEOUT_IN_MS)) as sess: - logging.info('Initializing TPU system to fetch topology for model ' - 'parallelism.') - topology = sess.run(tpu.initialize_system()) - device_assignment = tpu_device_assignment( - topology, - computation_shape=self._config.tpu_config.computation_shape, - num_replicas=self._config.tpu_config.num_shards) - logging.info('computation_shape: %s', - str(self._config.tpu_config.computation_shape)) - logging.info('num_replicas: %d', self._config.tpu_config.num_shards) - logging.info('device_assignment.topology.device_coordinates: %s', - str(device_assignment.topology.device_coordinates)) - logging.info('device_assignment.core_assignment: %s', - str(device_assignment.core_assignment)) - except errors.DeadlineExceededError: - raise ValueError( - 'Fail to connect master (%s). Please double check %s is ' - 'correct.' % (self._config.master, self._config.master)) - else: - device_assignment = None # All properties passed to _TPUContext are immutable. - self._ctx = _TPUContext(self._config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, device_assignment) + # pylint: disable=protected-access + self._ctx = tpu_context._get_tpu_context( + self._config, train_batch_size, + eval_batch_size, predict_batch_size, + use_tpu) def _create_global_step(self, graph): """Creates a global step suitable for TPUs. @@ -2052,21 +1730,6 @@ class TPUEstimator(estimator_lib.Estimator): util_lib.check_positive_integer(steps, 'Eval steps') - # TODO(ylc): Support evaluating with model parallelism in different cluster. 
- if ctx.device_assignment and (self._config.evaluation_master != - self._config.master): - raise ValueError( - 'In the model-parallel case, both training and evaluation must run ' - 'in the same cluster.') - - if self._config.tpu_config.num_shards > 8: - raise NotImplementedError( - 'TPU evaluation is only supported with one host.') - - if self._ctx._eval_batch_size is None: # pylint: disable=protected-access - raise ValueError('`eval_batch_size` cannot be `None`' - 'if evaluate() is called on TPU.') - return [ evaluation._StopAfterNEvalsHook( # pylint: disable=protected-access num_evals=steps), @@ -2320,7 +1983,7 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): (loss,) = tpu.shard( multi_tpu_eval_steps_on_single_shard, inputs=[], - num_shards=ctx.num_shards, + num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) @@ -2344,7 +2007,7 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): (loss,) = tpu.shard( multi_tpu_train_steps_on_single_shard, inputs=[], - num_shards=ctx.num_shards, + num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py new file mode 100644 index 0000000..e003313 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py @@ -0,0 +1,139 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# =================================================================== +"""TPU system metadata and associated tooling.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re + +from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session as session_lib +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.platform import tf_logging as logging + +_PINGING_MASTER_TIMEOUT_IN_MS = 60 * 1000 # 1 min +_RETRY_TIMES = 120 +_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000 # 5 mins + +_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$') + +# _TPUSystemMetadata is used by TPUEstimator to hold TPU configuration, +# including num_cores and num_hosts. 
# Immutable record of what was discovered about the TPU system behind a
# TensorFlow master.
_TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [
    'num_cores',  # Number of listed devices whose name matches a TPU device.
    'num_hosts',  # Number of distinct task ids that own at least one TPU core.
    'num_of_cores_per_host',  # Cores per host; 0 when no TPU core was found.
    'topology',  # Result of `_obtain_topology`, or None if not requested.
    'devices',  # Full device list returned by `Session.list_devices()`.
])


def _query_tpu_system_metadata(master_address, query_topology=False):
  """Automatically detects the TPU system metadata in the system.

  Opens a session against `master_address`, lists its devices and counts the
  ones whose name matches `_TPU_DEVICE_REG`, grouped by task id. Connection
  attempts that hit `DeadlineExceededError` are retried.

  Args:
    master_address: Target passed to `session_lib.Session`.
    query_topology: If true, also initializes the TPU system through
      `_obtain_topology` and stores the result in the returned metadata.

  Returns:
    A `_TPUSystemMetadata`. `num_cores`, `num_hosts` and
    `num_of_cores_per_host` are all zero when no TPU device is listed.

  Raises:
    ValueError: If the master still cannot be reached after `_RETRY_TIMES`
      retries.
    RuntimeError: If the hosts do not all expose the same number of TPU cores,
      or if `query_topology` is true and no TPU core was found.
  """
  tpu_core_count = 0
  devices = []
  # Maps task (host) id -> TPU core ids on that task. Both are kept as the
  # strings captured by `_TPU_DEVICE_REG`; only their counts are used below.
  device_dict = collections.defaultdict(list)

  retry_count = 1
  while True:
    logging.info('Querying Tensorflow master (%s) for TPU system metadata.',
                 master_address)
    try:
      # A throwaway graph keeps the probing session from touching the caller's
      # default graph.
      with ops.Graph().as_default():
        with session_lib.Session(
            master_address,
            config=config_pb2.ConfigProto(
                operation_timeout_in_ms=_PINGING_MASTER_TIMEOUT_IN_MS)) as sess:
          devices = sess.list_devices()
          for device in devices:
            match = _TPU_DEVICE_REG.match(device.name)
            if match:
              host_id = match.group(1)
              core_id = match.group(2)
              device_dict[host_id].append(core_id)
              tpu_core_count += 1
          break
    except errors.DeadlineExceededError:
      # The message previously claimed the address "is correct", which is the
      # opposite of the failure being reported.
      msg = ('Fail to connect Tensorflow master. It could be the TPU worker is '
             'not ready (still under scheduling) or Tensorflow '
             'master address is incorrect: got (%s).' %
             (master_address,))

      # TODO(xiejw): For local or grpc master we might not need retry logic
      # here.
      if retry_count <= _RETRY_TIMES:
        logging.warning('%s', msg)
        logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
        retry_count += 1
      else:
        raise ValueError(msg)

  num_of_cores_per_host = 0
  if tpu_core_count:
    # Every host is expected to expose the same number of cores, so the set of
    # per-host counts must collapse to a single value.
    num_cores_per_host_set = {
        len(core_ids) for core_ids in device_dict.values()}
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError(
          'TPU cores on each host is not same. This should not happen!. '
          'devices: {}'.format(devices))
    num_of_cores_per_host = num_cores_per_host_set.pop()

  topology = None
  if query_topology:
    if not tpu_core_count:
      raise RuntimeError(
          'Cannot find any TPU cores in the system (master address {}). '
          'This usually means the master address is incorrect or the '
          'TPU worker has some problems. Available devices: {}'.format(
              master_address, devices))

    topology = _obtain_topology(master_address)

  metadata = _TPUSystemMetadata(
      num_cores=tpu_core_count,
      num_hosts=len(device_dict),
      num_of_cores_per_host=num_of_cores_per_host,
      topology=topology,
      devices=devices)

  msg = 'Found TPU system %s' if tpu_core_count else 'Failed to find TPU: %s'
  logging.info(msg, metadata)
  return metadata


def _obtain_topology(master_address):
  """Initializes the TPU system behind `master_address` and returns the result.

  Args:
    master_address: Target passed to `session_lib.Session`.

  Returns:
    Whatever `sess.run(tpu.initialize_system())` yields. NOTE(review):
    presumably the serialized TPU topology -- confirm against the
    `tpu.initialize_system` contract.

  Raises:
    ValueError: If the initialization does not finish within
      `_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS` (surfaced by the session as
      `DeadlineExceededError`).
  """
  try:
    logging.info('Initializing TPU system (master: %s) to fetch topology '
                 'for model parallelism. This might take a while.',
                 master_address)
    with ops.Graph().as_default():
      session_config = config_pb2.ConfigProto(
          operation_timeout_in_ms=_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS)
      with session_lib.Session(
          master_address, config=session_config) as sess:
        topology = sess.run(tpu.initialize_system())
        return topology
  except errors.DeadlineExceededError:
    raise ValueError(
        'Fail to initialize TPU system with master (%s). '
        'Please double check the TPU system is functional.' % (
            master_address,))