[tf.data] Changed internal implementation of `make_csv_dataset`, and removed argument...
authorRachel Lim <rachelim@google.com>
Fri, 18 May 2018 17:21:59 +0000 (10:21 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 18 May 2018 17:24:33 +0000 (10:24 -0700)
PiperOrigin-RevId: 197164167

tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
tensorflow/contrib/data/python/ops/readers.py

index 1fcb78a..e023719 100644 (file)
@@ -622,14 +622,12 @@ class MakeCsvDatasetTest(test.TestCase):
     f.close()
     return fn
 
-  def _create_file(self, fileno, header=True, comment=True):
+  def _create_file(self, fileno, header=True):
     rows = []
     if header:
       rows.append(self.COLUMNS)
     for recno in range(self._num_records):
       rows.append(self._csv_values(fileno, recno))
-      if comment:
-        rows.append("# Some comment goes here. Ignore me.")
     return self._write_file("csv_file%d.csv" % fileno, rows)
 
   def _create_files(self):
@@ -650,9 +648,7 @@ class MakeCsvDatasetTest(test.TestCase):
       shuffle=False,
       shuffle_seed=None,
       header=True,
-      comment="#",
       na_value="",
-      default_float_type=dtypes.float32,
   ):
     return readers.make_csv_dataset(
         filenames,
@@ -664,9 +660,7 @@ class MakeCsvDatasetTest(test.TestCase):
         shuffle=shuffle,
         shuffle_seed=shuffle_seed,
         header=header,
-        comment=comment,
         na_value=na_value,
-        default_float_type=default_float_type,
         select_columns=select_cols,
     )
 
@@ -788,29 +782,6 @@ class MakeCsvDatasetTest(test.TestCase):
             num_epochs=10,
             label_name=None)
 
-  def testMakeCSVDataset_withNoComments(self):
-    """Tests that datasets can be created from CSV files with no header line.
-    """
-    defaults = self.DEFAULTS
-    file_without_header = self._create_file(
-        len(self._test_filenames), comment=False)
-    with ops.Graph().as_default() as g:
-      with self.test_session(graph=g) as sess:
-        dataset = self._make_csv_dataset(
-            file_without_header,
-            defaults,
-            batch_size=2,
-            num_epochs=10,
-            comment=None,
-        )
-        self._verify_records(
-            sess,
-            dataset,
-            [len(self._test_filenames)],
-            batch_size=2,
-            num_epochs=10,
-        )
-
   def testMakeCSVDataset_withNoHeader(self):
     """Tests that datasets can be created from CSV files with no header line.
     """
@@ -878,7 +849,7 @@ class MakeCsvDatasetTest(test.TestCase):
 
     In that case, we should infer the types from the first N records.
     """
-    # Test that it works with standard test files (with comments, header, etc)
+    # Test that it works with standard test files (with header, etc)
     with ops.Graph().as_default() as g:
       with self.test_session(graph=g) as sess:
         dataset = self._make_csv_dataset(
@@ -891,7 +862,9 @@ class MakeCsvDatasetTest(test.TestCase):
             num_epochs=10,
             defaults=[[], [], [], [], [""]])
 
-    # Test on a deliberately tricky file
+  def testMakeCSVDataset_withTypeInferenceTricky(self):
+    # Test on a deliberately tricky file (type changes as we read more rows, and
+    # there are null values)
     fn = os.path.join(self.get_temp_dir(), "file.csv")
     expected_dtypes = [
         dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float32,
@@ -916,20 +889,29 @@ class MakeCsvDatasetTest(test.TestCase):
             column_names=None,
             label_name=None,
             na_value="NAN",
-            default_float_type=dtypes.float32,
         )
         features = dataset.make_one_shot_iterator().get_next()
         # Check that types match
         for i in range(len(expected_dtypes)):
+          print(features["col%d" % i].dtype, expected_dtypes[i])
           assert features["col%d" % i].dtype == expected_dtypes[i]
         for i in range(len(rows)):
           assert sess.run(features) == dict(zip(col_names, expected[i]))
 
-    # With float64 as default type for floats
+  def testMakeCSVDataset_withTypeInferenceAllTypes(self):
+    # Test that we make the correct inference for all types with fallthrough
+    fn = os.path.join(self.get_temp_dir(), "file.csv")
     expected_dtypes = [
-        dtypes.int32, dtypes.int64, dtypes.float64, dtypes.float64,
+        dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64,
         dtypes.string, dtypes.string
     ]
+    col_names = ["col%d" % i for i in range(len(expected_dtypes))]
+    rows = [[1, 2**31 + 1, 1.0, 4e40, "abc", ""]]
+    expected = [[
+        1, 2**31 + 1, 1.0, 4e40, "abc".encode("utf-8"), "".encode("utf-8")
+    ]]
+    self._write_file("file.csv", [col_names] + rows)
+
     with ops.Graph().as_default() as g:
       with self.test_session(graph=g) as sess:
         dataset = self._make_csv_dataset(
@@ -938,7 +920,6 @@ class MakeCsvDatasetTest(test.TestCase):
             column_names=None,
             label_name=None,
             na_value="NAN",
-            default_float_type=dtypes.float64,
         )
         features = dataset.make_one_shot_iterator().get_next()
         # Check that types match
index 2c57d11..75c31a9 100644 (file)
@@ -18,7 +18,6 @@ from __future__ import division
 from __future__ import print_function
 
 import csv
-from math import ceil
 
 import numpy as np
 
@@ -36,9 +35,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import gfile
 from tensorflow.python.util import deprecation
 
@@ -70,7 +67,7 @@ def _is_valid_float(str_val, float_dtype):
     return False
 
 
-def _infer_type(str_val, na_value, prev_type, float_dtype):
+def _infer_type(str_val, na_value, prev_type):
   """Given a string, infers its tensor type.
 
   Infers the type of a value by picking the least 'permissive' type possible,
@@ -81,29 +78,33 @@ def _infer_type(str_val, na_value, prev_type, float_dtype):
     na_value: Additional string to recognize as a NA/NaN CSV value.
     prev_type: Type previously inferred based on values of this column that
       we've seen up till now.
-    float_dtype: Either `tf.float32` or `tf.float64`. Denotes what float type
-      to parse float strings as.
   Returns:
     Inferred dtype.
   """
   if str_val in ("", na_value):
+    # If the field is null, it gives no extra information about its type
     return prev_type
 
-  if _is_valid_int32(str_val) and prev_type in (None, dtypes.int32):
-    return dtypes.int32
+  type_list = [
+      dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string
+  ]  # list of types to try, ordered from least permissive to most
 
-  if _is_valid_int64(str_val) and prev_type in (None, dtypes.int32,
-                                                dtypes.int64):
-    return dtypes.int64
+  type_functions = [
+      _is_valid_int32,
+      _is_valid_int64,
+      lambda str_val: _is_valid_float(str_val, dtypes.float32),
+      lambda str_val: _is_valid_float(str_val, dtypes.float64),
+      lambda str_val: True,
+  ]  # Corresponding list of validation functions
 
-  if _is_valid_float(str_val, float_dtype) and prev_type != dtypes.string:
-    return float_dtype
+  for i in range(len(type_list)):
+    validation_fn = type_functions[i]
+    if validation_fn(str_val) and (prev_type is None or
+                                   prev_type in type_list[:i + 1]):
+      return type_list[i]
 
-  return dtypes.string
 
-
-def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header,
-                  comment):
+def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header):
   """Generator that yields rows of CSV file(s) in order."""
   for fn in filenames:
     with file_io.FileIO(fn, "r") as f:
@@ -115,9 +116,6 @@ def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header,
         next(rdr)  # Skip header lines
 
       for csv_row in rdr:
-        if comment is not None and csv_row[0].startswith(comment):
-          continue  # Skip comment lines
-
         if len(csv_row) != num_cols:
           raise ValueError(
               "Problem inferring types: CSV row has different number of fields "
@@ -126,22 +124,21 @@ def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header,
 
 
 def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim,
-                           na_value, header, comment, float_dtype,
-                           num_rows_for_inference, select_columns):
+                           na_value, header, num_rows_for_inference,
+                           select_columns):
   """Infers column types from the first N valid CSV records of files."""
   if select_columns is None:
     select_columns = range(num_cols)
   inferred_types = [None] * len(select_columns)
 
   for i, csv_row in enumerate(
-      _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header,
-                    comment)):
+      _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)):
     if num_rows_for_inference is not None and i >= num_rows_for_inference:
       break
 
     for j, col_index in enumerate(select_columns):
       inferred_types[j] = _infer_type(csv_row[col_index], na_value,
-                                      inferred_types[j], float_dtype)
+                                      inferred_types[j])
 
   # Replace None's with a default type
   inferred_types = [t or dtypes.string for t in inferred_types]
@@ -318,7 +315,6 @@ def make_csv_dataset(
     use_quote_delim=True,
     na_value="",
     header=True,
-    comment=None,
     num_epochs=None,
     shuffle=True,
     shuffle_buffer_size=10000,
@@ -327,7 +323,6 @@ def make_csv_dataset(
     num_parallel_reads=1,
     num_parallel_parser_calls=2,
     sloppy=False,
-    default_float_type=dtypes.float32,
     num_rows_for_inference=100,
 ):
   """Reads CSV files into a dataset.
@@ -381,9 +376,6 @@ def make_csv_dataset(
     header: A bool that indicates whether the first rows of provided CSV files
       correspond to header lines with column names, and should not be included
       in the data.
-    comment: An optional character string that marks lines that should not be
-      parsed as csv records. If this is provided, all lines that start with
-      this character will not be parsed.
     num_epochs: An int specifying the number of times this dataset is repeated.
       If None, cycles through the dataset forever.
     shuffle: A bool that indicates whether the input should be shuffled.
@@ -402,8 +394,6 @@ def make_csv_dataset(
       produced is deterministic prior to shuffling (elements are still
       randomized if `shuffle=True`. Note that if the seed is set, then order
       of elements after shuffling is deterministic). Defaults to `False`.
-    default_float_type: Either `tf.float32` or `tf.float64`. If defaults are
-      not provided, float-like strings are interpreted to be this type.
     num_rows_for_inference: Number of rows of a file to use for type inference
       if record_defaults is not provided. If None, reads all the rows of all
       the files. Defaults to 100.
@@ -425,8 +415,6 @@ def make_csv_dataset(
     dataset = dataset.shuffle(len(filenames), shuffle_seed)
 
   # Clean arguments; figure out column names and defaults
-  if comment is not None and len(comment) != 1:
-    raise ValueError("`comment` arg must be a single-character string or None")
 
   if column_names is None:
     if not header:
@@ -449,8 +437,7 @@ def make_csv_dataset(
     # construction time
     column_defaults = _infer_column_defaults(
         filenames, len(column_names), field_delim, use_quote_delim, na_value,
-        header, comment, default_float_type, num_rows_for_inference,
-        select_columns)
+        header, num_rows_for_inference, select_columns)
 
   if select_columns is not None and len(column_defaults) != len(select_columns):
     raise ValueError(
@@ -464,43 +451,33 @@ def make_csv_dataset(
   if label_name is not None and label_name not in column_names:
     raise ValueError("`label_name` provided must be one of the columns.")
 
-  # Define map and filter functions
-  def filter_fn(line):
-    return math_ops.not_equal(string_ops.substr(line, 0, 1), comment)
-
   def filename_to_dataset(filename):
-    ds = core_readers.TextLineDataset(filename)
-    if header:
-      ds = ds.skip(1)
-    if comment is not None:
-      ds = ds.filter(filter_fn)
-    return ds
+    return CsvDataset(
+        filename,
+        record_defaults=column_defaults,
+        field_delim=field_delim,
+        use_quote_delim=use_quote_delim,
+        na_value=na_value,
+        select_cols=select_columns,
+        header=header)
 
-  def decode_csv(line):
-    """Decodes CSV line into features.
+  def map_fn(*columns):
+    """Organizes columns into a features dictionary.
 
     Args:
-      line: String tensor corresponding to one csv record.
+      *columns: list of `Tensor`s corresponding to one csv record.
     Returns:
       A dictionary of feature names to values for that particular record. If
       label_name is provided, extracts the label feature to be returned as the
       second element of the tuple.
     """
-    columns = parsing_ops.decode_csv(
-        line,
-        column_defaults,
-        field_delim=field_delim,
-        use_quote_delim=use_quote_delim,
-        na_value=na_value,
-        select_cols=select_columns,
-    )
     features = dict(zip(column_names, columns))
     if label_name is not None:
       label = features.pop(label_name)
       return features, label
     return features
 
-  # Read files sequentially or in parallel
+  # Read files sequentially (if num_parallel_reads=1) or in parallel
   dataset = dataset.apply(
       interleave_ops.parallel_interleave(
           filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))
@@ -508,17 +485,12 @@ def make_csv_dataset(
   dataset = _maybe_shuffle_and_repeat(
       dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
 
-  # Use map_and_batch for perf
-  # TODO(b/76425672): use num_parallel_calls for better performance tuning when
-  # that is added
-  dataset = dataset.apply(
-      batching.map_and_batch(
-          map_func=decode_csv,
-          batch_size=batch_size,
-          num_parallel_batches=int(
-              ceil(num_parallel_parser_calls / batch_size))))
-
+  # Apply batch before map for perf, because map has high overhead relative
+  # to the size of the computation in each map
+  dataset = dataset.batch(batch_size=batch_size)
+  dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls)
   dataset = dataset.prefetch(prefetch_buffer_size)
+
   return dataset