from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.training import training_util
+from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
_INITIAL_LOSS = 1e7
raise TypeError(
'For mode PREDICT, `input_fn` must return `Dataset` instead of '
'`features` and `labels`.')
+ if batch_axis is not None:
+ raise TypeError('For mode PREDICT, batch_axis is not supported yet.')
inputs = _InputsWithStoppingSignals(
- dataset=inputs.dataset, batch_size=ctx.batch_size_for_input_fn)
+ dataset=inputs.dataset, batch_size=ctx.batch_size_for_input_fn,
+ add_padding=True)
if is_dataset:
hooks.append(inputs.dataset_initializer_hook())
2. `input_fn` must return a `Dataset` instance rather than `features`. In
fact, .train() and .evaluate() also support a Dataset as the return value.
- 3. Each batch returned by `Dataset`'s iterator must have the *same static*
- shape. This means two things:
- - batch_size cannot be `None`
- - the final batch must be padded by user to a full batch.
-
Example (MNIST):
----------------
```
[total_examples, height, width, 3], minval=-1, maxval=1)
dataset = tf.data.Dataset.from_tensor_slices(images)
- dataset = dataset.batch(batch_size)
dataset = dataset.map(lambda images: {'image': images})
- def pad(tensor, missing_count):
- # Pads out the batch dimension to the complete batch_size.
- rank = len(tensor.shape)
- assert rank > 0
- padding = tf.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
- padded_shape = (batch_size,) + tuple(tensor.shape[1:])
- padded_tensor = tf.pad(tensor, padding)
- padded_tensor.set_shape(padded_shape)
- return padded_tensor
-
- def pad_batch_if_incomplete(batch_features):
- # Pads out the batch dimension for all features.
- real_batch_size = tf.shape(batch_features["image"])[0]
-
- missing_count = tf.constant(batch_size, tf.int32) - real_batch_size
-
- padded_features = {
- key: pad(tensor, missing_count)
- for key, tensor in batch_features.iteritems()
- }
- padding_mask = tf.concat(
- [
- tf.zeros((real_batch_size, 1), dtype=tf.int32),
- tf.ones((missing_count, 1), dtype=tf.int32)
- ],
- axis=0)
- padding_mask.set_shape((batch_size, 1))
- padded_features["is_padding"] = padding_mask
- return padded_features
-
- dataset = dataset.map(pad_batch_if_incomplete)
-
+ dataset = dataset.batch(batch_size)
return dataset
def model_fn(features, labels, params, mode):
predictions, message=(
'The estimated size for TPUEstimatorSpec.predictions is too '
'large.'))
- stopping_signals = host_call_ret['signals']
+ signals = host_call_ret['signals']
with ops.control_dependencies(host_ops):
host_ops = []  # Empty; we do not need it anymore.
scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal(
- stopping_signals)
+ signals)
+ predictions = _PaddingSignals.slice_tensor_or_dict(
+ predictions, signals)
hooks = [
_StoppingPredictHook(scalar_stopping_signal),
return self._dataset
-# TODO(xiejw): Extend this to support final partial batch.
class _InputsWithStoppingSignals(_Inputs):
"""Inputs with `_StopSignals` inserted into the dataset."""
- def __init__(self, dataset, batch_size):
+ def __init__(self, dataset, batch_size, add_padding=False):
assert dataset is not None
user_provided_dataset = dataset.map(
_InputsWithStoppingSignals.insert_stopping_signal(
- stop=False, batch_size=batch_size))
+ stop=False, batch_size=batch_size, add_padding=add_padding))
final_batch_dataset = dataset.take(1).map(
_InputsWithStoppingSignals.insert_stopping_signal(
- stop=True, batch_size=batch_size))
+ stop=True, batch_size=batch_size, add_padding=add_padding))
dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2)
super(_InputsWithStoppingSignals, self).__init__(dataset=dataset)
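For intuition, the `take(1)`/`concatenate` construction in `__init__` can be reproduced in isolation. Below is a minimal sketch using the public `tf.data` API; the range dataset and the `(value, stop)` tuples are illustrative stand-ins for the real features/signals dicts:

```python
import tensorflow as tf

# Stand-in for the user's dataset: every element is tagged stop=False, then
# one recycled batch (obtained via take(1)) is appended tagged stop=True,
# mirroring the construction in __init__ above.
dataset = tf.data.Dataset.range(3)
user_part = dataset.map(lambda x: (x, False))
final_part = dataset.take(1).map(lambda x: (x, True))
combined = user_part.concatenate(final_part).prefetch(2)

next_element = combined.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  try:
    while True:
      # Prints (0, False), (1, False), (2, False), (0, True).
      print(sess.run(next_element))
  except tf.errors.OutOfRangeError:
    pass
```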
return signals
@staticmethod
- def insert_stopping_signal(stop, batch_size):
+ def insert_stopping_signal(stop, batch_size, add_padding=False):
"""Inserts stopping_signal into dataset via _map_fn.
Here we change the data structure in the dataset, such that the return value
Args:
stop: bool, state of current stopping signals.
batch_size: int, batch size.
+    add_padding: bool, whether to pad the tensors to the full batch size.
Returns:
A map_fn passed to dataset.map API.
args = args[0]
features, labels = _Inputs._parse_inputs(args)
new_input_dict = {}
- new_input_dict['features'] = features
- if labels is not None:
- new_input_dict['labels'] = labels
+
+ if add_padding:
+ padding_mask, features, labels = (
+ _PaddingSignals.pad_features_and_labels(
+ features, labels, batch_size))
+
+ new_input_dict['features'] = features
+ if labels is not None:
+ new_input_dict['labels'] = labels
+
+ else:
+ new_input_dict['features'] = features
+ if labels is not None:
+ new_input_dict['labels'] = labels
+ padding_mask = None
+
new_input_dict['signals'] = _StopSignals(
- stop=stop, batch_size=batch_size).as_dict()
+ stop=stop, batch_size=batch_size, padding_mask=padding_mask).as_dict()
+
return new_input_dict
return _map_fn
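A minimal sketch of the element structure the returned map_fn produces. The dataset and feature name below are illustrative, and the internal signal construction is approximated with public ops:

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices({'x': [[1.], [2.]]}).batch(2)

def map_fn(features):
  # Approximates the stop=False, add_padding=False branch: the user batch is
  # wrapped in a dict next to the stopping signals.
  return {
      'features': features,
      'signals': {'stopping': tf.zeros([2, 1], dtype=tf.bool)},
  }

element = dataset.map(map_fn).make_one_shot_iterator().get_next()
# element is {'features': {'x': <Tensor>}, 'signals': {'stopping': <Tensor>}}
```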
class _StopSignals(object):
"""Signals class holding all logic to handle TPU stopping condition."""
- NON_STOPPING_SIGNAL = 0.0
- STOPPING_SIGNAL = 1.0
+ NON_STOPPING_SIGNAL = False
+ STOPPING_SIGNAL = True
- def __init__(self, stop, batch_size):
+ def __init__(self, stop, batch_size, padding_mask=None):
self._stop = stop
self._batch_size = batch_size
+ self._padding_mask = padding_mask
def as_dict(self):
+ """Returns the signals as Python dict."""
shape = [self._batch_size, 1]
- dtype = dtypes.float32
+ dtype = dtypes.bool
if self._stop:
stopping = array_ops.ones(shape=shape, dtype=dtype)
else:
stopping = array_ops.zeros(shape=shape, dtype=dtype)
- return {'stopping': stopping}
+ signals = {'stopping': stopping}
+ if self._padding_mask is not None:
+ signals['padding_mask'] = self._padding_mask
+ return signals
@staticmethod
def as_scalar_stopping_signal(signals):
@staticmethod
def should_stop(scalar_stopping_signal):
- return scalar_stopping_signal >= _StopSignals.STOPPING_SIGNAL
+    if isinstance(scalar_stopping_signal, ops.Tensor):
+      # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the
+      # TF way to check whether scalar_stopping_signal is True.
+      return math_ops.logical_and(
+          scalar_stopping_signal, _StopSignals.STOPPING_SIGNAL)
+    else:
+      # In the non-Tensor case, this is called from a SessionRunHook, where
+      # the graph can no longer be modified, so we use pure Python.
+      return bool(scalar_stopping_signal)
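Both branches can be exercised directly. A small sketch with illustrative values, contrasting the in-graph check with the pure-Python check used from a SessionRunHook:

```python
import numpy as np
import tensorflow as tf

# Tensor branch: the check must stay in the graph as a bool op.
scalar_signal = tf.constant(True)
graph_check = tf.logical_and(scalar_signal, True)
with tf.Session() as sess:
  assert sess.run(graph_check)

# Hook branch: the evaluated signal arrives as a numpy bool, so a plain
# Python truth test is enough.
assert bool(np.bool_(True))
```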
+
+
+class _PaddingSignals(object):
+ """Signals class holding all logic to handle padding."""
+
+ @staticmethod
+ def pad_features_and_labels(features, labels, batch_size):
+ """Pads out the batch dimension of features and labels."""
+ real_batch_size = array_ops.shape(
+ _PaddingSignals._find_any_tensor(features))[0]
+
+ batch_size_tensor = constant_op.constant(batch_size, dtypes.int32)
+
+ check_greater = check_ops.assert_greater_equal(
+ batch_size_tensor, real_batch_size,
+ data=(batch_size_tensor, real_batch_size),
+ message='The real batch size should not be greater than batch_size.')
+
+ with ops.control_dependencies([check_greater]):
+ missing_count = batch_size_tensor - real_batch_size
+
+ def pad_single_tensor(tensor):
+ """Pads out the batch dimension of a tensor to the complete batch_size."""
+ rank = len(tensor.shape)
+ assert rank > 0
+ padding = array_ops.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
+ padded_shape = (batch_size,) + tuple(tensor.shape[1:])
+ padded_tensor = array_ops.pad(tensor, padding)
+ padded_tensor.set_shape(padded_shape)
+ return padded_tensor
+
+ def nest_pad(tensor_or_dict):
+ return nest.map_structure(pad_single_tensor, tensor_or_dict)
+
+ features = nest_pad(features)
+ if labels is not None:
+ labels = nest_pad(labels)
+
+ padding_mask = _PaddingSignals._padding_mask(
+ real_batch_size, missing_count, batch_size)
+
+ return padding_mask, features, labels
+
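To make the padding concrete, here is a self-contained sketch of the same logic using the public API; the batch size, feature name, and shapes are illustrative:

```python
import tensorflow as tf

batch_size = 4
# A partial batch of 3 examples, each with feature shape (2,).
features = {'image': tf.ones((3, 2), dtype=tf.float32)}

real_batch_size = tf.shape(features['image'])[0]
missing_count = tf.constant(batch_size, tf.int32) - real_batch_size

def pad_single_tensor(tensor):
  # Zero-pad the batch (first) dimension up to the full batch_size, then
  # pin down the static shape so compilation sees a fixed batch.
  rank = len(tensor.shape)
  padding = tf.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
  padded = tf.pad(tensor, padding)
  padded.set_shape((batch_size,) + tuple(tensor.shape[1:]))
  return padded

padded_features = {k: pad_single_tensor(v) for k, v in features.items()}
padding_mask = tf.concat(
    [tf.zeros([real_batch_size], tf.int32),
     tf.ones([missing_count], tf.int32)], axis=0)

with tf.Session() as sess:
  mask, image = sess.run([padding_mask, padded_features['image']])
  print(mask)         # => [0 0 0 1]
  print(image.shape)  # => (4, 2)
```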
+ @staticmethod
+ def slice_tensor_or_dict(tensor_or_dict, signals):
+ """Slice the real Tensors according to padding mask in signals."""
+
+ padding_mask = signals['padding_mask']
+ batch_size = array_ops.shape(padding_mask)[0]
+
+    def verify_batch_size(tensor):
+      # Checks that the first dimension matches the full batch size; the
+      # control dependency ties the check to the returned tensor.
+      check_batch_size = math_ops.equal(batch_size, tensor.shape[0])
+      with ops.control_dependencies([check_batch_size]):
+        return array_ops.identity(tensor)
+
+ def slice_single_tensor(tensor):
+ rank = len(tensor.shape)
+ assert rank > 0
+ real_batch_size = batch_size - math_ops.reduce_sum(padding_mask)
+ return verify_batch_size(tensor)[0:real_batch_size]
+
+    # As we split the Tensors across all TPU cores and concat them back, it
+    # is important to ensure the real data is placed before the padded data,
+    # i.e., that order is preserved. Given that, the sliced padding mask
+    # should be all 0's. If this assertion fails, the slice logic here does
+    # not hold.
+ sliced_padding_mask = slice_single_tensor(padding_mask)
+ assert_padding_mask = math_ops.equal(
+ math_ops.reduce_sum(sliced_padding_mask), 0)
+
+ with ops.control_dependencies([assert_padding_mask]):
+ should_stop = _StopSignals.should_stop(
+ _StopSignals.as_scalar_stopping_signal(signals))
+
+ is_full_batch = math_ops.equal(math_ops.reduce_sum(padding_mask), 0)
+
+ def slice_fn(tensor):
+      # If the current batch is a full batch or is the batch carrying the
+      # stopping signals, we skip the slice to save the extra work.
+ return control_flow_ops.cond(
+ math_ops.logical_or(should_stop, is_full_batch),
+ (lambda: verify_batch_size(tensor)),
+ (lambda: slice_single_tensor(tensor)))
+
+ return nest.map_structure(slice_fn, tensor_or_dict)
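And the reverse direction: a self-contained sketch, with illustrative values, of how the padding mask recovers the real rows from a padded batch:

```python
import tensorflow as tf

# Two real rows followed by two padded rows, matching the mask below.
padded_predictions = tf.constant([[1.], [2.], [0.], [0.]])
padding_mask = tf.constant([0, 0, 1, 1], dtype=tf.int32)

batch_size = tf.shape(padding_mask)[0]
real_batch_size = batch_size - tf.reduce_sum(padding_mask)
sliced = padded_predictions[0:real_batch_size]

with tf.Session() as sess:
  print(sess.run(sliced))  # => [[1.], [2.]]; the padded rows are dropped.
```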
+
+ @staticmethod
+ def _find_any_tensor(batch_features):
+ tensors = [x for x in nest.flatten(batch_features)
+ if isinstance(x, ops.Tensor)]
+ if not tensors:
+ raise ValueError('Cannot find any Tensor in features dict.')
+ return tensors[0]
+
+ @staticmethod
+ def _padding_mask(real_batch_size, missing_count, batch_size):
+ padding_mask = array_ops.concat(
+ [
+ array_ops.zeros((real_batch_size,), dtype=dtypes.int32),
+ array_ops.ones((missing_count,), dtype=dtypes.int32)
+ ],
+ axis=0)
+ padding_mask.set_shape((batch_size,))
+ return padding_mask
class _SignalsHelper(object):
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TPU Estimator Signalling Tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.tpu.python.tpu import tpu_estimator
+from tensorflow.python import data as dataset_lib
+from tensorflow.python.client import session
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+
+
+def make_input_fn(num_samples):
+ a = np.linspace(0, 100.0, num=num_samples)
+ b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1))
+
+ def input_fn(params):
+ batch_size = params['batch_size']
+ da1 = dataset_lib.Dataset.from_tensor_slices(a)
+ da2 = dataset_lib.Dataset.from_tensor_slices(b)
+
+ dataset = dataset_lib.Dataset.zip((da1, da2))
+ dataset = dataset.map(lambda fa, fb: {'a': fa, 'b': fb})
+ dataset = dataset.batch(batch_size)
+ return dataset
+ return input_fn, (a, b)
+
+
+def make_input_fn_with_labels(num_samples):
+ a = np.linspace(0, 100.0, num=num_samples)
+ b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1))
+
+ def input_fn(params):
+ batch_size = params['batch_size']
+ da1 = dataset_lib.Dataset.from_tensor_slices(a)
+ da2 = dataset_lib.Dataset.from_tensor_slices(b)
+
+ dataset = dataset_lib.Dataset.zip((da1, da2))
+ dataset = dataset.map(lambda fa, fb: ({'a': fa}, fb))
+ dataset = dataset.batch(batch_size)
+ return dataset
+ return input_fn, (a, b)
+
+
+class TPUEstimatorStoppingSignalsTest(test.TestCase):
+
+ def test_normal_output_without_signals(self):
+ num_samples = 4
+ batch_size = 2
+
+ params = {'batch_size': batch_size}
+ input_fn, (a, b) = make_input_fn(num_samples=num_samples)
+
+ with ops.Graph().as_default():
+ dataset = input_fn(params)
+ features = dataset.make_one_shot_iterator().get_next()
+
+      # With tf.data.Dataset.batch, the batch dimension is None, i.e., dynamic shape.
+ self.assertIsNone(features['a'].shape.as_list()[0])
+
+ with session.Session() as sess:
+ result = sess.run(features)
+ self.assertAllEqual(a[:batch_size], result['a'])
+ self.assertAllEqual(b[:batch_size], result['b'])
+
+ # This run should work as num_samples / batch_size = 2.
+ result = sess.run(features)
+ self.assertAllEqual(a[batch_size:num_samples], result['a'])
+ self.assertAllEqual(b[batch_size:num_samples], result['b'])
+
+ with self.assertRaises(errors.OutOfRangeError):
+ # Given num_samples and batch_size, this run should fail.
+ sess.run(features)
+
+ def test_output_with_stopping_signals(self):
+ num_samples = 4
+ batch_size = 2
+
+ params = {'batch_size': batch_size}
+ input_fn, (a, b) = make_input_fn(num_samples=num_samples)
+
+ with ops.Graph().as_default():
+ dataset = input_fn(params)
+ inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size)
+ hook = inputs.dataset_initializer_hook()
+ features, _ = inputs.features_and_labels()
+ signals = inputs.signals()
+
+      # With tf.data.Dataset.batch, the batch dimension is None, i.e., dynamic shape.
+ self.assertIsNone(features['a'].shape.as_list()[0])
+
+ with session.Session() as sess:
+ hook.begin()
+ hook.after_create_session(sess, coord=None)
+
+ result, evaluated_signals = sess.run([features, signals])
+ self.assertAllEqual(a[:batch_size], result['a'])
+ self.assertAllEqual(b[:batch_size], result['b'])
+ self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+
+ # This run should work as num_samples / batch_size = 2.
+ result, evaluated_signals = sess.run([features, signals])
+ self.assertAllEqual(a[batch_size:num_samples], result['a'])
+ self.assertAllEqual(b[batch_size:num_samples], result['b'])
+ self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+
+        # This run should work, *but* it sees STOP ('1') as the signal.
+ _, evaluated_signals = sess.run([features, signals])
+ self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(features)
+
+
+class TPUEstimatorStoppingSignalsWithPaddingTest(test.TestCase):
+
+ def test_num_samples_divisible_by_batch_size(self):
+ num_samples = 4
+ batch_size = 2
+
+ params = {'batch_size': batch_size}
+ input_fn, (a, b) = make_input_fn(num_samples=num_samples)
+
+ with ops.Graph().as_default():
+ dataset = input_fn(params)
+ inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size,
+ add_padding=True)
+ hook = inputs.dataset_initializer_hook()
+ features, _ = inputs.features_and_labels()
+ signals = inputs.signals()
+
+ # With padding, all shapes are static now.
+ self.assertEqual(batch_size, features['a'].shape.as_list()[0])
+
+ with session.Session() as sess:
+ hook.begin()
+ hook.after_create_session(sess, coord=None)
+
+ result, evaluated_signals = sess.run([features, signals])
+ self.assertAllEqual(a[:batch_size], result['a'])
+ self.assertAllEqual(b[:batch_size], result['b'])
+ self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+ self.assertAllEqual([0.] * batch_size,
+ evaluated_signals['padding_mask'])
+
+ # This run should work as num_samples / batch_size = 2.
+ result, evaluated_signals = sess.run([features, signals])
+ self.assertAllEqual(a[batch_size:num_samples], result['a'])
+ self.assertAllEqual(b[batch_size:num_samples], result['b'])
+ self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+ self.assertAllEqual([0.] * batch_size,
+ evaluated_signals['padding_mask'])
+
+        # This run should work, *but* it sees STOP ('1') as the signal.
+ _, evaluated_signals = sess.run([features, signals])
+ self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(features)
+
+ def test_num_samples_not_divisible_by_batch_size(self):
+ num_samples = 5
+ batch_size = 2
+
+ params = {'batch_size': batch_size}
+ input_fn, (a, b) = make_input_fn_with_labels(num_samples=num_samples)
+
+ with ops.Graph().as_default():
+ dataset = input_fn(params)
+ inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size,
+ add_padding=True)
+ hook = inputs.dataset_initializer_hook()
+ features, labels = inputs.features_and_labels()
+ signals = inputs.signals()
+
+ # With padding, all shapes are static.
+ self.assertEqual(batch_size, features['a'].shape.as_list()[0])
+
+ with session.Session() as sess:
+ hook.begin()
+ hook.after_create_session(sess, coord=None)
+
+ evaluated_features, evaluated_labels, evaluated_signals = (
+ sess.run([features, labels, signals]))
+ self.assertAllEqual(a[:batch_size], evaluated_features['a'])
+ self.assertAllEqual(b[:batch_size], evaluated_labels)
+ self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+ self.assertAllEqual([0.] * batch_size,
+ evaluated_signals['padding_mask'])
+
+ # This run should work as num_samples / batch_size >= 2.
+ evaluated_features, evaluated_labels, evaluated_signals = (
+ sess.run([features, labels, signals]))
+ self.assertAllEqual(a[batch_size:2*batch_size], evaluated_features['a'])
+ self.assertAllEqual(b[batch_size:2*batch_size], evaluated_labels)
+ self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+ self.assertAllEqual([0.] * batch_size,
+ evaluated_signals['padding_mask'])
+
+ # This is the final partial batch.
+ evaluated_features, evaluated_labels, evaluated_signals = (
+ sess.run([features, labels, signals]))
+ real_batch_size = num_samples % batch_size
+
+ # Assert the real part.
+ self.assertAllEqual(a[2*batch_size:num_samples],
+ evaluated_features['a'][:real_batch_size])
+ self.assertAllEqual(b[2*batch_size:num_samples],
+ evaluated_labels[:real_batch_size])
+ # Assert the padded part.
+ self.assertAllEqual([0.0] * (batch_size - real_batch_size),
+ evaluated_features['a'][real_batch_size:])
+ self.assertAllEqual([[0.0]] * (batch_size - real_batch_size),
+ evaluated_labels[real_batch_size:])
+
+ self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+
+        padding = ([0.] * real_batch_size
+                   + [1.] * (batch_size - real_batch_size))
+ self.assertAllEqual(padding, evaluated_signals['padding_mask'])
+
+        # This run should work, *but* it sees STOP ('1') as the signal.
+ _, evaluated_signals = sess.run([features, signals])
+ self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(features)
+
+ def test_slice(self):
+ num_samples = 3
+ batch_size = 2
+
+ params = {'batch_size': batch_size}
+ input_fn, (a, b) = make_input_fn(num_samples=num_samples)
+
+ with ops.Graph().as_default():
+ dataset = input_fn(params)
+ inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size,
+ add_padding=True)
+ hook = inputs.dataset_initializer_hook()
+ features, _ = inputs.features_and_labels()
+ signals = inputs.signals()
+
+ sliced_features = (
+ tpu_estimator._PaddingSignals.slice_tensor_or_dict(
+ features, signals))
+
+ with session.Session() as sess:
+ hook.begin()
+ hook.after_create_session(sess, coord=None)
+
+ result, evaluated_signals = sess.run([sliced_features, signals])
+ self.assertAllEqual(a[:batch_size], result['a'])
+ self.assertAllEqual(b[:batch_size], result['b'])
+ self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+
+ # This is the final partial batch.
+ result, evaluated_signals = sess.run([sliced_features, signals])
+ self.assertEqual(1, len(result['a']))
+ self.assertAllEqual(a[batch_size:num_samples], result['a'])
+ self.assertAllEqual(b[batch_size:num_samples], result['b'])
+ self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+
+        # This run should work, *but* it sees STOP ('1') as the signal.
+ _, evaluated_signals = sess.run([sliced_features, signals])
+ self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(sliced_features)
+
+
+if __name__ == '__main__':
+ test.main()