Add helpers to stream data from the GCE VM to a Cloud TPU.
author Brennan Saeta <saeta@google.com>
Tue, 27 Feb 2018 04:21:07 +0000 (20:21 -0800)
committer Gunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
PiperOrigin-RevId: 187122870

tensorflow/contrib/tpu/BUILD
tensorflow/contrib/tpu/python/tpu/datasets.py [new file with mode: 0644]
tensorflow/contrib/tpu/python/tpu/datasets_test.py [new file with mode: 0644]

index c48e84d..095b482 100644 (file)
@@ -163,6 +163,7 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
+        ":datasets",
         ":profiler",
         ":tpu_py",
         "//tensorflow/contrib/tpu/proto:topology_proto_py",
@@ -181,6 +182,33 @@ py_library(
     ],
 )
 
+py_library(
+    name = "datasets",
+    srcs = [
+        "python/tpu/datasets.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/contrib/data/python/ops:transformation_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:function",
+        "//tensorflow/python:functional_ops",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
+        "//tensorflow/python/data/ops:readers",
+    ],
+)
+
+tf_py_test(
+    name = "datasets_test",
+    srcs = ["python/tpu/datasets_test.py"],
+    additional_deps = [
+        "//tensorflow/python:client_testlib",
+        ":datasets",
+    ],
+    grpc_enabled = True,
+)
+
 tf_py_test(
     name = "tpu_test",
     size = "small",
diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py
new file mode 100644 (file)
index 0000000..29aea98
--- /dev/null
@@ -0,0 +1,192 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Library of Cloud TPU helper functions for data loading."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import functional_ops
+
+
+def _TextLineDataset(filename):
+  buffer_size = 8 * 1024 * 1024  # 8 MiB per file
+  dataset = readers.TextLineDataset(filename, buffer_size=buffer_size)
+  return dataset
+
+
+def _TFRecordDataset(filename):
+  buffer_size = 8 * 1024 * 1024  # 8 MiB per file
+  dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size)
+  return dataset
+
+
+_FILETYPE_MAP = {
+    'tfrecord': _TFRecordDataset,
+    'textline': _TextLineDataset,
+    'text': _TextLineDataset,
+}
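+
+# In addition to the string shorthands in _FILETYPE_MAP, `filetype` may be any
+# function from a filename to a `tf.data.Dataset`; a minimal sketch (the
+# record size here is an assumption for illustration):
+#
+#   def _FixedLengthDataset(filename):
+#     return readers.FixedLengthRecordDataset(filename, record_bytes=16)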
+
+
+def StreamingFilesDataset(files,
+                          filetype=None,
+                          file_reader_job=None,
+                          worker_job=None,
+                          num_epochs=None,
+                          filename_shuffle_buffer_size=None,
+                          num_parallel_reads=None,
+                          batch_transfer_size=None,
+                          sloppy=None):
+  """StreamingFilesDataset constructs a dataset to stream from workers (GCE VM).
+
+  Because Cloud TPUs are allocated over the network, a Cloud TPU cannot read
+  files local to your GCE VM. In order to train using files stored on your local
+  VM (e.g. on local SSD for extreme performance), use the StreamingFilesDataset
+  helper to generate a dataset to feed your Cloud TPU with files from your GCE
+  VM.
+
+  The resulting dataset may raise an `OutOfRangeError` if no files are found
+  as a result of the file glob expansion.
+
+  Note: StreamingFilesDataset assumes that the session is using a
+  TPUClusterResolver and therefore has both a worker and a coordinator job.
+  File loading will be done on the coordinator job.
+
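+  Example (a minimal sketch; the file glob is illustrative, and the session is
+  assumed to be configured with a TPUClusterResolver):
+
+  ```
+  dataset = StreamingFilesDataset(
+      '/data/train-*.tfrecord', filetype='tfrecord')
+  iterator = dataset.make_initializable_iterator()
+  # Run `iterator.initializer` in the session before calling get_next().
+  next_element = iterator.get_next()
+  ```
+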
+  Args:
+    files: A string glob to match files, or a `tf.data.Dataset` generating file
+      names.
+    filetype: A string (one of 'tfrecord', 'textline', or 'text') or a
+      single-argument TensorFlow function that, when given a filename, returns
+      a dataset.
+    file_reader_job: An optional string that corresponds to the job that should
+      perform the file reads.
+    worker_job: An optional string that corresponds to the job that should
+      process the tensors (i.e. your GPU or TPU worker).
+    num_epochs: The number of epochs through the training set to generate. By
+      default, the dataset repeats indefinitely.
+    filename_shuffle_buffer_size: An optional integer controlling the size of
+      the buffer used to shuffle file names. Set to 0 or False to read the
+      files in a deterministic order.
+    num_parallel_reads: An optional integer controlling the number of files to
+      read from concurrently. (Set to 1 for no parallelism.)
+    batch_transfer_size: An optional integer controlling the batching used to
+      amortize the remote function invocation overhead. Set to a very large
+      number to increase throughput. Set to a very small number to reduce memory
+      consumption. Set to False to skip batching.
+    sloppy: (Optional.) If `True`, read input data as fast as possible, without
+      maintaining a deterministic order. Defaults to `False`.
+
+  Returns:
+    A `tf.data.Dataset` with an infinite stream of elements generated by a
+    parallel interleaving of the set of files matched (or generated) by
+    `files`, where each element's type is the output of the dataset specified
+    by `filetype`.
+
+  Raises:
+    ValueError: if any argument is not of the expected type.
+  """
+  if filetype is None:
+    filetype = 'tfrecord'
+
+  if isinstance(filetype, six.string_types):
+    if filetype not in _FILETYPE_MAP:
+      raise ValueError('Unexpected filetype: %s' % filetype)
+    reader_fn = _FILETYPE_MAP[filetype]
+  elif callable(filetype):
+    reader_fn = filetype
+  else:
+    raise ValueError('filetype should be a string or a callable')
+
+  file_reader_job = file_reader_job or 'coordinator'
+
+  worker_job = worker_job or 'worker'
+
+  if filename_shuffle_buffer_size is None:
+    filename_shuffle_buffer_size = 4096
+
+  num_parallel_reads = num_parallel_reads or 8
+
+  if batch_transfer_size is None:
+    batch_transfer_size = 1024
+
+  if sloppy is None:
+    sloppy = False
+
+  with ops.device('/job:%s' % file_reader_job):
+    if isinstance(files, six.string_types):
+      source_dataset = dataset_ops.Dataset.list_files(files)
+    elif isinstance(files, dataset_ops.Dataset):
+      source_dataset = files
+    else:
+      raise ValueError('files was not a string or a dataset: %s' % files)
+
+    if filename_shuffle_buffer_size:
+      source_dataset = source_dataset.shuffle(
+          buffer_size=filename_shuffle_buffer_size)
+
+    # NOTE: We perform the `repeat` on the source dataset, because the output
+    # dataset does not currently have enough information to recreate an iterator
+    # over the source dataset when it reaches the end.
+    source_dataset = source_dataset.repeat(num_epochs)
+
+    source_dataset = source_dataset.apply(
+        interleave_ops.parallel_interleave(
+            reader_fn, cycle_length=num_parallel_reads, sloppy=sloppy))
+
+    if batch_transfer_size:
+      # Note: we can safely call batch_and_drop_remainder because we have an
+      # infinite stream of elements.
+      source_dataset = source_dataset.apply(
+          batching.batch_and_drop_remainder(batch_transfer_size))
+
+    source_dataset = source_dataset.prefetch(1)
+
+    source_iterator = source_dataset.make_one_shot_iterator()
+    source_handle = source_iterator.string_handle()
+
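+  # LoadingFunc runs on the file_reader_job: given the source iterator's
+  # string handle, it dereferences the iterator and returns its next element.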
+  @function.Defun(dtypes.string)
+  def LoadingFunc(h):
+    remote_iterator = iterator_ops.Iterator.from_string_handle(
+        h, source_dataset.output_types, source_dataset.output_shapes)
+    return remote_iterator.get_next()
+
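+  # MapFn runs on the worker job; each call uses remote_call to execute
+  # LoadingFunc on the file_reader_job, pulling one transfer batch across the
+  # network.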
+  def MapFn(unused_input):
+    return functional_ops.remote_call(
+        args=[source_handle],
+        Tout=[dtypes.string],
+        f=LoadingFunc,
+        target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)
+
+  with ops.device('/job:%s' % worker_job):
+    # TODO(saeta,mrry): Switch to using _GeneratorDataset.
+
+    # identity = lambda x: x
+    # dummy = constant_op.constant(0)
+    # output_dataset = dataset_ops._GeneratorDataset(dummy, identity, MapFn,
+    #                                                identity)
+
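+    # Drive the remote calls with a trivial infinite dataset; MapFn ignores
+    # its input and instead fetches the next element via the remote iterator.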
+    output_dataset = dataset_ops.Dataset.range(2).repeat().map(MapFn)
+    output_dataset = output_dataset.prefetch(1)
+
+    if batch_transfer_size:
+      # Undo the batching used during the transfer.
+      output_dataset = output_dataset.apply(batching.unbatch()).prefetch(1)
+
+  return output_dataset
diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py
new file mode 100644 (file)
index 0000000..2c40797
--- /dev/null
@@ -0,0 +1,181 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TPU datasets tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.tpu.python.tpu import datasets
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.lib.io import python_io
+from tensorflow.python.platform import test
+from tensorflow.python.training import server_lib
+from tensorflow.python.util import compat
+
+_NUM_FILES = 10
+_NUM_ENTRIES = 200
+
+
+class DatasetsTest(test.TestCase):
+
+  def setUp(self):
+    super(DatasetsTest, self).setUp()
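+    # Create two in-process gRPC servers to play the 'coordinator' and
+    # 'worker' roles expected by StreamingFilesDataset.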
+    self._coord = server_lib.Server.create_local_server()
+    self._worker = server_lib.Server.create_local_server()
+
+    self._cluster_def = cluster_pb2.ClusterDef()
+    worker_job = self._cluster_def.job.add()
+    worker_job.name = 'worker'
+    worker_job.tasks[0] = self._worker.target[len('grpc://'):]
+    coord_job = self._cluster_def.job.add()
+    coord_job.name = 'coordinator'
+    coord_job.tasks[0] = self._coord.target[len('grpc://'):]
+
+    session_config = config_pb2.ConfigProto(cluster_def=self._cluster_def)
+
+    self._sess = session.Session(self._worker.target, config=session_config)
+
+  def testTextLineDataset(self):
+    all_contents = []
+    for i in range(_NUM_FILES):
+      filename = os.path.join(self.get_temp_dir(), 'text_line.%d.txt' % i)
+      contents = []
+      for j in range(_NUM_ENTRIES):
+        contents.append(compat.as_bytes('%d: %d' % (i, j)))
+      with open(filename, 'wb') as f:
+        f.write(b'\n'.join(contents))
+      all_contents.extend(contents)
+
+    dataset = datasets.StreamingFilesDataset(
+        os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text')
+
+    iterator = dataset.make_initializable_iterator()
+    self._sess.run(iterator.initializer)
+    get_next = iterator.get_next()
+
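+    # Read two passes' worth of records to confirm that the dataset repeats.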
+    retrieved_values = []
+    for _ in range(2 * len(all_contents)):
+      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))
+
+    self.assertEqual(set(all_contents), set(retrieved_values))
+
+  def testTFRecordDataset(self):
+    all_contents = []
+    for i in range(_NUM_FILES):
+      filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i)
+      writer = python_io.TFRecordWriter(filename)
+      for j in range(_NUM_ENTRIES):
+        record = compat.as_bytes('Record %d of file %d' % (j, i))
+        writer.write(record)
+        all_contents.append(record)
+      writer.close()
+
+    dataset = datasets.StreamingFilesDataset(
+        os.path.join(self.get_temp_dir(), 'tf_record*'), filetype='tfrecord')
+
+    iterator = dataset.make_initializable_iterator()
+    self._sess.run(iterator.initializer)
+    get_next = iterator.get_next()
+
+    retrieved_values = []
+    for _ in range(2 * len(all_contents)):
+      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))
+
+    self.assertEqual(set(all_contents), set(retrieved_values))
+
+  def testTFRecordDatasetFromDataset(self):
+    filenames = []
+    all_contents = []
+    for i in range(_NUM_FILES):
+      filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i)
+      filenames.append(filename)
+      writer = python_io.TFRecordWriter(filename)
+      for j in range(_NUM_ENTRIES):
+        record = compat.as_bytes('Record %d of file %d' % (j, i))
+        writer.write(record)
+        all_contents.append(record)
+      writer.close()
+
+    filenames = dataset_ops.Dataset.from_tensor_slices(filenames)
+
+    dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord')
+
+    iterator = dataset.make_initializable_iterator()
+    self._sess.run(iterator.initializer)
+    get_next = iterator.get_next()
+
+    retrieved_values = []
+    for _ in range(2 * len(all_contents)):
+      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))
+
+    self.assertEqual(set(all_contents), set(retrieved_values))
+
+  def testArbitraryReaderFunc(self):
+
+    def MakeRecord(i, j):
+      return compat.as_bytes('%04d-%04d' % (i, j))
+
+    record_bytes = len(MakeRecord(10, 200))
+
+    all_contents = []
+    for i in range(_NUM_FILES):
+      filename = os.path.join(self.get_temp_dir(), 'fixed_length.%d' % i)
+      with open(filename, 'wb') as f:
+        for j in range(_NUM_ENTRIES):
+          record = MakeRecord(i, j)
+          f.write(record)
+          all_contents.append(record)
+
+    def FixedLengthFile(filename):
+      return readers.FixedLengthRecordDataset(filename, record_bytes)
+
+    dataset = datasets.StreamingFilesDataset(
+        os.path.join(self.get_temp_dir(), 'fixed_length*'),
+        filetype=FixedLengthFile)
+
+    iterator = dataset.make_initializable_iterator()
+    self._sess.run(iterator.initializer)
+    get_next = iterator.get_next()
+
+    retrieved_values = []
+    for _ in range(2 * len(all_contents)):
+      retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))
+
+    self.assertEqual(set(all_contents), set(retrieved_values))
+
+  def testUnexpectedFiletypeString(self):
+    with self.assertRaises(ValueError):
+      datasets.StreamingFilesDataset(
+          os.path.join(self.get_temp_dir(), '*'), filetype='foo')
+
+  def testUnexpectedFiletypeType(self):
+    with self.assertRaises(ValueError):
+      datasets.StreamingFilesDataset(
+          os.path.join(self.get_temp_dir(), '*'), filetype=3)
+
+  def testUnexpectedFilesType(self):
+    with self.assertRaises(ValueError):
+      datasets.StreamingFilesDataset(123, filetype='tfrecord')
+
+
+if __name__ == '__main__':
+  test.main()