StreamingFilesDataset fixes (#19413)
author: Nand Dalal <nand@clarifai.com>
Mon, 21 May 2018 03:15:21 +0000 (22:15 -0500)
committer: Shanqing Cai <cais@google.com>
Mon, 21 May 2018 03:15:21 +0000 (23:15 -0400)
* use source_dataset.output_dtypes to yield correctly typed output dataset

* add test and fix issue introduced by 2a6c5998a239f41926ca295ac20bb595862fd5ff

tensorflow/contrib/tpu/python/tpu/datasets.py
tensorflow/contrib/tpu/python/tpu/datasets_test.py

index 2e472a2..d879170 100644 (file)
@@ -166,11 +166,21 @@ def StreamingFilesDataset(files,
     return remote_iterator.get_next()
 
   def MapFn(unused_input):
-    return functional_ops.remote_call(
+    if isinstance(source_dataset.output_types, dtypes.DType):
+      output_types = [source_dataset.output_types]
+    elif isinstance(source_dataset.output_types, (list, tuple)):
+      output_types = source_dataset.output_types
+    else:
+      raise ValueError('source dataset has invalid output types')
+    remote_calls = functional_ops.remote_call(
         args=[source_handle],
-        Tout=[dtypes.string],
+        Tout=output_types,
         f=LoadingFunc,
-        target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)[0]
+        target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)
+    if len(remote_calls) == 1:
+      return remote_calls[0]
+    else:
+      return remote_calls
 
   with ops.device('/job:%s' % worker_job):
     output_dataset = dataset_ops.Dataset.range(2).repeat().map(
index 918cf0e..b58d05e 100644 (file)
@@ -26,6 +26,8 @@ from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import readers
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.lib.io import python_io
 from tensorflow.python.platform import test
 from tensorflow.python.training import server_lib
@@ -162,6 +164,30 @@ class DatasetsTest(test.TestCase):
 
     self.assertEqual(set(all_contents), set(retrieved_values))
 
+  def testArbitraryReaderFuncFromDatasetGenerator(self):
+
+    def my_generator():
+      yield (1, [1] * 10)
+
+    def gen_dataset(dummy):
+      return dataset_ops.Dataset.from_generator(
+          my_generator, (dtypes.int64, dtypes.int64),
+          (tensor_shape.TensorShape([]), tensor_shape.TensorShape([10])))
+
+    dataset = datasets.StreamingFilesDataset(
+        dataset_ops.Dataset.range(10), filetype=gen_dataset)
+
+    iterator = dataset.make_initializable_iterator()
+    self._sess.run(iterator.initializer)
+    get_next = iterator.get_next()
+
+    retrieved_values = self._sess.run(get_next)
+
+    self.assertIsInstance(retrieved_values, (list, tuple))
+    self.assertEqual(len(retrieved_values), 2)
+    self.assertEqual(retrieved_values[0], 1)
+    self.assertItemsEqual(retrieved_values[1], [1] * 10)
+
   def testUnexpectedFiletypeString(self):
     with self.assertRaises(ValueError):
       datasets.StreamingFilesDataset(