[tf.data] Support tf.SparseTensor components in tf.contrib.data.get_single_element().
authorDerek Murray <mrry@google.com>
Wed, 14 Mar 2018 02:58:22 +0000 (19:58 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Mar 2018 03:02:20 +0000 (20:02 -0700)
PiperOrigin-RevId: 188970548

tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
tensorflow/contrib/data/python/ops/get_single_element.py

index 32ea44f..87b7c6d 100644 (file)
@@ -22,6 +22,7 @@ from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
@@ -33,17 +34,25 @@ class GetSingleElementTest(test.TestCase):
     take_value = array_ops.placeholder_with_default(
         constant_op.constant(1, dtype=dtypes.int64), shape=[])
 
+    def make_sparse(x):
+      x_1d = array_ops.reshape(x, [1])
+      x_2d = array_ops.reshape(x, [1, 1])
+      return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d)
+
     dataset = (dataset_ops.Dataset.range(100)
                .skip(skip_value)
-               .map(lambda x: x * x)
+               .map(lambda x: (x * x, make_sparse(x)))
                .take(take_value))
 
     element = get_single_element.get_single_element(dataset)
 
     with self.test_session() as sess:
-      self.assertEqual(0, sess.run(element, feed_dict={skip_value: 0}))
-      self.assertEqual(25, sess.run(element, feed_dict={skip_value: 5}))
-      self.assertEqual(100, sess.run(element, feed_dict={skip_value: 10}))
+      for x in [0, 5, 10]:
+        dense_val, sparse_val = sess.run(element, feed_dict={skip_value: x})
+        self.assertEqual(x * x, dense_val)
+        self.assertAllEqual([[x]], sparse_val.indices)
+        self.assertAllEqual([x], sparse_val.values)
+        self.assertAllEqual([x], sparse_val.dense_shape)
 
       with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                    "Dataset was empty."):
index a817b45..3a07df5 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
 from tensorflow.python.ops import gen_dataset_ops
 
 
@@ -59,9 +60,14 @@ def get_single_element(dataset):
   """
   if not isinstance(dataset, dataset_ops.Dataset):
     raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
-  return nest.pack_sequence_as(
-      dataset.output_types,
-      gen_dataset_ops.dataset_to_single_element(
+
+  nested_ret = nest.pack_sequence_as(
+      dataset.output_types, gen_dataset_ops.dataset_to_single_element(
           dataset._as_variant_tensor(),  # pylint: disable=protected-access
-          output_types=nest.flatten(dataset.output_types),
-          output_shapes=nest.flatten(dataset.output_shapes)))
+          output_types=nest.flatten(sparse.as_dense_types(
+              dataset.output_types, dataset.output_classes)),
+          output_shapes=nest.flatten(sparse.as_dense_shapes(
+              dataset.output_shapes, dataset.output_classes))))
+  return sparse.deserialize_sparse_tensors(
+      nested_ret, dataset.output_types, dataset.output_shapes,
+      dataset.output_classes)