[tf.data] Add support for `tf.SparseTensor` to `tf.contrib.data.scan()`
author     Jiri Simsa <jsimsa@google.com>
           Mon, 30 Apr 2018 21:08:29 +0000 (14:08 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Mon, 30 Apr 2018 21:11:19 +0000 (14:11 -0700)
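
With this change, the `scan()` transformation accepts sparse tensors both as
the accumulated state and as the emitted output. A minimal sketch of the new
capability, modeled on the `testSparseCount` test added below (`_sparse` is a
local helper for illustration, not part of the API):

    import numpy as np
    import tensorflow as tf

    def _sparse(i):
      # A 1x1 sparse tensor whose single value is `i`.
      return tf.SparseTensorValue(
          indices=np.array([[0, 0]]),
          values=(i * np.array([1])),
          dense_shape=np.array([1, 1]))

    # The state is sparse: each step adds 1 to the state's single value
    # and emits the previous state.
    dataset = tf.data.Dataset.from_tensors(0).repeat().apply(
        tf.contrib.data.scan(
            _sparse(0),
            lambda state, _: (_sparse(state.values[0] + 1), state)))

    # Successive elements are tf.SparseTensorValues holding 0, 1, 2, ...
    next_element = dataset.take(3).make_one_shot_iterator().get_next()

Note that the classes of the new state returned by the scan function must
match those of the initial state; a mismatch raises a TypeError.
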
PiperOrigin-RevId: 194842266

tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
tensorflow/contrib/data/python/ops/scan_ops.py

index 1a97a84..f544b1c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
@@ -35,15 +36,19 @@ from tensorflow.python.platform import test
 
 class ScanDatasetTest(test.TestCase):
 
-  def _count(self, start, step):
-    return dataset_ops.Dataset.from_tensors(0).repeat(None).apply(
-        scan_ops.scan(start, lambda state, _: (state + step, state)))
+  def _counting_dataset(self, start, scan_fn):
+    return dataset_ops.Dataset.from_tensors(0).repeat().apply(
+        scan_ops.scan(start, scan_fn))
 
   def testCount(self):
+    def make_scan_fn(step):
+      return lambda state, _: (state + step, state)
+
     start = array_ops.placeholder(dtypes.int32, shape=[])
     step = array_ops.placeholder(dtypes.int32, shape=[])
     take = array_ops.placeholder(dtypes.int64, shape=[])
-    iterator = self._count(start, step).take(take).make_initializable_iterator()
+    iterator = self._counting_dataset(
+        start, make_scan_fn(step)).take(take).make_initializable_iterator()
     next_element = iterator.get_next()
 
     with self.test_session() as sess:
@@ -78,6 +83,36 @@ class ScanDatasetTest(test.TestCase):
     self.assertEqual(5, self.evaluate(next_element()))
     self.assertEqual(8, self.evaluate(next_element()))
 
+  def testSparseCount(self):
+    def _sparse(i):
+      return sparse_tensor.SparseTensorValue(
+          indices=np.array([[0, 0]]),
+          values=(i * np.array([1])),
+          dense_shape=np.array([1, 1]))
+
+    def make_scan_fn(step):
+      return lambda state, _: (_sparse(state.values[0] + step), state)
+
+    start = array_ops.placeholder(dtypes.int32, shape=[])
+    step = array_ops.placeholder(dtypes.int32, shape=[])
+    take = array_ops.placeholder(dtypes.int64, shape=[])
+    iterator = self._counting_dataset(
+        _sparse(start),
+        make_scan_fn(step)).take(take).make_initializable_iterator()
+    next_element = iterator.get_next()
+
+    with self.test_session() as sess:
+      for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
+                                            (10, 2, 10), (10, -1, 10),
+                                            (10, -2, 10)]:
+        sess.run(iterator.initializer,
+                 feed_dict={start: start_val, step: step_val, take: take_val})
+        for expected, _ in zip(
+            itertools.count(start_val, step_val), range(take_val)):
+          self.assertEqual(expected, sess.run(next_element).values[0])
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(next_element)
+
   def testChangingStateShape(self):
     # Test the fixed-point shape invariant calculations: start with
     # initial values with known shapes, and use a scan function that
index 60ef7ef..e911ad0 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -24,6 +24,7 @@ from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import sparse
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import gen_dataset_ops
 
 
@@ -36,18 +37,22 @@ class _ScanDataset(dataset_ops.Dataset):
     self._input_dataset = input_dataset
 
     with ops.name_scope("initial_state"):
+      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+      # values to tensors.
       self._initial_state = nest.pack_sequence_as(initial_state, [
-          ops.convert_to_tensor(t, name="component_%d" % i)
+          sparse_tensor.SparseTensor.from_value(t)
+          if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
+              t, name="component_%d" % i)
           for i, t in enumerate(nest.flatten(initial_state))
       ])
 
-    # Compute initial values for the state shapes and types based on
-    # the initial state. These will be refined by running
-    # `tf_scan_func` one or more times below.
-    # TODO(b/68937811): Allow the initial state to be a tf.SparseTensor.
+    # Compute initial values for the state classes, shapes and types based on
+    # the initial state. The shapes may be refined by running `tf_scan_func` one
+    # or more times below.
+    self._state_classes = sparse.get_classes(self._initial_state)
     self._state_shapes = nest.pack_sequence_as(
         self._initial_state,
-        [t.shape for t in nest.flatten(self._initial_state)])
+        [t.get_shape() for t in nest.flatten(self._initial_state)])
     self._state_types = nest.pack_sequence_as(
         self._initial_state,
         [t.dtype for t in nest.flatten(self._initial_state)])
@@ -62,67 +67,100 @@ class _ScanDataset(dataset_ops.Dataset):
     need_to_rerun = True
     while need_to_rerun:
 
-      flat_state_shapes = nest.flatten(self._state_shapes)
-      flat_state_types = nest.flatten(self._state_types)
-
-      # Create a list in which `tf_scan_func` will store the s
+      # Create a list in which `tf_scan_func` will store the new shapes.
       flat_new_state_shapes = []
 
-      @function.Defun(*(flat_state_types + nest.flatten(
-          sparse.as_dense_types(input_dataset.output_types,
-                                input_dataset.output_classes))))
+      @function.Defun(*(nest.flatten(
+          sparse.as_dense_types(
+              self._state_types, self._state_classes)) + nest.flatten(
+                  sparse.as_dense_types(input_dataset.output_types,
+                                        input_dataset.output_classes))))
       def tf_scan_func(*args):
         """A wrapper for Defun that facilitates shape inference."""
         # Pass in shape information from the state and input_dataset.
-        # TODO(b/69424092): Check that neither inputs nor outputs are sparse.
-        dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
-                                              input_dataset.output_classes)
-        for arg, shape in zip(args,
-                              flat_state_shapes + nest.flatten(dense_shapes)):
+        for arg, shape in zip(
+            args,
+            nest.flatten(
+                sparse.as_dense_shapes(self._state_shapes, self._state_classes))
+            + nest.flatten(
+                sparse.as_dense_shapes(input_dataset.output_shapes,
+                                       input_dataset.output_classes))):
           arg.set_shape(shape)
 
-        pivot = len(flat_state_shapes)
-        old_state = nest.pack_sequence_as(self._initial_state, args[:pivot])
-        input_value = nest.pack_sequence_as(input_dataset.output_types,
-                                            args[pivot:])
-
-        ret = scan_func(old_state, input_value)
+        pivot = len(nest.flatten(self._state_shapes))
+        nested_state_args = nest.pack_sequence_as(self._state_types,
+                                                  args[:pivot])
+        nested_state_args = sparse.deserialize_sparse_tensors(
+            nested_state_args, self._state_types, self._state_shapes,
+            self._state_classes)
+        nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
+                                                  args[pivot:])
+        nested_input_args = sparse.deserialize_sparse_tensors(
+            nested_input_args, input_dataset.output_types,
+            input_dataset.output_shapes, input_dataset.output_classes)
+
+        ret = scan_func(nested_state_args, nested_input_args)
         if not isinstance(ret, collections.Sequence) or len(ret) != 2:
           raise TypeError("The scan function must return a pair comprising the "
                           "new state and the output value.")
+
+        # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+        # values to tensors.
+        ret = nest.pack_sequence_as(ret, [
+            sparse_tensor.SparseTensor.from_value(t)
+            if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
+            for t in nest.flatten(ret)
+        ])
         new_state, output_value = ret
 
-        flat_new_state = [
-            ops.convert_to_tensor(t) for t in nest.flatten(new_state)
-        ]
-        flat_output_value = [
-            ops.convert_to_tensor(t) for t in nest.flatten(output_value)
-        ]
+        # Extract and validate class information from the returned values.
+        for t, clazz in zip(
+            nest.flatten(new_state), nest.flatten(self._state_classes)):
+          if not isinstance(t, clazz):
+            raise TypeError(
+                "The element classes for the new state must match the initial "
+                "state. Expected %s; got %s." %
+                (self._state_classes,
+                 nest.pack_sequence_as(
+                     self._state_types,
+                     [type(t) for t in nest.flatten(new_state)])))
+        self._output_classes = sparse.get_classes(output_value)
 
         # Extract shape information from the returned values.
-        flat_new_state_shapes.extend([t.shape for t in flat_new_state])
+        flat_new_state_shapes.extend(
+            [t.get_shape() for t in nest.flatten(new_state)])
         self._output_shapes = nest.pack_sequence_as(
-            output_value, [t.shape for t in flat_output_value])
+            output_value, [t.get_shape() for t in nest.flatten(output_value)])
 
         # Extract and validate type information from the returned values.
-        for t, dtype in zip(flat_new_state, flat_state_types):
+        for t, dtype in zip(
+            nest.flatten(new_state), nest.flatten(self._state_types)):
           if t.dtype != dtype:
             raise TypeError(
                 "The element types for the new state must match the initial "
                 "state. Expected %s; got %s." %
-                (self._state_types, nest.pack_sequence_as(
-                    self._state_types, [t.dtype for t in flat_new_state])))
-        self._output_classes = nest.pack_sequence_as(
-            output_value, [ops.Tensor for _ in flat_output_value])
+                (self._state_types,
+                 nest.pack_sequence_as(
+                     self._state_types,
+                     [t.dtype for t in nest.flatten(new_state)])))
         self._output_types = nest.pack_sequence_as(
-            output_value, [t.dtype for t in flat_output_value])
-
-        return flat_new_state + flat_output_value
+            output_value, [t.dtype for t in nest.flatten(output_value)])
+
+        # Serialize any sparse tensors.
+        new_state = nest.pack_sequence_as(new_state, [
+            t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state))
+        ])
+        output_value = nest.pack_sequence_as(output_value, [
+            t for t in nest.flatten(
+                sparse.serialize_sparse_tensors(output_value))
+        ])
+        return nest.flatten(new_state) + nest.flatten(output_value)
 
       # Use the private method that will execute `tf_scan_func` but delay
       # adding it to the graph in case we need to rerun the function.
       tf_scan_func._create_definition_if_needed()  # pylint: disable=protected-access
 
+      flat_state_shapes = nest.flatten(self._state_shapes)
       weakened_state_shapes = [
           original.most_specific_compatible_shape(new)
           for original, new in zip(flat_state_shapes, flat_new_state_shapes)
@@ -150,7 +190,7 @@ class _ScanDataset(dataset_ops.Dataset):
     input_t = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
     return gen_dataset_ops.scan_dataset(
         input_t,
-        nest.flatten(self._initial_state),
+        nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)),
         self._scan_func.captured_inputs,
         f=self._scan_func,
         output_types=nest.flatten(