From 896a6d74959c02b5c41087f96e77ef166fe484e3 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 23 May 2018 10:01:15 -0700 Subject: [PATCH] Keep column order in make_csv_dataset. PiperOrigin-RevId: 197742412 --- .../data/python/kernel_tests/csv_dataset_op_test.py | 17 +++++++++++++++++ tensorflow/contrib/data/python/ops/readers.py | 5 +++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index 641a389..f9f11a1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -308,6 +308,23 @@ class CsvDatasetOpTest(test.TestCase): record_defaults=record_defaults, ) + def testMakeCsvDataset_fieldOrder(self): + data = [[ + '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19', + '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19' + ]] + file_path = self.setup_files(data) + + with ops.Graph().as_default() as g: + ds = readers.make_csv_dataset( + file_path, batch_size=1, shuffle=False, num_epochs=1) + next_batch = ds.make_one_shot_iterator().get_next() + + with self.test_session(graph=g) as sess: + result = list(sess.run(next_batch).values()) + + self.assertEqual(result, sorted(result)) + class CsvDatasetBenchmark(test.Benchmark): """Benchmarks for the various ways of creating a dataset from CSV files. diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 75c31a9..f938153 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import csv import numpy as np @@ -467,11 +468,11 @@ def make_csv_dataset( Args: *columns: list of `Tensor`s corresponding to one csv record. Returns: - A dictionary of feature names to values for that particular record. If + An OrderedDict of feature names to values for that particular record. If label_name is provided, extracts the label feature to be returned as the second element of the tuple. """ - features = dict(zip(column_names, columns)) + features = collections.OrderedDict(zip(column_names, columns)) if label_name is not None: label = features.pop(label_name) return features, label -- 2.7.4