_SERVICE_KEY = run_config_lib._SERVICE_KEY
_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
_NUM_CORES_PER_HOST = 8
-
# pylint: enable=protected-access
+class InputPipelineConfig(object):
+ r"""Please see the definition of these values in TPUConfig."""
+ PER_SHARD_V1 = 1
+ PER_HOST_V1 = 2
+ PER_HOST_V2 = 3
+
+
# TODO(b/72511246) Provide a simplified api to configure model parallelism.
class TPUConfig(
collections.namedtuple('TPUConfig', [
partitioned across 4 cores which span two cores in both x and y
coordinates. Please refer to @{tf.contrib.tpu.Topology} for the
geometry of a TPU mesh.
- per_host_input_for_training: If `True`, `input_fn` is invoked Per-Host
- rather than Per-Core. With Per-Host input pipeline deployment, `input_fn`
- is invoked once on each host. With Per-Core input pipeline deployment, it
- is invoked once for each core. To be precise, with a global batch size
- `train_batch_size` in `TPUEstimator` constructor, the batch size for each
- shard is `train_batch_size` // #hosts. With Per-Core input pipeline
- deployment, the shard batch size is `train_batch_size` // #cores.
+ per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
+ `input_fn` is invoked per-host rather than per-core. With per-host input
+ pipeline configuration, `input_fn` is invoked once on each host. With the
+ per-core input pipeline configuration, it is invoked once for each core.
+ With a global batch size `train_batch_size` in `TPUEstimator` constructor,
+ the batch size for each shard is `train_batch_size` // #hosts in the
+ `True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is
+ `train_batch_size` // #cores. With the per-core input pipeline
+ configuration, the shard batch size is also `train_batch_size` // #cores.
+ Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN.
tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
within TPUEstimator, however when using ClusterSpec propagation in more
esoteric cluster configurations, you may need to specify the job name as a
raise ValueError('computation_shape elements can only be 1 or 2; got '
'computation_shape={}'.format(computation_shape))
+ # per_host_input_for_training may be True, False, or integer in [1..3].
+ # Map legacy values (True, False) to numeric values.
+ if per_host_input_for_training is False:
+ per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1
+ elif per_host_input_for_training is True:
+ per_host_input_for_training = InputPipelineConfig.PER_HOST_V1
+
# Check initial_infeed_sleep_secs.
if initial_infeed_sleep_secs:
util_lib.check_positive_integer(initial_infeed_sleep_secs,
import numpy as np
from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment
+from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.platform import tf_logging as logging
"""Return true if input_fn is invoked per-core (other than per-host)."""
mode = self._assert_mode()
return (mode == model_fn_lib.ModeKeys.TRAIN and
- not self._config.tpu_config.per_host_input_for_training)
+ (self._config.tpu_config.per_host_input_for_training is
+ tpu_config.InputPipelineConfig.PER_SHARD_V1))
+
+ def is_input_per_host_with_iterators(self):
+ """Return true if input_fn should be run in the per-host v2 config."""
+ return (self._config.tpu_config.per_host_input_for_training is
+ tpu_config.InputPipelineConfig.PER_HOST_V2)
def is_running_on_cpu(self, is_export_mode=False):
"""Determines whether the input_fn and model_fn should be invoked on CPU.
return global_batch_size
# On TPU
- if self.is_input_sharded_per_core():
+ if self.is_input_sharded_per_core() or (
+ self.is_input_per_host_with_iterators()):
# We prohibit per core input sharding for the model parallelism case,
# therefore it is safe to use num_cores here.
return global_batch_size // self.num_cores
return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
+def generate_per_host_v2_enqueue_ops_fn_for_host(
+ ctx, input_fn, inputs_structure_recorder, device, host_id):
+ """Generates infeed enqueue ops for per-host input_fn on a single host."""
+ del host_id # unused
+ captured_infeed_queue = _CapturedObject()
+ hooks = []
+
+ with ops.device(device):
+ inputs = _Inputs.from_input_fn(input_fn())
+
+ is_dataset = inputs.is_dataset
+ if not is_dataset:
+ raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 '
+ 'input pipeline configuration.')
+ if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
+ # TODO(b/XXX): Add predict support for PER_HOST_V2
+ raise TypeError('Most PREDICT not yet supported in PER_HOST_V2 mode.')
+
+ hooks.append(inputs.dataset_initializer_hook())
+
+ def enqueue_ops_fn():
+ """Generates the per_host enqueue ops."""
+ control_deps = []
+ per_host_sharded_inputs = []
+ num_replicas_per_host = ctx.num_of_replicas_per_host
+ with ops.device(device):
+ if not inputs.is_dataset:
+ raise TypeError('`input_fn` must return a `Dataset` for this mode.')
+ for _ in range(num_replicas_per_host):
+ # Use control dependencies to ensure a deterministic ordering.
+ with ops.control_dependencies(control_deps):
+ features, labels = inputs.features_and_labels() # Calls get_next()
+
+ inputs_structure_recorder.validate_and_record_structure(
+ features, labels)
+ flattened_inputs = (
+ inputs_structure_recorder.flatten_features_and_labels(
+ features, labels))
+
+ control_deps.extend(flattened_inputs)
+ per_host_sharded_inputs.append(flattened_inputs)
+
+ infeed_queue = tpu_feed.InfeedQueue(
+ number_of_tuple_elements=len(per_host_sharded_inputs[0]))
+ captured_infeed_queue.capture(infeed_queue)
+ infeed_queue.set_configuration_from_sharded_input_tensors(
+ per_host_sharded_inputs)
+
+ per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
+ per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function)
+ return per_host_enqueue_ops
+
+ return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
+
+
class _InputPipeline(object):
"""`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.
host_device = tpu_host_placement_fn(host_id=host_id)
with ops.device(host_device):
with ops.name_scope('input_pipeline_task%d' % (host_id)):
- enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
- generate_per_host_enqueue_ops_fn_for_host(
- self._ctx, self._input_fn, self._inputs_structure_recorder,
- self._batch_axis, host_device, host_id))
+ if self._ctx.is_input_per_host_with_iterators():
+ enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
+ generate_per_host_v2_enqueue_ops_fn_for_host(
+ self._ctx, self._input_fn,
+ self._inputs_structure_recorder, host_device, host_id))
+ else:
+ enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
+ generate_per_host_enqueue_ops_fn_for_host(
+ self._ctx, self._input_fn,
+ self._inputs_structure_recorder, self._batch_axis,
+ host_device, host_id))
all_hooks.extend(hooks)
# NOTE(xiejw): We dispatch here based on the return type of the
labels to match up with the corresponding images. If None is supplied,
and per_host_input_for_training is True, batches will be sharded based
on the major dimension. If tpu_config.per_host_input_for_training is
- False, batch_axis is ignored.
+ False or `PER_HOST_V2`, batch_axis is ignored.
Raises:
ValueError: `params` has reserved keys already.
raise ValueError('`train_batch_size` cannot be `None`')
util_lib.check_positive_integer(train_batch_size, 'train_batch_size')
- if (not config.tpu_config.per_host_input_for_training and
+ if (config.tpu_config.per_host_input_for_training is
+ tpu_config.InputPipelineConfig.PER_SHARD_V1 and
config.tpu_config.computation_shape):
raise ValueError(
'Model parallelism only supports per host input for training. '
def features_and_labels(self):
"""Gets `features` and `labels`."""
if self.is_dataset:
+ if self._iterator is None:
+ raise RuntimeError('Internal error: Must call dataset_initializer_hook '
+ 'before calling features_and_labels(). Please file '
+ 'a bug!')
return _Inputs._parse_inputs(self._iterator.get_next())
return (self._features, self._labels)