Fixed the bug that the export code triggers the TPU validation.
authorJianwei Xie <xiejw@google.com>
Tue, 20 Mar 2018 15:32:30 +0000 (08:32 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Mar 2018 15:36:54 +0000 (08:36 -0700)
PiperOrigin-RevId: 189745966

tensorflow/contrib/tpu/python/tpu/tpu_estimator.py

index 5a8fa04..f61f6bb 100644 (file)
@@ -1044,8 +1044,8 @@ class _ModelFnWrapper(object):
     self._params = params
     self._ctx = ctx
 
-  def call_without_tpu(self, features, labels):
-    return self._call_model_fn(features, labels)
+  def call_without_tpu(self, features, labels, is_export_mode):
+    return self._call_model_fn(features, labels, is_export_mode=is_export_mode)
 
   def convert_to_single_tpu_train_step(self, dequeue_fn):
     """Converts user provided model_fn` as a single train step on TPU.
@@ -1204,7 +1204,7 @@ class _ModelFnWrapper(object):
 
     return predict_step, host_calls, captured_scaffold_fn
 
-  def _call_model_fn(self, features, labels, is_export_mode=True):
+  def _call_model_fn(self, features, labels, is_export_mode=False):
     """Calls the model_fn with required parameters."""
     model_fn_args = util.fn_args(self._model_fn)
     kwargs = {}
@@ -1230,7 +1230,11 @@ class _ModelFnWrapper(object):
                        'required by TPUEstimator to pass batch size as '
                        'params[\'batch_size\']'.format(self._model_fn))
 
-    batch_size_for_model_fn = self._ctx.batch_size_for_model_fn
+    if is_export_mode:
+      batch_size_for_model_fn = None
+    else:
+      batch_size_for_model_fn = self._ctx.batch_size_for_model_fn
+
     if batch_size_for_model_fn is not None:
       params[_BATCH_SIZE_KEY] = batch_size_for_model_fn
 
@@ -1778,6 +1782,8 @@ class TPUEstimator(estimator_lib.Estimator):
         eval_batch_size, predict_batch_size,
         use_tpu)
 
+    self._is_input_fn_invoked = None
+
   def _create_global_step(self, graph):
     """Creates a global step suitable for TPUs.
 
@@ -1860,6 +1866,9 @@ class TPUEstimator(estimator_lib.Estimator):
     if 'mode' in input_fn_args:
       kwargs['mode'] = mode
 
+    # Records the fact input_fn has been invoked.
+    self._is_input_fn_invoked = True
+
     with self._ctx.with_mode(mode) as ctx:
       # Setting the batch size in params first. This helps user to have same
       # input_fn for use_tpu=True/False.
@@ -1907,15 +1916,24 @@ class TPUEstimator(estimator_lib.Estimator):
       with self._ctx.with_mode(mode) as ctx:
         model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)
 
-        # For export_savedmodel, input_fn is never passed to Estimator. So,
-        # if features is callable, it means it is the input_fn passed by
-        # TPUEstimator._call_input_fn. Then we can know if the mode == PREDICT,
-        # it implies, it is the .predict API, not export_savedmodel API.
-        is_export_mode = not callable(features)
+        if mode != model_fn_lib.ModeKeys.PREDICT:
+          is_export_mode = False
+        else:
+          # For export_savedmodel, input_fn is never passed to Estimator. So, by
+          # checking the self._is_input_fn_invoked bit, we can know, given the
+          # mode == PREDICT, it is the .predict API, not export_savedmodel API.
+          if self._is_input_fn_invoked:
+            is_export_mode = False
+          else:
+            is_export_mode = True
+
+        # Clear the bit.
+        self._is_input_fn_invoked = None
 
         if ctx.is_running_on_cpu(is_export_mode=is_export_mode):
           logging.info('Running %s on CPU', mode)
-          return model_fn_wrapper.call_without_tpu(features, labels)
+          return model_fn_wrapper.call_without_tpu(
+              features, labels, is_export_mode=is_export_mode)
 
         assert labels is None, '`labels` passed to `model_fn` must be `None`.'
         # TPUEstimator._call_input_fn passes `input_fn` as features to here.