From: Jianwei Xie Date: Tue, 22 May 2018 19:36:35 +0000 (-0700) Subject: Detect unknown batch size in predictions dict X-Git-Tag: upstream/v1.9.0_rc1~38^2~4^2~204 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a4c9efe6a5bf143f844b1cffbdc839c399620b9b;p=platform%2Fupstream%2Ftensorflow.git Detect unknown batch size in predictions dict PiperOrigin-RevId: 197606059 --- diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 77d117b..f0c7564 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -1264,13 +1264,11 @@ class _ModelFnWrapper(object): 'estimator_spec used by TPU prediction must have type' '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec))) + self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions) + captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) to_record = {} identity_fn = lambda **kwargs: kwargs - # TODO(xiejw): Adds validation for prediction dictionrary. - # TODO(xiejw): Adds support for single tensor as predictions. - if not isinstance(tpu_estimator_spec.predictions, dict): - raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.') to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions] to_record['signals'] = [identity_fn, stopping_signals] if tpu_estimator_spec.host_call is not None: @@ -1282,6 +1280,21 @@ class _ModelFnWrapper(object): return predict_step, host_calls, captured_scaffold_fn + def _verify_tpu_spec_predictions(self, predictions): + """Validates TPUEstimatorSpec.predictions dict.""" + # TODO(xiejw): Adds validation for prediction dictionrary. + # TODO(xiejw): Adds support for single tensor as predictions. + if not isinstance(predictions, dict): + raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.') + + for (key, tensor) in predictions.items(): + if tensor.shape[0].value is None: + raise ValueError( + 'The tensor with key ({}) in TPUEstimatorSpec.predictions has ' + 'dynamic shape (should be static). Tensor: {}'.format( + key, tensor)) + return predictions + def _call_model_fn(self, features, labels, is_export_mode=False): """Calls the model_fn with required parameters.""" model_fn_args = function_utils.fn_args(self._model_fn)