TPU: Implement 3rd gen input pipeline config.
authorBrennan Saeta <saeta@google.com>
Thu, 29 Mar 2018 00:54:01 +0000 (17:54 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 29 Mar 2018 00:56:45 +0000 (17:56 -0700)
In this new configuration, we are able to drive a Cloud TPU at full device performance, and achieve over 3k images/sec on ResNet-50. The previous bottleneck was the un-pipeline-able split that occurred after the iterator.get_next() call. This split (when not splitting on the batch-major dimension) caused the training job to be single-threaded-CPU-bottlenecked, resulting in a performance of only ~2650 images/sec on ResNet-50.

This latest input pipeline configuration requires the use of datasets. By requiring datasets, we gain the ability to call get_next() num_replicas times per host, and avoid the expensive split op. (Note: this also opens up potential future avenues for further optimization.) Despite this, we retain a lot of nice usability properties that per_host_v1 (aka input pipeline config v2) gave us.

PiperOrigin-RevId: 190865741

tensorflow/contrib/tpu/python/tpu/tpu_config.py
tensorflow/contrib/tpu/python/tpu/tpu_context.py
tensorflow/contrib/tpu/python/tpu/tpu_estimator.py

index 38b5ea2..cc1a7fd 100644 (file)
@@ -35,10 +35,16 @@ _TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
 _SERVICE_KEY = run_config_lib._SERVICE_KEY
 _TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
 _NUM_CORES_PER_HOST = 8
-
 # pylint: enable=protected-access
 
 
+class InputPipelineConfig(object):
+  r"""Please see the definition of these values in TPUConfig."""
+  PER_SHARD_V1 = 1
+  PER_HOST_V1 = 2
+  PER_HOST_V2 = 3
+
+
 # TODO(b/72511246) Provide a simplified api to configure model parallelism.
 class TPUConfig(
     collections.namedtuple('TPUConfig', [
@@ -68,13 +74,16 @@ class TPUConfig(
       partitioned across 4 cores which span two cores in both x and y
       coordinates.  Please refer to @{tf.contrib.tpu.Topology} for the
       geometry of a TPU mesh.
-    per_host_input_for_training: If `True`, `input_fn` is invoked Per-Host
-      rather than Per-Core. With Per-Host input pipeline deployment, `input_fn`
-      is invoked once on each host. With Per-Core input pipeline deployment, it
-      is invoked once for each core. To be precise, with a global batch size
-      `train_batch_size` in `TPUEstimator` constructor, the batch size for each
-      shard is `train_batch_size` // #hosts. With Per-Core input pipeline
-      deployment, the shard batch size is `train_batch_size` // #cores.
+    per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
+      `input_fn` is invoked per-host rather than per-core. With per-host input
+      pipeline configuration, `input_fn` is invoked once on each host. With the
+      per-core input pipeline configuration, it is invoked once for each core.
+      With a global batch size `train_batch_size` in `TPUEstimator` constructor,
+      the batch size for each shard is `train_batch_size` // #hosts in the
+      `True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is
+      `train_batch_size` // #cores. With the per-core input pipeline
+      configuration, the shard batch size is also `train_batch_size` // #cores.
+      Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN.
     tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
       within TPUEstimator, however when using ClusterSpec propagation in more
       esoteric cluster configurations, you may need to specify the job name as a
@@ -117,6 +126,13 @@ class TPUConfig(
         raise ValueError('computation_shape elements can only be 1 or 2; got '
                          'computation_shape={}'.format(computation_shape))
 
+    # per_host_input_for_training may be True, False, or integer in [1..3].
+    # Map legacy values (True, False) to numeric values.
+    if per_host_input_for_training is False:
+      per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1
+    elif per_host_input_for_training is True:
+      per_host_input_for_training = InputPipelineConfig.PER_HOST_V1
+
     # Check initial_infeed_sleep_secs.
     if initial_infeed_sleep_secs:
       util_lib.check_positive_integer(initial_infeed_sleep_secs,
index 3bac2db..fbc1173 100644 (file)
@@ -24,6 +24,7 @@ import copy
 import numpy as np
 
 from tensorflow.contrib.tpu.python.tpu import device_assignment  as tpu_device_assignment
+from tensorflow.contrib.tpu.python.tpu import tpu_config
 from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.platform import tf_logging as logging
@@ -205,7 +206,13 @@ class _TPUContext(object):
     """Return true if input_fn is invoked per-core (other than per-host)."""
     mode = self._assert_mode()
     return (mode == model_fn_lib.ModeKeys.TRAIN and
-            not self._config.tpu_config.per_host_input_for_training)
+            (self._config.tpu_config.per_host_input_for_training is
+             tpu_config.InputPipelineConfig.PER_SHARD_V1))
+
+  def is_input_per_host_with_iterators(self):
+    """Return true if input_fn should be run in the per-host v2 config."""
+    return (self._config.tpu_config.per_host_input_for_training is
+            tpu_config.InputPipelineConfig.PER_HOST_V2)
 
   def is_running_on_cpu(self, is_export_mode=False):
     """Determines whether the input_fn and model_fn should be invoked on CPU.
@@ -271,7 +278,8 @@ class _TPUContext(object):
       return global_batch_size
 
     # On TPU
-    if self.is_input_sharded_per_core():
+    if self.is_input_sharded_per_core() or (
+        self.is_input_per_host_with_iterators()):
       # We prohibit per core input sharding for the model parallelism case,
       # therefore it is safe to use num_cores here.
       return global_batch_size // self.num_cores
index 152f8c8..fa56708 100644 (file)
@@ -740,6 +740,61 @@ def generate_per_host_enqueue_ops_fn_for_host(
   return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
 
 
+def generate_per_host_v2_enqueue_ops_fn_for_host(
+    ctx, input_fn, inputs_structure_recorder, device, host_id):
+  """Generates infeed enqueue ops for per-host input_fn on a single host."""
+  del host_id  # unused
+  captured_infeed_queue = _CapturedObject()
+  hooks = []
+
+  with ops.device(device):
+    inputs = _Inputs.from_input_fn(input_fn())
+
+    is_dataset = inputs.is_dataset
+    if not is_dataset:
+      raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 '
+                      'input pipeline configuration.')
+    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
+      # TODO(b/XXX): Add predict support for PER_HOST_V2
+      raise TypeError('Most PREDICT not yet supported in PER_HOST_V2 mode.')
+
+    hooks.append(inputs.dataset_initializer_hook())
+
+  def enqueue_ops_fn():
+    """Generates the per_host enqueue ops."""
+    control_deps = []
+    per_host_sharded_inputs = []
+    num_replicas_per_host = ctx.num_of_replicas_per_host
+    with ops.device(device):
+      if not inputs.is_dataset:
+        raise TypeError('`input_fn` must return a `Dataset` for this mode.')
+      for _ in range(num_replicas_per_host):
+        # Use control dependencies to ensure a deterministic ordering.
+        with ops.control_dependencies(control_deps):
+          features, labels = inputs.features_and_labels()  # Calls get_next()
+
+        inputs_structure_recorder.validate_and_record_structure(
+            features, labels)
+        flattened_inputs = (
+            inputs_structure_recorder.flatten_features_and_labels(
+                features, labels))
+
+        control_deps.extend(flattened_inputs)
+        per_host_sharded_inputs.append(flattened_inputs)
+
+    infeed_queue = tpu_feed.InfeedQueue(
+        number_of_tuple_elements=len(per_host_sharded_inputs[0]))
+    captured_infeed_queue.capture(infeed_queue)
+    infeed_queue.set_configuration_from_sharded_input_tensors(
+        per_host_sharded_inputs)
+
+    per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
+        per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function)
+    return per_host_enqueue_ops
+
+  return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
+
+
 class _InputPipeline(object):
   """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.
 
@@ -975,10 +1030,17 @@ class _InputPipeline(object):
         host_device = tpu_host_placement_fn(host_id=host_id)
         with ops.device(host_device):
           with ops.name_scope('input_pipeline_task%d' % (host_id)):
-            enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
-                generate_per_host_enqueue_ops_fn_for_host(
-                    self._ctx, self._input_fn, self._inputs_structure_recorder,
-                    self._batch_axis, host_device, host_id))
+            if self._ctx.is_input_per_host_with_iterators():
+              enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
+                  generate_per_host_v2_enqueue_ops_fn_for_host(
+                      self._ctx, self._input_fn,
+                      self._inputs_structure_recorder, host_device, host_id))
+            else:
+              enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
+                  generate_per_host_enqueue_ops_fn_for_host(
+                      self._ctx, self._input_fn,
+                      self._inputs_structure_recorder, self._batch_axis,
+                      host_device, host_id))
             all_hooks.extend(hooks)
 
             # NOTE(xiejw): We dispatch here based on the return type of the
@@ -1724,7 +1786,7 @@ class TPUEstimator(estimator_lib.Estimator):
         labels to match up with the corresponding images. If None is supplied,
         and per_host_input_for_training is True, batches will be sharded based
         on the major dimension. If tpu_config.per_host_input_for_training is
-        False, batch_axis is ignored.
+        False or `PER_HOST_V2`, batch_axis is ignored.
 
     Raises:
       ValueError: `params` has reserved keys already.
@@ -1744,7 +1806,8 @@ class TPUEstimator(estimator_lib.Estimator):
         raise ValueError('`train_batch_size` cannot be `None`')
       util_lib.check_positive_integer(train_batch_size, 'train_batch_size')
 
-      if (not config.tpu_config.per_host_input_for_training and
+      if (config.tpu_config.per_host_input_for_training is
+          tpu_config.InputPipelineConfig.PER_SHARD_V1 and
           config.tpu_config.computation_shape):
         raise ValueError(
             'Model parallelism only supports per host input for training. '
@@ -2362,6 +2425,10 @@ class _Inputs(object):
   def features_and_labels(self):
     """Gets `features` and `labels`."""
     if self.is_dataset:
+      if self._iterator is None:
+        raise RuntimeError('Internal error: Must call dataset_initializer_hook '
+                           'before calling features_and_labels(). Please file '
+                           'a bug!')
       return _Inputs._parse_inputs(self._iterator.get_next())
 
     return (self._features, self._labels)