[tf.data] better benchmarking code in tests for measuring improvements to csv parsing
authorRachel Lim <rachelim@google.com>
Tue, 29 May 2018 21:22:18 +0000 (14:22 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 29 May 2018 21:24:46 +0000 (14:24 -0700)
PiperOrigin-RevId: 198457501

tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py

index f9f11a1..8c138c7 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import os
+import string
 import tempfile
 import time
 
@@ -329,67 +330,93 @@ class CsvDatasetOpTest(test.TestCase):
 class CsvDatasetBenchmark(test.Benchmark):
   """Benchmarks for the various ways of creating a dataset from CSV files.
   """
+  FLOAT_VAL = '1.23456E12'
+  STR_VAL = string.ascii_letters * 10
 
-  def _setUp(self):
+  def _setUp(self, str_val):
     # Since this isn't test.TestCase, have to manually create a test dir
     gfile.MakeDirs(googletest.GetTempDir())
     self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
 
     self._num_cols = [4, 64, 256]
-    self._batch_size = 500
+    self._num_per_iter = 5000
     self._filenames = []
     for n in self._num_cols:
       fn = os.path.join(self._temp_dir, 'file%d.csv' % n)
       with open(fn, 'w') as f:
-        # Just write 10 rows and use `repeat`...
-        row = ','.join(['1.23456E12' for _ in range(n)])
-        f.write('\n'.join([row for _ in range(10)]))
+        # Just write 100 rows and use `repeat`... Assumes the cost
+        # of creating an iterator is not significant
+        row = ','.join([str_val for _ in range(n)])
+        f.write('\n'.join([row for _ in range(100)]))
       self._filenames.append(fn)
 
   def _tearDown(self):
     gfile.DeleteRecursively(self._temp_dir)
 
   def _runBenchmark(self, dataset, num_cols, prefix):
-    next_element = dataset.make_one_shot_iterator().get_next()
-    with session.Session() as sess:
-      for _ in range(5):
-        sess.run(next_element)
-      deltas = []
-      for _ in range(10):
+    dataset = dataset.skip(self._num_per_iter - 1)
+    deltas = []
+    for _ in range(10):
+      next_element = dataset.make_one_shot_iterator().get_next()
+      with session.Session() as sess:
         start = time.time()
+        # NOTE: This depends on the underlying implementation of skip, to have
+        # the net effect of calling `GetNext` num_per_iter times on the
+        # input dataset. We do it this way (instead of a python for loop, or
+        # batching N inputs in one iter) so that the overhead from session.run
+        # or batch doesn't dominate. If we eventually optimize skip, this has
+        # to change.
         sess.run(next_element)
         end = time.time()
-        deltas.append(end - start)
-    median_wall_time = np.median(deltas) / 100
+      deltas.append(end - start)
+    # Median wall time per CSV record read and decoded
+    median_wall_time = np.median(deltas) / self._num_per_iter
     print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols,
                                                     median_wall_time))
     self.report_benchmark(
-        iters=self._batch_size,
+        iters=self._num_per_iter,
         wall_time=median_wall_time,
         name='%s_with_cols_%d' % (prefix, num_cols))
 
-  def benchmarkBatchThenMap(self):
-    self._setUp()
+  def benchmarkMapWithFloats(self):
+    self._setUp(self.FLOAT_VAL)
     for i in range(len(self._filenames)):
       num_cols = self._num_cols[i]
       kwargs = {'record_defaults': [[0.0]] * num_cols}
       dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
       dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs))  # pylint: disable=cell-var-from-loop
-      dataset = dataset.batch(self._batch_size)
-      self._runBenchmark(dataset, num_cols, 'csv_map_then_batch')
+      self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv')
+    self._tearDown()
+
+  def benchmarkMapWithStrings(self):
+    self._setUp(self.STR_VAL)
+    for i in range(len(self._filenames)):
+      num_cols = self._num_cols[i]
+      kwargs = {'record_defaults': [['']] * num_cols}
+      dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
+      dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs))  # pylint: disable=cell-var-from-loop
+      self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv')
     self._tearDown()
 
-  def benchmarkCsvDataset(self):
-    self._setUp()
+  def benchmarkCsvDatasetWithFloats(self):
+    self._setUp(self.FLOAT_VAL)
     for i in range(len(self._filenames)):
       num_cols = self._num_cols[i]
       kwargs = {'record_defaults': [[0.0]] * num_cols}
       dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
       dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat()  # pylint: disable=cell-var-from-loop
-      dataset = dataset.batch(self._batch_size)
-      self._runBenchmark(dataset, num_cols, 'csv_fused_dataset')
+      self._runBenchmark(dataset, num_cols, 'csv_float_fused_dataset')
     self._tearDown()
 
+  def benchmarkCsvDatasetWithStrings(self):
+    self._setUp(self.STR_VAL)
+    for i in range(len(self._filenames)):
+      num_cols = self._num_cols[i]
+      kwargs = {'record_defaults': [['']] * num_cols}
+      dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
+      dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat()  # pylint: disable=cell-var-from-loop
+      self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset')
+    self._tearDown()
 
 if __name__ == '__main__':
   test.main()