self._params = params
self._ctx = ctx
- def call_without_tpu(self, features, labels):
- return self._call_model_fn(features, labels)
+ def call_without_tpu(self, features, labels, is_export_mode):
+ return self._call_model_fn(features, labels, is_export_mode=is_export_mode)
def convert_to_single_tpu_train_step(self, dequeue_fn):
"""Converts user provided model_fn` as a single train step on TPU.
return predict_step, host_calls, captured_scaffold_fn
- def _call_model_fn(self, features, labels, is_export_mode=True):
+ def _call_model_fn(self, features, labels, is_export_mode=False):
"""Calls the model_fn with required parameters."""
model_fn_args = util.fn_args(self._model_fn)
kwargs = {}
'required by TPUEstimator to pass batch size as '
'params[\'batch_size\']'.format(self._model_fn))
- batch_size_for_model_fn = self._ctx.batch_size_for_model_fn
+ if is_export_mode:
+ batch_size_for_model_fn = None
+ else:
+ batch_size_for_model_fn = self._ctx.batch_size_for_model_fn
+
if batch_size_for_model_fn is not None:
params[_BATCH_SIZE_KEY] = batch_size_for_model_fn
eval_batch_size, predict_batch_size,
use_tpu)
+ self._is_input_fn_invoked = None
+
def _create_global_step(self, graph):
"""Creates a global step suitable for TPUs.
if 'mode' in input_fn_args:
kwargs['mode'] = mode
+ # Records the fact input_fn has been invoked.
+ self._is_input_fn_invoked = True
+
with self._ctx.with_mode(mode) as ctx:
# Setting the batch size in params first. This helps user to have same
# input_fn for use_tpu=True/False.
with self._ctx.with_mode(mode) as ctx:
model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)
- # For export_savedmodel, input_fn is never passed to Estimator. So,
- # if features is callable, it means it is the input_fn passed by
- # TPUEstimator._call_input_fn. Then we can know if the mode == PREDICT,
- # it implies, it is the .predict API, not export_savedmodel API.
- is_export_mode = not callable(features)
+ if mode != model_fn_lib.ModeKeys.PREDICT:
+ is_export_mode = False
+ else:
+ # For export_savedmodel, input_fn is never passed to Estimator. So, by
+ # checking the self._is_input_fn_invoked bit, we can know, given the
+ # mode == PREDICT, it is the .predict API, not export_savedmodel API.
+ if self._is_input_fn_invoked:
+ is_export_mode = False
+ else:
+ is_export_mode = True
+
+ # Clear the bit.
+ self._is_input_fn_invoked = None
if ctx.is_running_on_cpu(is_export_mode=is_export_mode):
logging.info('Running %s on CPU', mode)
- return model_fn_wrapper.call_without_tpu(features, labels)
+ return model_fn_wrapper.call_without_tpu(
+ features, labels, is_export_mode=is_export_mode)
assert labels is None, '`labels` passed to `model_fn` must be `None`.'
# TPUEstimator._call_input_fn passes `input_fn` as features to here.