Add reservoir sampling to DB summary writer
authorJustine Tunney <jart@google.com>
Fri, 12 Jan 2018 00:08:50 +0000 (16:08 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 12 Jan 2018 00:12:46 +0000 (16:12 -0800)
This thing is kind of cool. It's able to turn a 350mB event log into a
35mB SQLite file at 80mBps with one Macbook core. Best of all, this was
accomplished using a normalized schema without the embedded protos.

PiperOrigin-RevId: 181676380

12 files changed:
tensorflow/contrib/summary/summary_ops_test.py
tensorflow/contrib/tensorboard/db/BUILD
tensorflow/contrib/tensorboard/db/schema.cc
tensorflow/contrib/tensorboard/db/summary_db_writer.cc
tensorflow/contrib/tensorboard/db/summary_db_writer.h
tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
tensorflow/core/kernels/BUILD
tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
tensorflow/core/kernels/summary_kernels.cc
tensorflow/core/lib/db/BUILD
tensorflow/core/lib/db/sqlite.cc
tensorflow/core/lib/db/sqlite.h

index 4ef03434b76ee04ce1bb0bd09c27a46db115bab3..dfaa4182bb867cc03480320eaf1804da36206655 100644 (file)
@@ -18,12 +18,14 @@ from __future__ import print_function
 
 import tempfile
 
+import numpy as np
 import six
 
 from tensorflow.contrib.summary import summary_ops
 from tensorflow.contrib.summary import summary_test_util
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import node_def_pb2
+from tensorflow.core.framework import types_pb2
 from tensorflow.python.eager import function
 from tensorflow.python.eager import test
 from tensorflow.python.framework import dtypes
@@ -37,6 +39,23 @@ from tensorflow.python.training import training_util
 get_all = summary_test_util.get_all
 get_one = summary_test_util.get_one
 
+_NUMPY_NUMERIC_TYPES = {
+    types_pb2.DT_HALF: np.float16,
+    types_pb2.DT_FLOAT: np.float32,
+    types_pb2.DT_DOUBLE: np.float64,
+    types_pb2.DT_INT8: np.int8,
+    types_pb2.DT_INT16: np.int16,
+    types_pb2.DT_INT32: np.int32,
+    types_pb2.DT_INT64: np.int64,
+    types_pb2.DT_UINT8: np.uint8,
+    types_pb2.DT_UINT16: np.uint16,
+    types_pb2.DT_UINT32: np.uint32,
+    types_pb2.DT_UINT64: np.uint64,
+    types_pb2.DT_COMPLEX64: np.complex64,
+    types_pb2.DT_COMPLEX128: np.complex128,
+    types_pb2.DT_BOOL: np.bool_,
+}
+
 
 class TargetTest(test_util.TensorFlowTestCase):
 
@@ -154,8 +173,9 @@ class DbTest(summary_test_util.SummaryDbTest):
       with writer.as_default():
         self.assertEqual(5, adder(int64(2), int64(3)).numpy())
 
-    six.assertCountEqual(self, [1, 1, 1],
-                         get_all(self.db, 'SELECT step FROM Tensors'))
+    six.assertCountEqual(
+        self, [1, 1, 1],
+        get_all(self.db, 'SELECT step FROM Tensors WHERE dtype IS NOT NULL'))
     six.assertCountEqual(self, ['x', 'y', 'sum'],
                          get_all(self.db, 'SELECT tag_name FROM Tags'))
     x_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "x"')
@@ -166,8 +186,9 @@ class DbTest(summary_test_util.SummaryDbTest):
       with writer.as_default():
         self.assertEqual(9, adder(int64(4), int64(5)).numpy())
 
-    six.assertCountEqual(self, [1, 1, 1, 2, 2, 2],
-                         get_all(self.db, 'SELECT step FROM Tensors'))
+    six.assertCountEqual(
+        self, [1, 1, 1, 2, 2, 2],
+        get_all(self.db, 'SELECT step FROM Tensors WHERE dtype IS NOT NULL'))
     six.assertCountEqual(self, [x_id, y_id, sum_id],
                          get_all(self.db, 'SELECT tag_id FROM Tags'))
     self.assertEqual(2, get_tensor(self.db, x_id, 1))
@@ -212,9 +233,15 @@ class DbTest(summary_test_util.SummaryDbTest):
 
 
 def get_tensor(db, tag_id, step):
-  return get_one(
-      db, 'SELECT tensor FROM Tensors WHERE tag_id = ? AND step = ?', tag_id,
-      step)
+  cursor = db.execute(
+      'SELECT dtype, shape, data FROM Tensors WHERE series = ? AND step = ?',
+      (tag_id, step))
+  dtype, shape, data = cursor.fetchone()
+  assert dtype in _NUMPY_NUMERIC_TYPES
+  buf = np.frombuffer(data, dtype=_NUMPY_NUMERIC_TYPES[dtype])
+  if not shape:
+    return buf[0]
+  return buf.reshape([int(i) for i in shape.split(',')])
 
 
 def int64(x):
index 3a3402c59b1ed5d3c4e97674cb7b5ba8f44b6601..4c9cc4ccd6e93151618d203a104217c90ad9a526 100644 (file)
@@ -5,12 +5,13 @@ package(default_visibility = ["//tensorflow:internal"])
 
 licenses(["notice"])  # Apache 2.0
 
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_copts")
 
 cc_library(
     name = "schema",
     srcs = ["schema.cc"],
     hdrs = ["schema.h"],
+    copts = tf_copts(),
     deps = [
         "//tensorflow/core:lib",
         "//tensorflow/core/lib/db:sqlite",
@@ -32,8 +33,10 @@ cc_library(
     name = "summary_db_writer",
     srcs = ["summary_db_writer.cc"],
     hdrs = ["summary_db_writer.h"],
+    copts = tf_copts(),
     deps = [
         ":schema",
+        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
@@ -47,6 +50,7 @@ tf_cc_test(
     size = "small",
     srcs = ["summary_db_writer_test.cc"],
     deps = [
+        ":schema",
         ":summary_db_writer",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
index 2cd00876f8b95c04c5c131b1c683a6956b934bc4..6ccd386dc0f6da65e3ae1e5016670e0d56c7bc53 100644 (file)
@@ -22,8 +22,7 @@ namespace {
 Status Run(Sqlite* db, const char* sql) {
   SqliteStatement stmt;
   TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt));
-  TF_RETURN_IF_ERROR(stmt.StepAndReset());
-  return Status::OK();
+  return stmt.StepAndReset();
 }
 
 }  // namespace
@@ -38,37 +37,34 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
   db->PrepareOrDie("PRAGMA user_version=0").StepAndResetOrDie();
   Status s;
 
-  // Creates Ids table.
+  // Ids identify resources.
   //
-  // This table must be used to randomly allocate Permanent IDs for
-  // all top-level tables, in order to maintain an invariant where
-  // foo_id != bar_id for all IDs of any two tables.
+  // This table can be used to efficiently generate Permanent IDs in
+  // conjunction with a random number generator. Unlike rowids these
+  // IDs safe to use in URLs and unique across tables.
   //
-  // A row should only be deleted from this table if it can be
-  // guaranteed that it exists absolutely nowhere else in the entire
-  // system.
+  // Within any given system, there can't be any foo_id == bar_id for
+  // all rows of any two (Foos, Bars) tables. A row should only be
+  // deleted from this table if there's a very high level of confidence
+  // it exists nowhere else in the system.
   //
   // Fields:
-  //   id: An ID that was allocated globally. This must be in the
-  //     range [1,2**47). 0 is assigned the same meaning as NULL and
-  //     shouldn't be stored; 2**63-1 is reserved for statically
-  //     allocating space in a page to UPDATE later; and all other
-  //     int64 values are reserved for future use.
+  //   id: The system-wide ID. This must be in the range [1,2**47). 0
+  //     is assigned the same meaning as NULL and shouldn't be stored
+  //     and all other int64 values are reserved for future use. Please
+  //     note that id is also the rowid.
   s.Update(Run(db, R"sql(
     CREATE TABLE IF NOT EXISTS Ids (
       id INTEGER PRIMARY KEY
     )
   )sql"));
 
-  // Creates Descriptions table.
-  //
-  // This table allows TensorBoard to associate Markdown text with any
-  // object in the database that has a Permanent ID.
+  // Descriptions are Markdown text that can be associated with any
+  // resource that has a Permanent ID.
   //
   // Fields:
-  //   id: The Permanent ID of the associated object. This is also the
-  //     SQLite rowid.
-  //   description: Arbitrary Markdown text.
+  //   id: The foo_id of the associated row in Foos.
+  //   description: Arbitrary NUL-terminated Markdown text.
   s.Update(Run(db, R"sql(
     CREATE TABLE IF NOT EXISTS Descriptions (
       id INTEGER PRIMARY KEY,
@@ -76,121 +72,136 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
     )
   )sql"));
 
-  // Creates Tensors table.
+  // Tensors are 0..n-dimensional numbers or strings.
   //
   // Fields:
-  //   rowid: Ephemeral b-tree ID dictating locality.
-  //   tag_id: ID of associated Tag.
+  //   rowid: Ephemeral b-tree ID.
+  //   series: The Permanent ID of a different resource, e.g. tag_id. A
+  //     tensor will be vacuumed if no series == foo_id exists for all
+  //     rows of all Foos. When series is NULL this tensor may serve
+  //     undefined purposes. This field should be set on placeholders.
+  //   step: Arbitrary number to uniquely order tensors within series.
+  //     The meaning of step is undefined when series is NULL. This may
+  //     be set on placeholders to prepopulate index pages.
   //   computed_time: Float UNIX timestamp with microsecond precision.
   //     In the old summaries system that uses FileWriter, this is the
   //     wall time around when tf.Session.run finished. In the new
   //     summaries system, it is the wall time of when the tensor was
   //     computed. On systems with monotonic clocks, it is calculated
   //     by adding the monotonic run duration to Run.started_time.
-  //     This field is not indexed because, in practice, it should be
-  //     ordered the same or nearly the same as TensorIndex, so local
-  //     insertion sort might be more suitable.
-  //   step: User-supplied number, ordering this tensor in Tag.
-  //     If NULL then the Tag must have only one Tensor.
-  //   tensor: Can be an INTEGER (DT_INT64), FLOAT (DT_DOUBLE), or
-  //     BLOB. The structure of a BLOB is currently undefined, but in
-  //     essence it is a Snappy tf.TensorProto that spills over into
-  //     TensorChunks.
+  //   dtype: The tensorflow::DataType ID. For example, DT_INT64 is 9.
+  //     When NULL or 0 this must be treated as a placeholder row that
+  //     does not officially exist.
+  //   shape: A comma-delimited list of int64 >=0 values representing
+  //     length of each dimension in the tensor. This must be a valid
+  //     shape. That means no -1 values and, in the case of numeric
+  //     tensors, length(data) == product(shape) * sizeof(dtype). Empty
+  //     means this is a scalar a.k.a. 0-dimensional tensor.
+  //   data: Little-endian raw tensor memory. If dtype is DT_STRING and
+  //     shape is empty, the nullness of this field indicates whether or
+  //     not it contains the tensor contents; otherwise TensorStrings
+  //     must be queried. If dtype is NULL then ZEROBLOB can be used on
+  //     this field to reserve row space to be updated later.
   s.Update(Run(db, R"sql(
     CREATE TABLE IF NOT EXISTS Tensors (
       rowid INTEGER PRIMARY KEY,
-      tag_id INTEGER NOT NULL,
-      computed_time REAL,
+      series INTEGER,
       step INTEGER,
-      tensor BLOB
+      dtype INTEGER,
+      computed_time REAL,
+      shape TEXT,
+      data BLOB
     )
   )sql"));
 
-  // Uniquely indexes (tag_id, step) on Tensors table.
   s.Update(Run(db, R"sql(
-    CREATE UNIQUE INDEX IF NOT EXISTS TensorIndex
-    ON Tensors (tag_id, step)
+    CREATE UNIQUE INDEX IF NOT EXISTS
+      TensorSeriesStepIndex
+    ON
+      Tensors (series, step)
+    WHERE
+      series IS NOT NULL
+      AND step IS NOT NULL
   )sql"));
 
-  // Creates TensorChunks table.
+  // TensorStrings are the flat contents of 1..n dimensional DT_STRING
+  // Tensors.
   //
-  // This table can be used to split up a tensor across many rows,
-  // which has the advantage of not slowing down table scans on the
-  // main table, allowing asynchronous fetching, minimizing copying,
-  // and preventing large buffers from being allocated.
+  // The number of rows associated with a Tensor must be equal to the
+  // product of its Tensors.shape.
   //
   // Fields:
-  //   rowid: Ephemeral b-tree ID dictating locality.
-  //   tag_id: ID of associated Tag.
-  //   step: Same as corresponding Tensors.step.
-  //   sequence: 1-indexed sequence number for ordering chunks. Please
-  //     note that the 0th index is Tensors.tensor.
-  //   chunk: Bytes of next chunk in tensor.
+  //   rowid: Ephemeral b-tree ID.
+  //   tensor_rowid: References Tensors.rowid.
+  //   idx: Index in flattened tensor, starting at 0.
+  //   data: The string value at a particular index. NUL characters are
+  //     permitted.
   s.Update(Run(db, R"sql(
-    CREATE TABLE IF NOT EXISTS TensorChunks (
+    CREATE TABLE IF NOT EXISTS TensorStrings (
       rowid INTEGER PRIMARY KEY,
-      tag_id INTEGER NOT NULL,
-      step INTEGER,
-      sequence INTEGER,
-      chunk BLOB
+      tensor_rowid INTEGER NOT NULL,
+      idx INTEGER NOT NULL,
+      data BLOB
     )
   )sql"));
 
-  // Uniquely indexes (tag_id, step, sequence) on TensorChunks table.
   s.Update(Run(db, R"sql(
-    CREATE UNIQUE INDEX IF NOT EXISTS TensorChunkIndex
-    ON TensorChunks (tag_id, step, sequence)
+    CREATE UNIQUE INDEX IF NOT EXISTS TensorStringIndex
+    ON TensorStrings (tensor_rowid, idx)
   )sql"));
 
-  // Creates Tags table.
+  // Tags are series of Tensors.
   //
   // Fields:
-  //   rowid: Ephemeral b-tree ID dictating locality.
+  //   rowid: Ephemeral b-tree ID.
   //   tag_id: The Permanent ID of the Tag.
   //   run_id: Optional ID of associated Run.
-  //   tag_name: The tag field in summary.proto, unique across Run.
   //   inserted_time: Float UNIX timestamp with µs precision. This is
   //     always the wall time of when the row was inserted into the
   //     DB. It may be used as a hint for an archival job.
+  //   tag_name: The tag field in summary.proto, unique across Run.
   //   display_name: Optional for GUI and defaults to tag_name.
   //   plugin_name: Arbitrary TensorBoard plugin name for dispatch.
   //   plugin_data: Arbitrary data that plugin wants.
+  //
+  // TODO(jart): Maybe there should be a Plugins table?
   s.Update(Run(db, R"sql(
     CREATE TABLE IF NOT EXISTS Tags (
       rowid INTEGER PRIMARY KEY,
       run_id INTEGER,
       tag_id INTEGER NOT NULL,
-      tag_name TEXT,
       inserted_time DOUBLE,
+      tag_name TEXT,
       display_name TEXT,
       plugin_name TEXT,
       plugin_data BLOB
     )
   )sql"));
 
-  // Uniquely indexes tag_id on Tags table.
   s.Update(Run(db, R"sql(
     CREATE UNIQUE INDEX IF NOT EXISTS TagIdIndex
     ON Tags (tag_id)
   )sql"));
 
-  // Uniquely indexes (run_id, tag_name) on Tags table.
   s.Update(Run(db, R"sql(
-    CREATE UNIQUE INDEX IF NOT EXISTS TagNameIndex
-    ON Tags (run_id, tag_name)
-    WHERE tag_name IS NOT NULL
+    CREATE UNIQUE INDEX IF NOT EXISTS
+      TagRunNameIndex
+    ON
+      Tags (run_id, tag_name)
+    WHERE
+      run_id IS NOT NULL
+      AND tag_name IS NOT NULL
   )sql"));
 
-  // Creates Runs table.
+  // Runs are groups of Tags.
   //
-  // This table stores information about Runs. Each row usually
-  // represents a single attempt at training or testing a TensorFlow
-  // model, with a given set of hyper-parameters, whose summaries are
-  // written out to a single event logs directory with a monotonic step
-  // counter.
+  // Each Run usually represents a single attempt at training or testing
+  // a TensorFlow model, with a given set of hyper-parameters, whose
+  // summaries are written out to a single event logs directory with a
+  // monotonic step counter.
   //
   // Fields:
-  //   rowid: Ephemeral b-tree ID dictating locality.
+  //   rowid: Ephemeral b-tree ID.
   //   run_id: The Permanent ID of the Run. This has a 1:1 mapping
   //     with a SummaryWriter instance. If two writers spawn for a
   //     given (user_name, run_name, run_name) then each should
@@ -199,8 +210,8 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
   //     previous invocations will then enter limbo, where they may be
   //     accessible for certain operations, but should be garbage
   //     collected eventually.
-  //   experiment_id: Optional ID of associated Experiment.
   //   run_name: User-supplied string, unique across Experiment.
+  //   experiment_id: Optional ID of associated Experiment.
   //   inserted_time: Float UNIX timestamp with µs precision. This is
   //     always the time the row was inserted into the database. It
   //     does not change.
@@ -215,40 +226,33 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
   //     SummaryWriter resource that created this run was destroyed.
   //     Once this value becomes non-NULL a Run and its Tags and
   //     Tensors should be regarded as immutable.
-  //   graph_id: ID of associated Graphs row.
   s.Update(Run(db, R"sql(
     CREATE TABLE IF NOT EXISTS Runs (
       rowid INTEGER PRIMARY KEY,
       experiment_id INTEGER,
       run_id INTEGER NOT NULL,
-      run_name TEXT,
       inserted_time REAL,
       started_time REAL,
       finished_time REAL,
-      graph_id INTEGER
+      run_name TEXT
     )
   )sql"));
 
-  // Uniquely indexes run_id on Runs table.
   s.Update(Run(db, R"sql(
     CREATE UNIQUE INDEX IF NOT EXISTS RunIdIndex
     ON Runs (run_id)
   )sql"));
 
-  // Uniquely indexes (experiment_id, run_name) on Runs table.
   s.Update(Run(db, R"sql(
     CREATE UNIQUE INDEX IF NOT EXISTS RunNameIndex
     ON Runs (experiment_id, run_name)
     WHERE run_name IS NOT NULL
   )sql"));
 
-  // Creates Experiments table.
-  //
-  // This table stores information about experiments, which are sets of
-  // runs.
+  // Experiments are groups of Runs.
   //
   // Fields:
-  //   rowid: Ephemeral b-tree ID dictating locality.
+  //   rowid: Ephemeral b-tree ID.
   //   user_id: Optional ID of associated User.
   //   experiment_id: The Permanent ID of the Experiment.
   //   experiment_name: User-supplied string, unique across User.
@@ -259,34 +263,39 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
   //     the MIN(experiment.started_time, run.started_time) of each
   //     Run added to the database, including Runs which have since
   //     been overwritten.
+  //   is_watching: A boolean indicating if someone is actively
+  //     looking at this Experiment in the TensorBoard GUI. Tensor
+  //     writers that do reservoir sampling can query this value to
+  //     decide if they want the "keep last" behavior. This improves
+  //     the performance of long running training while allowing low
+  //     latency feedback in TensorBoard.
   s.Update(Run(db, R"sql(
     CREATE TABLE IF NOT EXISTS Experiments (
       rowid INTEGER PRIMARY KEY,
       user_id INTEGER,
       experiment_id INTEGER NOT NULL,
-      experiment_name TEXT,
       inserted_time REAL,
-      started_time REAL
+      started_time REAL,
+      is_watching INTEGER,
+      experiment_name TEXT
     )
   )sql"));
 
-  // Uniquely indexes experiment_id on Experiments table.
   s.Update(Run(db, R"sql(
     CREATE UNIQUE INDEX IF NOT EXISTS ExperimentIdIndex
     ON Experiments (experiment_id)
   )sql"));
 
-  // Uniquely indexes (user_id, experiment_name) on Experiments table.
   s.Update(Run(db, R"sql(
     CREATE UNIQUE INDEX IF NOT EXISTS ExperimentNameIndex
     ON Experiments (user_id, experiment_name)
     WHERE experiment_name IS NOT NULL
   )sql"));
 
-  // Creates Users table.
+  // Users are people who love TensorBoard.
   //
   // Fields:
-  //   rowid: Ephemeral b-tree ID dictating locality.
+  //   rowid: Ephemeral b-tree ID.
   //   user_id: The Permanent ID of the User.
   //   user_name: Unique user name.
   //   email: Optional unique email address.
@@ -297,61 +306,66 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
     CREATE TABLE IF NOT EXISTS Users (
       rowid INTEGER PRIMARY KEY,
       user_id INTEGER NOT NULL,
+      inserted_time REAL,
       user_name TEXT,
-      email TEXT,
-      inserted_time REAL
+      email TEXT
     )
   )sql"));
 
-  // Uniquely indexes user_id on Users table.
   s.Update(Run(db, R"sql(
     CREATE UNIQUE INDEX IF NOT EXISTS UserIdIndex
     ON Users (user_id)
   )sql"));
 
-  // Uniquely indexes user_name on Users table.
   s.Update(Run(db, R"sql(
     CREATE UNIQUE INDEX IF NOT EXISTS UserNameIndex
     ON Users (user_name)
     WHERE user_name IS NOT NULL
   )sql"));
 
-  // Uniquely indexes email on Users table.
   s.Update(Run(db, R"sql(
     CREATE UNIQUE INDEX IF NOT EXISTS UserEmailIndex
     ON Users (email)
     WHERE email IS NOT NULL
   )sql"));
 
-  // Creates Graphs table.
+  // Graphs define how Tensors flowed in Runs.
   //
   // Fields:
-  //   rowid: Ephemeral b-tree ID dictating locality.
+  //   rowid: Ephemeral b-tree ID.
+  //   run_id: The Permanent ID of the associated Run. Only one Graph
+  //     can be associated with a Run.
   //   graph_id: The Permanent ID of the Graph.
   //   inserted_time: Float UNIX timestamp with µs precision. This is
   //     always the wall time of when the row was inserted into the
   //     DB. It may be used as a hint for an archival job.
-  //   node_def: Contains Snappy tf.GraphDef proto. All fields will be
-  //     cleared except those not expressed in SQL.
+  //   node_def: Contains tf.GraphDef proto. All fields will be cleared
+  //     except those not expressed in SQL.
   s.Update(Run(db, R"sql(
     CREATE TABLE IF NOT EXISTS Graphs (
       rowid INTEGER PRIMARY KEY,
+      run_id INTEGER,
       graph_id INTEGER NOT NULL,
       inserted_time REAL,
       graph_def BLOB
     )
   )sql"));
 
-  // Uniquely indexes graph_id on Graphs table.
   s.Update(Run(db, R"sql(
     CREATE UNIQUE INDEX IF NOT EXISTS GraphIdIndex
     ON Graphs (graph_id)
   )sql"));
 
-  // Creates Nodes table.
+  s.Update(Run(db, R"sql(
+    CREATE UNIQUE INDEX IF NOT EXISTS GraphRunIndex
+    ON Graphs (run_id)
+    WHERE run_id IS NOT NULL
+  )sql"));
+
+  // Nodes are the vertices in Graphs.
   //
   // Fields:
-  //   rowid: Ephemeral b-tree ID dictating locality.
+  //   rowid: Ephemeral b-tree ID.
   //   graph_id: The Permanent ID of the associated Graph.
   //   node_id: ID for this node. This is more like a 0-index within
   //     the Graph. Please note indexes are allowed to be removed.
@@ -361,8 +375,10 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
   //     node_def.name proto field must not be cleared.
   //   op: Copied from tf.NodeDef proto.
   //   device: Copied from tf.NodeDef proto.
-  //   node_def: Contains Snappy tf.NodeDef proto. All fields will be
-  //     cleared except those not expressed in SQL.
+  //   node_def: Contains tf.NodeDef proto. All fields will be cleared
+  //     except those not expressed in SQL.
+  //
+  // TODO(jart): Make separate tables for op and device strings.
   s.Update(Run(db, R"sql(
     CREATE TABLE IF NOT EXISTS Nodes (
       rowid INTEGER PRIMARY KEY,
@@ -375,32 +391,35 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
     )
   )sql"));
 
-  // Uniquely indexes (graph_id, node_id) on Nodes table.
   s.Update(Run(db, R"sql(
     CREATE UNIQUE INDEX IF NOT EXISTS NodeIdIndex
     ON Nodes (graph_id, node_id)
   )sql"));
 
-  // Uniquely indexes (graph_id, node_name) on Nodes table.
   s.Update(Run(db, R"sql(
     CREATE UNIQUE INDEX IF NOT EXISTS NodeNameIndex
     ON Nodes (graph_id, node_name)
     WHERE node_name IS NOT NULL
   )sql"));
 
-  // Creates NodeInputs table.
+  // NodeInputs are directed edges between Nodes in Graphs.
   //
   // Fields:
-  //   rowid: Ephemeral b-tree ID dictating locality.
+  //   rowid: Ephemeral b-tree ID.
   //   graph_id: The Permanent ID of the associated Graph.
   //   node_id: Index of Node in question. This can be considered the
   //     'to' vertex.
   //   idx: Used for ordering inputs on a given Node.
   //   input_node_id: Nodes.node_id of the corresponding input node.
   //     This can be considered the 'from' vertex.
+  //   input_node_idx: Since a Node can output multiple Tensors, this
+  //     is the integer index of which of those outputs is our input.
+  //     NULL is treated as 0.
   //   is_control: If non-zero, indicates this input is a controlled
   //     dependency, which means this isn't an edge through which
   //     tensors flow. NULL means 0.
+  //
+  // TODO(jart): Rename to NodeEdges.
   s.Update(Run(db, R"sql(
     CREATE TABLE IF NOT EXISTS NodeInputs (
       rowid INTEGER PRIMARY KEY,
@@ -408,11 +427,11 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
       node_id INTEGER NOT NULL,
       idx INTEGER NOT NULL,
       input_node_id INTEGER NOT NULL,
+      input_node_idx INTEGER,
       is_control INTEGER
     )
   )sql"));
 
-  // Uniquely indexes (graph_id, node_id, idx) on NodeInputs table.
   s.Update(Run(db, R"sql(
     CREATE UNIQUE INDEX IF NOT EXISTS NodeInputsIndex
     ON NodeInputs (graph_id, node_id, idx)
index 44887930c15a02b36399bdb073c609bc52de7f8d..889ac43415b62a2979babe98252fb78534a52f0a 100644 (file)
@@ -14,17 +14,37 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
 
-#include "tensorflow/contrib/tensorboard/db/schema.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/summary.pb.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/db/sqlite.h"
 #include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/util/event.pb.h"
 
+// TODO(jart): Break this up into multiple files with excellent unit tests.
+// TODO(jart): Make decision to write in separate op.
+// TODO(jart): Add really good busy handling.
+
+// clang-format off
+#define CALL_SUPPORTED_TYPES(m) \
+  TF_CALL_string(m)             \
+  TF_CALL_half(m)               \
+  TF_CALL_float(m)              \
+  TF_CALL_double(m)             \
+  TF_CALL_complex64(m)          \
+  TF_CALL_complex128(m)         \
+  TF_CALL_int8(m)               \
+  TF_CALL_int16(m)              \
+  TF_CALL_int32(m)              \
+  TF_CALL_int64(m)              \
+  TF_CALL_uint8(m)              \
+  TF_CALL_uint16(m)             \
+  TF_CALL_uint32(m)             \
+  TF_CALL_uint64(m)
+// clang-format on
+
 namespace tensorflow {
 namespace {
 
@@ -33,115 +53,145 @@ const uint64 kIdTiers[] = {
     0x7fffffULL,        // 23-bit (3 bytes on disk)
     0x7fffffffULL,      // 31-bit (4 bytes on disk)
     0x7fffffffffffULL,  // 47-bit (5 bytes on disk)
-                        // Remaining bits reserved for future use.
+                        // remaining bits for future use
 };
 const int kMaxIdTier = sizeof(kIdTiers) / sizeof(uint64);
 const int kIdCollisionDelayMicros = 10;
 const int kMaxIdCollisions = 21;  // sum(2**i*10µs for i in range(21))~=21s
 const int64 kAbsent = 0LL;
-const int64 kReserved = 0x7fffffffffffffffLL;
 
-double GetWallTime(Env* env) {
+const char* kScalarPluginName = "scalars";
+const char* kImagePluginName = "images";
+const char* kAudioPluginName = "audio";
+const char* kHistogramPluginName = "histograms";
+
+const int kScalarSlots = 10000;
+const int kImageSlots = 10;
+const int kAudioSlots = 10;
+const int kHistogramSlots = 1;
+const int kTensorSlots = 10;
+
+const int64 kReserveMinBytes = 32;
+const double kReserveMultiplier = 1.5;
+
+// Flush is a misnomer because what we're actually doing is having lots
+// of commits inside any SqliteTransaction that writes potentially
+// hundreds of megs but doesn't need the transaction to maintain its
+// invariants. This ensures the WAL read penalty is small and might
+// allow writers in other processes a chance to schedule.
+const uint64 kFlushBytes = 1024 * 1024;
+
+double DoubleTime(uint64 micros) {
   // TODO(@jart): Follow precise definitions for time laid out in schema.
   // TODO(@jart): Use monotonic clock from gRPC codebase.
-  return static_cast<double>(env->NowMicros()) / 1.0e6;
+  return static_cast<double>(micros) / 1.0e6;
 }
 
-Status Serialize(const protobuf::MessageLite& proto, string* output) {
-  output->clear();
-  if (!proto.SerializeToString(output)) {
-    return errors::DataLoss("SerializeToString failed");
+string StringifyShape(const TensorShape& shape) {
+  string result;
+  bool first = true;
+  for (const auto& dim : shape) {
+    if (first) {
+      first = false;
+    } else {
+      strings::StrAppend(&result, ",");
+    }
+    strings::StrAppend(&result, dim.size);
   }
-  return Status::OK();
+  return result;
 }
 
-Status BindProto(SqliteStatement* stmt, int parameter,
-                 const protobuf::MessageLite& proto) {
-  string serialized;
-  TF_RETURN_IF_ERROR(Serialize(proto, &serialized));
-  stmt->BindBlob(parameter, serialized);
+Status CheckSupportedType(const Tensor& t) {
+#define CASE(T)                  \
+  case DataTypeToEnum<T>::value: \
+    break;
+  switch (t.dtype()) {
+    CALL_SUPPORTED_TYPES(CASE)
+    default:
+      return errors::Unimplemented(DataTypeString(t.dtype()),
+                                   " tensors unsupported on platform");
+  }
   return Status::OK();
+#undef CASE
 }
 
-Status BindTensor(SqliteStatement* stmt, int parameter, const Tensor& t) {
-  // TODO(@jart): Make portable between little and big endian systems.
-  // TODO(@jart): Use TensorChunks with minimal copying for big tensors.
-  // TODO(@jart): Add field to indicate encoding.
-  TensorProto p;
-  t.AsProtoTensorContent(&p);
-  return BindProto(stmt, parameter, p);
-}
-
-// Tries to fudge shape and dtype to something with smaller storage.
-Status CoerceScalar(const Tensor& t, Tensor* out) {
+Tensor AsScalar(const Tensor& t) {
+  Tensor t2{t.dtype(), {}};
+#define CASE(T)                        \
+  case DataTypeToEnum<T>::value:       \
+    t2.scalar<T>()() = t.flat<T>()(0); \
+    break;
   switch (t.dtype()) {
-    case DT_DOUBLE:
-      *out = t;
-      break;
-    case DT_INT64:
-      *out = t;
-      break;
-    case DT_FLOAT:
-      *out = {DT_DOUBLE, {}};
-      out->scalar<double>()() = t.scalar<float>()();
-      break;
-    case DT_HALF:
-      *out = {DT_DOUBLE, {}};
-      out->scalar<double>()() = static_cast<double>(t.scalar<Eigen::half>()());
-      break;
-    case DT_INT32:
-      *out = {DT_INT64, {}};
-      out->scalar<int64>()() = t.scalar<int32>()();
-      break;
-    case DT_INT16:
-      *out = {DT_INT64, {}};
-      out->scalar<int64>()() = t.scalar<int16>()();
-      break;
-    case DT_INT8:
-      *out = {DT_INT64, {}};
-      out->scalar<int64>()() = t.scalar<int8>()();
-      break;
-    case DT_UINT32:
-      *out = {DT_INT64, {}};
-      out->scalar<int64>()() = t.scalar<uint32>()();
-      break;
-    case DT_UINT16:
-      *out = {DT_INT64, {}};
-      out->scalar<int64>()() = t.scalar<uint16>()();
-      break;
-    case DT_UINT8:
-      *out = {DT_INT64, {}};
-      out->scalar<int64>()() = t.scalar<uint8>()();
-      break;
+    CALL_SUPPORTED_TYPES(CASE)
     default:
-      return errors::Unimplemented("Scalar summary for dtype ",
-                                   DataTypeString(t.dtype()),
-                                   " is not supported.");
+      t2 = {DT_FLOAT, {}};
+      t2.scalar<float>()() = NAN;
+      break;
   }
-  return Status::OK();
+  return t2;
+#undef CASE
+}
+
+void PatchPluginName(SummaryMetadata* metadata, const char* name) {
+  if (metadata->plugin_data().plugin_name().empty()) {
+    metadata->mutable_plugin_data()->set_plugin_name(name);
+  }
+}
+
+int GetSlots(const Tensor& t, const SummaryMetadata& metadata) {
+  if (metadata.plugin_data().plugin_name() == kScalarPluginName) {
+    return kScalarSlots;
+  } else if (metadata.plugin_data().plugin_name() == kImagePluginName) {
+    return kImageSlots;
+  } else if (metadata.plugin_data().plugin_name() == kAudioPluginName) {
+    return kAudioSlots;
+  } else if (metadata.plugin_data().plugin_name() == kHistogramPluginName) {
+    return kHistogramSlots;
+  } else if (t.dims() == 0 && t.dtype() != DT_STRING) {
+    return kScalarSlots;
+  } else {
+    return kTensorSlots;
+  }
+}
+
+Status SetDescription(Sqlite* db, int64 id, const StringPiece& markdown) {
+  const char* sql = R"sql(
+    INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?)
+  )sql";
+  SqliteStatement insert_desc;
+  TF_RETURN_IF_ERROR(db->Prepare(sql, &insert_desc));
+  insert_desc.BindInt(1, id);
+  insert_desc.BindText(2, markdown);
+  return insert_desc.StepAndReset();
 }
 
-/// \brief Generates unique IDs randomly in the [1,2**63-2] range.
+/// \brief Generates unique IDs randomly in the [1,2**63-1] range.
 ///
 /// This class starts off generating IDs in the [1,2**23-1] range,
 /// because it's human friendly and occupies 4 bytes max on disk with
 /// SQLite's zigzag varint encoding. Then, each time a collision
 /// happens, the random space is increased by 8 bits.
 ///
-/// This class uses exponential back-off so writes will slow down as
-/// the ID space becomes exhausted.
+/// This class uses exponential back-off so writes gradually slow down
+/// as IDs become exhausted but reads are still possible.
+///
+/// This class is thread safe.
 class IdAllocator {
  public:
-  IdAllocator(Env* env, Sqlite* db)
-      : env_{env},
-        inserter_{db->PrepareOrDie("INSERT INTO Ids (id) VALUES (?)")} {}
+  IdAllocator(Env* env, Sqlite* db) : env_{env}, db_{db} {
+    DCHECK(env_ != nullptr);
+    DCHECK(db_ != nullptr);
+  }
 
-  Status CreateNewId(int64* id) {
+  Status CreateNewId(int64* id) LOCKS_EXCLUDED(mu_) {
+    mutex_lock lock(mu_);
     Status s;
+    SqliteStatement stmt;
+    TF_RETURN_IF_ERROR(db_->Prepare("INSERT INTO Ids (id) VALUES (?)", &stmt));
     for (int i = 0; i < kMaxIdCollisions; ++i) {
       int64 tid = MakeRandomId();
-      inserter_.BindInt(1, tid);
-      s = inserter_.StepAndReset();
+      stmt.BindInt(1, tid);
+      s = stmt.StepAndReset();
       if (s.ok()) {
         *id = tid;
         break;
@@ -167,34 +217,38 @@ class IdAllocator {
   }
 
  private:
-  int64 MakeRandomId() {
+  int64 MakeRandomId() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     int64 id = static_cast<int64>(random::New64() & kIdTiers[tier_]);
     if (id == kAbsent) ++id;
-    if (id == kReserved) --id;
     return id;
   }
 
-  Env* env_;
-  SqliteStatement inserter_;
-  int tier_ = 0;
+  mutex mu_;
+  Env* const env_;
+  Sqlite* const db_;
+  int tier_ GUARDED_BY(mu_) = 0;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(IdAllocator);
 };
 
-class GraphSaver {
+class GraphWriter {
  public:
-  static Status Save(Env* env, Sqlite* db, IdAllocator* id_allocator,
-                     GraphDef* graph, int64* graph_id) {
-    TF_RETURN_IF_ERROR(id_allocator->CreateNewId(graph_id));
-    GraphSaver saver{env, db, graph, *graph_id};
+  static Status Save(Sqlite* db, SqliteTransaction* txn, IdAllocator* ids,
+                     GraphDef* graph, uint64 now, int64 run_id, int64* graph_id)
+      SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) {
+    TF_RETURN_IF_ERROR(ids->CreateNewId(graph_id));
+    GraphWriter saver{db, txn, graph, now, *graph_id};
     saver.MapNameToNodeId();
-    TF_RETURN_IF_ERROR(saver.SaveNodeInputs());
-    TF_RETURN_IF_ERROR(saver.SaveNodes());
-    TF_RETURN_IF_ERROR(saver.SaveGraph());
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodeInputs(), "SaveNodeInputs");
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodes(), "SaveNodes");
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveGraph(run_id), "SaveGraph");
     return Status::OK();
   }
 
  private:
-  GraphSaver(Env* env, Sqlite* db, GraphDef* graph, int64 graph_id)
-      : env_(env), db_(db), graph_(graph), graph_id_(graph_id) {}
+  GraphWriter(Sqlite* db, SqliteTransaction* txn, GraphDef* graph, uint64 now,
+              int64 graph_id)
+      : db_(db), txn_(txn), graph_(graph), now_(now), graph_id_(graph_id) {}
 
   void MapNameToNodeId() {
     size_t toto = static_cast<size_t>(graph_->node_size());
@@ -209,161 +263,193 @@ class GraphSaver {
   }
 
   Status SaveNodeInputs() {
-    auto insert = db_->PrepareOrDie(R"sql(
-      INSERT INTO NodeInputs (graph_id, node_id, idx, input_node_id, is_control)
-      VALUES (?, ?, ?, ?, ?)
-    )sql");
+    const char* sql = R"sql(
+      INSERT INTO NodeInputs (
+        graph_id,
+        node_id,
+        idx,
+        input_node_id,
+        input_node_idx,
+        is_control
+      ) VALUES (?, ?, ?, ?, ?, ?)
+    )sql";
+    SqliteStatement insert;
+    TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
     for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
       const NodeDef& node = graph_->node(node_id);
       for (int idx = 0; idx < node.input_size(); ++idx) {
         StringPiece name = node.input(idx);
-        insert.BindInt(1, graph_id_);
-        insert.BindInt(2, node_id);
-        insert.BindInt(3, idx);
+        int64 input_node_id;
+        int64 input_node_idx = 0;
+        int64 is_control = 0;
+        size_t i = name.rfind(':');
+        if (i != StringPiece::npos) {
+          if (!strings::safe_strto64(name.substr(i + 1, name.size() - i - 1),
+                                     &input_node_idx)) {
+            return errors::DataLoss("Bad NodeDef.input: ", name);
+          }
+          name.remove_suffix(name.size() - i);
+        }
         if (!name.empty() && name[0] == '^') {
           name.remove_prefix(1);
-          insert.BindInt(5, 1);
+          is_control = 1;
         }
         auto e = name_to_node_id_.find(name);
         if (e == name_to_node_id_.end()) {
           return errors::DataLoss("Could not find node: ", name);
         }
-        insert.BindInt(4, e->second);
+        input_node_id = e->second;
+        insert.BindInt(1, graph_id_);
+        insert.BindInt(2, node_id);
+        insert.BindInt(3, idx);
+        insert.BindInt(4, input_node_id);
+        insert.BindInt(5, input_node_idx);
+        insert.BindInt(6, is_control);
+        unflushed_bytes_ += insert.size();
         TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node.name(),
                                         " -> ", name);
+        TF_RETURN_IF_ERROR(MaybeFlush());
       }
     }
     return Status::OK();
   }
 
   Status SaveNodes() {
-    auto insert = db_->PrepareOrDie(R"sql(
-      INSERT INTO Nodes (graph_id, node_id, node_name, op, device, node_def)
-      VALUES (?, ?, ?, ?, ?, snap(?))
-    )sql");
+    const char* sql = R"sql(
+      INSERT INTO Nodes (
+        graph_id,
+        node_id,
+        node_name,
+        op,
+        device,
+        node_def)
+      VALUES (?, ?, ?, ?, ?, ?)
+    )sql";
+    SqliteStatement insert;
+    TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
     for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
       NodeDef* node = graph_->mutable_node(node_id);
       insert.BindInt(1, graph_id_);
       insert.BindInt(2, node_id);
       insert.BindText(3, node->name());
+      insert.BindText(4, node->op());
+      insert.BindText(5, node->device());
       node->clear_name();
-      if (!node->op().empty()) {
-        insert.BindText(4, node->op());
-        node->clear_op();
-      }
-      if (!node->device().empty()) {
-        insert.BindText(5, node->device());
-        node->clear_device();
-      }
+      node->clear_op();
+      node->clear_device();
       node->clear_input();
-      TF_RETURN_IF_ERROR(BindProto(&insert, 6, *node));
+      string node_def;
+      if (node->SerializeToString(&node_def)) {
+        insert.BindBlobUnsafe(6, node_def);
+      }
+      unflushed_bytes_ += insert.size();
       TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node->name());
+      TF_RETURN_IF_ERROR(MaybeFlush());
     }
     return Status::OK();
   }
 
-  Status SaveGraph() {
-    auto insert = db_->PrepareOrDie(R"sql(
-      INSERT INTO Graphs (graph_id, inserted_time, graph_def)
-      VALUES (?, ?, snap(?))
-    )sql");
-    insert.BindInt(1, graph_id_);
-    insert.BindDouble(2, GetWallTime(env_));
+  Status SaveGraph(int64 run_id) {
+    const char* sql = R"sql(
+      INSERT OR REPLACE INTO Graphs (
+        run_id,
+        graph_id,
+        inserted_time,
+        graph_def
+      ) VALUES (?, ?, ?, ?)
+    )sql";
+    SqliteStatement insert;
+    TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
+    if (run_id != kAbsent) insert.BindInt(1, run_id);
+    insert.BindInt(2, graph_id_);
+    insert.BindDouble(3, DoubleTime(now_));
     graph_->clear_node();
-    TF_RETURN_IF_ERROR(BindProto(&insert, 3, *graph_));
+    string graph_def;
+    if (graph_->SerializeToString(&graph_def)) {
+      insert.BindBlobUnsafe(4, graph_def);
+    }
     return insert.StepAndReset();
   }
 
-  Env* env_;
-  Sqlite* db_;
-  GraphDef* graph_;
-  int64 graph_id_;
+  Status MaybeFlush() {
+    if (unflushed_bytes_ >= kFlushBytes) {
+      TF_RETURN_WITH_CONTEXT_IF_ERROR(txn_->Commit(), "flushing ",
+                                      unflushed_bytes_, " bytes");
+      unflushed_bytes_ = 0;
+    }
+    return Status::OK();
+  }
+
+  Sqlite* const db_;
+  SqliteTransaction* const txn_;
+  uint64 unflushed_bytes_ = 0;
+  GraphDef* const graph_;
+  const uint64 now_;
+  const int64 graph_id_;
   std::vector<string> name_copies_;
   std::unordered_map<StringPiece, int64, StringPieceHasher> name_to_node_id_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GraphWriter);
 };
 
-class RunWriter {
+/// \brief Run metadata manager.
+///
+/// This class gives us Tag IDs we can pass to SeriesWriter. In order
+/// to do that, rows are created in the Ids, Tags, Runs, Experiments,
+/// and Users tables.
+///
+/// This class is thread safe.
+class RunMetadata {
  public:
-  RunWriter(Env* env, Sqlite* db, const string& experiment_name,
-            const string& run_name, const string& user_name)
-      : env_{env},
-        db_{db},
-        id_allocator_{env_, db_},
+  RunMetadata(IdAllocator* ids, const string& experiment_name,
+              const string& run_name, const string& user_name)
+      : ids_{ids},
         experiment_name_{experiment_name},
         run_name_{run_name},
-        user_name_{user_name},
-        insert_tensor_{db_->PrepareOrDie(R"sql(
-          INSERT OR REPLACE INTO Tensors (tag_id, step, computed_time, tensor)
-          VALUES (?, ?, ?, snap(?))
-        )sql")} {
-    db_->Ref();
+        user_name_{user_name} {
+    DCHECK(ids_ != nullptr);
   }
 
-  ~RunWriter() {
-    if (run_id_ != kAbsent) {
-      auto update = db_->PrepareOrDie(R"sql(
-        UPDATE Runs SET finished_time = ? WHERE run_id = ?
-      )sql");
-      update.BindDouble(1, GetWallTime(env_));
-      update.BindInt(2, run_id_);
-      Status s = update.StepAndReset();
-      if (!s.ok()) {
-        LOG(ERROR) << "Failed to set Runs[" << run_id_
-                   << "].finish_time: " << s.ToString();
-      }
-    }
-    db_->Unref();
-  }
+  const string& experiment_name() { return experiment_name_; }
+  const string& run_name() { return run_name_; }
+  const string& user_name() { return user_name_; }
 
-  Status InsertTensor(int64 tag_id, int64 step, double computed_time,
-                      Tensor t) {
-    insert_tensor_.BindInt(1, tag_id);
-    insert_tensor_.BindInt(2, step);
-    insert_tensor_.BindDouble(3, computed_time);
-    if (t.shape().dims() == 0 && t.dtype() == DT_INT64) {
-      insert_tensor_.BindInt(4, t.scalar<int64>()());
-    } else if (t.shape().dims() == 0 && t.dtype() == DT_DOUBLE) {
-      insert_tensor_.BindDouble(4, t.scalar<double>()());
-    } else {
-      TF_RETURN_IF_ERROR(BindTensor(&insert_tensor_, 4, t));
-    }
-    return insert_tensor_.StepAndReset();
+  int64 run_id() LOCKS_EXCLUDED(mu_) {
+    mutex_lock lock(mu_);
+    return run_id_;
   }
 
-  Status InsertGraph(std::unique_ptr<GraphDef> g, double computed_time) {
-    TF_RETURN_IF_ERROR(InitializeRun(computed_time));
+  Status SetGraph(Sqlite* db, uint64 now, double computed_time,
+                  std::unique_ptr<GraphDef> g) SQLITE_TRANSACTIONS_EXCLUDED(*db)
+      LOCKS_EXCLUDED(mu_) {
+    int64 run_id;
+    {
+      mutex_lock lock(mu_);
+      TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time));
+      run_id = run_id_;
+    }
     int64 graph_id;
+    SqliteTransaction txn(*db);  // only to increase performance
     TF_RETURN_IF_ERROR(
-        GraphSaver::Save(env_, db_, &id_allocator_, g.get(), &graph_id));
-    if (run_id_ != kAbsent) {
-      auto set =
-          db_->PrepareOrDie("UPDATE Runs SET graph_id = ? WHERE run_id = ?");
-      set.BindInt(1, graph_id);
-      set.BindInt(2, run_id_);
-      TF_RETURN_IF_ERROR(set.StepAndReset());
-    }
-    return Status::OK();
+        GraphWriter::Save(db, &txn, ids_, g.get(), now, run_id, &graph_id));
+    return txn.Commit();
   }
 
-  Status GetTagId(double computed_time, const string& tag_name,
-                  const SummaryMetadata& metadata, int64* tag_id) {
-    TF_RETURN_IF_ERROR(InitializeRun(computed_time));
+  Status GetTagId(Sqlite* db, uint64 now, double computed_time,
+                  const string& tag_name, int64* tag_id,
+                  const SummaryMetadata& metadata) LOCKS_EXCLUDED(mu_) {
+    mutex_lock lock(mu_);
+    TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time));
     auto e = tag_ids_.find(tag_name);
     if (e != tag_ids_.end()) {
       *tag_id = e->second;
       return Status::OK();
     }
-    TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(tag_id));
+    TF_RETURN_IF_ERROR(ids_->CreateNewId(tag_id));
     tag_ids_[tag_name] = *tag_id;
-    if (!metadata.summary_description().empty()) {
-      SqliteStatement insert_description = db_->PrepareOrDie(R"sql(
-        INSERT INTO Descriptions (id, description) VALUES (?, ?)
-      )sql");
-      insert_description.BindInt(1, *tag_id);
-      insert_description.BindText(2, metadata.summary_description());
-      TF_RETURN_IF_ERROR(insert_description.StepAndReset());
-    }
-    SqliteStatement insert = db_->PrepareOrDie(R"sql(
+    TF_RETURN_IF_ERROR(
+        SetDescription(db, *tag_id, metadata.summary_description()));
+    const char* sql = R"sql(
       INSERT INTO Tags (
         run_id,
         tag_id,
@@ -372,30 +458,54 @@ class RunWriter {
         display_name,
         plugin_name,
         plugin_data
-      ) VALUES (?, ?, ?, ?, ?, ?, ?)
-    )sql");
-    if (run_id_ != kAbsent) insert.BindInt(1, run_id_);
-    insert.BindInt(2, *tag_id);
-    insert.BindText(3, tag_name);
-    insert.BindDouble(4, GetWallTime(env_));
-    if (!metadata.display_name().empty()) {
-      insert.BindText(5, metadata.display_name());
-    }
-    if (!metadata.plugin_data().plugin_name().empty()) {
-      insert.BindText(6, metadata.plugin_data().plugin_name());
-    }
-    if (!metadata.plugin_data().content().empty()) {
-      insert.BindBlob(7, metadata.plugin_data().content());
-    }
+      ) VALUES (
+        :run_id,
+        :tag_id,
+        :tag_name,
+        :inserted_time,
+        :display_name,
+        :plugin_name,
+        :plugin_data
+      )
+    )sql";
+    SqliteStatement insert;
+    TF_RETURN_IF_ERROR(db->Prepare(sql, &insert));
+    if (run_id_ != kAbsent) insert.BindInt(":run_id", run_id_);
+    insert.BindInt(":tag_id", *tag_id);
+    insert.BindTextUnsafe(":tag_name", tag_name);
+    insert.BindDouble(":inserted_time", DoubleTime(now));
+    insert.BindTextUnsafe(":display_name", metadata.display_name());
+    insert.BindTextUnsafe(":plugin_name", metadata.plugin_data().plugin_name());
+    insert.BindBlobUnsafe(":plugin_data", metadata.plugin_data().content());
     return insert.StepAndReset();
   }
 
+  Status GetIsWatching(Sqlite* db, bool* is_watching)
+      SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) {
+    mutex_lock lock(mu_);
+    if (experiment_id_ == kAbsent) {
+      *is_watching = true;
+      return Status::OK();
+    }
+    const char* sql = R"sql(
+      SELECT is_watching FROM Experiments WHERE experiment_id = ?
+    )sql";
+    SqliteStatement stmt;
+    TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt));
+    stmt.BindInt(1, experiment_id_);
+    TF_RETURN_IF_ERROR(stmt.StepOnce());
+    *is_watching = stmt.ColumnInt(0) != 0;
+    return Status::OK();
+  }
+
  private:
-  Status InitializeUser() {
+  Status InitializeUser(Sqlite* db, uint64 now) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     if (user_id_ != kAbsent || user_name_.empty()) return Status::OK();
-    SqliteStatement get = db_->PrepareOrDie(R"sql(
+    const char* get_sql = R"sql(
       SELECT user_id FROM Users WHERE user_name = ?
-    )sql");
+    )sql";
+    SqliteStatement get;
+    TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get));
     get.BindText(1, user_name_);
     bool is_done;
     TF_RETURN_IF_ERROR(get.Step(&is_done));
@@ -403,22 +513,29 @@ class RunWriter {
       user_id_ = get.ColumnInt(0);
       return Status::OK();
     }
-    TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&user_id_));
-    SqliteStatement insert = db_->PrepareOrDie(R"sql(
-      INSERT INTO Users (user_id, user_name, inserted_time) VALUES (?, ?, ?)
-    )sql");
+    TF_RETURN_IF_ERROR(ids_->CreateNewId(&user_id_));
+    const char* insert_sql = R"sql(
+      INSERT INTO Users (
+        user_id,
+        user_name,
+        inserted_time
+      ) VALUES (?, ?, ?)
+    )sql";
+    SqliteStatement insert;
+    TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
     insert.BindInt(1, user_id_);
     insert.BindText(2, user_name_);
-    insert.BindDouble(3, GetWallTime(env_));
+    insert.BindDouble(3, DoubleTime(now));
     TF_RETURN_IF_ERROR(insert.StepAndReset());
     return Status::OK();
   }
 
-  Status InitializeExperiment(double computed_time) {
+  Status InitializeExperiment(Sqlite* db, uint64 now, double computed_time)
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     if (experiment_name_.empty()) return Status::OK();
     if (experiment_id_ == kAbsent) {
-      TF_RETURN_IF_ERROR(InitializeUser());
-      SqliteStatement get = db_->PrepareOrDie(R"sql(
+      TF_RETURN_IF_ERROR(InitializeUser(db, now));
+      const char* get_sql = R"sql(
         SELECT
           experiment_id,
           started_time
@@ -427,7 +544,9 @@ class RunWriter {
         WHERE
           user_id IS ?
           AND experiment_name = ?
-      )sql");
+      )sql";
+      SqliteStatement get;
+      TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get));
       if (user_id_ != kAbsent) get.BindInt(1, user_id_);
       get.BindText(2, experiment_name_);
       bool is_done;
@@ -436,30 +555,41 @@ class RunWriter {
         experiment_id_ = get.ColumnInt(0);
         experiment_started_time_ = get.ColumnInt(1);
       } else {
-        TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&experiment_id_));
+        TF_RETURN_IF_ERROR(ids_->CreateNewId(&experiment_id_));
         experiment_started_time_ = computed_time;
-        SqliteStatement insert = db_->PrepareOrDie(R"sql(
+        const char* insert_sql = R"sql(
           INSERT INTO Experiments (
             user_id,
             experiment_id,
             experiment_name,
             inserted_time,
-            started_time
-          ) VALUES (?, ?, ?, ?, ?)
-        )sql");
+            started_time,
+            is_watching
+          ) VALUES (?, ?, ?, ?, ?, ?)
+        )sql";
+        SqliteStatement insert;
+        TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
         if (user_id_ != kAbsent) insert.BindInt(1, user_id_);
         insert.BindInt(2, experiment_id_);
         insert.BindText(3, experiment_name_);
-        insert.BindDouble(4, GetWallTime(env_));
+        insert.BindDouble(4, DoubleTime(now));
         insert.BindDouble(5, computed_time);
+        insert.BindInt(6, 0);
         TF_RETURN_IF_ERROR(insert.StepAndReset());
       }
     }
     if (computed_time < experiment_started_time_) {
       experiment_started_time_ = computed_time;
-      SqliteStatement update = db_->PrepareOrDie(R"sql(
-        UPDATE Experiments SET started_time = ? WHERE experiment_id = ?
-      )sql");
+      const char* update_sql = R"sql(
+        UPDATE
+          Experiments
+        SET
+          started_time = ?
+        WHERE
+          experiment_id = ?
+      )sql";
+      SqliteStatement update;
+      TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update));
       update.BindDouble(1, computed_time);
       update.BindInt(2, experiment_id_);
       TF_RETURN_IF_ERROR(update.StepAndReset());
@@ -467,13 +597,14 @@ class RunWriter {
     return Status::OK();
   }
 
-  Status InitializeRun(double computed_time) {
+  Status InitializeRun(Sqlite* db, uint64 now, double computed_time)
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     if (run_name_.empty()) return Status::OK();
-    TF_RETURN_IF_ERROR(InitializeExperiment(computed_time));
+    TF_RETURN_IF_ERROR(InitializeExperiment(db, now, computed_time));
     if (run_id_ == kAbsent) {
-      TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&run_id_));
+      TF_RETURN_IF_ERROR(ids_->CreateNewId(&run_id_));
       run_started_time_ = computed_time;
-      SqliteStatement insert = db_->PrepareOrDie(R"sql(
+      const char* insert_sql = R"sql(
         INSERT OR REPLACE INTO Runs (
           experiment_id,
           run_id,
@@ -481,19 +612,28 @@ class RunWriter {
           inserted_time,
           started_time
         ) VALUES (?, ?, ?, ?, ?)
-      )sql");
+      )sql";
+      SqliteStatement insert;
+      TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
       if (experiment_id_ != kAbsent) insert.BindInt(1, experiment_id_);
       insert.BindInt(2, run_id_);
       insert.BindText(3, run_name_);
-      insert.BindDouble(4, GetWallTime(env_));
+      insert.BindDouble(4, DoubleTime(now));
       insert.BindDouble(5, computed_time);
       TF_RETURN_IF_ERROR(insert.StepAndReset());
     }
     if (computed_time < run_started_time_) {
       run_started_time_ = computed_time;
-      SqliteStatement update = db_->PrepareOrDie(R"sql(
-        UPDATE Runs SET started_time = ? WHERE run_id = ?
-      )sql");
+      const char* update_sql = R"sql(
+        UPDATE
+          Runs
+        SET
+          started_time = ?
+        WHERE
+          run_id = ?
+      )sql";
+      SqliteStatement update;
+      TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update));
       update.BindDouble(1, computed_time);
       update.BindInt(2, run_id_);
       TF_RETURN_IF_ERROR(update.StepAndReset());
@@ -501,79 +641,400 @@ class RunWriter {
     return Status::OK();
   }
 
-  Env* env_;
-  Sqlite* db_;
-  IdAllocator id_allocator_;
+  mutex mu_;
+  IdAllocator* const ids_;
   const string experiment_name_;
   const string run_name_;
   const string user_name_;
-  int64 experiment_id_ = kAbsent;
-  int64 run_id_ = kAbsent;
-  int64 user_id_ = kAbsent;
-  std::unordered_map<string, int64> tag_ids_;
-  double experiment_started_time_ = 0.0;
-  double run_started_time_ = 0.0;
-  SqliteStatement insert_tensor_;
+  int64 experiment_id_ GUARDED_BY(mu_) = kAbsent;
+  int64 run_id_ GUARDED_BY(mu_) = kAbsent;
+  int64 user_id_ GUARDED_BY(mu_) = kAbsent;
+  double experiment_started_time_ GUARDED_BY(mu_) = 0.0;
+  double run_started_time_ GUARDED_BY(mu_) = 0.0;
+  std::unordered_map<string, int64> tag_ids_ GUARDED_BY(mu_);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(RunMetadata);
 };
 
+/// \brief Tensor writer for a single series, e.g. Tag.
+///
+/// This class can be used to write an infinite stream of Tensors to the
+/// database in a fixed block of contiguous disk space. This is
+/// accomplished using Algorithm R reservoir sampling.
+///
+/// The reservoir consists of a fixed number of rows, which are inserted
+/// using ZEROBLOB upon receiving the first sample, which is used to
+/// predict how big the other ones are likely to be. This is done
+/// transactionally in a way that tries to be mindful of other processes
+/// that might be trying to access the same DB.
+///
+/// Once the reservoir fills up, rows are replaced at random, and writes
+/// gradually become no-ops. This allows long training to go fast
+/// without configuration. The exception is when someone is actually
+/// looking at TensorBoard. When that happens, the "keep last" behavior
+/// is turned on and Append() will always result in a write.
+///
+/// If no one is watching training, this class still holds on to the
+/// most recent "dangling" Tensor, so if Finish() is called, the most
+/// recent training state can be written to disk.
+///
+/// The randomly selected sampling points should be consistent across
+/// multiple instances.
+///
+/// This class is thread safe.
+class SeriesWriter {
+ public:
+  SeriesWriter(int64 series, int slots, RunMetadata* meta)
+      : series_{series},
+        slots_{slots},
+        meta_{meta},
+        rng_{std::mt19937_64::default_seed} {
+    DCHECK(series_ > 0);
+    DCHECK(slots_ > 0);
+  }
+
+  Status Append(Sqlite* db, int64 step, uint64 now, double computed_time,
+                Tensor t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
+      LOCKS_EXCLUDED(mu_) {
+    mutex_lock lock(mu_);
+    if (rowids_.empty()) {
+      Status s = Reserve(db, t);
+      if (!s.ok()) {
+        rowids_.clear();
+        return s;
+      }
+    }
+    DCHECK(rowids_.size() == slots_);
+    int64 rowid;
+    size_t i = count_;
+    if (i < slots_) {
+      rowid = last_rowid_ = rowids_[i];
+    } else {
+      i = rng_() % (i + 1);
+      if (i < slots_) {
+        rowid = last_rowid_ = rowids_[i];
+      } else {
+        bool keep_last;
+        TF_RETURN_IF_ERROR(meta_->GetIsWatching(db, &keep_last));
+        if (!keep_last) {
+          ++count_;
+          dangling_tensor_.reset(new Tensor(std::move(t)));
+          dangling_step_ = step;
+          dangling_computed_time_ = computed_time;
+          return Status::OK();
+        }
+        rowid = last_rowid_;
+      }
+    }
+    Status s = Write(db, rowid, step, computed_time, t);
+    if (s.ok()) {
+      ++count_;
+      dangling_tensor_.reset();
+    }
+    return s;
+  }
+
+  Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
+      LOCKS_EXCLUDED(mu_) {
+    mutex_lock lock(mu_);
+    // Short runs: Delete unused pre-allocated Tensors.
+    if (count_ < rowids_.size()) {
+      SqliteTransaction txn(*db);
+      const char* sql = R"sql(
+        DELETE FROM Tensors WHERE rowid = ?
+      )sql";
+      SqliteStatement deleter;
+      TF_RETURN_IF_ERROR(db->Prepare(sql, &deleter));
+      for (size_t i = count_; i < rowids_.size(); ++i) {
+        deleter.BindInt(1, rowids_[i]);
+        TF_RETURN_IF_ERROR(deleter.StepAndReset());
+      }
+      TF_RETURN_IF_ERROR(txn.Commit());
+      rowids_.clear();
+    }
+    // Long runs: Make last sample be the very most recent one.
+    if (dangling_tensor_) {
+      DCHECK(last_rowid_ != kAbsent);
+      TF_RETURN_IF_ERROR(Write(db, last_rowid_, dangling_step_,
+                               dangling_computed_time_, *dangling_tensor_));
+      dangling_tensor_.reset();
+    }
+    return Status::OK();
+  }
+
+ private:
+  Status Write(Sqlite* db, int64 rowid, int64 step, double computed_time,
+               const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) {
+    if (t.dtype() == DT_STRING) {
+      if (t.dims() == 0) {
+        return Update(db, step, computed_time, t, t.scalar<string>()(), rowid);
+      } else {
+        SqliteTransaction txn(*db);
+        TF_RETURN_IF_ERROR(
+            Update(db, step, computed_time, t, StringPiece(), rowid));
+        TF_RETURN_IF_ERROR(UpdateNdString(db, t, rowid));
+        return txn.Commit();
+      }
+    } else {
+      return Update(db, step, computed_time, t, t.tensor_data(), rowid);
+    }
+  }
+
+  Status Update(Sqlite* db, int64 step, double computed_time, const Tensor& t,
+                const StringPiece& data, int64 rowid) {
+    // TODO(jart): How can we ensure reservoir fills on replace?
+    const char* sql = R"sql(
+      UPDATE OR REPLACE
+        Tensors
+      SET
+        step = ?,
+        computed_time = ?,
+        dtype = ?,
+        shape = ?,
+        data = ?
+      WHERE
+        rowid = ?
+    )sql";
+    SqliteStatement stmt;
+    TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt));
+    stmt.BindInt(1, step);
+    stmt.BindDouble(2, computed_time);
+    stmt.BindInt(3, t.dtype());
+    stmt.BindText(4, StringifyShape(t.shape()));
+    stmt.BindBlobUnsafe(5, data);
+    stmt.BindInt(6, rowid);
+    TF_RETURN_IF_ERROR(stmt.StepAndReset());
+    return Status::OK();
+  }
+
+  Status UpdateNdString(Sqlite* db, const Tensor& t, int64 tensor_rowid)
+      SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) {
+    DCHECK_EQ(t.dtype(), DT_STRING);
+    DCHECK_GT(t.dims(), 0);
+    const char* deleter_sql = R"sql(
+      DELETE FROM TensorStrings WHERE tensor_rowid = ?
+    )sql";
+    SqliteStatement deleter;
+    TF_RETURN_IF_ERROR(db->Prepare(deleter_sql, &deleter));
+    deleter.BindInt(1, tensor_rowid);
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(deleter.StepAndReset(), tensor_rowid);
+    const char* inserter_sql = R"sql(
+      INSERT INTO TensorStrings (
+        tensor_rowid,
+        idx,
+        data
+      ) VALUES (?, ?, ?)
+    )sql";
+    SqliteStatement inserter;
+    TF_RETURN_IF_ERROR(db->Prepare(inserter_sql, &inserter));
+    auto flat = t.flat<string>();
+    for (int64 i = 0; i < flat.size(); ++i) {
+      inserter.BindInt(1, tensor_rowid);
+      inserter.BindInt(2, i);
+      inserter.BindBlobUnsafe(3, flat(i));
+      TF_RETURN_WITH_CONTEXT_IF_ERROR(inserter.StepAndReset(), "i=", i);
+    }
+    return Status::OK();
+  }
+
+  Status Reserve(Sqlite* db, const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    SqliteTransaction txn(*db);  // only for performance
+    unflushed_bytes_ = 0;
+    if (t.dtype() == DT_STRING) {
+      if (t.dims() == 0) {
+        TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.scalar<string>()().size()));
+      } else {
+        TF_RETURN_IF_ERROR(ReserveTensors(db, &txn, kReserveMinBytes));
+      }
+    } else {
+      TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.tensor_data().size()));
+    }
+    return txn.Commit();
+  }
+
+  Status ReserveData(Sqlite* db, SqliteTransaction* txn, size_t size)
+      SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    int64 space =
+        static_cast<int64>(static_cast<double>(size) * kReserveMultiplier);
+    if (space < kReserveMinBytes) space = kReserveMinBytes;
+    return ReserveTensors(db, txn, space);
+  }
+
+  Status ReserveTensors(Sqlite* db, SqliteTransaction* txn,
+                        int64 reserved_bytes)
+      SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    const char* sql = R"sql(
+      INSERT INTO Tensors (
+        series,
+        data
+      ) VALUES (?, ZEROBLOB(?))
+    )sql";
+    SqliteStatement insert;
+    TF_RETURN_IF_ERROR(db->Prepare(sql, &insert));
+    // TODO(jart): Maybe preallocate index pages by setting step. This
+    //             is tricky because UPDATE OR REPLACE can have a side
+    //             effect of deleting preallocated rows.
+    for (int64 i = 0; i < slots_; ++i) {
+      insert.BindInt(1, series_);
+      insert.BindInt(2, reserved_bytes);
+      TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), "i=", i);
+      rowids_.push_back(db->last_insert_rowid());
+      unflushed_bytes_ += reserved_bytes;
+      TF_RETURN_IF_ERROR(MaybeFlush(db, txn));
+    }
+    return Status::OK();
+  }
+
+  Status MaybeFlush(Sqlite* db, SqliteTransaction* txn)
+      SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    if (unflushed_bytes_ >= kFlushBytes) {
+      TF_RETURN_WITH_CONTEXT_IF_ERROR(txn->Commit(), "flushing ",
+                                      unflushed_bytes_, " bytes");
+      unflushed_bytes_ = 0;
+    }
+    return Status::OK();
+  }
+
+  mutex mu_;
+  const int64 series_;
+  const int slots_;
+  RunMetadata* const meta_;
+  std::mt19937_64 rng_ GUARDED_BY(mu_);
+  uint64 count_ GUARDED_BY(mu_) = 0;
+  int64 last_rowid_ GUARDED_BY(mu_) = kAbsent;
+  std::vector<int64> rowids_ GUARDED_BY(mu_);
+  uint64 unflushed_bytes_ GUARDED_BY(mu_) = 0;
+  std::unique_ptr<Tensor> dangling_tensor_ GUARDED_BY(mu_);
+  int64 dangling_step_ GUARDED_BY(mu_) = 0;
+  double dangling_computed_time_ GUARDED_BY(mu_) = 0.0;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(SeriesWriter);
+};
+
+/// \brief Tensor writer for a single Run.
+///
+/// This class farms out tensors to SeriesWriter instances. It also
+/// keeps track of whether or not someone is watching the TensorBoard
+/// GUI, so it can avoid writes when possible.
+///
+/// This class is thread safe.
+class RunWriter {
+ public:
+  explicit RunWriter(RunMetadata* meta) : meta_{meta} {}
+
+  Status Append(Sqlite* db, int64 tag_id, int64 step, uint64 now,
+                double computed_time, Tensor t, int slots)
+      SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) {
+    SeriesWriter* writer = GetSeriesWriter(tag_id, slots);
+    return writer->Append(db, step, now, computed_time, std::move(t));
+  }
+
+  Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
+      LOCKS_EXCLUDED(mu_) {
+    mutex_lock lock(mu_);
+    if (series_writers_.empty()) return Status::OK();
+    for (auto i = series_writers_.begin(); i != series_writers_.end(); ++i) {
+      if (!i->second) continue;
+      TF_RETURN_WITH_CONTEXT_IF_ERROR(i->second->Finish(db),
+                                      "finish tag_id=", i->first);
+      i->second.reset();
+    }
+    return Status::OK();
+  }
+
+ private:
+  SeriesWriter* GetSeriesWriter(int64 tag_id, int slots) LOCKS_EXCLUDED(mu_) {
+    mutex_lock sl(mu_);
+    auto spot = series_writers_.find(tag_id);
+    if (spot == series_writers_.end()) {
+      SeriesWriter* writer = new SeriesWriter(tag_id, slots, meta_);
+      series_writers_[tag_id].reset(writer);
+      return writer;
+    } else {
+      return spot->second.get();
+    }
+  }
+
+  mutex mu_;
+  RunMetadata* const meta_;
+  std::unordered_map<int64, std::unique_ptr<SeriesWriter>> series_writers_
+      GUARDED_BY(mu_);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(RunWriter);
+};
+
+/// \brief SQLite implementation of SummaryWriterInterface.
+///
+/// This class is thread safe.
 class SummaryDbWriter : public SummaryWriterInterface {
  public:
-  SummaryDbWriter(Env* env, Sqlite* db,
-                  const string& experiment_name, const string& run_name,
-                  const string& user_name)
-      : env_{env},
-        run_writer_{env, db, experiment_name, run_name, user_name} {}
-  ~SummaryDbWriter() override {}
+  SummaryDbWriter(Env* env, Sqlite* db, const string& experiment_name,
+                  const string& run_name, const string& user_name)
+      : SummaryWriterInterface(),
+        env_{env},
+        db_{db},
+        ids_{env_, db_},
+        meta_{&ids_, experiment_name, run_name, user_name},
+        run_{&meta_} {
+    DCHECK(env_ != nullptr);
+    db_->Ref();
+  }
+
+  ~SummaryDbWriter() override {
+    core::ScopedUnref unref(db_);
+    Status s = run_.Finish(db_);
+    if (!s.ok()) {
+      // TODO(jart): Retry on transient errors here.
+      LOG(ERROR) << s.ToString();
+    }
+    int64 run_id = meta_.run_id();
+    if (run_id == kAbsent) return;
+    const char* sql = R"sql(
+      UPDATE Runs SET finished_time = ? WHERE run_id = ?
+    )sql";
+    SqliteStatement update;
+    s = db_->Prepare(sql, &update);
+    if (s.ok()) {
+      update.BindDouble(1, DoubleTime(env_->NowMicros()));
+      update.BindInt(2, run_id);
+      s = update.StepAndReset();
+    }
+    if (!s.ok()) {
+      LOG(ERROR) << "Failed to set Runs[" << run_id
+                 << "].finish_time: " << s.ToString();
+    }
+  }
 
   Status Flush() override { return Status::OK(); }
 
   Status WriteTensor(int64 global_step, Tensor t, const string& tag,
                      const string& serialized_metadata) override {
-    mutex_lock ml(mu_);
+    TF_RETURN_IF_ERROR(CheckSupportedType(t));
     SummaryMetadata metadata;
-    if (!serialized_metadata.empty()) {
-      metadata.ParseFromString(serialized_metadata);
+    if (!metadata.ParseFromString(serialized_metadata)) {
+      return errors::InvalidArgument("Bad serialized_metadata");
     }
-    double now = GetWallTime(env_);
-    int64 tag_id;
-    TF_RETURN_IF_ERROR(run_writer_.GetTagId(now, tag, metadata, &tag_id));
-    return run_writer_.InsertTensor(tag_id, global_step, now, t);
+    return Write(global_step, t, tag, metadata);
   }
 
   Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
-    Tensor t2;
-    TF_RETURN_IF_ERROR(CoerceScalar(t, &t2));
-    // TODO(jart): Generate scalars plugin metadata on this value.
-    return WriteTensor(global_step, std::move(t2), tag, "");
+    TF_RETURN_IF_ERROR(CheckSupportedType(t));
+    SummaryMetadata metadata;
+    PatchPluginName(&metadata, kScalarPluginName);
+    return Write(global_step, AsScalar(t), tag, metadata);
   }
 
   Status WriteGraph(int64 global_step, std::unique_ptr<GraphDef> g) override {
-    mutex_lock ml(mu_);
-    return run_writer_.InsertGraph(std::move(g), GetWallTime(env_));
+    uint64 now = env_->NowMicros();
+    return meta_.SetGraph(db_, now, DoubleTime(now), std::move(g));
   }
 
   Status WriteEvent(std::unique_ptr<Event> e) override {
-    switch (e->what_case()) {
-      case Event::WhatCase::kSummary: {
-        mutex_lock ml(mu_);
-        Status s;
-        for (const auto& value : e->summary().value()) {
-          s.Update(WriteSummary(e.get(), value));
-        }
-        return s;
-      }
-      case Event::WhatCase::kGraphDef: {
-        mutex_lock ml(mu_);
-        std::unique_ptr<GraphDef> graph{new GraphDef};
-        if (!ParseProtoUnlimited(graph.get(), e->graph_def())) {
-          return errors::DataLoss("parse event.graph_def failed");
-        }
-        return run_writer_.InsertGraph(std::move(graph), e->wall_time());
-      }
-      default:
-        // TODO(@jart): Handle other stuff.
-        return Status::OK();
-    }
+    return MigrateEvent(std::move(e));
   }
 
   Status WriteHistogram(int64 global_step, Tensor t,
@@ -600,26 +1061,165 @@ class SummaryDbWriter : public SummaryWriterInterface {
   string DebugString() override { return "SummaryDbWriter"; }
 
  private:
-  Status WriteSummary(const Event* e, const Summary::Value& summary)
-      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-    switch (summary.value_case()) {
-      case Summary::Value::ValueCase::kSimpleValue: {
-        int64 tag_id;
-        TF_RETURN_IF_ERROR(run_writer_.GetTagId(e->wall_time(), summary.tag(),
-                                                summary.metadata(), &tag_id));
-        Tensor t{DT_DOUBLE, {}};
-        t.scalar<double>()() = summary.simple_value();
-        return run_writer_.InsertTensor(tag_id, e->step(), e->wall_time(), t);
+  Status Write(int64 step, const Tensor& t, const string& tag,
+               const SummaryMetadata& metadata) {
+    uint64 now = env_->NowMicros();
+    double computed_time = DoubleTime(now);
+    int64 tag_id;
+    TF_RETURN_IF_ERROR(
+        meta_.GetTagId(db_, now, computed_time, tag, &tag_id, metadata));
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(
+        run_.Append(db_, tag_id, step, now, computed_time, t,
+                    GetSlots(t, metadata)),
+        meta_.user_name(), "/", meta_.experiment_name(), "/", meta_.run_name(),
+        "/", tag, "@", step);
+    return Status::OK();
+  }
+
+  Status MigrateEvent(std::unique_ptr<Event> e) {
+    switch (e->what_case()) {
+      case Event::WhatCase::kSummary: {
+        uint64 now = env_->NowMicros();
+        auto summaries = e->mutable_summary();
+        for (int i = 0; i < summaries->value_size(); ++i) {
+          Summary::Value* value = summaries->mutable_value(i);
+          TF_RETURN_WITH_CONTEXT_IF_ERROR(
+              MigrateSummary(e.get(), value, now), meta_.user_name(), "/",
+              meta_.experiment_name(), "/", meta_.run_name(), "/", value->tag(),
+              "@", e->step());
+        }
+        break;
       }
+      case Event::WhatCase::kGraphDef:
+        TF_RETURN_WITH_CONTEXT_IF_ERROR(
+            MigrateGraph(e.get(), e->graph_def()), meta_.user_name(), "/",
+            meta_.experiment_name(), "/", meta_.run_name(), "/__graph__@",
+            e->step());
+        break;
       default:
-        // TODO(@jart): Handle the rest.
-        return Status::OK();
+        // TODO(@jart): Handle other stuff.
+        break;
     }
+    return Status::OK();
   }
 
-  mutex mu_;
-  Env* env_;
-  RunWriter run_writer_ GUARDED_BY(mu_);
+  Status MigrateGraph(const Event* e, const string& graph_def) {
+    uint64 now = env_->NowMicros();
+    std::unique_ptr<GraphDef> graph{new GraphDef};
+    if (!ParseProtoUnlimited(graph.get(), graph_def)) {
+      return errors::InvalidArgument("bad proto");
+    }
+    return meta_.SetGraph(db_, now, e->wall_time(), std::move(graph));
+  }
+
+  Status MigrateSummary(const Event* e, Summary::Value* s, uint64 now) {
+    switch (s->value_case()) {
+      case Summary::Value::ValueCase::kTensor:
+        TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateTensor(e, s, now), "tensor");
+        break;
+      case Summary::Value::ValueCase::kSimpleValue:
+        TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateScalar(e, s, now), "scalar");
+        break;
+      case Summary::Value::ValueCase::kHisto:
+        TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateHistogram(e, s, now), "histo");
+        break;
+      case Summary::Value::ValueCase::kImage:
+        TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateImage(e, s, now), "image");
+        break;
+      case Summary::Value::ValueCase::kAudio:
+        TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateAudio(e, s, now), "audio");
+        break;
+      default:
+        break;
+    }
+    return Status::OK();
+  }
+
+  Status MigrateTensor(const Event* e, Summary::Value* s, uint64 now) {
+    Tensor t;
+    if (!t.FromProto(s->tensor())) return errors::InvalidArgument("bad proto");
+    TF_RETURN_IF_ERROR(CheckSupportedType(t));
+    int64 tag_id;
+    TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
+                                      &tag_id, s->metadata()));
+    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t,
+                       GetSlots(t, s->metadata()));
+  }
+
+  // TODO(jart): Refactor Summary -> Tensor logic into separate file.
+
+  Status MigrateScalar(const Event* e, Summary::Value* s, uint64 now) {
+    // See tensorboard/plugins/scalar/summary.py and data_compat.py
+    Tensor t{DT_FLOAT, {}};
+    t.scalar<float>()() = s->simple_value();
+    int64 tag_id;
+    PatchPluginName(s->mutable_metadata(), kScalarPluginName);
+    TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
+                                      &tag_id, s->metadata()));
+    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
+                       std::move(t), kScalarSlots);
+  }
+
+  Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) {
+    const HistogramProto& histo = s->histo();
+    int k = histo.bucket_size();
+    if (k != histo.bucket_limit_size()) {
+      return errors::InvalidArgument("size mismatch");
+    }
+    // See tensorboard/plugins/histogram/summary.py and data_compat.py
+    Tensor t{DT_DOUBLE, {k, 3}};
+    auto data = t.flat<double>();
+    for (int i = 0; i < k; ++i) {
+      double left_edge = ((i - 1 >= 0) ? histo.bucket_limit(i - 1)
+                                       : std::numeric_limits<double>::min());
+      double right_edge = ((i + 1 < k) ? histo.bucket_limit(i + 1)
+                                       : std::numeric_limits<double>::max());
+      data(i + 0) = left_edge;
+      data(i + 1) = right_edge;
+      data(i + 2) = histo.bucket(i);
+    }
+    int64 tag_id;
+    PatchPluginName(s->mutable_metadata(), kHistogramPluginName);
+    TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
+                                      &tag_id, s->metadata()));
+    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
+                       std::move(t), kHistogramSlots);
+  }
+
+  Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) {
+    // See tensorboard/plugins/image/summary.py and data_compat.py
+    Tensor t{DT_STRING, {3}};
+    auto img = s->mutable_image();
+    t.flat<string>()(0) = strings::StrCat(img->width());
+    t.flat<string>()(1) = strings::StrCat(img->height());
+    t.flat<string>()(2) = std::move(*img->mutable_encoded_image_string());
+    int64 tag_id;
+    PatchPluginName(s->mutable_metadata(), kImagePluginName);
+    TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
+                                      &tag_id, s->metadata()));
+    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
+                       std::move(t), kImageSlots);
+  }
+
+  Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) {
+    // See tensorboard/plugins/audio/summary.py and data_compat.py
+    Tensor t{DT_STRING, {1, 2}};
+    auto wav = s->mutable_audio();
+    t.flat<string>()(0) = std::move(*wav->mutable_encoded_audio_string());
+    t.flat<string>()(1) = "";
+    int64 tag_id;
+    PatchPluginName(s->mutable_metadata(), kAudioPluginName);
+    TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
+                                      &tag_id, s->metadata()));
+    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
+                       std::move(t), kAudioSlots);
+  }
+
+  Env* const env_;
+  Sqlite* const db_;
+  IdAllocator ids_;
+  RunMetadata meta_;
+  RunWriter run_;
 };
 
 }  // namespace
@@ -627,8 +1227,6 @@ class SummaryDbWriter : public SummaryWriterInterface {
 Status CreateSummaryDbWriter(Sqlite* db, const string& experiment_name,
                              const string& run_name, const string& user_name,
                              Env* env, SummaryWriterInterface** result) {
-  *result = nullptr;
-  TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db));
   *result = new SummaryDbWriter(env, db, experiment_name, run_name, user_name);
   return Status::OK();
 }
index 5a3de195de8c1d94dfc1f153cea4eb02f3258e1b..746da1533b157bf7b2be5c85ada8b61ba224cc3e 100644 (file)
@@ -19,16 +19,15 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/db/sqlite.h"
 #include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 
 /// \brief Creates SQLite SummaryWriterInterface.
 ///
 /// This can be used to write tensors from the execution graph directly
-/// to a database. The schema will be created automatically, but only
-/// if necessary. Entries in the Users, Experiments, and Runs tables
-/// will be created automatically if they don't already exist.
+/// to a database. The schema must be created beforehand. Entries in
+/// Users, Experiments, and Runs tables will be created automatically
+/// if they don't already exist.
 ///
 /// Please note that the type signature of this function may change in
 /// the future if support for other DBs is added to core.
index 68444c35be216119206765b00527efe86dc66d4d..29b8063218de72aac1a73bbfb440e75fcdd5013f 100644 (file)
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
 
+#include "tensorflow/contrib/tensorboard/db/schema.h"
+#include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/summary.pb.h"
@@ -27,8 +29,6 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
-const float kTolerance = 1e-5;
-
 Tensor MakeScalarInt64(int64 x) {
   Tensor t(DT_INT64, TensorShape({}));
   t.scalar<int64>()() = x;
@@ -50,6 +50,7 @@ class SummaryDbWriterTest : public ::testing::Test {
  protected:
   void SetUp() override {
     TF_ASSERT_OK(Sqlite::Open(":memory:", SQLITE_OPEN_READWRITE, &db_));
+    TF_ASSERT_OK(SetupTensorboardSqliteDb(db_));
   }
 
   void TearDown() override {
@@ -138,7 +139,7 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) {
   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs"));
   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags"));
-  ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+  ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
 
   int64 user_id = QueryInt("SELECT user_id FROM Users");
   int64 experiment_id = QueryInt("SELECT experiment_id FROM Experiments");
@@ -170,17 +171,13 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) {
   EXPECT_EQ("plugin_data", QueryString("SELECT plugin_data FROM Tags"));
   EXPECT_EQ("description", QueryString("SELECT description FROM Descriptions"));
 
-  EXPECT_EQ(tag_id, QueryInt("SELECT tag_id FROM Tensors WHERE step = 1"));
+  EXPECT_EQ(tag_id, QueryInt("SELECT series FROM Tensors WHERE step = 1"));
   EXPECT_EQ(0.023,
             QueryDouble("SELECT computed_time FROM Tensors WHERE step = 1"));
-  EXPECT_FALSE(
-      QueryString("SELECT tensor FROM Tensors WHERE step = 1").empty());
 
-  EXPECT_EQ(tag_id, QueryInt("SELECT tag_id FROM Tensors WHERE step = 2"));
+  EXPECT_EQ(tag_id, QueryInt("SELECT series FROM Tensors WHERE step = 2"));
   EXPECT_EQ(0.046,
             QueryDouble("SELECT computed_time FROM Tensors WHERE step = 2"));
-  EXPECT_FALSE(
-      QueryString("SELECT tensor FROM Tensors WHERE step = 2").empty());
 }
 
 TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) {
@@ -191,7 +188,7 @@ TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) {
   ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
   ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs"));
   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags"));
-  ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+  ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
 }
 
 TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) {
@@ -208,33 +205,24 @@ TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) {
   TF_ASSERT_OK(writer_->WriteEvent(std::move(e)));
   TF_ASSERT_OK(writer_->Flush());
   ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags"));
-  ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+  ASSERT_EQ(20000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
   int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'π'");
   int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'φ'");
   EXPECT_GT(tag1_id, 0LL);
   EXPECT_GT(tag2_id, 0LL);
   EXPECT_EQ(123.456, QueryDouble(strings::StrCat(
-                         "SELECT computed_time FROM Tensors WHERE tag_id = ",
+                         "SELECT computed_time FROM Tensors WHERE series = ",
                          tag1_id, " AND step = 7")));
   EXPECT_EQ(123.456, QueryDouble(strings::StrCat(
-                         "SELECT computed_time FROM Tensors WHERE tag_id = ",
+                         "SELECT computed_time FROM Tensors WHERE series = ",
                          tag2_id, " AND step = 7")));
-  EXPECT_NEAR(3.14,
-              QueryDouble(strings::StrCat(
-                  "SELECT tensor FROM Tensors WHERE tag_id = ", tag1_id,
-                  " AND step = 7")),
-              kTolerance);  // Summary::simple_value is float
-  EXPECT_NEAR(1.61,
-              QueryDouble(strings::StrCat(
-                  "SELECT tensor FROM Tensors WHERE tag_id = ", tag2_id,
-                  " AND step = 7")),
-              kTolerance);
 }
 
 TEST_F(SummaryDbWriterTest, WriteGraph) {
   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "R", "", &env_, &writer_));
   env_.AdvanceByMillis(23);
   GraphDef graph;
+  graph.mutable_library()->add_gradient()->set_function_name("funk");
   NodeDef* node = graph.add_node();
   node->set_name("x");
   node->set_op("Placeholder");
@@ -260,11 +248,17 @@ TEST_F(SummaryDbWriterTest, WriteGraph) {
   ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Nodes"));
   ASSERT_EQ(3LL, QueryInt("SELECT COUNT(*) FROM NodeInputs"));
 
+  ASSERT_EQ(QueryInt("SELECT run_id FROM Runs"),
+            QueryInt("SELECT run_id FROM Graphs"));
+
   int64 graph_id = QueryInt("SELECT graph_id FROM Graphs");
   EXPECT_GT(graph_id, 0LL);
-  EXPECT_EQ(graph_id, QueryInt("SELECT graph_id FROM Runs"));
   EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Graphs"));
-  EXPECT_FALSE(QueryString("SELECT graph_def FROM Graphs").empty());
+
+  GraphDef graph2;
+  graph2.ParseFromString(QueryString("SELECT graph_def FROM Graphs"));
+  EXPECT_EQ(0, graph2.node_size());
+  EXPECT_EQ("funk", graph2.library().gradient(0).function_name());
 
   EXPECT_EQ("x", QueryString("SELECT node_name FROM Nodes WHERE node_id = 0"));
   EXPECT_EQ("y", QueryString("SELECT node_name FROM Nodes WHERE node_id = 1"));
@@ -307,33 +301,6 @@ TEST_F(SummaryDbWriterTest, WriteGraph) {
   EXPECT_EQ(1LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 2"));
 }
 
-TEST_F(SummaryDbWriterTest, WriteScalarInt32_CoercesToInt64) {
-  TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
-  Tensor t(DT_INT32, {});
-  t.scalar<int32>()() = -17;
-  TF_ASSERT_OK(writer_->WriteScalar(1, t, "t"));
-  TF_ASSERT_OK(writer_->Flush());
-  ASSERT_EQ(-17LL, QueryInt("SELECT tensor FROM Tensors"));
-}
-
-TEST_F(SummaryDbWriterTest, WriteScalarInt8_CoercesToInt64) {
-  TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
-  Tensor t(DT_INT8, {});
-  t.scalar<int8>()() = static_cast<int8>(-17);
-  TF_ASSERT_OK(writer_->WriteScalar(1, t, "t"));
-  TF_ASSERT_OK(writer_->Flush());
-  ASSERT_EQ(-17LL, QueryInt("SELECT tensor FROM Tensors"));
-}
-
-TEST_F(SummaryDbWriterTest, WriteScalarUint8_CoercesToInt64) {
-  TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
-  Tensor t(DT_UINT8, {});
-  t.scalar<uint8>()() = static_cast<uint8>(254);
-  TF_ASSERT_OK(writer_->WriteScalar(1, t, "t"));
-  TF_ASSERT_OK(writer_->Flush());
-  ASSERT_EQ(254LL, QueryInt("SELECT tensor FROM Tensors"));
-}
-
 TEST_F(SummaryDbWriterTest, UsesIdsTable) {
   SummaryMetadata metadata;
   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_,
index 72a857d01adcb0f9e4022b2b4a829b81572b5cda..878faf261bff022e86d693a848e1e314c4bd8b3a 100644 (file)
@@ -5930,6 +5930,7 @@ tf_kernel_library(
     srcs = ["summary_kernels.cc"],
     deps = [
         ":summary_interface",
+        "//tensorflow/contrib/tensorboard/db:schema",
         "//tensorflow/contrib/tensorboard/db:summary_db_writer",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
index 924bf27452b8a10851bbe6a991aace153f500982..029a0aab97290e30783e415274323a1e43f9740b 100644 (file)
@@ -34,9 +34,8 @@ Status SqliteQueryConnection::Open(const string& data_source_name,
     return errors::FailedPrecondition(
         "Failed to open query connection: Connection already opened.");
   }
-  TF_RETURN_IF_ERROR(Sqlite::Open(data_source_name,
-                                  SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE,
-                                  &db_));
+  TF_RETURN_IF_ERROR(Sqlite::Open(
+      data_source_name, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, &db_));
   query_ = query;
   output_types_ = output_types;
   return Status::OK();
@@ -87,6 +86,7 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry(
 #define INT_CASE(T) CASE(T, ColumnInt)
 #define DOUBLE_CASE(T) CASE(T, ColumnDouble)
 #define STRING_CASE(T) CASE(T, ColumnString)
+  // clang-format off
   switch (data_type) {
     TF_CALL_int8(INT_CASE)
     TF_CALL_uint8(INT_CASE)
@@ -102,13 +102,13 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry(
     case DT_BOOL:
       tensor->scalar<bool>()() = stmt_.ColumnInt(column_index) != 0;
       break;
-      // Error preemptively thrown by SqlDatasetOp::MakeDataset in this case.
-    default: {
+    // Error preemptively thrown by SqlDatasetOp::MakeDataset in this case.
+    default:
       LOG(FATAL)
           << "Use of unsupported TensorFlow data type by 'SqlQueryConnection': "
           << DataTypeString(data_type) << ".";
-    }
   }
+  // clang-format on
 }
 
 }  // namespace sql
index 2ecde70b519a4b9cd67a1f9f03880c59a4fb22aa..a815f540b10a3d6bc4cc98f39d72796c85734d84 100644 (file)
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/contrib/tensorboard/db/schema.h"
 #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -70,6 +71,7 @@ class CreateSummaryDbWriterOp : public OpKernel {
                                      SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE,
                                      &db));
     core::ScopedUnref unref(db);
+    OP_REQUIRES_OK(ctx, SetupTensorboardSqliteDb(db));
     OP_REQUIRES_OK(
         ctx, CreateSummaryDbWriter(db, experiment_name,
                                    run_name, user_name, ctx->env(), &s));
index 31c7d0c1b68c7df7fd3c804b687edb2587bdfa9b..9ff87e8d66d2575966c703a896ac9ff0bc51661a 100644 (file)
@@ -5,12 +5,13 @@ package(default_visibility = ["//tensorflow:internal"])
 
 licenses(["notice"])  # Apache 2.0
 
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_copts")
 
 cc_library(
     name = "sqlite",
     srcs = ["sqlite.cc"],
     hdrs = ["sqlite.h"],
+    copts = tf_copts(),
     deps = [
         ":snapfn",
         "//tensorflow/core:lib",
@@ -22,7 +23,7 @@ cc_library(
 cc_library(
     name = "snapfn",
     srcs = ["snapfn.cc"],
-    copts = ["-DSQLITE_OMIT_LOAD_EXTENSION"],
+    copts = tf_copts() + ["-DSQLITE_OMIT_LOAD_EXTENSION"],
     linkstatic = 1,
     deps = [
         "@org_sqlite",
index 0683e8133d7fe77059d8b54484aebfc1e732a80e..cb6943379d4ebe38c79ba9097d4c3183c7b8c205 100644 (file)
@@ -173,11 +173,6 @@ Status Sqlite::Prepare(const StringPiece& sql, SqliteStatement* stmt) {
   return Status::OK();
 }
 
-SqliteStatement::~SqliteStatement() {
-  if (stmt_ != nullptr) sqlite3_finalize(stmt_);
-  if (db_ != nullptr) db_->Unref();
-}
-
 Status SqliteStatement::Step(bool* is_done) {
   DCHECK(stmt_ != nullptr);
   if (TF_PREDICT_FALSE(bind_error_ != SQLITE_OK)) {
index 2aa82560b955d5f9193fc56e0e12bf1b75e805c8..0faa458f1d692a103099d5b05d0400944ffdaad7 100644 (file)
@@ -105,7 +105,7 @@ class LOCKABLE Sqlite : public core::RefCounted {
   }
 
   /// \brief Returns rowid assigned to last successful insert.
-  int64 last_insert_row_id() const EXCLUSIVE_LOCKS_REQUIRED(this) {
+  int64 last_insert_rowid() const EXCLUSIVE_LOCKS_REQUIRED(this) {
     return sqlite3_last_insert_rowid(db_);
   }
 
@@ -151,7 +151,10 @@ class SqliteStatement {
   ///
   /// This can take milliseconds if it was blocking the Sqlite
   /// connection object from being freed.
-  ~SqliteStatement();
+  ~SqliteStatement() {
+    sqlite3_finalize(stmt_);
+    if (db_ != nullptr) db_->Unref();
+  }
 
   /// \brief Returns true if statement is initialized.
   explicit operator bool() const { return stmt_ != nullptr; }
@@ -432,6 +435,10 @@ class SCOPED_LOCKABLE SqliteTransaction {
   TF_DISALLOW_COPY_AND_ASSIGN(SqliteTransaction);
 };
 
+#define SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(...) \
+  EXCLUSIVE_LOCKS_REQUIRED(__VA_ARGS__)
+#define SQLITE_TRANSACTIONS_EXCLUDED(...) LOCKS_EXCLUDED(__VA_ARGS__)
+
 inline SqliteStatement Sqlite::PrepareOrDie(const StringPiece& sql) {
   SqliteStatement stmt;
   TF_CHECK_OK(Prepare(sql, &stmt));