[tf.data] Optimizations on make_csv_dataset internals.
authorRachel Lim <rachelim@google.com>
Thu, 29 Mar 2018 15:19:17 +0000 (08:19 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 29 Mar 2018 15:21:46 +0000 (08:21 -0700)
PiperOrigin-RevId: 190933143

tensorflow/contrib/data/python/ops/readers.py

index 95edca6..9a48aa0 100644 (file)
@@ -18,9 +18,11 @@ from __future__ import division
 from __future__ import print_function
 
 import csv
+from math import ceil
 
 import numpy as np
 
+from tensorflow.contrib.data.python.ops import batching
 from tensorflow.contrib.data.python.ops import interleave_ops
 from tensorflow.contrib.data.python.ops import shuffle_ops
 from tensorflow.python.data.ops import dataset_ops
@@ -176,6 +178,9 @@ def make_csv_dataset(
     shuffle_buffer_size=10000,
     shuffle_seed=None,
     prefetch_buffer_size=1,
+    num_parallel_reads=1,
+    num_parallel_parser_calls=2,
+    sloppy=False,
     default_float_type=dtypes.float32,
     num_rows_for_inference=100,
 ):
@@ -231,6 +236,15 @@ def make_csv_dataset(
     prefetch_buffer_size: An int specifying the number of feature batches to
       prefetch for performance improvement. Recommended value is the number of
       batches consumed per training step.
+    num_parallel_reads: Number of threads used to read CSV records from files.
+      If >1, the results will be interleaved.
+    num_parallel_parser_calls: Number of parallel invocations of the CSV parsing
+      function on CSV records.
+    sloppy: If `True`, reading performance will be improved at
+      the cost of non-deterministic ordering. If `False`, the order of elements
+      produced is deterministic prior to shuffling (elements are still
+      randomized if `shuffle=True`. Note that if the seed is set, then order
+      of elements after shuffling is deterministic). Defaults to `False`.
     default_float_type: Either `tf.float32` or `tf.float64`. If defaults are
       not provided, float-like strings are interpreted to be this type.
     num_rows_for_inference: Number of rows of a file to use for type inference
@@ -247,11 +261,16 @@ def make_csv_dataset(
   Raises:
     ValueError: If any of the arguments is malformed.
   """
-  filenames = _get_file_names(file_pattern, shuffle)
+  # Create dataset of all matching filenames
+  filenames = _get_file_names(file_pattern, False)
+  dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
+  if shuffle:
+    dataset = dataset.shuffle(len(filenames), shuffle_seed)
+
+  # Clean arguments; figure out column names and defaults
   if comment is not None and len(comment) != 1:
     raise ValueError("`comment` arg must be a single-character string or None")
 
-  # Clean arguments; figure out column names and defaults
   if column_names is None:
     if not header:
       raise ValueError("Cannot infer column names without a header line.")
@@ -272,7 +291,6 @@ def make_csv_dataset(
         filenames, len(column_names), field_delim, use_quote_delim, na_value,
         header, comment, default_float_type, num_rows_for_inference)
 
-  dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
   if label_name is not None and label_name not in column_names:
     raise ValueError("`label_name` provided must be one of the columns.")
 
@@ -311,16 +329,31 @@ def make_csv_dataset(
       return features, label
     return features
 
-  # TODO(rachelim): interleave records from files for better shuffling
-  dataset = dataset.flat_map(filename_to_dataset)
-  # TODO(rachelim): use fused shuffle_and_repeat for perf
-  if shuffle:
+  # Read files sequentially or in parallel
+  dataset = dataset.apply(
+      interleave_ops.parallel_interleave(
+          filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))
+
+  if num_epochs != 1 and shuffle:
+    # Use shuffle_and_repeat for perf
+    dataset = dataset.apply(
+        shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs,
+                                       shuffle_seed))
+  elif shuffle:
     dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed)
-  if num_epochs != 1:
+  elif num_epochs != 1:
     dataset = dataset.repeat(num_epochs)
 
-  dataset = dataset.batch(batch_size)
-  dataset = dataset.map(decode_csv)
+  # Use map_and_batch for perf
+  # TODO(b/76425672): use num_parallel_calls for better performance tuning when
+  # that is added
+  dataset = dataset.apply(
+      batching.map_and_batch(
+          map_func=decode_csv,
+          batch_size=batch_size,
+          num_parallel_batches=int(
+              ceil(num_parallel_parser_calls / batch_size))))
+
   dataset = dataset.prefetch(prefetch_buffer_size)
   return dataset
 
@@ -416,12 +449,10 @@ def make_batched_features_dataset(file_pattern,
     `Tensor` or `SparseTensor` objects.
   """
   # Create dataset of all matching filenames
+  filenames = _get_file_names(file_pattern, False)
+  dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
   if shuffle:
-    dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=True)
-  else:
-    # TODO(b/73959787): Use Dataset.list_files() once ordering is deterministic.
-    filenames = _get_file_names(file_pattern, shuffle)
-    dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
+    dataset = dataset.shuffle(len(filenames), shuffle_seed)
 
   # Read `Example` records from files as tensor objects.
   if reader_args is None: