From: Derek Murray
Date: Tue, 3 Apr 2018 17:00:02 +0000 (-0700)
Subject: [tf.data] Fix handling of nested structures in `tf.contrib.data.prefetch_to_device()`.
X-Git-Tag: tflite-v0.1.7~39^2^2~91
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=cfc886ac6064a04c71dd6c52e8c21ebec91eae50;p=platform%2Fupstream%2Ftensorflow.git

[tf.data] Fix handling of nested structures in `tf.contrib.data.prefetch_to_device()`.

PiperOrigin-RevId: 191456191
---

diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 676959a..f2c57f9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -231,6 +231,37 @@ class StagingAreaOpsTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(next_element)
 
+  def testPrefetchDictToDevice(self):
+    host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+    device_dataset = host_dataset.apply(
+        prefetching_ops.prefetch_to_device("/cpu:1"))
+
+    # NOTE(mrry): This device block creates the "host" dataset and iterator on
+    # /cpu:0, and ensures that the prefetching is across devices. In typical use
+    # this would not be necessary, because the GPU device would not support any
+    # of the dataset-related ops.
+    with ops.device("/cpu:0"):
+      iterator = device_dataset.make_one_shot_iterator()
+
+    self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+    self.assertEqual(host_dataset.output_types, iterator.output_types)
+    self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+    self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+    self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+    self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+    next_element = iterator.get_next()
+    self.assertEqual(dtypes.int64, next_element["a"].dtype)
+    self.assertEqual([], next_element["a"].shape)
+
+    worker_config = config_pb2.ConfigProto()
+    worker_config.device_count["CPU"] = 2
+    with self.test_session(config=worker_config) as sess:
+      for i in range(10):
+        self.assertEqual({"a": i}, sess.run(next_element))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(next_element)
+
   def testPrefetchToDeviceGpu(self):
     if not test_util.is_gpu_available():
       self.skipTest("No GPU available")
@@ -248,5 +279,6 @@ class StagingAreaOpsTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(next_element)
 
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 98651bb..554bfaa 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -28,6 +28,7 @@ from tensorflow.python.data.util import sparse
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
 
 
 # TODO(rohanj): Add a python class that constructs resource in the __init__
@@ -77,10 +78,24 @@ class _PrefetchToDeviceIterator(object):
 
     @function.Defun(dtypes.string)
     def _prefetch_fn(handle):
+      """Prefetches one element from `input_iterator`."""
       remote_iterator = iterator_ops.Iterator.from_string_handle(
           handle, input_iterator.output_types,
           input_iterator.output_shapes, input_iterator.output_classes)
-      return remote_iterator.get_next()
+      ret = remote_iterator.get_next()
+
+      # Convert any `SparseTensorValue`s to `SparseTensor`s.
+      ret = nest.pack_sequence_as(ret, [
+          sparse_tensor_lib.SparseTensor.from_value(t)
+          if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret)
+      ])
+
+      # Serialize any sparse tensors and convert result to tensors.
+      ret = nest.pack_sequence_as(ret, [
+          ops.convert_to_tensor(t)
+          for t in nest.flatten(sparse.serialize_sparse_tensors(ret))
+      ])
+      return nest.flatten(ret)
 
     with ops.device(device):
       self._buffering_resource = function_buffering_resource(
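
For context, a minimal usage sketch of the behavior this change fixes, mirroring the new testPrefetchDictToDevice test above but written against the public TF 1.x-era API. The "/gpu:0" device string is an illustrative assumption (the test itself uses "/cpu:1" under a two-CPU ConfigProto); any second device should work.

# Sketch only, not part of this commit; assumes TF 1.7+ with a second device.
import tensorflow as tf

# Each element is a nested structure (here, a dict) -- the case named in the
# commit title that `prefetch_to_device()` previously mishandled.
dataset = tf.data.Dataset.range(10).map(lambda x: {"a": x})
dataset = dataset.apply(tf.contrib.data.prefetch_to_device("/gpu:0"))

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()  # a dict: {"a": <scalar int64 tensor>}

with tf.Session() as sess:
  for _ in range(10):
    print(sess.run(next_element))  # {'a': 0}, {'a': 1}, ...

This works because the rewritten _prefetch_fn returns nest.flatten(ret), a flat list of tensors as a Defun requires, rather than the nested structure itself, serializing any sparse components to dense tensors first so they survive the transfer.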