'estimator_spec used by TPU prediction must have type'
'`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))
+ self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions)
+
captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
to_record = {}
identity_fn = lambda **kwargs: kwargs
- # TODO(xiejw): Adds validation for prediction dictionrary.
- # TODO(xiejw): Adds support for single tensor as predictions.
- if not isinstance(tpu_estimator_spec.predictions, dict):
- raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.')
to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions]
to_record['signals'] = [identity_fn, stopping_signals]
if tpu_estimator_spec.host_call is not None:
return predict_step, host_calls, captured_scaffold_fn
+ def _verify_tpu_spec_predictions(self, predictions):
+ """Validates TPUEstimatorSpec.predictions dict."""
+ # TODO(xiejw): Adds validation for prediction dictionrary.
+ # TODO(xiejw): Adds support for single tensor as predictions.
+ if not isinstance(predictions, dict):
+ raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.')
+
+ for (key, tensor) in predictions.items():
+ if tensor.shape[0].value is None:
+ raise ValueError(
+ 'The tensor with key ({}) in TPUEstimatorSpec.predictions has '
+ 'dynamic shape (should be static). Tensor: {}'.format(
+ key, tensor))
+ return predictions
+
def _call_model_fn(self, features, labels, is_export_mode=False):
"""Calls the model_fn with required parameters."""
model_fn_args = function_utils.fn_args(self._model_fn)