from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import script_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
lambda: build_dataset(seq_lens2), 8)
+class RestructuredDatasetTest(test.TestCase):
+
+ def test_assert_element_shape(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(3).map(create_dataset)
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 10)))
+ with self.assertRaises(ValueError):
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+
+ def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 10)))
+ iterator = (
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+
if __name__ == "__main__":
test.main()
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.framework import with_shape
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
return self._output_shapes
+def assert_element_shape(expected_shapes):
+ """Assert the shape of this `Dataset`.
+
+ ```python
+ shapes = [tf.TensorShape([16, 256]), tf.TensorShape(None)]
+ result = dataset.apply(tf.contrib.data.assert_element_shape(shapes))
+ print(result.output_shapes) # ==> "((16, 256), <unknown>)"
+ ```
+
+ If dataset shapes and expected_shape, are fully defined, assert they match.
+ Otherwise, add assert op that will validate the shapes when tensors are
+ evaluated, and set shapes on tensors, respectively.
+
+ Args:
+ expected_shapes: A nested structure of `tf.TensorShape` objects.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}
+ """
+
+ def _check_shape(*elements):
+ flatten_tensors = nest.flatten(elements)
+ flatten_shapes = nest.flatten(expected_shapes)
+ checked_tensors = [with_shape(shape, tensor)
+ for shape, tensor in zip(flatten_shapes,
+ flatten_tensors)]
+ return nest.pack_sequence_as(elements, checked_tensors)
+
+ def _apply_fn(dataset):
+ return _RestructuredDataset(
+ dataset.map(_check_shape),
+ dataset.output_types,
+ output_shapes=expected_shapes,
+ output_classes=dataset.output_classes)
+
+ return _apply_fn
+
+
class _MapAndBatchDataset(dataset_ops.MapDataset):
"""A `Dataset` that maps a function over a batch of elements."""