Detect unknown batch size in predictions dict
authorJianwei Xie <xiejw@google.com>
Tue, 22 May 2018 19:36:35 +0000 (12:36 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 22 May 2018 19:39:28 +0000 (12:39 -0700)
PiperOrigin-RevId: 197606059

tensorflow/contrib/tpu/python/tpu/tpu_estimator.py

index 77d117b..f0c7564 100644 (file)
@@ -1264,13 +1264,11 @@ class _ModelFnWrapper(object):
             'estimator_spec used by TPU prediction must have type'
             '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))
 
+      self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions)
+
       captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
       to_record = {}
       identity_fn = lambda **kwargs: kwargs
-      # TODO(xiejw): Adds validation for prediction dictionrary.
-      # TODO(xiejw): Adds support for single tensor as predictions.
-      if not isinstance(tpu_estimator_spec.predictions, dict):
-        raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.')
       to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions]
       to_record['signals'] = [identity_fn, stopping_signals]
       if tpu_estimator_spec.host_call is not None:
@@ -1282,6 +1280,21 @@ class _ModelFnWrapper(object):
 
     return predict_step, host_calls, captured_scaffold_fn
 
+  def _verify_tpu_spec_predictions(self, predictions):
+    """Validates TPUEstimatorSpec.predictions dict."""
+    # TODO(xiejw): Adds validation for prediction dictionrary.
+    # TODO(xiejw): Adds support for single tensor as predictions.
+    if not isinstance(predictions, dict):
+      raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.')
+
+    for (key, tensor) in predictions.items():
+      if tensor.shape[0].value is None:
+        raise ValueError(
+            'The tensor with key ({}) in TPUEstimatorSpec.predictions has '
+            'dynamic shape (should be static). Tensor: {}'.format(
+                key, tensor))
+    return predictions
+
   def _call_model_fn(self, features, labels, is_export_mode=False):
     """Calls the model_fn with required parameters."""
     model_fn_args = function_utils.fn_args(self._model_fn)