[tf.data] Fix handling of nested structures in `tf.contrib.data.prefetch_to_device()`.
authorDerek Murray <mrry@google.com>
Tue, 3 Apr 2018 17:00:02 +0000 (10:00 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 3 Apr 2018 17:05:51 +0000 (10:05 -0700)
PiperOrigin-RevId: 191456191

tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
tensorflow/contrib/data/python/ops/prefetching_ops.py

index 676959a..f2c57f9 100644 (file)
@@ -231,6 +231,37 @@ class StagingAreaOpsTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(next_element)
 
+  def testPrefetchDictToDevice(self):
+    host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+    device_dataset = host_dataset.apply(
+        prefetching_ops.prefetch_to_device("/cpu:1"))
+
+    # NOTE(mrry): This device block creates the "host" dataset and iterator on
+    # /cpu:0, and ensures that the prefetching is across devices. In typical use
+    # this would not be necessary, because the GPU device would not support any
+    # of the dataset-related ops.
+    with ops.device("/cpu:0"):
+      iterator = device_dataset.make_one_shot_iterator()
+
+    self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+    self.assertEqual(host_dataset.output_types, iterator.output_types)
+    self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+    self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+    self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+    self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+    next_element = iterator.get_next()
+    self.assertEqual(dtypes.int64, next_element["a"].dtype)
+    self.assertEqual([], next_element["a"].shape)
+
+    worker_config = config_pb2.ConfigProto()
+    worker_config.device_count["CPU"] = 2
+    with self.test_session(config=worker_config) as sess:
+      for i in range(10):
+        self.assertEqual({"a": i}, sess.run(next_element))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(next_element)
+
   def testPrefetchToDeviceGpu(self):
     if not test_util.is_gpu_available():
       self.skipTest("No GPU available")
@@ -248,5 +279,6 @@ class StagingAreaOpsTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(next_element)
 
+
 if __name__ == "__main__":
   test.main()
index 98651bb..554bfaa 100644 (file)
@@ -28,6 +28,7 @@ from tensorflow.python.data.util import sparse
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
 
 
 # TODO(rohanj): Add a python class that constructs resource in the __init__
@@ -77,10 +78,24 @@ class _PrefetchToDeviceIterator(object):
 
     @function.Defun(dtypes.string)
     def _prefetch_fn(handle):
+      """Prefetches one element from `input_iterator`."""
       remote_iterator = iterator_ops.Iterator.from_string_handle(
           handle, input_iterator.output_types, input_iterator.output_shapes,
           input_iterator.output_classes)
-      return remote_iterator.get_next()
+      ret = remote_iterator.get_next()
+
+      # Convert any `SparseTensorValue`s to `SparseTensor`s.
+      ret = nest.pack_sequence_as(ret, [
+          sparse_tensor_lib.SparseTensor.from_value(t)
+          if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret)
+      ])
+
+      # Serialize any sparse tensors and convert result to tensors.
+      ret = nest.pack_sequence_as(ret, [
+          ops.convert_to_tensor(t)
+          for t in nest.flatten(sparse.serialize_sparse_tensors(ret))
+      ])
+      return nest.flatten(ret)
 
     with ops.device(device):
       self._buffering_resource = function_buffering_resource(