return remote_iterator.get_next()
def MapFn(unused_input):
- return functional_ops.remote_call(
+ if isinstance(source_dataset.output_types, dtypes.DType):
+ output_types = [source_dataset.output_types]
+ elif isinstance(source_dataset.output_types, (list, tuple)):
+ output_types = source_dataset.output_types
+ else:
+ raise ValueError('source dataset has invalid output types')
+ remote_calls = functional_ops.remote_call(
args=[source_handle],
- Tout=[dtypes.string],
+ Tout=output_types,
f=LoadingFunc,
- target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)[0]
+ target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)
+ if len(remote_calls) == 1:
+ return remote_calls[0]
+ else:
+ return remote_calls
with ops.device('/job:%s' % worker_job):
output_dataset = dataset_ops.Dataset.range(2).repeat().map(
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
self.assertEqual(set(all_contents), set(retrieved_values))
+ def testArbitraryReaderFuncFromDatasetGenerator(self):
+
+ def my_generator():
+ yield (1, [1] * 10)
+
+ def gen_dataset(dummy):
+ return dataset_ops.Dataset.from_generator(
+ my_generator, (dtypes.int64, dtypes.int64),
+ (tensor_shape.TensorShape([]), tensor_shape.TensorShape([10])))
+
+ dataset = datasets.StreamingFilesDataset(
+ dataset_ops.Dataset.range(10), filetype=gen_dataset)
+
+ iterator = dataset.make_initializable_iterator()
+ self._sess.run(iterator.initializer)
+ get_next = iterator.get_next()
+
+ retrieved_values = self._sess.run(get_next)
+
+ self.assertIsInstance(retrieved_values, (list, tuple))
+ self.assertEqual(len(retrieved_values), 2)
+ self.assertEqual(retrieved_values[0], 1)
+ self.assertItemsEqual(retrieved_values[1], [1] * 10)
+
def testUnexpectedFiletypeString(self):
with self.assertRaises(ValueError):
datasets.StreamingFilesDataset(