From: Brennan Saeta
Date: Thu, 29 Mar 2018 00:54:01 +0000 (-0700)
Subject: TPU: Implement 3rd gen input pipeline config.
X-Git-Tag: tflite-v0.1.7~67^2^2~25
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=628552228c76d2ee7f2eef4d56175a89941e3e1d;p=platform%2Fupstream%2Ftensorflow.git

TPU: Implement 3rd gen input pipeline config.

In this new configuration, we are able to drive a Cloud TPU at full device
performance, achieving over 3k images/sec on ResNet-50. The previous
bottleneck was the un-pipeline-able split that occurred after the
iterator.get_next() call. When not splitting on the batch-major dimension,
this split bottlenecked the training job on a single CPU thread, limiting
performance to ~2650 images/sec on ResNet-50.

This latest input pipeline configuration requires the use of datasets. By
requiring datasets, we gain the ability to call get_next() num_replicas
times per host and avoid the expensive split op. (Note: this also opens up
potential future avenues for further optimization.) At the same time, we
retain most of the usability properties that per_host_v1 (aka input
pipeline config v2) gave us.
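Example usage (a minimal sketch, not part of this change: the master and
model_dir values below are placeholders, and the exact RunConfig arguments
may differ slightly across releases):

    from tensorflow.contrib.tpu.python.tpu import tpu_config

    run_config = tpu_config.RunConfig(
        master='grpc://10.240.1.2:8470',    # placeholder TPU master address
        model_dir='gs://my-bucket/resnet',  # placeholder checkpoint dir
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=100,
            num_shards=8,
            per_host_input_for_training=(
                tpu_config.InputPipelineConfig.PER_HOST_V2)))

With this configuration, input_fn is invoked once per host and must return
a tf.data.Dataset (enforced in tpu_estimator.py below).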
PiperOrigin-RevId: 190865741
---
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 38b5ea2..cc1a7fd 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -35,10 +35,16 @@ _TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
 _SERVICE_KEY = run_config_lib._SERVICE_KEY
 _TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
 _NUM_CORES_PER_HOST = 8
-
 # pylint: enable=protected-access


+class InputPipelineConfig(object):
+  r"""Please see the definition of these values in TPUConfig."""
+  PER_SHARD_V1 = 1
+  PER_HOST_V1 = 2
+  PER_HOST_V2 = 3
+
+
 # TODO(b/72511246) Provide a simplified api to configure model parallelism.
 class TPUConfig(
     collections.namedtuple('TPUConfig', [
@@ -68,13 +74,16 @@ class TPUConfig(
       partitioned across 4 cores which span two cores in both x and y
       coordinates. Please refer to @{tf.contrib.tpu.Topology} for the
       geometry of a TPU mesh.
-    per_host_input_for_training: If `True`, `input_fn` is invoked Per-Host
-      rather than Per-Core. With Per-Host input pipeline deployment, `input_fn`
-      is invoked once on each host. With Per-Core input pipeline deployment, it
-      is invoked once for each core. To be precise, with a global batch size
-      `train_batch_size` in `TPUEstimator` constructor, the batch size for each
-      shard is `train_batch_size` // #hosts. With Per-Core input pipeline
-      deployment, the shard batch size is `train_batch_size` // #cores.
+    per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
+      `input_fn` is invoked per-host rather than per-core. With per-host input
+      pipeline configuration, `input_fn` is invoked once on each host. With the
+      per-core input pipeline configuration, it is invoked once for each core.
+      With a global batch size `train_batch_size` in `TPUEstimator` constructor,
+      the batch size for each shard is `train_batch_size` // #hosts in the
+      `True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is
+      `train_batch_size` // #cores. With the per-core input pipeline
+      configuration, the shard batch size is also `train_batch_size` // #cores.
+      Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN.
     tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
       within TPUEstimator, however when using ClusterSpec propagation in more
       esoteric cluster configurations, you may need to specify the job name as a
@@ -117,6 +126,13 @@ class TPUConfig(
       raise ValueError('computation_shape elements can only be 1 or 2; got '
                        'computation_shape={}'.format(computation_shape))

+    # per_host_input_for_training may be True, False, or integer in [1..3].
+    # Map legacy values (True, False) to numeric values.
+    if per_host_input_for_training is False:
+      per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1
+    elif per_host_input_for_training is True:
+      per_host_input_for_training = InputPipelineConfig.PER_HOST_V1
+
     # Check initial_infeed_sleep_secs.
     if initial_infeed_sleep_secs:
       util_lib.check_positive_integer(initial_infeed_sleep_secs,
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 3bac2db..fbc1173 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -24,6 +24,7 @@ import copy
 import numpy as np

 from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment
+from tensorflow.contrib.tpu.python.tpu import tpu_config
 from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.platform import tf_logging as logging
@@ -205,7 +206,13 @@ class _TPUContext(object):
     """Return true if input_fn is invoked per-core (other than per-host)."""
     mode = self._assert_mode()
     return (mode == model_fn_lib.ModeKeys.TRAIN and
-            not self._config.tpu_config.per_host_input_for_training)
+            (self._config.tpu_config.per_host_input_for_training is
+             tpu_config.InputPipelineConfig.PER_SHARD_V1))
+
+  def is_input_per_host_with_iterators(self):
+    """Return true if input_fn should be run in the per-host v2 config."""
+    return (self._config.tpu_config.per_host_input_for_training is
+            tpu_config.InputPipelineConfig.PER_HOST_V2)

   def is_running_on_cpu(self, is_export_mode=False):
     """Determines whether the input_fn and model_fn should be invoked on CPU.
@@ -271,7 +278,8 @@ class _TPUContext(object):
       return global_batch_size

     # On TPU
-    if self.is_input_sharded_per_core():
+    if self.is_input_sharded_per_core() or (
+        self.is_input_per_host_with_iterators()):
       # We prohibit per core input sharding for the model parallelism case,
       # therefore it is safe to use num_cores here.
       return global_batch_size // self.num_cores
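To make the batch-size bookkeeping above concrete, a small worked example
(the slice shape is hypothetical; the arithmetic follows the TPUConfig
docstring and batch_size_for_input_fn above):

    # Hypothetical slice: 2 hosts x 8 cores per host, global batch 1024.
    train_batch_size = 1024
    num_hosts = 2
    num_cores = 16

    per_host_v1 = train_batch_size // num_hosts  # 512; one get_next() per host
    per_host_v2 = train_batch_size // num_cores  # 64; one get_next() per replica
    # PER_HOST_V2 issues get_next() once per replica (8x per host here), so
    # each host still consumes 8 * 64 == 512 examples per step, but without
    # the post-get_next() split op.
    assert per_host_v2 * (num_cores // num_hosts) == per_host_v1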
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 152f8c8..fa56708 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -740,6 +740,61 @@ def generate_per_host_enqueue_ops_fn_for_host(
   return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset


+def generate_per_host_v2_enqueue_ops_fn_for_host(
+    ctx, input_fn, inputs_structure_recorder, device, host_id):
+  """Generates infeed enqueue ops for per-host input_fn on a single host."""
+  del host_id  # unused
+  captured_infeed_queue = _CapturedObject()
+  hooks = []
+
+  with ops.device(device):
+    inputs = _Inputs.from_input_fn(input_fn())
+
+    is_dataset = inputs.is_dataset
+    if not is_dataset:
+      raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 '
+                      'input pipeline configuration.')
+    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
+      # TODO(b/XXX): Add predict support for PER_HOST_V2
+      raise TypeError('Mode PREDICT not yet supported in PER_HOST_V2 mode.')
+
+    hooks.append(inputs.dataset_initializer_hook())
+
+  def enqueue_ops_fn():
+    """Generates the per_host enqueue ops."""
+    control_deps = []
+    per_host_sharded_inputs = []
+    num_replicas_per_host = ctx.num_of_replicas_per_host
+    with ops.device(device):
+      if not inputs.is_dataset:
+        raise TypeError('`input_fn` must return a `Dataset` for this mode.')
+      for _ in range(num_replicas_per_host):
+        # Use control dependencies to ensure a deterministic ordering.
+        with ops.control_dependencies(control_deps):
+          features, labels = inputs.features_and_labels()  # Calls get_next()
+
+        inputs_structure_recorder.validate_and_record_structure(
+            features, labels)
+        flattened_inputs = (
+            inputs_structure_recorder.flatten_features_and_labels(
+                features, labels))
+
+        control_deps.extend(flattened_inputs)
+        per_host_sharded_inputs.append(flattened_inputs)
+
+    infeed_queue = tpu_feed.InfeedQueue(
+        number_of_tuple_elements=len(per_host_sharded_inputs[0]))
+    captured_infeed_queue.capture(infeed_queue)
+    infeed_queue.set_configuration_from_sharded_input_tensors(
+        per_host_sharded_inputs)
+
+    per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
+        per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function)
+    return per_host_enqueue_ops
+
+  return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
+
+
 class _InputPipeline(object):
   """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.
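The loop in enqueue_ops_fn above is the heart of the change: instead of one
get_next() per host followed by a split, get_next() is issued once per
replica, with the calls serialized via control dependencies. A stripped-down
sketch of just that pattern (TF 1.x graph mode; the names and sizes are
illustrative, not from this patch):

    import tensorflow as tf

    dataset = tf.data.Dataset.range(1024).batch(64)
    iterator = dataset.make_initializable_iterator()

    num_replicas_per_host = 8  # hypothetical replica count
    control_deps = []
    per_replica_batches = []
    for _ in range(num_replicas_per_host):
      # Chain each get_next() on the tensors from the previous call so the
      # replicas dequeue their batches in a deterministic order.
      with tf.control_dependencies(control_deps):
        batch = iterator.get_next()
      control_deps.append(batch)
      per_replica_batches.append(batch)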
@@ -975,10 +1030,17 @@ class _InputPipeline(
         host_device = tpu_host_placement_fn(host_id=host_id)
         with ops.device(host_device):
           with ops.name_scope('input_pipeline_task%d' % (host_id)):
-            enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
-                generate_per_host_enqueue_ops_fn_for_host(
-                    self._ctx, self._input_fn, self._inputs_structure_recorder,
-                    self._batch_axis, host_device, host_id))
+            if self._ctx.is_input_per_host_with_iterators():
+              enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
+                  generate_per_host_v2_enqueue_ops_fn_for_host(
+                      self._ctx, self._input_fn,
+                      self._inputs_structure_recorder, host_device, host_id))
+            else:
+              enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
+                  generate_per_host_enqueue_ops_fn_for_host(
+                      self._ctx, self._input_fn,
+                      self._inputs_structure_recorder, self._batch_axis,
+                      host_device, host_id))
             all_hooks.extend(hooks)

             # NOTE(xiejw): We dispatch here based on the return type of the
@@ -1724,7 +1786,7 @@ class TPUEstimator(estimator_lib.Estimator):
         labels to match up with the corresponding images. If None is supplied,
         and per_host_input_for_training is True, batches will be sharded based
         on the major dimension. If tpu_config.per_host_input_for_training is
-        False, batch_axis is ignored.
+        False or `PER_HOST_V2`, batch_axis is ignored.

     Raises:
       ValueError: `params` has reserved keys already.
@@ -1744,7 +1806,8 @@ class TPUEstimator(estimator_lib.Estimator):
         raise ValueError('`train_batch_size` cannot be `None`')
       util_lib.check_positive_integer(train_batch_size, 'train_batch_size')

-      if (not config.tpu_config.per_host_input_for_training and
+      if (config.tpu_config.per_host_input_for_training is
+          tpu_config.InputPipelineConfig.PER_SHARD_V1 and
           config.tpu_config.computation_shape):
         raise ValueError(
             'Model parallelism only supports per host input for training. '
@@ -2362,6 +2425,10 @@ class _Inputs(object):
   def features_and_labels(self):
     """Gets `features` and `labels`."""
     if self.is_dataset:
+      if self._iterator is None:
+        raise RuntimeError('Internal error: Must call dataset_initializer_hook '
+                           'before calling features_and_labels(). Please file '
+                           'a bug!')
       return _Inputs._parse_inputs(self._iterator.get_next())
     return (self._features, self._labels)
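For completeness, a sketch of an input_fn that satisfies the PER_HOST_V2
requirements enforced above (the in-memory data is a stand-in; a real
pipeline would read from files):

    import tensorflow as tf

    def input_fn(params):
      # Under PER_HOST_V2, params['batch_size'] is the per-replica batch
      # size (train_batch_size // #cores), per batch_size_for_input_fn.
      batch_size = params['batch_size']
      images = tf.zeros([1024, 224, 224, 3])     # stand-in data
      labels = tf.zeros([1024], dtype=tf.int32)  # stand-in data
      dataset = tf.data.Dataset.from_tensor_slices((images, labels)).repeat()
      # Infeed needs statically-shaped batches; in this era of tf.data,
      # tf.contrib.data.batch_and_drop_remainder provides them.
      dataset = dataset.apply(
          tf.contrib.data.batch_and_drop_remainder(batch_size))
      return dataset  # return the Dataset itself, not iterator tensors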