Keep column order in make_csv_dataset.
author: Mark Daoust <markdaoust@google.com>
Wed, 23 May 2018 17:01:15 +0000 (10:01 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Wed, 23 May 2018 17:06:20 +0000 (10:06 -0700)
PiperOrigin-RevId: 197742412

tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
tensorflow/contrib/data/python/ops/readers.py

index 641a389..f9f11a1 100644 (file)
@@ -308,6 +308,23 @@ class CsvDatasetOpTest(test.TestCase):
         record_defaults=record_defaults,
     )
 
+  def testMakeCsvDataset_fieldOrder(self):
+    data = [[
+        '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19',
+        '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19'
+    ]]
+    file_path = self.setup_files(data)
+
+    with ops.Graph().as_default() as g:
+      ds = readers.make_csv_dataset(
+          file_path, batch_size=1, shuffle=False, num_epochs=1)
+      next_batch = ds.make_one_shot_iterator().get_next()
+
+    with self.test_session(graph=g) as sess:
+      result = list(sess.run(next_batch).values())
+
+    self.assertEqual(result, sorted(result))
+
 
 class CsvDatasetBenchmark(test.Benchmark):
   """Benchmarks for the various ways of creating a dataset from CSV files.
index 75c31a9..f938153 100644 (file)
@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import csv
 
 import numpy as np
@@ -467,11 +468,11 @@ def make_csv_dataset(
     Args:
       *columns: list of `Tensor`s corresponding to one csv record.
     Returns:
-      A dictionary of feature names to values for that particular record. If
+      An OrderedDict of feature names to values for that particular record. If
       label_name is provided, extracts the label feature to be returned as the
       second element of the tuple.
     """
-    features = dict(zip(column_names, columns))
+    features = collections.OrderedDict(zip(column_names, columns))
     if label_name is not None:
       label = features.pop(label_name)
       return features, label