From 7b15f7a55dcd5e908211e86ec42b49136b1ccc25 Mon Sep 17 00:00:00 2001
From: Brennan Saeta
Date: Mon, 26 Feb 2018 20:21:07 -0800
Subject: [PATCH] Add helpers to stream data from the GCE VM to a Cloud TPU.

PiperOrigin-RevId: 187122870
---
 tensorflow/contrib/tpu/BUILD                       |  28 +++
 tensorflow/contrib/tpu/python/tpu/datasets.py      | 192 +++++++++++++++++++++
 tensorflow/contrib/tpu/python/tpu/datasets_test.py | 181 +++++++++++++++++++
 3 files changed, 401 insertions(+)
 create mode 100644 tensorflow/contrib/tpu/python/tpu/datasets.py
 create mode 100644 tensorflow/contrib/tpu/python/tpu/datasets_test.py

diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index c48e84d..095b482 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -163,6 +163,7 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
+        ":datasets",
         ":profiler",
         ":tpu_py",
         "//tensorflow/contrib/tpu/proto:topology_proto_py",
@@ -181,6 +182,33 @@ py_library(
     ],
 )
 
+py_library(
+    name = "datasets",
+    srcs = [
+        "python/tpu/datasets.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/contrib/data/python/ops:transformation_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:function",
+        "//tensorflow/python:functional_ops",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
+        "//tensorflow/python/data/ops:readers",
+    ],
+)
+
+tf_py_test(
+    name = "datasets_test",
+    srcs = ["python/tpu/datasets_test.py"],
+    additional_deps = [
+        "//tensorflow/python:client_testlib",
+        ":datasets",
+    ],
+    grpc_enabled = True,
+)
+
 tf_py_test(
     name = "tpu_test",
     size = "small",
diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py
new file mode 100644
index 0000000..29aea98
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/datasets.py
@@ -0,0 +1,192 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Library of Cloud TPU helper functions for data loading."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import functional_ops
+
+
+def _TextLineDataset(filename):
+  buffer_size = 8 * 1024 * 1024  # 8 MiB per file
+  dataset = readers.TextLineDataset(filename, buffer_size=buffer_size)
+  return dataset
+
+
+def _TFRecordDataset(filename):
+  buffer_size = 8 * 1024 * 1024  # 8 MiB per file
+  dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size)
+  return dataset
+
+
+_FILETYPE_MAP = {
+    'tfrecord': _TFRecordDataset,
+    'textline': _TextLineDataset,
+    'text': _TextLineDataset,
+}
+
+
+def StreamingFilesDataset(files,
+                          filetype=None,
+                          file_reader_job=None,
+                          worker_job=None,
+                          num_epochs=None,
+                          filename_shuffle_buffer_size=None,
+                          num_parallel_reads=None,
+                          batch_transfer_size=None,
+                          sloppy=None):
+  """StreamingFilesDataset constructs a dataset to stream from workers (GCE VM).
+
+  Because Cloud TPUs are allocated over the network, a Cloud TPU cannot read
+  files local to your GCE VM. In order to train using files stored on your local
+  VM (e.g. on local SSD for extreme performance), use the StreamingFilesDataset
+  helper to generate a dataset to feed your Cloud TPU with files from your GCE
+  VM.
+
+  The resulting dataset may raise an OutOfRangeError if there are no files
+  found as a result of the fileglob expansion.
+
+  Note: StreamingFilesDataset assumes that the session is using a
+  TPUClusterResolver and therefore has a worker and a coordinator job. File
+  loading will be done on the coordinator job.
+
+  Args:
+    files: A string glob to match files, or a `tf.data.Dataset` generating file
+      names.
+    filetype: A string ('tfrecord', 'textline', or 'text') or a single-argument
+      TensorFlow function that, when given a filename, returns a dataset.
+    file_reader_job: An optional string that corresponds to the job that should
+      perform the file reads.
+    worker_job: An optional string that corresponds to the job that should
+      process the tensors (i.e. your GPU or TPU worker).
+    num_epochs: The number of epochs through the training set that should be
+      generated. By default, it will repeat infinitely.
+    filename_shuffle_buffer_size: An optional integer whose value controls the
+      shuffling of the file names. If you would like to read from the files in
+      the same order, set to 0 or False.
+    num_parallel_reads: An optional integer controlling the number of files to
+      read from concurrently. (Set to 1 for no parallelism.)
+    batch_transfer_size: An optional integer controlling the batching used to
+      amortize the remote function invocation overhead. Set to a very large
+      number to increase throughput. Set to a very small number to reduce memory
+      consumption. Set to False to skip batching.
+    sloppy: (Optional.) If `True`, read input data as fast as possible, without
+      maintaining a deterministic order. Defaults to `False`.
+  Returns:
+    A `tf.data.Dataset` with an infinite stream of elements generated by a
+    parallel interleaving of the set of files matched (or generated) by `files`,
+    whose type is the output of the dataset specified by `filetype`.
+
+  Raises:
+    ValueError: if any argument is not of the expected type.
+  """
+  if filetype is None:
+    filetype = 'tfrecord'
+
+  if isinstance(filetype, str):
+    if filetype not in _FILETYPE_MAP:
+      raise ValueError('Unexpected filetype: %s' % filetype)
+    reader_fn = _FILETYPE_MAP[filetype]
+  elif callable(filetype):
+    reader_fn = filetype
+  else:
+    raise ValueError('filetype should be a string or a callable')
+
+  file_reader_job = file_reader_job or 'coordinator'
+
+  worker_job = worker_job or 'worker'
+
+  if filename_shuffle_buffer_size is None:
+    filename_shuffle_buffer_size = 4096
+
+  num_parallel_reads = num_parallel_reads or 8
+
+  if batch_transfer_size is None:
+    batch_transfer_size = 1024
+
+  if sloppy is None:
+    sloppy = False
+
+  with ops.device('/job:%s' % file_reader_job):
+    if isinstance(files, str):
+      source_dataset = dataset_ops.Dataset.list_files(files)
+    elif isinstance(files, dataset_ops.Dataset):
+      source_dataset = files
+    else:
+      raise ValueError('files was not a string or a dataset: %s' % files)
+
+    if filename_shuffle_buffer_size:
+      source_dataset = source_dataset.shuffle(
+          buffer_size=filename_shuffle_buffer_size)
+
+    # NOTE: We perform the `repeat` on the source dataset, because the output
+    # dataset does not currently have enough information to recreate an iterator
+    # over the source dataset when it reaches the end.
+    source_dataset = source_dataset.repeat(num_epochs)
+
+    source_dataset = source_dataset.apply(
+        interleave_ops.parallel_interleave(
+            reader_fn, cycle_length=num_parallel_reads, sloppy=sloppy))
+
+    if batch_transfer_size:
+      # Note: we can safely call batch_and_drop_remainder because we have an
+      # infinite stream of TFRecords.
+      source_dataset = source_dataset.apply(
+          batching.batch_and_drop_remainder(batch_transfer_size))
+
+    source_dataset = source_dataset.prefetch(1)
+
+    source_iterator = source_dataset.make_one_shot_iterator()
+    source_handle = source_iterator.string_handle()
+
+  @function.Defun(dtypes.string)
+  def LoadingFunc(h):
+    remote_iterator = iterator_ops.Iterator.from_string_handle(
+        h, source_dataset.output_types, source_dataset.output_shapes)
+    return remote_iterator.get_next()
+
+  def MapFn(unused_input):
+    return functional_ops.remote_call(
+        args=[source_handle],
+        Tout=[dtypes.string],
+        f=LoadingFunc,
+        target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)
+
+  with ops.device('/job:%s' % worker_job):
+    # TODO(saeta,mrry): Switch to using _GeneratorDataset.
+
+    # identity = lambda x: x
+    # dummy = constant_op.constant(0)
+    # output_dataset = dataset_ops._GeneratorDataset(dummy, identity, MapFn,
+    #                                                identity)
+
+    output_dataset = dataset_ops.Dataset.range(2).repeat().map(MapFn)
+    output_dataset = output_dataset.prefetch(1)
+
+    if batch_transfer_size:
+      # Undo the batching used during the transfer.
+      output_dataset = output_dataset.apply(batching.unbatch()).prefetch(1)
+
+  return output_dataset
diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py
new file mode 100644
index 0000000..2c40797
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/datasets_test.py
@@ -0,0 +1,181 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TPU datasets tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.tpu.python.tpu import datasets
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.lib.io import python_io
+from tensorflow.python.platform import test
+from tensorflow.python.training import server_lib
+from tensorflow.python.util import compat
+
+_NUM_FILES = 10
+_NUM_ENTRIES = 200
+
+
+class DatasetsTest(test.TestCase):
+
+  def setUp(self):
+    super(DatasetsTest, self).setUp()
+    self._coord = server_lib.Server.create_local_server()
+    self._worker = server_lib.Server.create_local_server()
+
+    self._cluster_def = cluster_pb2.ClusterDef()
+    worker_job = self._cluster_def.job.add()
+    worker_job.name = 'worker'
+    worker_job.tasks[0] = self._worker.target[len('grpc://'):]
+    coord_job = self._cluster_def.job.add()
+    coord_job.name = 'coordinator'
+    coord_job.tasks[0] = self._coord.target[len('grpc://'):]
+
+    session_config = config_pb2.ConfigProto(cluster_def=self._cluster_def)
+
+    self._sess = session.Session(self._worker.target, config=session_config)
+
+  def testTextLineDataset(self):
+    all_contents = []
+    for i in range(_NUM_FILES):
+      filename = os.path.join(self.get_temp_dir(), 'text_line.%d.txt' % i)
+      contents = []
+      for j in range(_NUM_ENTRIES):
+        contents.append(compat.as_bytes('%d: %d' % (i, j)))
+      with open(filename, 'wb') as f:
+        f.write(b'\n'.join(contents))
+      all_contents.extend(contents)
+
+    dataset = datasets.StreamingFilesDataset(
+        os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text')
+
+    iterator = dataset.make_initializable_iterator()
+    self._sess.run(iterator.initializer)
+    get_next = iterator.get_next()
+
+    retrieved_values = []
+    for _ in range(2 * len(all_contents)):
+      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))
+
+    self.assertEqual(set(all_contents), set(retrieved_values))
+
+  def testTFRecordDataset(self):
+    all_contents = []
+    for i in range(_NUM_FILES):
+      filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i)
+      writer = python_io.TFRecordWriter(filename)
+      for j in range(_NUM_ENTRIES):
+        record = compat.as_bytes('Record %d of file %d' % (j, i))
+        writer.write(record)
+        all_contents.append(record)
+      writer.close()
+
+    dataset = datasets.StreamingFilesDataset(
+        os.path.join(self.get_temp_dir(), 'tf_record*'), filetype='tfrecord')
+
+    iterator = dataset.make_initializable_iterator()
+    self._sess.run(iterator.initializer)
+    get_next = iterator.get_next()
+
+    retrieved_values = []
+    for _ in range(2 * len(all_contents)):
+      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))
+
+    self.assertEqual(set(all_contents), set(retrieved_values))
+
+  def testTFRecordDatasetFromDataset(self):
+    filenames = []
+    all_contents = []
+    for i in range(_NUM_FILES):
+      filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i)
+      filenames.append(filename)
+      writer = python_io.TFRecordWriter(filename)
+      for j in range(_NUM_ENTRIES):
+        record = compat.as_bytes('Record %d of file %d' % (j, i))
+        writer.write(record)
+        all_contents.append(record)
+      writer.close()
+
+    filenames = dataset_ops.Dataset.from_tensor_slices(filenames)
+
+    dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord')
+
+    iterator = dataset.make_initializable_iterator()
+    self._sess.run(iterator.initializer)
+    get_next = iterator.get_next()
+
+    retrieved_values = []
+    for _ in range(2 * len(all_contents)):
+      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))
+
+    self.assertEqual(set(all_contents), set(retrieved_values))
+
+  def testArbitraryReaderFunc(self):
+
+    def MakeRecord(i, j):
+      return compat.as_bytes('%04d-%04d' % (i, j))
+
+    record_bytes = len(MakeRecord(10, 200))
+
+    all_contents = []
+    for i in range(_NUM_FILES):
+      filename = os.path.join(self.get_temp_dir(), 'fixed_length.%d' % i)
+      with open(filename, 'wb') as f:
+        for j in range(_NUM_ENTRIES):
+          record = MakeRecord(i, j)
+          f.write(record)
+          all_contents.append(record)
+
+    def FixedLengthFile(filename):
+      return readers.FixedLengthRecordDataset(filename, record_bytes)
+
+    dataset = datasets.StreamingFilesDataset(
+        os.path.join(self.get_temp_dir(), 'fixed_length*'),
+        filetype=FixedLengthFile)
+
+    iterator = dataset.make_initializable_iterator()
+    self._sess.run(iterator.initializer)
+    get_next = iterator.get_next()
+
+    retrieved_values = []
+    for _ in range(2 * len(all_contents)):
+      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))
+
+    self.assertEqual(set(all_contents), set(retrieved_values))
+
+  def testUnexpectedFiletypeString(self):
+    with self.assertRaises(ValueError):
+      datasets.StreamingFilesDataset(
+          os.path.join(self.get_temp_dir(), '*'), filetype='foo')
+
+  def testUnexpectedFiletypeType(self):
+    with self.assertRaises(ValueError):
+      datasets.StreamingFilesDataset(
+          os.path.join(self.get_temp_dir(), '*'), filetype=3)
+
+  def testUnexpectedFilesType(self):
+    with self.assertRaises(ValueError):
+      datasets.StreamingFilesDataset(123, filetype='tfrecord')
+
+
+if __name__ == '__main__':
+  test.main()
-- 
2.7.4
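
A minimal usage sketch for StreamingFilesDataset, mirroring the local-server
setup used in datasets_test.py; the fileglob below is a placeholder, and the
two in-process servers merely stand in for a real coordinator/worker
deployment (e.g. one resolved via TPUClusterResolver):

    from tensorflow.contrib.tpu.python.tpu import datasets
    from tensorflow.core.protobuf import cluster_pb2
    from tensorflow.core.protobuf import config_pb2
    from tensorflow.python.client import session
    from tensorflow.python.training import server_lib

    # Stand up two in-process servers to play the 'coordinator' (file reader)
    # and 'worker' roles that StreamingFilesDataset uses by default.
    coord = server_lib.Server.create_local_server()
    worker = server_lib.Server.create_local_server()

    cluster_def = cluster_pb2.ClusterDef()
    worker_job = cluster_def.job.add()
    worker_job.name = 'worker'
    worker_job.tasks[0] = worker.target[len('grpc://'):]
    coord_job = cluster_def.job.add()
    coord_job.name = 'coordinator'
    coord_job.tasks[0] = coord.target[len('grpc://'):]

    # The session targets the worker; the ClusterDef lets the worker reach the
    # coordinator, where the file reads actually happen.
    sess = session.Session(
        worker.target, config=config_pb2.ConfigProto(cluster_def=cluster_def))

    dataset = datasets.StreamingFilesDataset(
        '/path/on/gce/vm/train-*.tfrecord',  # placeholder fileglob
        filetype='tfrecord')

    iterator = dataset.make_initializable_iterator()
    get_next = iterator.get_next()
    sess.run(iterator.initializer)
    first_record = sess.run(get_next)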