add assert_element_shape method for tf.contrib.data (#17480)
authorYan Facai (颜发才) <facai.yan@gmail.com>
Fri, 6 Apr 2018 02:51:12 +0000 (10:51 +0800)
committerDerek Murray <derek.murray@gmail.com>
Fri, 6 Apr 2018 02:51:12 +0000 (19:51 -0700)
* ENH: add assert_element_shape method

* CLN: add indentation

* ENH: raise exception when wrong shape is given

* CLN: fix too long line

tensorflow/contrib/data/__init__.py
tensorflow/contrib/data/python/kernel_tests/BUILD
tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
tensorflow/contrib/data/python/ops/BUILD
tensorflow/contrib/data/python/ops/batching.py

index 1704831..125260b 100644 (file)
@@ -25,6 +25,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
 @@Counter
 @@SqlDataset
 
+@@assert_element_shape
 @@batch_and_drop_remainder
 @@bucket_by_sequence_length
 @@dense_to_sparse_batch
@@ -55,6 +56,7 @@ from __future__ import print_function
 
 # pylint: disable=unused-import
 
+from tensorflow.contrib.data.python.ops.batching import assert_element_shape
 from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder
 from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch
 from tensorflow.contrib.data.python.ops.batching import map_and_batch
index c8699e0..7270d53 100644 (file)
@@ -22,6 +22,7 @@ py_test(
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:script_ops",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:string_ops",
         "//tensorflow/python:tensor_shape",
index 75482f6..413d873 100644 (file)
@@ -28,8 +28,10 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import script_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import test
 
@@ -579,5 +581,73 @@ class PaddedBatchDatasetSerializationTest(
                         lambda: build_dataset(seq_lens2), 8)
 
 
+class RestructuredDatasetTest(test.TestCase):
+
+  def test_assert_element_shape(self):
+
+    def create_unknown_shape_dataset(x):
+      return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32),
+                                           np.zeros((3, 4), dtype=np.int32)),
+                                [x],
+                                [dtypes.float32, dtypes.int32])
+
+    dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
+    unknown_shapes = (tensor_shape.TensorShape(None),
+                      tensor_shape.TensorShape(None))
+    self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+    expected_shapes = (tensor_shape.TensorShape(2),
+                       tensor_shape.TensorShape((3, 4)))
+    result = dataset.apply(batching.assert_element_shape(expected_shapes))
+    self.assertEqual(expected_shapes, result.output_shapes)
+
+    iterator = result.make_initializable_iterator()
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+    with self.test_session() as sess:
+      sess.run(init_op)
+      for _ in range(5):
+        sess.run(get_next)
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  def test_assert_wrong_element_shape(self):
+
+    def create_dataset(_):
+      return (array_ops.ones(2, dtype=dtypes.float32),
+              array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+    dataset = dataset_ops.Dataset.range(3).map(create_dataset)
+    wrong_shapes = (tensor_shape.TensorShape(2),
+                    tensor_shape.TensorShape((3, 10)))
+    with self.assertRaises(ValueError):
+      dataset.apply(batching.assert_element_shape(wrong_shapes))
+
+  def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
+
+    def create_unknown_shape_dataset(x):
+      return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32),
+                                           np.zeros((3, 4), dtype=np.int32)),
+                                [x],
+                                [dtypes.float32, dtypes.int32])
+
+    dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
+    unknown_shapes = (tensor_shape.TensorShape(None),
+                      tensor_shape.TensorShape(None))
+    self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+    wrong_shapes = (tensor_shape.TensorShape(2),
+                    tensor_shape.TensorShape((3, 10)))
+    iterator = (
+        dataset.apply(batching.assert_element_shape(wrong_shapes))
+        .make_initializable_iterator())
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+    with self.test_session() as sess:
+      sess.run(init_op)
+      with self.assertRaises(errors.InvalidArgumentError):
+        sess.run(get_next)
+
+
 if __name__ == "__main__":
   test.main()
index 236792b..a1a5c9e 100644 (file)
@@ -119,6 +119,7 @@ py_library(
     deps = [
         ":contrib_op_loader",
         ":gen_dataset_ops",
+        "//tensorflow/contrib/framework:framework_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:dataset_ops_gen",
index a212adf..1eba010 100644 (file)
@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.contrib.framework import with_shape
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import sparse
@@ -345,6 +346,45 @@ class _RestructuredDataset(dataset_ops.Dataset):
     return self._output_shapes
 
 
+def assert_element_shape(expected_shapes):
+  """Assert the shape of this `Dataset`.
+
+  ```python
+  shapes = [tf.TensorShape([16, 256]), tf.TensorShape(None)]
+  result = dataset.apply(tf.contrib.data.assert_element_shape(shapes))
+  print(result.output_shapes)  # ==> "((16, 256), <unknown>)"
+  ```
+
+  If dataset shapes and expected_shape, are fully defined, assert they match.
+  Otherwise, add assert op that will validate the shapes when tensors are
+  evaluated, and set shapes on tensors, respectively.
+
+  Args:
+    expected_shapes: A nested structure of `tf.TensorShape` objects.
+
+  Returns:
+    A `Dataset` transformation function, which can be passed to
+    @{tf.data.Dataset.apply}
+  """
+
+  def _check_shape(*elements):
+    flatten_tensors = nest.flatten(elements)
+    flatten_shapes = nest.flatten(expected_shapes)
+    checked_tensors = [with_shape(shape, tensor)
+                       for shape, tensor in zip(flatten_shapes,
+                                                flatten_tensors)]
+    return nest.pack_sequence_as(elements, checked_tensors)
+
+  def _apply_fn(dataset):
+    return _RestructuredDataset(
+        dataset.map(_check_shape),
+        dataset.output_types,
+        output_shapes=expected_shapes,
+        output_classes=dataset.output_classes)
+
+  return _apply_fn
+
+
 class _MapAndBatchDataset(dataset_ops.MapDataset):
   """A `Dataset` that maps a function over a batch of elements."""