Scaffold support in TPUEstimator.
author    Jianwei Xie <xiejw@google.com>
          Fri, 15 Dec 2017 16:11:45 +0000 (08:11 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Fri, 15 Dec 2017 16:15:23 +0000 (08:15 -0800)
PiperOrigin-RevId: 179194460

tensorflow/contrib/tpu/python/tpu/tpu.py
tensorflow/contrib/tpu/python/tpu/tpu_estimator.py

diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 7fb8a33698fdd2b37f42464e934331de65904bfe..24596bdb0af66314d402a7f6e21a8f00ca06dfbe 100644
@@ -52,6 +52,8 @@ _NOT_IMPLEMENTED_OPS = set([
     "TensorSummaryV2",
     ])
 
+_TPU_REPLICATE_ATTR = "_tpu_replicate"
+
 
 def _tpu_system_device_name(job):
   """Returns the device name for the TPU_SYSTEM device of `job`."""
@@ -138,9 +140,9 @@ class TPUReplicateContext(control_flow_ops.ControlFlowContext):
           "Non-resource Variables are not supported inside TPU computations "
           "(operator name: %s)" % op.name)
     # pylint: enable=protected-access
-    if "_tpu_replicate" in op.node_def.attr:
+    if _TPU_REPLICATE_ATTR in op.node_def.attr:
       raise ValueError("TPU computations cannot be nested")
-    op.node_def.attr["_tpu_replicate"].s = self._name
+    op.node_def.attr[_TPU_REPLICATE_ATTR].s = self._name
     op.graph.prevent_feeding(op)
     op.graph.prevent_fetching(op)
 
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index e324948be55b72dc0555dbd49de5b39bf7a09282..d66abd7b66fcf98286ea50b2548199eaa812e353 100644
@@ -365,13 +365,17 @@ class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
     'loss',
     'train_op',
     'eval_metrics',
-    'export_outputs'])):
+    'export_outputs',
+    'scaffold_fn'])):
   """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.
 
   See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and
   `export_outputs`.
 
-  TPU evaluation expects a slightly different signature from the
+  For evaluation, `eval_metrics` is a tuple of `metric_fn` and `tensors`,
+  where `metric_fn` runs on CPU to generate metrics and `tensors` represents
+  the `Tensor`s transferred from the TPU system to the CPU host and passed
+  to `metric_fn`. To be precise, TPU evaluation expects a slightly different
+  signature from the
   ${tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
   dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
   The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The
@@ -382,9 +386,11 @@ class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
   to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is
   dict. `metric_fn` takes the `tensors` and returns a dict from metric string
   name to the result of calling a metric function, namely a `(metric_tensor,
-  update_op)` tuple.
+  update_op)` tuple. See `TPUEstimator` for an MNIST example of how to
+  specify the `eval_metrics`.
 
-  See `TPUEstimator` for MNIST example how to specify the `eval_metrics`.
+  `scaffold_fn` is a function that runs on the CPU host to generate the
+  `Scaffold`. This function must not capture any `Tensor`s defined in
+  `model_fn`.
   """
 
   def __new__(cls,
@@ -393,7 +399,8 @@ class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
               loss=None,
               train_op=None,
               eval_metrics=None,
-              export_outputs=None):
+              export_outputs=None,
+              scaffold_fn=None):
     """Creates a validated `TPUEstimatorSpec` instance."""
     if eval_metrics is not None:
       _EvalMetrics.validate(eval_metrics)
@@ -403,18 +410,21 @@ class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
                                                 loss=loss,
                                                 train_op=train_op,
                                                 eval_metrics=eval_metrics,
-                                                export_outputs=export_outputs)
+                                                export_outputs=export_outputs,
+                                                scaffold_fn=scaffold_fn)
 
   def as_estimator_spec(self):
     """Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
     eval_metric_ops = _EvalMetrics.to_metric_metric_ops_for_cpu(
         self.eval_metrics)
+    scaffold = self.scaffold_fn() if self.scaffold_fn else None
     return model_fn_lib.EstimatorSpec(mode=self.mode,
                                       predictions=self.predictions,
                                       loss=self.loss,
                                       train_op=self.train_op,
                                       eval_metric_ops=eval_metric_ops,
-                                      export_outputs=self.export_outputs)
+                                      export_outputs=self.export_outputs,
+                                      scaffold=scaffold)
 
 
 class _InfeedOutfeedThreadBaseController(object):
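
For illustration, here is a minimal sketch of a `model_fn` that supplies both
`eval_metrics` and the new `scaffold_fn` field. It is hedged: `my_network` is
a hypothetical model-building function, and the `tf.contrib.tpu` export path
for `TPUEstimatorSpec` and `CrossShardOptimizer` is assumed.

  import tensorflow as tf

  def model_fn(features, labels, mode, params):
    logits = my_network(features)  # hypothetical model-building function
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    optimizer = tf.contrib.tpu.CrossShardOptimizer(
        tf.train.GradientDescentOptimizer(learning_rate=0.01))
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())

    def metric_fn(labels, logits):
      # Runs on the CPU host on the `tensors` transferred back from the TPU.
      return {'accuracy': tf.metrics.accuracy(
          labels=labels, predictions=tf.argmax(logits, axis=1))}

    def scaffold_fn():
      # Also runs on the CPU host; it must build everything it needs here
      # and must not close over Tensors created in model_fn above.
      return tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=5))

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metrics=(metric_fn, [labels, logits]),
        scaffold_fn=scaffold_fn)

Returning a plain `EstimatorSpec` with a `scaffold` instead only triggers the
warning added in `_verify_estimator_spec` below; the scaffold itself is ignored
on TPU.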
@@ -679,7 +689,7 @@ class _SetEvalIterationsHook(session_run_hook.SessionRunHook):
 def generate_per_core_enqueue_ops_fn_for_host(
     ctx, input_fn, inputs_structure_recorder):
   """Generates infeed enqueue ops for per-core input_fn on a single host."""
-  infeed_queue_holder = {'instance': None}
+  captured_infeed_queue = _CapturedObject()
 
   def enqueue_ops_fn():
     """A fn returns enqueue_ops."""
@@ -702,7 +712,7 @@ def generate_per_core_enqueue_ops_fn_for_host(
 
     infeed_queue = tpu_feed.InfeedQueue(
         number_of_tuple_elements=len(per_host_sharded_inputs[0]))
-    infeed_queue_holder['instance'] = infeed_queue
+    captured_infeed_queue.capture(infeed_queue)
     infeed_queue.set_configuration_from_sharded_input_tensors(
         per_host_sharded_inputs)
 
@@ -710,13 +720,13 @@ def generate_per_core_enqueue_ops_fn_for_host(
         per_host_sharded_inputs,
         tpu_ordinal_function=ctx.tpu_ordinal_function)
     return per_host_enqueue_ops
-  return enqueue_ops_fn, (lambda: infeed_queue_holder['instance'])
+  return enqueue_ops_fn, captured_infeed_queue
 
 
 def generate_per_host_enqueue_ops_fn_for_host(
     ctx, input_fn, inputs_structure_recorder, batch_axis, device):
   """Generates infeed enqueue ops for per-host input_fn on a single host."""
-  infeed_queue_holder = {'instance': None}
+  captured_infeed_queue = _CapturedObject()
 
   def enqueue_ops_fn():
     with ops.device(device):
@@ -736,7 +746,7 @@ def generate_per_host_enqueue_ops_fn_for_host(
           tuple_types=[t.dtype for t in unsharded_tensor_list],
           tuple_shapes=[t.shape for t in unsharded_tensor_list],
           shard_dimensions=batch_axis)
-      infeed_queue_holder['instance'] = infeed_queue
+      captured_infeed_queue.capture(infeed_queue)
       infeed_queue.set_number_of_shards(num_cores_per_host)
 
       per_host_enqueue_ops = (
@@ -744,7 +754,7 @@ def generate_per_host_enqueue_ops_fn_for_host(
               unsharded_tensor_list,
               placement_function=lambda x: device))
       return per_host_enqueue_ops
-  return enqueue_ops_fn, (lambda: infeed_queue_holder['instance'])
+  return enqueue_ops_fn, captured_infeed_queue
 
 
 class _InputPipeline(object):
@@ -934,7 +944,7 @@ class _InputPipeline(object):
         host_device = tpu_host_placement_fn(host_id=host_id)
         with ops.device(host_device):
           with ops.name_scope('input_pipeline_task%d' % (host_id)):
-            enqueue_ops_fn, infeed_queue_getter = (
+            enqueue_ops_fn, captured_infeed_queue = (
                 generate_per_core_enqueue_ops_fn_for_host(
                     self._ctx, self._input_fn, self._inputs_structure_recorder))
 
@@ -944,14 +954,14 @@ class _InputPipeline(object):
             else:
               enqueue_ops.append(enqueue_ops_fn())
-            # Infeed_queue_getter must be called after enqueue_ops_fn is called.
-            infeed_queues.append(infeed_queue_getter())
+            # `get` must be called after enqueue_ops_fn is called, since the
+            # infeed queue is captured inside it.
+            infeed_queues.append(captured_infeed_queue.get())
 
     else:
       for host_id in range(num_hosts):
         host_device = tpu_host_placement_fn(host_id=host_id)
         with ops.device(host_device):
           with ops.name_scope('input_pipeline_task%d' % (host_id)):
-            enqueue_ops_fn, infeed_queue_getter = (
+            enqueue_ops_fn, captured_infeed_queue = (
                 generate_per_host_enqueue_ops_fn_for_host(
                     self._ctx, self._input_fn, self._inputs_structure_recorder,
                     self._batch_axis, host_device))
@@ -961,7 +971,7 @@ class _InputPipeline(object):
                   device=host_device, op_fn=enqueue_ops_fn))
             else:
               enqueue_ops.append(enqueue_ops_fn())
-            infeed_queues.append(infeed_queue_getter())
+            infeed_queues.append(captured_infeed_queue.get())
     # infeed_queue is used to generate dequeue ops. The only thing it uses
     # for dequeue is dtypes and shapes, so any one of the queues can be used.
     # Here, grab the first one.
@@ -1029,6 +1039,8 @@ class _ModelFnWrapper(object):
-      A Fn representing the train step for TPU.
+      A tuple of the train step Fn for TPU and a `_CapturedObject` that
+      holds the user-provided `scaffold_fn`.
     """
 
+    captured_scaffold_fn = _CapturedObject()
+
     def train_step(loss):
       """Training step function for use inside a while loop."""
       del loss  # unused; required in function signature.
@@ -1037,9 +1049,15 @@ class _ModelFnWrapper(object):
       estimator_spec = self._verify_estimator_spec(
           self._call_model_fn(features, labels))
       loss, train_op = estimator_spec.loss, estimator_spec.train_op
+
+      if isinstance(estimator_spec, TPUEstimatorSpec):
+        captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
+      else:
+        captured_scaffold_fn.capture(None)
+
       with ops.control_dependencies([train_op]):
         return array_ops.identity(loss)
-    return train_step
+    return train_step, captured_scaffold_fn
 
   def convert_to_single_tpu_eval_step(self, dequeue_fn):
     """Converts user provided model_fn` as a single eval step on TPU.
@@ -1068,6 +1086,7 @@ class _ModelFnWrapper(object):
-      step for TPU. and eval_metrics is an `_EvalMetrics` instance.
+      step for TPU, eval_metrics is an `_EvalMetrics` instance, and
+      captured_scaffold_fn is a `_CapturedObject` holding the `scaffold_fn`.
     """
     eval_metrics = _EvalMetrics(self._ctx)
+    captured_scaffold_fn = _CapturedObject()
 
     def eval_step(total_loss):
       """Evaluation step function for use inside a while loop."""
@@ -1080,12 +1099,13 @@ class _ModelFnWrapper(object):
             '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))
 
       loss = tpu_estimator_spec.loss
+      captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
       eval_metrics.record(tpu_estimator_spec)
       outfeed_ops = tpu_ops.outfeed_enqueue_tuple(eval_metrics.outfeed_tensors)
 
       with ops.control_dependencies([outfeed_ops]):
         return math_ops.add(total_loss, loss)
-    return eval_step, eval_metrics
+    return eval_step, eval_metrics, captured_scaffold_fn
 
   def _call_model_fn(self, features, labels):
     """Calls the model_fn with required parameters."""
@@ -1139,6 +1159,10 @@ class _ModelFnWrapper(object):
       raise ValueError(err_msg.format('training_hooks'))
     if estimator_spec.evaluation_hooks:
       raise ValueError(err_msg.format('evaluation_hooks'))
+
+    if estimator_spec.scaffold:
+      logging.warning('EstimatorSpec.scaffold is ignored by TPU train/eval. '
+                      'Please use TPUEstimatorSpec.scaffold_fn instead.')
     return estimator_spec
 
 
@@ -1607,7 +1631,8 @@ class TPUEstimator(estimator_lib.Estimator):
             input_holders.generate_infeed_enqueue_ops_and_dequeue_fn())
 
         if mode == model_fn_lib.ModeKeys.TRAIN:
-          loss = _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)
+          loss, scaffold = (
+              _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
           hooks = [
               TPUInfeedOutfeedSessionHook(ctx, enqueue_ops),
               training.LoggingTensorHook(
@@ -1626,10 +1651,11 @@ class TPUEstimator(estimator_lib.Estimator):
               mode,
               loss=loss,
               training_hooks=hooks,
-              train_op=control_flow_ops.group(*update_ops))
+              train_op=control_flow_ops.group(*update_ops),
+              scaffold=scaffold)
 
         # Now eval.
-        total_loss, eval_metric_ops = _eval_on_tpu_system(
+        total_loss, eval_metric_ops, scaffold = _eval_on_tpu_system(
             ctx, model_fn_wrapper, dequeue_fn)
         iterations_per_loop_var = _create_or_get_iterations_per_loop()
         mean_loss = math_ops.div(
@@ -1660,7 +1686,8 @@ class TPUEstimator(estimator_lib.Estimator):
             mode,
             loss=mean_loss,
             evaluation_hooks=hooks,
-            eval_metric_ops=eval_metric_ops)
+            eval_metric_ops=eval_metric_ops,
+            scaffold=scaffold)
     return _model_fn
 
 
@@ -1669,7 +1696,7 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
   num_cores = ctx.num_cores
   iterations_per_loop_var = _create_or_get_iterations_per_loop()
 
-  single_tpu_eval_step, eval_metric_ops = (
+  single_tpu_eval_step, eval_metric_ops, captured_scaffold_fn = (
       model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn))
 
   def multi_tpu_eval_steps_on_single_shard():
@@ -1682,7 +1709,9 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
                       inputs=[],
                       num_shards=num_cores,
                       outputs_from_all_shards=False)
-  return loss, eval_metric_ops
+
+  scaffold = _get_scaffold(captured_scaffold_fn)
+  return loss, eval_metric_ops, scaffold
 
 
 def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
@@ -1690,8 +1719,8 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
   num_cores = ctx.num_cores
   iterations_per_loop_var = _create_or_get_iterations_per_loop()
 
-  single_tpu_train_step = model_fn_wrapper.convert_to_single_tpu_train_step(
-      dequeue_fn)
+  single_tpu_train_step, captured_scaffold_fn = (
+      model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn))
 
   def multi_tpu_train_steps_on_single_shard():
     return training_loop.repeat(
@@ -1704,7 +1733,9 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
                       inputs=[],
                       num_shards=num_cores,
                       outputs_from_all_shards=False)
-  return loss
+
+  scaffold = _get_scaffold(captured_scaffold_fn)
+  return loss, scaffold
 
 
 def _wrap_computation_in_while_loop(device, op_fn):
@@ -1740,3 +1771,76 @@ def _validate_tpu_training_graph():
         'CrossShardOptimizer must be used for model training on TPUs.')
 
 
+class _CapturedObject(object):
+  """A placeholder to capture an object.
+
+  This is useful when we need to capture a Python object in a TensorFlow
+  control-flow body function and use it outside that control flow.
+  """
+
+  def __init__(self):
+    self._object = None
+    self._captured = False
+
+  def capture(self, o):
+    if self._captured:
+      raise RuntimeError(
+          'InternalError: Object can be captured only once. '
+          'Please file a bug.')
+
+    self._captured = True
+    self._object = o
+
+  def get(self):
+    if not self._captured:
+      raise RuntimeError(
+          'InternalError: Object is not captured properly before `get`. '
+          'Please file a bug.')
+    return self._object
+
+
+def _get_scaffold(captured_scaffold_fn):
+  """Retrieves the Scaffold from `captured_scaffold_fn`."""
+  with _CapturingContext(message='Inside scaffold_fn'):
+    scaffold_fn = captured_scaffold_fn.get()
+    if scaffold_fn:
+      scaffold = scaffold_fn()
+      if scaffold is None:
+        raise ValueError(
+            'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
+    else:
+      scaffold = None
+
+  if scaffold:
+    wrapped_finalize = scaffold.finalize
+    def _finalize():
+      with _CapturingContext('Inside Scaffold.finalize'):
+        wrapped_finalize()
+    scaffold.finalize = _finalize
+  return scaffold
+
+
+class _CapturingContext(control_flow_ops.ControlFlowContext):
+  """Tracks references to Tensors defined in TPU replication."""
+
+  def __init__(self, message):
+    control_flow_ops.ControlFlowContext.__init__(self)
+    self._message = message
+
+  def AddOp(self, op):  # pylint: disable=invalid-name
+    for c in op.inputs:
+      if tpu._TPU_REPLICATE_ATTR in c.op.node_def.attr:  # pylint: disable=protected-access
+        raise ValueError(
+            '{}: Op {} depends on TPU computation {}, '
+            'which is not allowed.'.format(self._message, op, c))
+
+  def __enter__(self):
+    # pylint: disable=protected-access
+    self._g = ops.get_default_graph()
+    self._old = self._g._get_control_flow_context()
+    self._g._set_control_flow_context(self)
+    # pylint: enable=protected-access
+
+  def __exit__(self, _, __, ___):  # pylint: disable=invalid-name
+    self._g._set_control_flow_context(self._old)  # pylint: disable=protected-access
+
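
To make the new helpers concrete, here is a minimal, hedged sketch of the
`_CapturedObject` pattern that replaces the earlier `{'instance': None}`
holder dict; the plain dict below stands in for a real `InfeedQueue`:

  captured_infeed_queue = _CapturedObject()

  def enqueue_ops_fn():
    # Imagine this body runs inside host-side control flow: Python objects
    # created here would otherwise be invisible to the caller.
    infeed_queue = {'placeholder': 'for a real InfeedQueue'}
    captured_infeed_queue.capture(infeed_queue)  # a second capture raises
    return []  # the real fn returns per-host enqueue ops

  enqueue_ops_fn()
  infeed_queue = captured_infeed_queue.get()  # get() before capture() raises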
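
Similarly, a hedged sketch of what `_get_scaffold` accepts and rejects. Here
`loss` is a hypothetical stand-in for any Tensor produced inside the TPU
computation, whose op therefore carries `tpu._TPU_REPLICATE_ATTR`:

  # Accepted: builds all of its state on the CPU host when invoked.
  def good_scaffold_fn():
    return tf.train.Scaffold(saver=tf.train.Saver())

  # Rejected: tf.identity(loss) consumes a Tensor from the TPU computation,
  # so _CapturingContext.AddOp raises
  # ValueError('Inside scaffold_fn: Op ... depends on TPU computation ...').
  def bad_scaffold_fn():
    return tf.train.Scaffold(init_op=tf.identity(loss))

  # Also rejected: a scaffold_fn that returns None makes _get_scaffold
  # raise ValueError.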
+