Update global_step by default if the user specifies a host_call.
author Jonathan Hseu <jhseu@google.com>
Sat, 3 Feb 2018 00:35:10 +0000 (16:35 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Sat, 3 Feb 2018 00:39:02 +0000 (16:39 -0800)
PiperOrigin-RevId: 184352399

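For context, a minimal sketch (not part of this commit) of a TF 1.x `model_fn` that supplies a `host_call` to `TPUEstimatorSpec`. With this change, the CPU-side `global_step` variable is assigned from the outfeed before the user's host function runs, so host-side summaries (which default to the global step) are stamped with an approximately current step. The toy model, `learning_rate`, and `model_dir` params below are illustrative placeholders, not part of the change.

import tensorflow as tf

def model_fn(features, labels, mode, params):
  # Toy network and loss; stands in for a real model.
  predictions = tf.layers.dense(features, 1)
  loss = tf.losses.mean_squared_error(labels, predictions)

  optimizer = tf.contrib.tpu.CrossShardOptimizer(
      tf.train.GradientDescentOptimizer(params['learning_rate']))
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())

  def host_fn(loss_t):
    # Runs on the host CPU. tf.contrib.summary.scalar uses the global step by
    # default, which this change now keeps updated during TPU training.
    writer = tf.contrib.summary.create_file_writer(params['model_dir'])
    with writer.as_default(), tf.contrib.summary.always_record_summaries():
      tf.contrib.summary.scalar('loss', tf.reduce_mean(loss_t))
      return tf.contrib.summary.all_summary_ops()

  # Tensors passed through the outfeed need a leading batch dimension.
  host_call = (host_fn, [tf.reshape(loss, [1])])
  return tf.contrib.tpu.TPUEstimatorSpec(
      mode=mode, loss=loss, train_op=train_op, host_call=host_call)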
tensorflow/contrib/tpu/python/tpu/tpu_estimator.py

index b5082fc..56793f1 100644
@@ -63,6 +63,7 @@ from tensorflow.python.training import evaluation
 from tensorflow.python.training import session_run_hook
 from tensorflow.python.training import training
 from tensorflow.python.training import training_util
+from tensorflow.python.util import tf_inspect
 
 _INITIAL_LOSS = 1e7
 _ZERO_LOSS = 0.
@@ -484,7 +485,7 @@ class TPUEstimatorSpec(
     if self.eval_metrics is not None:
       host_calls['eval_metrics'] = self.eval_metrics
     if self.host_call is not None:
-      host_calls['host_call'] = self.host_call
+      host_calls['host_call'] = wrap_hostcall_with_global_step(self.host_call)
     host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls)
     eval_metric_ops = None
     if self.eval_metrics is not None:
@@ -1306,7 +1307,9 @@ class _ModelFnWrapper(object):
       if isinstance(estimator_spec, TPUEstimatorSpec):
         captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
         if estimator_spec.host_call is not None:
-          host_call.record({'host_call': estimator_spec.host_call})
+          host_call.record({
+              'host_call': wrap_hostcall_with_global_step(
+                  estimator_spec.host_call)})
           host_call_outfeed_ops = host_call.create_enqueue_op()
       else:
         captured_scaffold_fn.capture(None)
@@ -1361,6 +1364,8 @@ class _ModelFnWrapper(object):
       to_record = {}
       to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics
       if tpu_estimator_spec.host_call is not None:
+        # We assume that evaluate won't update the global step, so we don't
+        # wrap this host_call.
         to_record['host_call'] = tpu_estimator_spec.host_call
       host_calls.record(to_record)
 
@@ -1503,11 +1508,14 @@ class _OutfeedHostCall(object):
         raise ValueError('{}[1] should be tuple or list, or dict.'.format(name))
 
       if isinstance(host_call[1], (tuple, list)):
+        fullargspec = tf_inspect.getfullargspec(host_call[0])
         fn_args = util.fn_args(host_call[0])
-        if len(host_call[1]) != len(fn_args):
+        # wrap_hostcall_with_global_step returns a varargs fn, so allow that.
+        if fullargspec.varargs is None and len(host_call[1]) != len(fn_args):
           raise RuntimeError(
-              'In TPUEstimatorSpec.{}, length of tensors does not '
-              'match method args of metric_fn.'.format(name))
+              'In TPUEstimatorSpec.{}, length of tensors {} does not match '
+              'method args of the function, which takes {}.'.format(
+                  name, len(host_call[1]), len(fn_args)))
 
   @staticmethod
   def create_cpu_hostcall(host_calls):
@@ -1649,6 +1657,38 @@ class _OutfeedHostCall(object):
     return ret
 
 
+def wrap_hostcall_with_global_step(hostcall):
+  """Wrap the hostcall so that we update the global step upon every call."""
+  if hostcall is None:
+    return None
+  host_fn, tensors = hostcall
+
+  def global_step_host_fn(_global_step, *args, **kwargs):  # pylint: disable=invalid-name
+    # Note that we don't have any ordering here, so the graph may see a
+    # global_step that's off by 1.
+    state_ops.assign(
+        training.get_global_step(),
+        math_ops.cast(_global_step[0], dtypes.int64))
+    return host_fn(*args, **kwargs)
+  # Give the global step tensor a batch dimension. Reshape is not supported for
+  # int64, so we cast it to int32.
+  # TODO(jhseu): Remove the cast once int64 is supported.
+  global_step_tensor = array_ops.reshape(
+      math_ops.cast(training.get_global_step(), dtypes.int32), [1])
+  if isinstance(tensors, dict):
+    outfeed_tensors = {'_global_step': global_step_tensor}
+    outfeed_tensors.update(tensors)
+    return global_step_host_fn, outfeed_tensors
+  else:
+    fn_args = util.fn_args(host_fn)
+    if len(tensors) != len(fn_args):
+      raise RuntimeError(
+          'In TPUEstimatorSpec.host_call, length of tensors {} does not match '
+          'method args of the function, which takes {}.'.format(
+              len(tensors), len(fn_args)))
+    return global_step_host_fn, [global_step_tensor] + list(tensors)
+
+
 class _OutfeedHostCallHook(session_run_hook.SessionRunHook):
   """Hook to run host calls when use_tpu=False."""