From: Jianwei Xie
Date: Tue, 20 Mar 2018 15:32:30 +0000 (-0700)
Subject: Fixed the bug where the export code triggers TPU validation.
X-Git-Tag: tflite-v0.1.7~145^2^2~13
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4a4c13788634e73f3c1bd01abd142a607c2fd253;p=platform%2Fupstream%2Ftensorflow.git

Fixed the bug where the export code triggers TPU validation.

PiperOrigin-RevId: 189745966
---

diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 5a8fa04..f61f6bb 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -1044,8 +1044,8 @@ class _ModelFnWrapper(object):
     self._params = params
     self._ctx = ctx
 
-  def call_without_tpu(self, features, labels):
-    return self._call_model_fn(features, labels)
+  def call_without_tpu(self, features, labels, is_export_mode):
+    return self._call_model_fn(features, labels, is_export_mode=is_export_mode)
 
   def convert_to_single_tpu_train_step(self, dequeue_fn):
     """Converts the user-provided `model_fn` into a single train step on TPU.
@@ -1204,7 +1204,7 @@ class _ModelFnWrapper(object):
 
     return predict_step, host_calls, captured_scaffold_fn
 
-  def _call_model_fn(self, features, labels, is_export_mode=True):
+  def _call_model_fn(self, features, labels, is_export_mode=False):
     """Calls the model_fn with required parameters."""
     model_fn_args = util.fn_args(self._model_fn)
     kwargs = {}
@@ -1230,7 +1230,11 @@ class _ModelFnWrapper(object):
           'required by TPUEstimator to pass batch size as '
           'params[\'batch_size\']'.format(self._model_fn))
 
-    batch_size_for_model_fn = self._ctx.batch_size_for_model_fn
+    if is_export_mode:
+      batch_size_for_model_fn = None
+    else:
+      batch_size_for_model_fn = self._ctx.batch_size_for_model_fn
+
     if batch_size_for_model_fn is not None:
       params[_BATCH_SIZE_KEY] = batch_size_for_model_fn
 
@@ -1778,6 +1782,8 @@ class TPUEstimator(estimator_lib.Estimator):
         eval_batch_size,
         predict_batch_size,
         use_tpu)
 
+    self._is_input_fn_invoked = None
+
   def _create_global_step(self, graph):
     """Creates a global step suitable for TPUs.
 
@@ -1860,6 +1866,9 @@ class TPUEstimator(estimator_lib.Estimator):
     if 'mode' in input_fn_args:
       kwargs['mode'] = mode
 
+    # Records the fact that input_fn has been invoked.
+    self._is_input_fn_invoked = True
+
     with self._ctx.with_mode(mode) as ctx:
       # Setting the batch size in params first. This helps the user to have
       # the same input_fn for use_tpu=True/False.
@@ -1907,15 +1916,24 @@ class TPUEstimator(estimator_lib.Estimator):
     with self._ctx.with_mode(mode) as ctx:
       model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)
 
-      # For export_savedmodel, input_fn is never passed to Estimator. So,
-      # if features is callable, it means it is the input_fn passed by
-      # TPUEstimator._call_input_fn. Then we can know if the mode == PREDICT,
-      # it implies, it is the .predict API, not export_savedmodel API.
-      is_export_mode = not callable(features)
+      if mode != model_fn_lib.ModeKeys.PREDICT:
+        is_export_mode = False
+      else:
+        # For export_savedmodel, input_fn is never passed to Estimator. So,
+        # by checking the self._is_input_fn_invoked bit, we know that under
+        # mode == PREDICT this is the .predict API, not the export_savedmodel
+        # API.
+        if self._is_input_fn_invoked:
+          is_export_mode = False
+        else:
+          is_export_mode = True
+
+        # Clear the bit.
+        self._is_input_fn_invoked = None
 
       if ctx.is_running_on_cpu(is_export_mode=is_export_mode):
         logging.info('Running %s on CPU', mode)
-        return model_fn_wrapper.call_without_tpu(features, labels)
+        return model_fn_wrapper.call_without_tpu(
+            features, labels, is_export_mode=is_export_mode)
 
       assert labels is None, '`labels` passed to `model_fn` must be `None`.'
       # TPUEstimator._call_input_fn passes `input_fn` as features to here.
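
In essence, the patch replaces the old `not callable(features)` heuristic with a
one-shot flag: _call_input_fn sets _is_input_fn_invoked, and _call_model_fn reads
and then clears it, so a PREDICT-mode call that never went through input_fn must
have come from export_savedmodel. Below is a minimal, self-contained sketch of
that pattern; the class and method names are illustrative stand-ins, not the
real TPUEstimator call chain:

class _FlagDemo(object):
  """Illustrative stand-in for the one-shot flag introduced by the patch."""

  def __init__(self):
    # None (rather than False) mirrors the patch: the flag is simply
    # "unset" until _call_input_fn records an invocation.
    self._is_input_fn_invoked = None

  def _call_input_fn(self, input_fn):
    # Records the fact that input_fn has been invoked.
    self._is_input_fn_invoked = True
    return input_fn()

  def _call_model_fn(self, mode):
    """Returns True iff this PREDICT-mode call came from the export path."""
    if mode != 'predict':
      return False
    # Under PREDICT, only .predict() routes through _call_input_fn;
    # export_savedmodel calls the model_fn directly.
    is_export_mode = not self._is_input_fn_invoked
    # Clear the bit so the next call starts from a clean state.
    self._is_input_fn_invoked = None
    return is_export_mode

demo = _FlagDemo()
demo._call_input_fn(dict)                        # .predict(): input_fn runs first
assert demo._call_model_fn('predict') is False   # flag set, so not export mode
assert demo._call_model_fn('predict') is True    # no input_fn => export path

Note that in export mode the patch also skips setting params['batch_size']
(batch_size_for_model_fn is forced to None), so an exported model_fn sees
params without a TPU batch size.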