Reintroducing support for constants as outputs of tf.data.map(). This fixes a regress...
authorJiri Simsa <jsimsa@google.com>
Fri, 13 Apr 2018 01:02:58 +0000 (18:02 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 13 Apr 2018 01:05:14 +0000 (18:05 -0700)
PiperOrigin-RevId: 192702279

tensorflow/python/data/kernel_tests/map_dataset_op_test.py
tensorflow/python/data/ops/dataset_ops.py

index 0791c61..1ad0b9d 100644 (file)
@@ -624,6 +624,20 @@ class MapDatasetTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
+  def testConstantOutput(self):
+    iterator = (
+        dataset_ops.Dataset.range(10).map(lambda x: [x, "hello", 10])
+        .make_initializable_iterator())
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      sess.run(init_op)
+      for i in range(10):
+        self.assertEqual((i, b"hello", 10), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
 
 class MapDatasetBenchmark(test.Benchmark):
 
index c28de3d..406f172 100644 (file)
@@ -1155,10 +1155,12 @@ class _GeneratorDataset(Dataset):
       if isinstance(ret, list):
         ret = tuple(ret)
 
-      # Convert any `SparseTensorValue`s to `SparseTensor`s.
+      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+      # values to tensors.
       ret = nest.pack_sequence_as(ret, [
           sparse_tensor_lib.SparseTensor.from_value(t)
-          if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret)
+          if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t)
+          for t in nest.flatten(ret)
       ])
 
       self._state_classes = sparse.get_classes(ret)
@@ -1167,11 +1169,9 @@ class _GeneratorDataset(Dataset):
       self._state_types = nest.pack_sequence_as(
           ret, [t.dtype for t in nest.flatten(ret)])
 
-      # Serialize any sparse tensors and convert result to tensors.
-      ret = nest.pack_sequence_as(ret, [
-          ops.convert_to_tensor(t)
-          for t in nest.flatten(sparse.serialize_sparse_tensors(ret))
-      ])
+      # Serialize any sparse tensors.
+      ret = nest.pack_sequence_as(
+          ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
       return nest.flatten(ret)
 
     self._init_func = tf_init_func
@@ -1214,10 +1214,12 @@ class _GeneratorDataset(Dataset):
       if isinstance(ret, list):
         ret = tuple(ret)
 
-      # Convert any `SparseTensorValue`s to `SparseTensor`s.
+      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+      # values to tensors.
       ret = nest.pack_sequence_as(ret, [
           sparse_tensor_lib.SparseTensor.from_value(t)
-          if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret)
+          if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t)
+          for t in nest.flatten(ret)
       ])
 
       self._output_classes = sparse.get_classes(ret)
@@ -1226,11 +1228,9 @@ class _GeneratorDataset(Dataset):
       self._output_types = nest.pack_sequence_as(
           ret, [t.dtype for t in nest.flatten(ret)])
 
-      # Serialize any sparse tensors and convert result to tensors.
-      ret = nest.pack_sequence_as(ret, [
-          ops.convert_to_tensor(t)
-          for t in nest.flatten(sparse.serialize_sparse_tensors(ret))
-      ])
+      # Serialize any sparse tensors.
+      ret = nest.pack_sequence_as(
+          ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
       return nest.flatten(ret)
 
     self._next_func = tf_next_func
@@ -1816,10 +1816,12 @@ class MapDataset(Dataset):
       if isinstance(ret, list):
         ret = tuple(ret)
 
-      # Convert any `SparseTensorValue`s to `SparseTensor`s.
+      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+      # values to tensors.
       ret = nest.pack_sequence_as(ret, [
           sparse_tensor_lib.SparseTensor.from_value(t)
-          if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret)
+          if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t)
+          for t in nest.flatten(ret)
       ])
 
       self._output_classes = sparse.get_classes(ret)
@@ -1828,11 +1830,9 @@ class MapDataset(Dataset):
       self._output_types = nest.pack_sequence_as(
           ret, [t.dtype for t in nest.flatten(ret)])
 
-      # Serialize any sparse tensors and convert result to tensors.
-      ret = nest.pack_sequence_as(ret, [
-          ops.convert_to_tensor(t)
-          for t in nest.flatten(sparse.serialize_sparse_tensors(ret))
-      ])
+      # Serialize any sparse tensors.
+      ret = nest.pack_sequence_as(
+          ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
       return nest.flatten(ret)
 
     self._map_func = tf_map_func