From a2e1b4dcbd0ecd310efa2eb258dcbdbcf942af86 Mon Sep 17 00:00:00 2001 From: Nand Dalal Date: Sun, 20 May 2018 22:15:21 -0500 Subject: [PATCH] StreamingFilesDataset fixes (#19413) * use source_dataset.output_types to yield correctly typed output dataset * add test and fix issue introduced by 2a6c5998a239f41926ca295ac20bb595862fd5ff --- tensorflow/contrib/tpu/python/tpu/datasets.py | 16 ++++++++++--- tensorflow/contrib/tpu/python/tpu/datasets_test.py | 26 ++++++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py index 2e472a2..d879170 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets.py @@ -166,11 +166,21 @@ def StreamingFilesDataset(files, return remote_iterator.get_next() def MapFn(unused_input): - return functional_ops.remote_call( + if isinstance(source_dataset.output_types, dtypes.DType): + output_types = [source_dataset.output_types] + elif isinstance(source_dataset.output_types, (list, tuple)): + output_types = source_dataset.output_types + else: + raise ValueError('source dataset has invalid output types') + remote_calls = functional_ops.remote_call( args=[source_handle], - Tout=[dtypes.string], + Tout=output_types, f=LoadingFunc, - target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)[0] + target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job) + if len(remote_calls) == 1: + return remote_calls[0] + else: + return remote_calls with ops.device('/job:%s' % worker_job): output_dataset = dataset_ops.Dataset.range(2).repeat().map( diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py index 918cf0e..b58d05e 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets_test.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets_test.py @@ -26,6 +26,8 @@ from tensorflow.core.protobuf import config_pb2 from 
tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.platform import test from tensorflow.python.training import server_lib @@ -162,6 +164,30 @@ class DatasetsTest(test.TestCase): self.assertEqual(set(all_contents), set(retrieved_values)) + def testArbitraryReaderFuncFromDatasetGenerator(self): + + def my_generator(): + yield (1, [1] * 10) + + def gen_dataset(dummy): + return dataset_ops.Dataset.from_generator( + my_generator, (dtypes.int64, dtypes.int64), + (tensor_shape.TensorShape([]), tensor_shape.TensorShape([10]))) + + dataset = datasets.StreamingFilesDataset( + dataset_ops.Dataset.range(10), filetype=gen_dataset) + + iterator = dataset.make_initializable_iterator() + self._sess.run(iterator.initializer) + get_next = iterator.get_next() + + retrieved_values = self._sess.run(get_next) + + self.assertIsInstance(retrieved_values, (list, tuple)) + self.assertEqual(len(retrieved_values), 2) + self.assertEqual(retrieved_values[0], 1) + self.assertItemsEqual(retrieved_values[1], [1] * 10) + def testUnexpectedFiletypeString(self): with self.assertRaises(ValueError): datasets.StreamingFilesDataset( -- 2.7.4