From 86adab02897a4ec4403f1106ba68fffb4f802085 Mon Sep 17 00:00:00 2001 From: Shivani Agrawal Date: Wed, 9 May 2018 12:15:11 -0700 Subject: [PATCH] [tf.data] Saveable iterator for SqlDataset. PiperOrigin-RevId: 196009176 --- tensorflow/contrib/data/python/kernel_tests/BUILD | 1 + .../python/kernel_tests/sql_dataset_op_test.py | 28 ++++++- tensorflow/core/kernels/data/sql_dataset_ops.cc | 89 ++++++++++++++++++---- 3 files changed, 101 insertions(+), 17 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 7643c2a..9855688 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -407,6 +407,7 @@ py_test( srcs = ["sql_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index e26cef8..4148add 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -22,6 +22,7 @@ import os import sqlite3 +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import readers from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class SqlDatasetTest(test.TestCase): +class SqlDatasetTestBase(test.TestCase): def _createSqlDataset(self, output_types, num_repeats=1): dataset = readers.SqlDataset(self.driver_name, self.data_source_name, @@ -92,6 +93,9 @@ class SqlDatasetTest(test.TestCase): conn.commit() conn.close() + +class 
SqlDatasetTest(SqlDatasetTestBase): + # Test that SqlDataset can read from a database table. def testReadResultSet(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, @@ -652,5 +656,27 @@ class SqlDatasetTest(test.TestCase): sess.run(get_next) +class SqlDatasetSerializationTest( + SqlDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_repeats): + data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) + query = ("SELECT first_name, last_name, motto FROM students ORDER BY " + "first_name DESC") + output_types = (dtypes.string, dtypes.string, dtypes.string) + return readers.SqlDataset(driver_name, data_source_name, query, + output_types).repeat(num_repeats) + + def testSQLSaveable(self): + num_repeats = 4 + num_outputs = num_repeats * 2 + self.run_core_tests(lambda: self._build_dataset(num_repeats), + lambda: self._build_dataset(num_repeats // 2), + num_outputs) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc index d50e9c9..634b3c2 100644 --- a/tensorflow/core/kernels/data/sql_dataset_ops.cc +++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc @@ -70,17 +70,19 @@ class SqlDatasetOp : public DatasetOpKernel { "The set of supported databases is: {'sqlite'}.", driver_name.c_str()))); - *output = new Dataset(driver_name, data_source_name, query, output_types_, - output_shapes_); + *output = new Dataset(ctx, driver_name, data_source_name, query, + output_types_, output_shapes_); } private: - class Dataset : public DatasetBase { + class Dataset : public GraphDatasetBase { public: - Dataset(const string& driver_name, const string& data_source_name, - const string& query, const DataTypeVector& output_types, + Dataset(OpKernelContext* ctx, const string& 
driver_name, + const string& data_source_name, const string& query, + const DataTypeVector& output_types, const std::vector& output_shapes) - : driver_name_(driver_name), + : GraphDatasetBase(ctx), + driver_name_(driver_name), data_source_name_(data_source_name), query_(query), output_types_(output_types), @@ -102,6 +104,21 @@ class SqlDatasetOp : public DatasetOpKernel { string DebugString() override { return "SqlDatasetOp::Dataset"; } + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* driver_name_node; + TF_RETURN_IF_ERROR(b->AddScalar(driver_name_, &driver_name_node)); + Node* data_source_name_node; + TF_RETURN_IF_ERROR( + b->AddScalar(data_source_name_, &data_source_name_node)); + Node* query_node; + TF_RETURN_IF_ERROR(b->AddScalar(query_, &query_node)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {driver_name_node, data_source_name_node, query_node}, output)); + return Status::OK(); + } + private: class Iterator : public DatasetIterator { public: @@ -121,22 +138,62 @@ class SqlDatasetOp : public DatasetOpKernel { bool* end_of_sequence) override { mutex_lock l(mu_); if (!query_connection_initialized_) { - query_connection_initialized_ = true; - query_connection_ = sql::DriverManager::CreateQueryConnection( - dataset()->driver_name_); - Status s = query_connection_->Open(dataset()->data_source_name_, - dataset()->query_, - dataset()->output_types_); - if (!s.ok()) { - LOG(WARNING) << "Failed to connect to database: " << s; - return s; - } + TF_RETURN_IF_ERROR(InitializeQueryConnection()); } + next_calls_++; return query_connection_->GetNext(ctx, out_tensors, end_of_sequence); } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + if (query_connection_initialized_) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("next_calls"), next_calls_)); + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + 
IteratorStateReader* reader) override { + mutex_lock l(mu_); + if (reader->Contains(full_name("next_calls"))) { + TF_RETURN_IF_ERROR(InitializeQueryConnection()); + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("next_calls"), &next_calls_)); + int64 rem_next_calls = next_calls_; + std::vector out_tensors; + bool end_of_sequence = false; + while (rem_next_calls--) { + TF_RETURN_IF_ERROR(query_connection_->GetNext(ctx, &out_tensors, + &end_of_sequence)); + out_tensors.clear(); + } + } else { + query_connection_initialized_ = false; + } + return Status::OK(); + } + private: + Status InitializeQueryConnection() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + query_connection_initialized_ = true; + query_connection_ = + sql::DriverManager::CreateQueryConnection(dataset()->driver_name_); + Status s = query_connection_->Open(dataset()->data_source_name_, + dataset()->query_, + dataset()->output_types_); + next_calls_ = 0; + if (!s.ok()) { + LOG(WARNING) << "Failed to connect to database: " << s; + return s; + } + return Status::OK(); + } + mutex mu_; + // TODO(shivaniagrawal): explore ways to seek into a SQLite database. + int64 next_calls_ GUARDED_BY(mu_) = 0; std::unique_ptr query_connection_ GUARDED_BY(mu_); bool query_connection_initialized_ GUARDED_BY(mu_) = false; }; -- 2.7.4