from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
take_value = array_ops.placeholder_with_default(
constant_op.constant(1, dtype=dtypes.int64), shape=[])
+ def make_sparse(x):
+ x_1d = array_ops.reshape(x, [1])
+ x_2d = array_ops.reshape(x, [1, 1])
+ return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d)
+
dataset = (dataset_ops.Dataset.range(100)
.skip(skip_value)
- .map(lambda x: x * x)
+ .map(lambda x: (x * x, make_sparse(x)))
.take(take_value))
element = get_single_element.get_single_element(dataset)
with self.test_session() as sess:
- self.assertEqual(0, sess.run(element, feed_dict={skip_value: 0}))
- self.assertEqual(25, sess.run(element, feed_dict={skip_value: 5}))
- self.assertEqual(100, sess.run(element, feed_dict={skip_value: 10}))
+ for x in [0, 5, 10]:
+ dense_val, sparse_val = sess.run(element, feed_dict={skip_value: x})
+ self.assertEqual(x * x, dense_val)
+ self.assertAllEqual([[x]], sparse_val.indices)
+ self.assertAllEqual([x], sparse_val.values)
+ self.assertAllEqual([x], sparse_val.dense_shape)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Dataset was empty."):
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
from tensorflow.python.ops import gen_dataset_ops
"""
if not isinstance(dataset, dataset_ops.Dataset):
raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
- return nest.pack_sequence_as(
- dataset.output_types,
- gen_dataset_ops.dataset_to_single_element(
+
+ nested_ret = nest.pack_sequence_as(
+ dataset.output_types, gen_dataset_ops.dataset_to_single_element(
dataset._as_variant_tensor(), # pylint: disable=protected-access
- output_types=nest.flatten(dataset.output_types),
- output_shapes=nest.flatten(dataset.output_shapes)))
+ output_types=nest.flatten(sparse.as_dense_types(
+ dataset.output_types, dataset.output_classes)),
+ output_shapes=nest.flatten(sparse.as_dense_shapes(
+ dataset.output_shapes, dataset.output_classes))))
+ return sparse.deserialize_sparse_tensors(
+ nested_ret, dataset.output_types, dataset.output_shapes,
+ dataset.output_classes)