'loss',
'train_op',
'eval_metrics',
- 'export_outputs'])):
+ 'export_outputs',
+ 'scaffold_fn'])):
"""Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.
See `EstimatorSpec` for `mode`, 'predictions, 'loss', 'train_op', and
'export_outputs`.
- TPU evaluation expects a slightly different signature from the
+ For evaluation, `eval_metrics` is a tuple of `metric_fn` and `tensors`, where
+ `metric_fn` runs on CPU to generate metrics and `tensors` represents the
+ `Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`.
+ To be precise, TPU evaluation expects a slightly different signature from the
${tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The
to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is
dict. `metric_fn` takes the `tensors` and returns a dict from metric string
name to the result of calling a metric function, namely a `(metric_tensor,
- update_op)` tuple.
+ update_op)` tuple. See `TPUEstimator` for an MNIST example of how to specify
+ `eval_metrics`.
- See `TPUEstimator` for MNIST example how to specify the `eval_metrics`.
+ `scaffold_fn` is a function running on CPU to generate the `Scaffold`. This
+ function should not capture any Tensors in `model_fn`.
"""
def __new__(cls,
loss=None,
train_op=None,
eval_metrics=None,
- export_outputs=None):
+ export_outputs=None,
+ scaffold_fn=None):
"""Creates a validated `TPUEstimatorSpec` instance."""
if eval_metrics is not None:
_EvalMetrics.validate(eval_metrics)
loss=loss,
train_op=train_op,
eval_metrics=eval_metrics,
- export_outputs=export_outputs)
+ export_outputs=export_outputs,
+ scaffold_fn=scaffold_fn)
def as_estimator_spec(self):
"""Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
eval_metric_ops = _EvalMetrics.to_metric_metric_ops_for_cpu(
self.eval_metrics)
+ scaffold = self.scaffold_fn() if self.scaffold_fn else None
return model_fn_lib.EstimatorSpec(mode=self.mode,
predictions=self.predictions,
loss=self.loss,
train_op=self.train_op,
eval_metric_ops=eval_metric_ops,
- export_outputs=self.export_outputs)
+ export_outputs=self.export_outputs,
+ scaffold=scaffold)
class _InfeedOutfeedThreadBaseController(object):
def generate_per_core_enqueue_ops_fn_for_host(
ctx, input_fn, inputs_structure_recorder):
"""Generates infeed enqueue ops for per-core input_fn on a single host."""
- infeed_queue_holder = {'instance': None}
+ captured_infeed_queue = _CapturedObject()
def enqueue_ops_fn():
"""A fn returns enqueue_ops."""
infeed_queue = tpu_feed.InfeedQueue(
number_of_tuple_elements=len(per_host_sharded_inputs[0]))
- infeed_queue_holder['instance'] = infeed_queue
+ captured_infeed_queue.capture(infeed_queue)
infeed_queue.set_configuration_from_sharded_input_tensors(
per_host_sharded_inputs)
per_host_sharded_inputs,
tpu_ordinal_function=ctx.tpu_ordinal_function)
return per_host_enqueue_ops
- return enqueue_ops_fn, (lambda: infeed_queue_holder['instance'])
+ return enqueue_ops_fn, captured_infeed_queue
def generate_per_host_enqueue_ops_fn_for_host(
ctx, input_fn, inputs_structure_recorder, batch_axis, device):
"""Generates infeed enqueue ops for per-host input_fn on a single host."""
- infeed_queue_holder = {'instance': None}
+ captured_infeed_queue = _CapturedObject()
def enqueue_ops_fn():
with ops.device(device):
tuple_types=[t.dtype for t in unsharded_tensor_list],
tuple_shapes=[t.shape for t in unsharded_tensor_list],
shard_dimensions=batch_axis)
- infeed_queue_holder['instance'] = infeed_queue
+ captured_infeed_queue.capture(infeed_queue)
infeed_queue.set_number_of_shards(num_cores_per_host)
per_host_enqueue_ops = (
unsharded_tensor_list,
placement_function=lambda x: device))
return per_host_enqueue_ops
- return enqueue_ops_fn, (lambda: infeed_queue_holder['instance'])
+ return enqueue_ops_fn, captured_infeed_queue
class _InputPipeline(object):
host_device = tpu_host_placement_fn(host_id=host_id)
with ops.device(host_device):
with ops.name_scope('input_pipeline_task%d' % (host_id)):
- enqueue_ops_fn, infeed_queue_getter = (
+ enqueue_ops_fn, captured_infeed_queue = (
generate_per_core_enqueue_ops_fn_for_host(
self._ctx, self._input_fn, self._inputs_structure_recorder))
else:
enqueue_ops.append(enqueue_ops_fn())
- # Infeed_queue_getter must be called after enqueue_ops_fn is called.
+ # captured_infeed_queue.get() must be called after enqueue_ops_fn is called.
- infeed_queues.append(infeed_queue_getter())
+ infeed_queues.append(captured_infeed_queue.get())
else:
for host_id in range(num_hosts):
host_device = tpu_host_placement_fn(host_id=host_id)
with ops.device(host_device):
with ops.name_scope('input_pipeline_task%d' % (host_id)):
- enqueue_ops_fn, infeed_queue_getter = (
+ enqueue_ops_fn, captured_infeed_queue = (
generate_per_host_enqueue_ops_fn_for_host(
self._ctx, self._input_fn, self._inputs_structure_recorder,
self._batch_axis, host_device))
device=host_device, op_fn=enqueue_ops_fn))
else:
enqueue_ops.append(enqueue_ops_fn())
- infeed_queues.append(infeed_queue_getter())
+ infeed_queues.append(captured_infeed_queue.get())
# infeed_queue is used to generate dequeue ops. The only thing it uses for
# dequeue is dtypes and types. So, any one can be used. Here, grab the
# first one.
A Fn representing the train step for TPU.
"""
+ captured_scaffold_fn = _CapturedObject()
+
def train_step(loss):
"""Training step function for use inside a while loop."""
del loss # unused; required in function signature.
estimator_spec = self._verify_estimator_spec(
self._call_model_fn(features, labels))
loss, train_op = estimator_spec.loss, estimator_spec.train_op
+
+ if isinstance(estimator_spec, TPUEstimatorSpec):
+ captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
+ else:
+ captured_scaffold_fn.capture(None)
+
with ops.control_dependencies([train_op]):
return array_ops.identity(loss)
- return train_step
+ return train_step, captured_scaffold_fn
def convert_to_single_tpu_eval_step(self, dequeue_fn):
"""Converts user provided model_fn` as a single eval step on TPU.
step for TPU. and eval_metrics is an `_EvalMetrics` instance.
"""
eval_metrics = _EvalMetrics(self._ctx)
+ captured_scaffold_fn = _CapturedObject()
def eval_step(total_loss):
"""Evaluation step function for use inside a while loop."""
'`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))
loss = tpu_estimator_spec.loss
+ captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
eval_metrics.record(tpu_estimator_spec)
outfeed_ops = tpu_ops.outfeed_enqueue_tuple(eval_metrics.outfeed_tensors)
with ops.control_dependencies([outfeed_ops]):
return math_ops.add(total_loss, loss)
- return eval_step, eval_metrics
+ return eval_step, eval_metrics, captured_scaffold_fn
def _call_model_fn(self, features, labels):
"""Calls the model_fn with required parameters."""
raise ValueError(err_msg.format('training_hooks'))
if estimator_spec.evaluation_hooks:
raise ValueError(err_msg.format('evaluation_hooks'))
+
+ if estimator_spec.scaffold:
+ logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. '
+ 'Please use TPUEstimatorSpec.')
return estimator_spec
input_holders.generate_infeed_enqueue_ops_and_dequeue_fn())
if mode == model_fn_lib.ModeKeys.TRAIN:
- loss = _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)
+ loss, scaffold = (
+ _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
hooks = [
TPUInfeedOutfeedSessionHook(ctx, enqueue_ops),
training.LoggingTensorHook(
mode,
loss=loss,
training_hooks=hooks,
- train_op=control_flow_ops.group(*update_ops))
+ train_op=control_flow_ops.group(*update_ops),
+ scaffold=scaffold)
# Now eval.
- total_loss, eval_metric_ops = _eval_on_tpu_system(
+ total_loss, eval_metric_ops, scaffold = _eval_on_tpu_system(
ctx, model_fn_wrapper, dequeue_fn)
iterations_per_loop_var = _create_or_get_iterations_per_loop()
mean_loss = math_ops.div(
mode,
loss=mean_loss,
evaluation_hooks=hooks,
- eval_metric_ops=eval_metric_ops)
+ eval_metric_ops=eval_metric_ops,
+ scaffold=scaffold)
return _model_fn
num_cores = ctx.num_cores
iterations_per_loop_var = _create_or_get_iterations_per_loop()
- single_tpu_eval_step, eval_metric_ops = (
+ single_tpu_eval_step, eval_metric_ops, captured_scaffold_fn = (
model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn))
def multi_tpu_eval_steps_on_single_shard():
inputs=[],
num_shards=num_cores,
outputs_from_all_shards=False)
- return loss, eval_metric_ops
+
+ scaffold = _get_scaffold(captured_scaffold_fn)
+ return loss, eval_metric_ops, scaffold
def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
num_cores = ctx.num_cores
iterations_per_loop_var = _create_or_get_iterations_per_loop()
- single_tpu_train_step = model_fn_wrapper.convert_to_single_tpu_train_step(
- dequeue_fn)
+ single_tpu_train_step, captured_scaffold_fn = (
+ model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn))
def multi_tpu_train_steps_on_single_shard():
return training_loop.repeat(
inputs=[],
num_shards=num_cores,
outputs_from_all_shards=False)
- return loss
+
+ scaffold = _get_scaffold(captured_scaffold_fn)
+ return loss, scaffold
def _wrap_computation_in_while_loop(device, op_fn):
'CrossShardOptimizer must be used for model training on TPUs.')
+class _CapturedObject(object):
+  """A placeholder to capture an object exactly once.
+
+  This is useful when we need to capture a Python object in the Tensorflow
+  control flow body function and use it outside the control flow.
+
+  `capture` may be called at most once, and `get` may only be called after
+  `capture`; violating either rule raises `RuntimeError`. Capturing `None` is
+  allowed and is distinguished from "nothing captured yet" by a separate flag.
+  """
+
+  def __init__(self):
+    self._object = None
+    # Tracked separately from `_object` so that a captured `None` still counts
+    # as captured.
+    self._captured = False
+
+  def capture(self, o):
+    """Stores `o` as the captured object.
+
+    Args:
+      o: The object to hold. Any value is accepted, including `None`.
+
+    Raises:
+      RuntimeError: If an object has already been captured.
+    """
+    if self._captured:
+      raise RuntimeError(
+          'InternalError: Object can be captured only once. '
+          'Please file bug .')
+
+    self._captured = True
+    self._object = o
+
+  def get(self):
+    """Returns the captured object.
+
+    Raises:
+      RuntimeError: If `capture` has not been called yet.
+    """
+    if not self._captured:
+      raise RuntimeError(
+          'InternalError: Object is not captured properly before `get`. '
+          'Please file bug .')
+    return self._object
+
+
+def _get_scaffold(captured_scaffold_fn):
+ """Retrieves the Scaffold from `captured_scaffold_fn`.
+
+ Calls the captured `scaffold_fn`, if there is one, under a
+ `_CapturingContext`, and replaces the returned scaffold's `finalize` with a
+ wrapper that runs the original `finalize` under a `_CapturingContext` too.
+
+ Args:
+ captured_scaffold_fn: A holder whose `get()` returns the `scaffold_fn`
+ callable, or a falsy value when no `scaffold_fn` was supplied.
+
+ Returns:
+ The object returned by `scaffold_fn`, or `None` when no `scaffold_fn` was
+ captured.
+
+ Raises:
+ ValueError: If `scaffold_fn` returns `None`.
+ """
+ # NOTE(review): presumably ops created while this context is active are
+ # routed through `_CapturingContext.AddOp`, which rejects inputs coming from
+ # the TPU computation -- confirm against ControlFlowContext.
+ with _CapturingContext(message='Inside scaffold_fn'):
+ scaffold_fn = captured_scaffold_fn.get()
+ if scaffold_fn:
+ scaffold = scaffold_fn()
+ if scaffold is None:
+ raise ValueError(
+ 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
+ else:
+ scaffold = None
+
+ if scaffold:
+ # Keep a reference to the original `finalize` and install a wrapper on the
+ # scaffold instance so that finalization also runs under a
+ # `_CapturingContext`.
+ wrapped_finalize = scaffold.finalize
+ def _finalize():
+ with _CapturingContext('Inside Scaffold.finalize'):
+ wrapped_finalize()
+ scaffold.finalize = _finalize
+ return scaffold
+
+
+class _CapturingContext(control_flow_ops.ControlFlowContext):
+ """Tracks references to Tensors defined in TPU replication.
+
+ Used as a context manager: on entry it installs itself as the default
+ graph's control flow context and on exit it restores the previous one.
+ `AddOp` raises `ValueError` for any op that has an input whose producing op
+ carries the TPU replicate attribute; the error text is prefixed with the
+ `message` given at construction.
+ """
+
+ def __init__(self, message):
+ # `message` is used only as the prefix of the error raised by `AddOp`.
+ control_flow_ops.ControlFlowContext.__init__(self)
+ self._message = message
+
+ def AddOp(self, op): # pylint: disable=invalid-name
+ # NOTE(review): presumably called for each op created while this context is
+ # the graph's active control flow context -- confirm against
+ # ControlFlowContext.
+ for c in op.inputs:
+ if tpu._TPU_REPLICATE_ATTR in c.op.node_def.attr: # pylint: disable=protected-access
+ raise ValueError(
+ '{}: Op {} depends on TPU computation {}, '
+ 'which is not allowed.'.format(self._message, op, c))
+
+ def __enter__(self):
+ # Remembers the graph and its current control flow context so `__exit__`
+ # can restore them. Nothing is returned, so `with ... as x` would bind None.
+ # pylint: disable=protected-access
+ self._g = ops.get_default_graph()
+ self._old = self._g._get_control_flow_context()
+ self._g._set_control_flow_context(self)
+ # pylint: enable=protected-access
+
+ def __exit__(self, _, __, ___): # pylint: disable=invalid-name
+ # Restores the context saved in `__enter__`; exception info is ignored and
+ # nothing is returned, so exceptions are not suppressed.
+ self._g._set_control_flow_context(self._old) # pylint: disable=protected-access
+
+