From: Derek Murray Date: Wed, 14 Mar 2018 02:58:22 +0000 (-0700) Subject: [tf.data] Support tf.SparseTensor components in tf.contrib.data.get_single_element(). X-Git-Tag: tflite-v0.1.7~193^2~1^2~21 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c218b09d4f0c52f526be1481d47cdc3c4d005f5b;p=platform%2Fupstream%2Ftensorflow.git [tf.data] Support tf.SparseTensor components in tf.contrib.data.get_single_element(). PiperOrigin-RevId: 188970548 --- diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py index 32ea44f..87b7c6d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py @@ -22,6 +22,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -33,17 +34,25 @@ class GetSingleElementTest(test.TestCase): take_value = array_ops.placeholder_with_default( constant_op.constant(1, dtype=dtypes.int64), shape=[]) + def make_sparse(x): + x_1d = array_ops.reshape(x, [1]) + x_2d = array_ops.reshape(x, [1, 1]) + return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d) + dataset = (dataset_ops.Dataset.range(100) .skip(skip_value) - .map(lambda x: x * x) + .map(lambda x: (x * x, make_sparse(x))) .take(take_value)) element = get_single_element.get_single_element(dataset) with self.test_session() as sess: - self.assertEqual(0, sess.run(element, feed_dict={skip_value: 0})) - self.assertEqual(25, sess.run(element, feed_dict={skip_value: 5})) - self.assertEqual(100, sess.run(element, feed_dict={skip_value: 10})) + for x in [0, 5, 10]: + dense_val, sparse_val = sess.run(element, feed_dict={skip_value: x}) + self.assertEqual(x * x, dense_val) + self.assertAllEqual([[x]], sparse_val.indices) + self.assertAllEqual([x], sparse_val.values) + self.assertAllEqual([x], sparse_val.dense_shape) with self.assertRaisesRegexp(errors.InvalidArgumentError, "Dataset was empty."): diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py index a817b45..3a07df5 100644 --- a/tensorflow/contrib/data/python/ops/get_single_element.py +++ b/tensorflow/contrib/data/python/ops/get_single_element.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.ops import gen_dataset_ops @@ -59,9 +60,14 @@ def get_single_element(dataset): """ if not isinstance(dataset, dataset_ops.Dataset): raise TypeError("`dataset` must be a `tf.data.Dataset` object.") - return nest.pack_sequence_as( - dataset.output_types, - gen_dataset_ops.dataset_to_single_element( + + nested_ret = nest.pack_sequence_as( + dataset.output_types, gen_dataset_ops.dataset_to_single_element( dataset._as_variant_tensor(), # pylint: disable=protected-access - output_types=nest.flatten(dataset.output_types), - output_shapes=nest.flatten(dataset.output_shapes))) + output_types=nest.flatten(sparse.as_dense_types( + dataset.output_types, dataset.output_classes)), + output_shapes=nest.flatten(sparse.as_dense_shapes( + dataset.output_shapes, dataset.output_classes)))) + return sparse.deserialize_sparse_tensors( + nested_ret, dataset.output_types, dataset.output_shapes, + dataset.output_classes)