Adds final partial batch support for TPUEstimator.predict.
author     Jianwei Xie <xiejw@google.com>
           Tue, 20 Mar 2018 03:06:26 +0000 (20:06 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Tue, 20 Mar 2018 03:10:51 +0000 (20:10 -0700)
PiperOrigin-RevId: 189683528

tensorflow/contrib/tpu/BUILD
tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py [new file with mode: 0644]

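Before this change, `TPUEstimator.predict` required every batch, including the final one, to have the same static shape, so users padded the last batch by hand (see the docstring hunk below). This CL moves that work into the input pipeline: the final short batch is padded up to the static `batch_size`, a per-example padding mask rides along with the data as a "signal", and the padded rows are sliced off the predictions before they reach the user. A minimal sketch of the idea in public TF 1.x ops (`pad_to_full_batch` and `strip_padding` are illustrative names, not the helpers added below):

```python
import tensorflow as tf

batch_size = 8  # the static TPU batch size


def pad_to_full_batch(tensor):
  """Pads dim 0 of `tensor` up to `batch_size`; returns (padded, mask)."""
  real = tf.shape(tensor)[0]
  missing = batch_size - real
  rank = tensor.shape.ndims
  padded = tf.pad(tensor, tf.stack([[0, missing]] + [[0, 0]] * (rank - 1)))
  padded.set_shape([batch_size] + tensor.shape.as_list()[1:])
  # 0 marks a real example, 1 marks a padded row.
  mask = tf.concat(
      [tf.zeros([real], tf.int32), tf.ones([missing], tf.int32)], axis=0)
  return padded, mask


def strip_padding(predictions, mask):
  """Keeps only the real rows; padding always sits at the tail."""
  real = tf.shape(mask)[0] - tf.reduce_sum(mask)
  return predictions[:real]
```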
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index ed930e4..eea19e9 100644
@@ -271,6 +271,17 @@ tf_py_test(
     ],
 )
 
+tf_py_test(
+    name = "tpu_estimator_signals_test",
+    size = "small",
+    srcs = ["python/tpu/tpu_estimator_signals_test.py"],
+    additional_deps = [
+        ":tpu_estimator",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_test_lib",
+    ],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(
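With this rule in place, the new test should be runnable from a TensorFlow source tree with `bazel test //tensorflow/contrib/tpu:tpu_estimator_signals_test` (target path inferred from the rule above).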
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 32f15e6..5a8fa04 100644
@@ -49,6 +49,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
@@ -62,6 +63,7 @@ from tensorflow.python.training import evaluation
 from tensorflow.python.training import session_run_hook
 from tensorflow.python.training import training
 from tensorflow.python.training import training_util
+from tensorflow.python.util import nest
 from tensorflow.python.util import tf_inspect
 
 _INITIAL_LOSS = 1e7
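The two new imports back the two new mechanisms: `check_ops` supplies runtime batch-size assertions, and `nest` lets the padding walk arbitrarily nested feature/label structures. A tiny refresher on what `nest.map_structure` does (illustrative, not code from this CL):

```python
from tensorflow.python.util import nest

# map_structure applies a function to every leaf and preserves structure.
result = nest.map_structure(lambda x: x + 1, {'a': 1, 'b': (2, 3)})
assert result == {'a': 2, 'b': (3, 4)}
```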
@@ -678,8 +680,11 @@ def generate_per_host_enqueue_ops_fn_for_host(
         raise TypeError(
             'For mode PREDICT, `input_fn` must return `Dataset` instead of '
             '`features` and `labels`.')
+      if batch_axis is not None:
+        raise TypeError('For mode PREDICT, batch_axis is not supported yet.')
       inputs = _InputsWithStoppingSignals(
-          dataset=inputs.dataset, batch_size=ctx.batch_size_for_input_fn)
+          dataset=inputs.dataset, batch_size=ctx.batch_size_for_input_fn,
+          add_padding=True)
 
     if is_dataset:
       hooks.append(inputs.dataset_initializer_hook())
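In PREDICT mode the dataset is thus wrapped so every element carries a stopping signal and, now, a padding mask. A sketch of what the wrapper yields, mirroring the usage in the new test file (`ds` is made up for illustration):

```python
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_estimator

ds = tf.data.Dataset.from_tensor_slices({'x': tf.zeros([5, 3])}).batch(2)
inputs = tpu_estimator._InputsWithStoppingSignals(
    dataset=ds, batch_size=2, add_padding=True)
hook = inputs.dataset_initializer_hook()    # run before the session starts
features, _ = inputs.features_and_labels()  # statically shaped: [2, 3]
signals = inputs.signals()  # {'stopping': bool [2, 1], 'padding_mask': int32 [2]}
```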
@@ -1620,11 +1625,6 @@ class TPUEstimator(estimator_lib.Estimator):
   2. `input_fn` must return a `Dataset` instance rather than `features`. In
   fact, .train() and .evaluate() also support Dataset as return value.
 
-  3. Each batch returned by `Dataset`'s iterator must have the *same static*
-     shape. This means two things:
-     - batch_size cannot be `None`
-     - the final batch must be padded by user to a full batch.
-
   Example (MNIST):
   ----------------
   ```
@@ -1639,41 +1639,9 @@ class TPUEstimator(estimator_lib.Estimator):
         [total_examples, height, width, 3], minval=-1, maxval=1)
 
     dataset = tf.data.Dataset.from_tensor_slices(images)
-    dataset = dataset.batch(batch_size)
     dataset = dataset.map(lambda images: {'image': images})
 
-    def pad(tensor, missing_count):
-        # Pads out the batch dimension to the complete batch_size.
-        rank = len(tensor.shape)
-        assert rank > 0
-        padding = tf.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
-        padded_shape = (batch_size,) + tuple(tensor.shape[1:])
-        padded_tensor = tf.pad(tensor, padding)
-        padded_tensor.set_shape(padded_shape)
-        return padded_tensor
-
-    def pad_batch_if_incomplete(batch_features):
-      # Pads out the batch dimension for all features.
-      real_batch_size = tf.shape(batch_features["image"])[0]
-
-      missing_count = tf.constant(batch_size, tf.int32) - real_batch_size
-
-      padded_features = {
-          key: pad(tensor, missing_count)
-          for key, tensor in batch_features.iteritems()
-      }
-      padding_mask = tf.concat(
-          [
-              tf.zeros((real_batch_size, 1), dtype=tf.int32),
-              tf.ones((missing_count, 1), dtype=tf.int32)
-          ],
-          axis=0)
-      padding_mask.set_shape((batch_size, 1))
-      padded_features["is_padding"] = padding_mask
-      return padded_features
-
-    dataset = dataset.map(pad_batch_if_incomplete)
-
+    dataset = dataset.batch(batch_size)
     return dataset
 
   def model_fn(features, labels, params, mode):
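The removed `pad`/`pad_batch_if_incomplete` boilerplate is exactly what the new `_PaddingSignals` class does internally, so the user-facing contract shrinks to "return a batched `Dataset`". A hedged usage sketch (`estimator` is a configured TPUEstimator and `input_fn` a batched-Dataset input function, both elided):

```python
# Predicting over, say, 10 examples with batch_size 8 now yields exactly 10
# predictions; the 6 padded rows in the final batch are sliced off internally.
predictions = list(estimator.predict(input_fn=input_fn))
assert len(predictions) == 10
```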
@@ -2089,12 +2057,14 @@ class TPUEstimator(estimator_lib.Estimator):
             predictions, message=(
                 'The estimated size for TPUEstimatorSpec.predictions is too '
                 'large.'))
-        stopping_signals = host_call_ret['signals']
+        signals = host_call_ret['signals']
 
         with ops.control_dependencies(host_ops):
+          host_ops = []  # Empty, we do not need it anymore.
           scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal(
-              stopping_signals)
+              signals)
+          predictions = _PaddingSignals.slice_tensor_or_dict(
+              predictions, signals)
 
         hooks = [
             _StoppingPredictHook(scalar_stopping_signal),
@@ -2389,20 +2359,19 @@ class _Inputs(object):
     return self._dataset
 
 
-# TODO(xiejw): Extend this to support final partial batch.
 class _InputsWithStoppingSignals(_Inputs):
   """Inputs with `_StopSignals` inserted into the dataset."""
 
-  def __init__(self, dataset, batch_size):
+  def __init__(self, dataset, batch_size, add_padding=False):
 
     assert dataset is not None
 
     user_provided_dataset = dataset.map(
         _InputsWithStoppingSignals.insert_stopping_signal(
-            stop=False, batch_size=batch_size))
+            stop=False, batch_size=batch_size, add_padding=add_padding))
     final_batch_dataset = dataset.take(1).map(
         _InputsWithStoppingSignals.insert_stopping_signal(
-            stop=True, batch_size=batch_size))
+            stop=True, batch_size=batch_size, add_padding=add_padding))
     dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2)
 
     super(_InputsWithStoppingSignals, self).__init__(dataset=dataset)
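The construction above deserves a gloss: the user's dataset is mapped with stop=False, then one extra batch (recycled from the first) marked stop=True is concatenated, so the TPU loop always receives a well-formed final batch whose only job is to say "stop". The same trick in miniature, with the signal reduced to a scalar bool (`mark` is an illustrative name):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(4).batch(2)


def mark(stop):
  return lambda x: {'x': x, 'stop': tf.constant(stop)}

with_signals = ds.map(mark(False)).concatenate(
    ds.take(1).map(mark(True))).prefetch(2)
# Yields two real batches with stop=False, then one batch of recycled data
# whose only purpose is to carry stop=True.
```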
@@ -2432,7 +2401,7 @@ class _InputsWithStoppingSignals(_Inputs):
     return signals
 
   @staticmethod
-  def insert_stopping_signal(stop, batch_size):
+  def insert_stopping_signal(stop, batch_size, add_padding=False):
     """Inserts stopping_signal into dataset via _map_fn.
 
     Here we change the data structure in the dataset, such that the return value
@@ -2443,6 +2412,7 @@ class _InputsWithStoppingSignals(_Inputs):
     Args:
       stop: bool, state of current stopping signals.
       batch_size: int, batch size.
+      add_padding: bool, whether to pad the tensor to full batch size.
 
     Returns:
       A map_fn passed to dataset.map API.
@@ -2456,11 +2426,25 @@ class _InputsWithStoppingSignals(_Inputs):
         args = args[0]
       features, labels = _Inputs._parse_inputs(args)
       new_input_dict = {}
-      new_input_dict['features'] = features
-      if labels is not None:
-        new_input_dict['labels'] = labels
+
+      if add_padding:
+        padding_mask, features, labels = (
+            _PaddingSignals.pad_features_and_labels(
+                features, labels, batch_size))
+
+        new_input_dict['features'] = features
+        if labels is not None:
+          new_input_dict['labels'] = labels
+
+      else:
+        new_input_dict['features'] = features
+        if labels is not None:
+          new_input_dict['labels'] = labels
+        padding_mask = None
+
       new_input_dict['signals'] = _StopSignals(
-          stop=stop, batch_size=batch_size).as_dict()
+          stop=stop, batch_size=batch_size, padding_mask=padding_mask).as_dict()
+
       return new_input_dict
 
     return _map_fn
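For orientation, one dataset element after `insert_stopping_signal(..., add_padding=True)`, with shapes and dtypes as defined by the hunks below (batch_size = 8 assumed):

```python
# {
#     'features': {'image': <float32 [8, ...]>},  # padded to the full batch
#     'signals': {
#         'stopping': <bool [8, 1]>,       # all False until the stop batch
#         'padding_mask': <int32 [8]>,     # 1 where the row is padding
#     },
# }
```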
@@ -2469,23 +2453,28 @@ class _InputsWithStoppingSignals(_Inputs):
 class _StopSignals(object):
   """Signals class holding all logic to handle TPU stopping condition."""
 
-  NON_STOPPING_SIGNAL = 0.0
-  STOPPING_SIGNAL = 1.0
+  NON_STOPPING_SIGNAL = False
+  STOPPING_SIGNAL = True
 
-  def __init__(self, stop, batch_size):
+  def __init__(self, stop, batch_size, padding_mask=None):
     self._stop = stop
     self._batch_size = batch_size
+    self._padding_mask = padding_mask
 
   def as_dict(self):
+    """Returns the signals as Python dict."""
     shape = [self._batch_size, 1]
-    dtype = dtypes.float32
+    dtype = dtypes.bool
 
     if self._stop:
       stopping = array_ops.ones(shape=shape, dtype=dtype)
     else:
       stopping = array_ops.zeros(shape=shape, dtype=dtype)
 
-    return {'stopping': stopping}
+    signals = {'stopping': stopping}
+    if self._padding_mask is not None:
+      signals['padding_mask'] = self._padding_mask
+    return signals
 
   @staticmethod
   def as_scalar_stopping_signal(signals):
@@ -2493,7 +2482,118 @@ class _StopSignals(object):
 
   @staticmethod
   def should_stop(scalar_stopping_signal):
-    return scalar_stopping_signal >= _StopSignals.STOPPING_SIGNAL
+    if isinstance(scalar_stopping_signal, ops.Tensor):
+      # STOPPING_SIGNAL is the constant True. The logical_and here is simply
+      # the TF way to check whether scalar_stopping_signal is True.
+      return math_ops.logical_and(
+          scalar_stopping_signal, _StopSignals.STOPPING_SIGNAL)
+    else:
+      # In the non-Tensor case, this is called from a SessionRunHook, so the
+      # graph cannot be modified anymore; use pure Python instead.
+      return bool(scalar_stopping_signal)
+
+
+class _PaddingSignals(object):
+  """Signals class holding all logic to handle padding."""
+
+  @staticmethod
+  def pad_features_and_labels(features, labels, batch_size):
+    """Pads out the batch dimension of features and labels."""
+    real_batch_size = array_ops.shape(
+        _PaddingSignals._find_any_tensor(features))[0]
+
+    batch_size_tensor = constant_op.constant(batch_size, dtypes.int32)
+
+    check_greater = check_ops.assert_greater_equal(
+        batch_size_tensor, real_batch_size,
+        data=(batch_size_tensor, real_batch_size),
+        message='The real batch size should not be greater than batch_size.')
+
+    with ops.control_dependencies([check_greater]):
+      missing_count = batch_size_tensor - real_batch_size
+
+    def pad_single_tensor(tensor):
+      """Pads out the batch dimension of a tensor to the complete batch_size."""
+      rank = len(tensor.shape)
+      assert rank > 0
+      padding = array_ops.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
+      padded_shape = (batch_size,) + tuple(tensor.shape[1:])
+      padded_tensor = array_ops.pad(tensor, padding)
+      padded_tensor.set_shape(padded_shape)
+      return padded_tensor
+
+    def nest_pad(tensor_or_dict):
+      return nest.map_structure(pad_single_tensor, tensor_or_dict)
+
+    features = nest_pad(features)
+    if labels is not None:
+      labels = nest_pad(labels)
+
+    padding_mask = _PaddingSignals._padding_mask(
+        real_batch_size, missing_count, batch_size)
+
+    return padding_mask, features, labels
+
+  @staticmethod
+  def slice_tensor_or_dict(tensor_or_dict, signals):
+    """Slice the real Tensors according to padding mask in signals."""
+
+    padding_mask = signals['padding_mask']
+    batch_size = array_ops.shape(padding_mask)[0]
+
+    def verify_batch_size(tensor):
+      check_batch_size = check_ops.assert_equal(
+          batch_size, array_ops.shape(tensor)[0],
+          message='Tensor and padding mask must share the batch size.')
+      with ops.control_dependencies([check_batch_size]):
+        return array_ops.identity(tensor)
+
+    def slice_single_tensor(tensor):
+      rank = len(tensor.shape)
+      assert rank > 0
+      real_batch_size = batch_size - math_ops.reduce_sum(padding_mask)
+      return verify_batch_size(tensor)[0:real_batch_size]
+
+    # As we split the Tensors across all TPU cores and concat them back, it
+    # is important that the real data is placed before the padded data, i.e.,
+    # the order is preserved. Given that, the sliced padding mask must be all
+    # 0's; if this assertion fails, the slice logic here does not hold.
+    sliced_padding_mask = slice_single_tensor(padding_mask)
+    assert_padding_mask = check_ops.assert_equal(
+        math_ops.reduce_sum(sliced_padding_mask), 0,
+        message='Real data must be placed before the padded data.')
+
+    with ops.control_dependencies([assert_padding_mask]):
+      should_stop = _StopSignals.should_stop(
+          _StopSignals.as_scalar_stopping_signal(signals))
+
+    is_full_batch = math_ops.equal(math_ops.reduce_sum(padding_mask), 0)
+
+    def slice_fn(tensor):
+      # If the current batch is full batch or part of stopping signals, we do
+      # not need to slice to save performance.
+      return control_flow_ops.cond(
+          math_ops.logical_or(should_stop, is_full_batch),
+          (lambda: verify_batch_size(tensor)),
+          (lambda: slice_single_tensor(tensor)))
+
+    return nest.map_structure(slice_fn, tensor_or_dict)
+
+  @staticmethod
+  def _find_any_tensor(batch_features):
+    tensors = [x for x in nest.flatten(batch_features)
+               if isinstance(x, ops.Tensor)]
+    if not tensors:
+      raise ValueError('Cannot find any Tensor in features dict.')
+    return tensors[0]
+
+  @staticmethod
+  def _padding_mask(real_batch_size, missing_count, batch_size):
+    padding_mask = array_ops.concat(
+        [
+            array_ops.zeros((real_batch_size,), dtype=dtypes.int32),
+            array_ops.ones((missing_count,), dtype=dtypes.int32)
+        ],
+        axis=0)
+    padding_mask.set_shape((batch_size,))
+    return padding_mask
 
 
 class _SignalsHelper(object):
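A round-trip sanity sketch for the two new helpers, calling the private `_PaddingSignals` API directly the way the new test file does (values chosen for illustration):

```python
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_estimator

features = {'x': tf.constant([[1.], [2.], [3.]])}  # 3 real rows
mask, padded, _ = tpu_estimator._PaddingSignals.pad_features_and_labels(
    features, None, batch_size=4)
signals = {'stopping': tf.zeros([4, 1], tf.bool), 'padding_mask': mask}
sliced = tpu_estimator._PaddingSignals.slice_tensor_or_dict(padded, signals)

with tf.Session() as sess:
  print(sess.run(mask))                 # [0 0 0 1]
  print(sess.run(padded['x']).ravel())  # [1. 2. 3. 0.]
  print(sess.run(sliced['x']).ravel())  # [1. 2. 3.]
```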
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py
new file mode 100644
index 0000000..3e90957
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py
@@ -0,0 +1,291 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TPU Estimator Signalling Tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.tpu.python.tpu import tpu_estimator
+from tensorflow.python import data as dataset_lib
+from tensorflow.python.client import session
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+
+
+def make_input_fn(num_samples):
+  a = np.linspace(0, 100.0, num=num_samples)
+  b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1))
+
+  def input_fn(params):
+    batch_size = params['batch_size']
+    da1 = dataset_lib.Dataset.from_tensor_slices(a)
+    da2 = dataset_lib.Dataset.from_tensor_slices(b)
+
+    dataset = dataset_lib.Dataset.zip((da1, da2))
+    dataset = dataset.map(lambda fa, fb: {'a': fa, 'b': fb})
+    dataset = dataset.batch(batch_size)
+    return dataset
+  return input_fn, (a, b)
+
+
+def make_input_fn_with_labels(num_samples):
+  a = np.linspace(0, 100.0, num=num_samples)
+  b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1))
+
+  def input_fn(params):
+    batch_size = params['batch_size']
+    da1 = dataset_lib.Dataset.from_tensor_slices(a)
+    da2 = dataset_lib.Dataset.from_tensor_slices(b)
+
+    dataset = dataset_lib.Dataset.zip((da1, da2))
+    dataset = dataset.map(lambda fa, fb: ({'a': fa}, fb))
+    dataset = dataset.batch(batch_size)
+    return dataset
+  return input_fn, (a, b)
+
+
+class TPUEstimatorStoppingSignalsTest(test.TestCase):
+
+  def test_normal_output_without_signals(self):
+    num_samples = 4
+    batch_size = 2
+
+    params = {'batch_size': batch_size}
+    input_fn, (a, b) = make_input_fn(num_samples=num_samples)
+
+    with ops.Graph().as_default():
+      dataset = input_fn(params)
+      features = dataset.make_one_shot_iterator().get_next()
+
+      # With tf.data.Dataset.batch, the batch dimension is None (dynamic).
+      self.assertIsNone(features['a'].shape.as_list()[0])
+
+      with session.Session() as sess:
+        result = sess.run(features)
+        self.assertAllEqual(a[:batch_size], result['a'])
+        self.assertAllEqual(b[:batch_size], result['b'])
+
+        # This run should work as num_samples / batch_size = 2.
+        result = sess.run(features)
+        self.assertAllEqual(a[batch_size:num_samples], result['a'])
+        self.assertAllEqual(b[batch_size:num_samples], result['b'])
+
+        with self.assertRaises(errors.OutOfRangeError):
+          # Given num_samples and batch_size, this run should fail.
+          sess.run(features)
+
+  def test_output_with_stopping_signals(self):
+    num_samples = 4
+    batch_size = 2
+
+    params = {'batch_size': batch_size}
+    input_fn, (a, b) = make_input_fn(num_samples=num_samples)
+
+    with ops.Graph().as_default():
+      dataset = input_fn(params)
+      inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size)
+      hook = inputs.dataset_initializer_hook()
+      features, _ = inputs.features_and_labels()
+      signals = inputs.signals()
+
+      # With tf.data.Dataset.batch, the batch dimension is None (dynamic).
+      self.assertIsNone(features['a'].shape.as_list()[0])
+
+      with session.Session() as sess:
+        hook.begin()
+        hook.after_create_session(sess, coord=None)
+
+        result, evaluated_signals = sess.run([features, signals])
+        self.assertAllEqual(a[:batch_size], result['a'])
+        self.assertAllEqual(b[:batch_size], result['b'])
+        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+
+        # This run should work as num_samples / batch_size = 2.
+        result, evaluated_signals = sess.run([features, signals])
+        self.assertAllEqual(a[batch_size:num_samples], result['a'])
+        self.assertAllEqual(b[batch_size:num_samples], result['b'])
+        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+
+        # This run still works, *but* the signals now read STOP ('1').
+        _, evaluated_signals = sess.run([features, signals])
+        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
+
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(features)
+
+
+class TPUEstimatorStoppingSignalsWithPaddingTest(test.TestCase):
+
+  def test_num_samples_divisible_by_batch_size(self):
+    num_samples = 4
+    batch_size = 2
+
+    params = {'batch_size': batch_size}
+    input_fn, (a, b) = make_input_fn(num_samples=num_samples)
+
+    with ops.Graph().as_default():
+      dataset = input_fn(params)
+      inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size,
+                                                        add_padding=True)
+      hook = inputs.dataset_initializer_hook()
+      features, _ = inputs.features_and_labels()
+      signals = inputs.signals()
+
+      # With padding, all shapes are static now.
+      self.assertEqual(batch_size, features['a'].shape.as_list()[0])
+
+      with session.Session() as sess:
+        hook.begin()
+        hook.after_create_session(sess, coord=None)
+
+        result, evaluated_signals = sess.run([features, signals])
+        self.assertAllEqual(a[:batch_size], result['a'])
+        self.assertAllEqual(b[:batch_size], result['b'])
+        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+        self.assertAllEqual([0.] * batch_size,
+                            evaluated_signals['padding_mask'])
+
+        # This run should work as num_samples / batch_size = 2.
+        result, evaluated_signals = sess.run([features, signals])
+        self.assertAllEqual(a[batch_size:num_samples], result['a'])
+        self.assertAllEqual(b[batch_size:num_samples], result['b'])
+        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+        self.assertAllEqual([0.] * batch_size,
+                            evaluated_signals['padding_mask'])
+
+        # This run still works, *but* the signals now read STOP ('1').
+        _, evaluated_signals = sess.run([features, signals])
+        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
+
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(features)
+
+  def test_num_samples_not_divisible_by_batch_size(self):
+    num_samples = 5
+    batch_size = 2
+
+    params = {'batch_size': batch_size}
+    input_fn, (a, b) = make_input_fn_with_labels(num_samples=num_samples)
+
+    with ops.Graph().as_default():
+      dataset = input_fn(params)
+      inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size,
+                                                        add_padding=True)
+      hook = inputs.dataset_initializer_hook()
+      features, labels = inputs.features_and_labels()
+      signals = inputs.signals()
+
+      # With padding, all shapes are static.
+      self.assertEqual(batch_size, features['a'].shape.as_list()[0])
+
+      with session.Session() as sess:
+        hook.begin()
+        hook.after_create_session(sess, coord=None)
+
+        evaluated_features, evaluated_labels, evaluated_signals = (
+            sess.run([features, labels, signals]))
+        self.assertAllEqual(a[:batch_size], evaluated_features['a'])
+        self.assertAllEqual(b[:batch_size], evaluated_labels)
+        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+        self.assertAllEqual([0.] * batch_size,
+                            evaluated_signals['padding_mask'])
+
+        # This run should work as num_samples / batch_size >= 2.
+        evaluated_features, evaluated_labels, evaluated_signals = (
+            sess.run([features, labels, signals]))
+        self.assertAllEqual(a[batch_size:2*batch_size], evaluated_features['a'])
+        self.assertAllEqual(b[batch_size:2*batch_size], evaluated_labels)
+        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+        self.assertAllEqual([0.] * batch_size,
+                            evaluated_signals['padding_mask'])
+
+        # This is the final partial batch.
+        evaluated_features, evaluated_labels, evaluated_signals = (
+            sess.run([features, labels, signals]))
+        real_batch_size = num_samples % batch_size
+
+        # Assert the real part.
+        self.assertAllEqual(a[2*batch_size:num_samples],
+                            evaluated_features['a'][:real_batch_size])
+        self.assertAllEqual(b[2*batch_size:num_samples],
+                            evaluated_labels[:real_batch_size])
+        # Assert the padded part.
+        self.assertAllEqual([0.0] * (batch_size - real_batch_size),
+                            evaluated_features['a'][real_batch_size:])
+        self.assertAllEqual([[0.0]] * (batch_size - real_batch_size),
+                            evaluated_labels[real_batch_size:])
+
+        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+
+        padding = ([0.] * real_batch_size
+                   + [1.] * (batch_size - real_batch_size))
+        self.assertAllEqual(padding, evaluated_signals['padding_mask'])
+
+        # This run still works, *but* the signals now read STOP ('1').
+        _, evaluated_signals = sess.run([features, signals])
+        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
+
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(features)
+
+  def test_slice(self):
+    num_samples = 3
+    batch_size = 2
+
+    params = {'batch_size': batch_size}
+    input_fn, (a, b) = make_input_fn(num_samples=num_samples)
+
+    with ops.Graph().as_default():
+      dataset = input_fn(params)
+      inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size,
+                                                        add_padding=True)
+      hook = inputs.dataset_initializer_hook()
+      features, _ = inputs.features_and_labels()
+      signals = inputs.signals()
+
+      sliced_features = (
+          tpu_estimator._PaddingSignals.slice_tensor_or_dict(
+              features, signals))
+
+      with session.Session() as sess:
+        hook.begin()
+        hook.after_create_session(sess, coord=None)
+
+        result, evaluated_signals = sess.run([sliced_features, signals])
+        self.assertAllEqual(a[:batch_size], result['a'])
+        self.assertAllEqual(b[:batch_size], result['b'])
+        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+
+        # This is the final partial batch.
+        result, evaluated_signals = sess.run([sliced_features, signals])
+        self.assertEqual(1, len(result['a']))
+        self.assertAllEqual(a[batch_size:num_samples], result['a'])
+        self.assertAllEqual(b[batch_size:num_samples], result['b'])
+        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+
+        # This run still works, *but* the signals now read STOP ('1').
+        _, evaluated_signals = sess.run([sliced_features, signals])
+        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
+
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(sliced_features)
+
+
+if __name__ == '__main__':
+  test.main()