From 5a53c9b54d8781032ebf2cf26f93da3b2a33d1e4 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Thu, 12 Apr 2018 18:02:58 -0700 Subject: [PATCH] Reintroducing support for constants as outputs of tf.data.map(). This fixes a regression introduced by cl/176147440. PiperOrigin-RevId: 192702279 --- .../data/kernel_tests/map_dataset_op_test.py | 14 ++++++++ tensorflow/python/data/ops/dataset_ops.py | 42 +++++++++++----------- 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index 0791c61..1ad0b9d 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -624,6 +624,20 @@ class MapDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testConstantOutput(self): + iterator = ( + dataset_ops.Dataset.range(10).map(lambda x: [x, "hello", 10]) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + self.assertEqual((i, b"hello", 10), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + class MapDatasetBenchmark(test.Benchmark): diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index c28de3d..406f172 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -1155,10 +1155,12 @@ class _GeneratorDataset(Dataset): if isinstance(ret, list): ret = tuple(ret) - # Convert any `SparseTensorValue`s to `SparseTensor`s. + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. ret = nest.pack_sequence_as(ret, [ sparse_tensor_lib.SparseTensor.from_value(t) - if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret) + if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) ]) self._state_classes = sparse.get_classes(ret) @@ -1167,11 +1169,9 @@ class _GeneratorDataset(Dataset): self._state_types = nest.pack_sequence_as( ret, [t.dtype for t in nest.flatten(ret)]) - # Serialize any sparse tensors and convert result to tensors. - ret = nest.pack_sequence_as(ret, [ - ops.convert_to_tensor(t) - for t in nest.flatten(sparse.serialize_sparse_tensors(ret)) - ]) + # Serialize any sparse tensors. + ret = nest.pack_sequence_as( + ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) return nest.flatten(ret) self._init_func = tf_init_func @@ -1214,10 +1214,12 @@ class _GeneratorDataset(Dataset): if isinstance(ret, list): ret = tuple(ret) - # Convert any `SparseTensorValue`s to `SparseTensor`s. + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. ret = nest.pack_sequence_as(ret, [ sparse_tensor_lib.SparseTensor.from_value(t) - if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret) + if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) ]) self._output_classes = sparse.get_classes(ret) @@ -1226,11 +1228,9 @@ class _GeneratorDataset(Dataset): self._output_types = nest.pack_sequence_as( ret, [t.dtype for t in nest.flatten(ret)]) - # Serialize any sparse tensors and convert result to tensors. - ret = nest.pack_sequence_as(ret, [ - ops.convert_to_tensor(t) - for t in nest.flatten(sparse.serialize_sparse_tensors(ret)) - ]) + # Serialize any sparse tensors. + ret = nest.pack_sequence_as( + ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) return nest.flatten(ret) self._next_func = tf_next_func @@ -1816,10 +1816,12 @@ class MapDataset(Dataset): if isinstance(ret, list): ret = tuple(ret) - # Convert any `SparseTensorValue`s to `SparseTensor`s. + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. ret = nest.pack_sequence_as(ret, [ sparse_tensor_lib.SparseTensor.from_value(t) - if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret) + if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) ]) self._output_classes = sparse.get_classes(ret) @@ -1828,11 +1830,9 @@ class MapDataset(Dataset): self._output_types = nest.pack_sequence_as( ret, [t.dtype for t in nest.flatten(ret)]) - # Serialize any sparse tensors and convert result to tensors. - ret = nest.pack_sequence_as(ret, [ - ops.convert_to_tensor(t) - for t in nest.flatten(sparse.serialize_sparse_tensors(ret)) - ]) + # Serialize any sparse tensors. + ret = nest.pack_sequence_as( + ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) return nest.flatten(ret) self._map_func = tf_map_func -- 2.7.4