Add document for TPUEstimate.predict, including limitations and example.
authorJianwei Xie <xiejw@google.com>
Thu, 8 Mar 2018 22:29:45 +0000 (14:29 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 8 Mar 2018 22:33:47 +0000 (14:33 -0800)
PiperOrigin-RevId: 188390287

tensorflow/contrib/tpu/python/tpu/tpu_estimator.py

index 33251f2..d918b0f 100644 (file)
@@ -1517,14 +1517,20 @@ class TPUEstimator(estimator_lib.Estimator):
   size when calling the `input_fn` and `model_fn`. Users should specify
   global batch size in constructor, and then get the batch size for each shard
   in `input_fn` and `model_fn` by `params['batch_size']`.
-  For training, `model_fn` gets per-core batch size; `input_fn` may get
-  per-core or per-host batch size depending on
-  `per_host_input_for_training` in `TPUConfig`.
-  For evaluation, `model_fn` gets per-core batch size and `input_fn` get
-  per-host batch size.
+
+  - For training, `model_fn` gets per-core batch size; `input_fn` may get
+    per-core or per-host batch size depending on `per_host_input_for_training`
+    in `TPUConfig` (See docstring for TPUConfig for details).
+
+  - For evaluation and prediction, `model_fn` gets per-core batch size and
+    `input_fn` get per-host batch size.
+
+  Evaluation
+  ==========
 
   `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics`
   for TPU evaluation.
+
   `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where
   `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. (See
   `TPUEstimatorSpec` for details).  `metric_fn` takes the `tensors` and returns
@@ -1536,12 +1542,17 @@ class TPUEstimator(estimator_lib.Estimator):
   `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`.
 
   Current limitations:
+  --------------------
+
+  1. TPU evaluation only works on a single host (one TPU worker).
 
-  1. TPU evaluation only works on single host.
-  2. `input_fn` for evaluation should not throw OutOfRange error for all
-  evaluation steps and all batches should have the same size.
+  2. `input_fn` for evaluation should **NOT** raise an end-of-input exception
+     (`OutOfRangeError` or `StopIteration`). And all evaluation steps and all
+     batches should have the same size.
 
   Example (MNIST):
+  ----------------
+
   ```
   # The metric Fn which runs on CPU.
   def metric_fn(labels, logits):
@@ -1577,8 +1588,120 @@ class TPUEstimator(estimator_lib.Estimator):
           }))
   ```
 
-  Predict support on TPU is not yet implemented. So, `predict` and
-  `export_savedmodel` are executed on CPU, even if `use_tpu` is true.
+  Prediction
+  ==========
+
+  Prediction on TPU is an experimental feature to support large batch inference.
+  It is not designed for latency-critical system. In addition, due to some
+  usability issues, for prediction with small dataset, CPU `.predict`, i.e.,
+  creating a new `TPUEstimator` instance with `use_tpu=False`, might be more
+  convenient.
+
+  Note: In contrast to TPU training/evaluation, the `input_fn` for prediction
+  *should* raise an end-of-input exception (`OutOfRangeError` or
+  `StopIteration`), which serves as the stopping signal to `TPUEstimator`. To be
+  precise, the ops created by `input_fn` produce one batch of the data.
+  The `predict()` API processes one batch at a time. When reaching the end of
+  the data source, an end-of-input exception should be raised by one of these
+  operations. The user usually does not need to do this manually. As long as the
+  dataset is not repeated forever, the `tf.data` API will raise an end-of-input
+  exception automatically after the last batch has been produced.
+
+  Note: Estimator.predict returns a Python generator. Please consume all the
+  data from the generator so that TPUEstimator can shutdown the TPU system
+  properly for user.
+
+  Current limitations:
+  --------------------
+  1. TPU prediction only works on a single host (one TPU worker).
+
+  2. `input_fn` must return a `Dataset` instance rather than `features`. In
+  fact, .train() and .evaluate() also support Dataset as return value.
+
+  3. Each batch returned by `Dataset`'s iterator must have the *same static*
+     shape. This means two things:
+     - batch_size cannot be `None`
+     - the final batch must be padded by user to a full batch.
+
+  Example (MNIST):
+  ----------------
+  ```
+  height = 32
+  width = 32
+  total_examples = 100
+
+  def predict_input_fn(params):
+    batch_size = params['batch_size']
+
+    images = tf.random_uniform(
+        [total_examples, height, width, 3], minval=-1, maxval=1)
+
+    dataset = tf.data.Dataset.from_tensor_slices(images)
+    dataset = dataset.batch(batch_size)
+    dataset = dataset.map(lambda images: {'image': images})
+
+    def pad(tensor, missing_count):
+        # Pads out the batch dimension to the complete batch_size.
+        rank = len(tensor.shape)
+        assert rank > 0
+        padding = tf.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
+        padded_shape = (batch_size,) + tuple(tensor.shape[1:])
+        padded_tensor = tf.pad(tensor, padding)
+        padded_tensor.set_shape(padded_shape)
+        return padded_tensor
+
+    def pad_batch_if_incomplete(batch_features):
+      # Pads out the batch dimension for all features.
+      real_batch_size = tf.shape(batch_features["image"])[0]
+
+      missing_count = tf.constant(batch_size, tf.int32) - real_batch_size
+
+      padded_features = {
+          key: pad(tensor, missing_count)
+          for key, tensor in batch_features.iteritems()
+      }
+      padding_mask = tf.concat(
+          [
+              tf.zeros((real_batch_size, 1), dtype=tf.int32),
+              tf.ones((missing_count, 1), dtype=tf.int32)
+          ],
+          axis=0)
+      padding_mask.set_shape((batch_size, 1))
+      padded_features["is_padding"] = padding_mask
+      return padded_features
+
+    dataset = dataset.map(pad_batch_if_incomplete)
+
+    return dataset
+
+  def model_fn(features, labels, params, mode):
+     # Generate predictions, called 'output', from features['image']
+
+    if mode == tf.estimator.ModeKeys.PREDICT:
+      return tf.contrib.tpu.TPUEstimatorSpec(
+          mode=mode,
+          predictions={
+              'predictions': output,
+              'is_padding': features['is_padding']
+          })
+
+  tpu_est = TPUEstimator(
+      model_fn=model_fn,
+      ...,
+      predict_batch_size=16)
+
+  # Fully consume the generator so that TPUEstimator can shutdown the TPU
+  # system.
+  for item in tpu_est.predict(input_fn=input_fn):
+    # Filter out item if the `is_padding` is 1.
+    # Process the 'predictions'
+  ```
+
+  Exporting
+  =========
+
+  Exporting `SavedModel` support on TPU is not yet implemented. So,
+  `export_savedmodel` is executed on CPU, even if `use_tpu` is true.
   """
 
   def __init__(self,