Remove reservoir sampling from SummaryDbWriter
authorJustine Tunney <jart@google.com>
Tue, 22 May 2018 22:30:02 +0000 (15:30 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 22 May 2018 22:32:36 +0000 (15:32 -0700)
PiperOrigin-RevId: 197634162

tensorflow/contrib/tensorboard/db/summary_db_writer.cc
tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc

index 6590d6f..d5d8e41 100644 (file)
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
 
+#include <deque>
+
 #include "tensorflow/contrib/tensorboard/db/summary_converter.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
@@ -66,14 +68,9 @@ const char* kImagePluginName = "images";
 const char* kAudioPluginName = "audio";
 const char* kHistogramPluginName = "histograms";
 
-const int kScalarSlots = 10000;
-const int kImageSlots = 10;
-const int kAudioSlots = 10;
-const int kHistogramSlots = 1;
-const int kTensorSlots = 10;
-
 const int64 kReserveMinBytes = 32;
 const double kReserveMultiplier = 1.5;
+const int64 kPreallocateRows = 1000;
 
 // Flush is a misnomer because what we're actually doing is having lots
 // of commits inside any SqliteTransaction that writes potentially
@@ -139,22 +136,6 @@ void PatchPluginName(SummaryMetadata* metadata, const char* name) {
   }
 }
 
-int GetSlots(const Tensor& t, const SummaryMetadata& metadata) {
-  if (metadata.plugin_data().plugin_name() == kScalarPluginName) {
-    return kScalarSlots;
-  } else if (metadata.plugin_data().plugin_name() == kImagePluginName) {
-    return kImageSlots;
-  } else if (metadata.plugin_data().plugin_name() == kAudioPluginName) {
-    return kAudioSlots;
-  } else if (metadata.plugin_data().plugin_name() == kHistogramPluginName) {
-    return kHistogramSlots;
-  } else if (t.dims() == 0 && t.dtype() != DT_STRING) {
-    return kScalarSlots;
-  } else {
-    return kTensorSlots;
-  }
-}
-
 Status SetDescription(Sqlite* db, int64 id, const StringPiece& markdown) {
   const char* sql = R"sql(
     INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?)
@@ -481,24 +462,6 @@ class RunMetadata {
     return insert.StepAndReset();
   }
 
-  Status GetIsWatching(Sqlite* db, bool* is_watching)
-      SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) {
-    mutex_lock lock(mu_);
-    if (experiment_id_ == kAbsent) {
-      *is_watching = true;
-      return Status::OK();
-    }
-    const char* sql = R"sql(
-      SELECT is_watching FROM Experiments WHERE experiment_id = ?
-    )sql";
-    SqliteStatement stmt;
-    TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt));
-    stmt.BindInt(1, experiment_id_);
-    TF_RETURN_IF_ERROR(stmt.StepOnce());
-    *is_watching = stmt.ColumnInt(0) != 0;
-    return Status::OK();
-  }
-
  private:
   Status InitializeUser(Sqlite* db, uint64 now) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     if (user_id_ != kAbsent || user_name_.empty()) return Status::OK();
@@ -659,43 +622,15 @@ class RunMetadata {
 
 /// \brief Tensor writer for a single series, e.g. Tag.
 ///
-/// This class can be used to write an infinite stream of Tensors to the
-/// database in a fixed block of contiguous disk space. This is
-/// accomplished using Algorithm R reservoir sampling.
-///
-/// The reservoir consists of a fixed number of rows, which are inserted
-/// using ZEROBLOB upon receiving the first sample, which is used to
-/// predict how big the other ones are likely to be. This is done
-/// transactionally in a way that tries to be mindful of other processes
-/// that might be trying to access the same DB.
-///
-/// Once the reservoir fills up, rows are replaced at random, and writes
-/// gradually become no-ops. This allows long training to go fast
-/// without configuration. The exception is when someone is actually
-/// looking at TensorBoard. When that happens, the "keep last" behavior
-/// is turned on and Append() will always result in a write.
-///
-/// If no one is watching training, this class still holds on to the
-/// most recent "dangling" Tensor, so if Finish() is called, the most
-/// recent training state can be written to disk.
-///
-/// The randomly selected sampling points should be consistent across
-/// multiple instances.
-///
 /// This class is thread safe.
 class SeriesWriter {
  public:
-  SeriesWriter(int64 series, int slots, RunMetadata* meta)
-      : series_{series},
-        slots_{slots},
-        meta_{meta},
-        rng_{std::mt19937_64::default_seed} {
+  SeriesWriter(int64 series, RunMetadata* meta) : series_{series}, meta_{meta} {
     DCHECK(series_ > 0);
-    DCHECK(slots_ > 0);
   }
 
   Status Append(Sqlite* db, int64 step, uint64 now, double computed_time,
-                Tensor t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
+                const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
       LOCKS_EXCLUDED(mu_) {
     mutex_lock lock(mu_);
     if (rowids_.empty()) {
@@ -705,41 +640,20 @@ class SeriesWriter {
         return s;
       }
     }
-    DCHECK(rowids_.size() == slots_);
-    int64 rowid;
-    size_t i = count_;
-    if (i < slots_) {
-      rowid = last_rowid_ = rowids_[i];
-    } else {
-      i = rng_() % (i + 1);
-      if (i < slots_) {
-        rowid = last_rowid_ = rowids_[i];
-      } else {
-        bool keep_last;
-        TF_RETURN_IF_ERROR(meta_->GetIsWatching(db, &keep_last));
-        if (!keep_last) {
-          ++count_;
-          dangling_tensor_.reset(new Tensor(std::move(t)));
-          dangling_step_ = step;
-          dangling_computed_time_ = computed_time;
-          return Status::OK();
-        }
-        rowid = last_rowid_;
-      }
-    }
+    int64 rowid = rowids_.front();
     Status s = Write(db, rowid, step, computed_time, t);
     if (s.ok()) {
       ++count_;
-      dangling_tensor_.reset();
     }
+    rowids_.pop_front();
     return s;
   }
 
   Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
       LOCKS_EXCLUDED(mu_) {
     mutex_lock lock(mu_);
-    // Short runs: Delete unused pre-allocated Tensors.
-    if (count_ < rowids_.size()) {
+    // Delete unused pre-allocated Tensors.
+    if (!rowids_.empty()) {
       SqliteTransaction txn(*db);
       const char* sql = R"sql(
         DELETE FROM Tensors WHERE rowid = ?
@@ -747,19 +661,13 @@ class SeriesWriter {
       SqliteStatement deleter;
       TF_RETURN_IF_ERROR(db->Prepare(sql, &deleter));
       for (size_t i = count_; i < rowids_.size(); ++i) {
-        deleter.BindInt(1, rowids_[i]);
+        deleter.BindInt(1, rowids_.front());
         TF_RETURN_IF_ERROR(deleter.StepAndReset());
+        rowids_.pop_front();
       }
       TF_RETURN_IF_ERROR(txn.Commit());
       rowids_.clear();
     }
-    // Long runs: Make last sample be the very most recent one.
-    if (dangling_tensor_) {
-      DCHECK(last_rowid_ != kAbsent);
-      TF_RETURN_IF_ERROR(Write(db, last_rowid_, dangling_step_,
-                               dangling_computed_time_, *dangling_tensor_));
-      dangling_tensor_.reset();
-    }
     return Status::OK();
   }
 
@@ -783,7 +691,6 @@ class SeriesWriter {
 
   Status Update(Sqlite* db, int64 step, double computed_time, const Tensor& t,
                 const StringPiece& data, int64 rowid) {
-    // TODO(jart): How can we ensure reservoir fills on replace?
     const char* sql = R"sql(
       UPDATE OR REPLACE
         Tensors
@@ -878,7 +785,7 @@ class SeriesWriter {
     // TODO(jart): Maybe preallocate index pages by setting step. This
     //             is tricky because UPDATE OR REPLACE can have a side
     //             effect of deleting preallocated rows.
-    for (int64 i = 0; i < slots_; ++i) {
+    for (int64 i = 0; i < kPreallocateRows; ++i) {
       insert.BindInt(1, series_);
       insert.BindInt(2, reserved_bytes);
       TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), "i=", i);
@@ -902,16 +809,10 @@ class SeriesWriter {
 
   mutex mu_;
   const int64 series_;
-  const int slots_;
   RunMetadata* const meta_;
-  std::mt19937_64 rng_ GUARDED_BY(mu_);
   uint64 count_ GUARDED_BY(mu_) = 0;
-  int64 last_rowid_ GUARDED_BY(mu_) = kAbsent;
-  std::vector<int64> rowids_ GUARDED_BY(mu_);
+  std::deque<int64> rowids_ GUARDED_BY(mu_);
   uint64 unflushed_bytes_ GUARDED_BY(mu_) = 0;
-  std::unique_ptr<Tensor> dangling_tensor_ GUARDED_BY(mu_);
-  int64 dangling_step_ GUARDED_BY(mu_) = 0;
-  double dangling_computed_time_ GUARDED_BY(mu_) = 0.0;
 
   TF_DISALLOW_COPY_AND_ASSIGN(SeriesWriter);
 };
@@ -928,10 +829,10 @@ class RunWriter {
   explicit RunWriter(RunMetadata* meta) : meta_{meta} {}
 
   Status Append(Sqlite* db, int64 tag_id, int64 step, uint64 now,
-                double computed_time, Tensor t, int slots)
+                double computed_time, const Tensor& t)
       SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) {
-    SeriesWriter* writer = GetSeriesWriter(tag_id, slots);
-    return writer->Append(db, step, now, computed_time, std::move(t));
+    SeriesWriter* writer = GetSeriesWriter(tag_id);
+    return writer->Append(db, step, now, computed_time, t);
   }
 
   Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
@@ -948,11 +849,11 @@ class RunWriter {
   }
 
  private:
-  SeriesWriter* GetSeriesWriter(int64 tag_id, int slots) LOCKS_EXCLUDED(mu_) {
+  SeriesWriter* GetSeriesWriter(int64 tag_id) LOCKS_EXCLUDED(mu_) {
     mutex_lock sl(mu_);
     auto spot = series_writers_.find(tag_id);
     if (spot == series_writers_.end()) {
-      SeriesWriter* writer = new SeriesWriter(tag_id, slots, meta_);
+      SeriesWriter* writer = new SeriesWriter(tag_id, meta_);
       series_writers_[tag_id].reset(writer);
       return writer;
     } else {
@@ -1082,8 +983,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
     TF_RETURN_IF_ERROR(
         meta_.GetTagId(db_, now, computed_time, tag, &tag_id, metadata));
     TF_RETURN_WITH_CONTEXT_IF_ERROR(
-        run_.Append(db_, tag_id, step, now, computed_time, t,
-                    GetSlots(t, metadata)),
+        run_.Append(db_, tag_id, step, now, computed_time, t),
         meta_.user_name(), "/", meta_.experiment_name(), "/", meta_.run_name(),
         "/", tag, "@", step);
     return Status::OK();
@@ -1155,8 +1055,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
     int64 tag_id;
     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
                                       &tag_id, s->metadata()));
-    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t,
-                       GetSlots(t, s->metadata()));
+    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
   }
 
   // TODO(jart): Refactor Summary -> Tensor logic into separate file.
@@ -1169,8 +1068,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
     PatchPluginName(s->mutable_metadata(), kScalarPluginName);
     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
                                       &tag_id, s->metadata()));
-    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
-                       std::move(t), kScalarSlots);
+    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
   }
 
   Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) {
@@ -1195,8 +1093,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
     PatchPluginName(s->mutable_metadata(), kHistogramPluginName);
     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
                                       &tag_id, s->metadata()));
-    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
-                       std::move(t), kHistogramSlots);
+    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
   }
 
   Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) {
@@ -1210,8 +1107,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
     PatchPluginName(s->mutable_metadata(), kImagePluginName);
     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
                                       &tag_id, s->metadata()));
-    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
-                       std::move(t), kImageSlots);
+    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
   }
 
   Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) {
@@ -1224,8 +1120,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
     PatchPluginName(s->mutable_metadata(), kAudioPluginName);
     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
                                       &tag_id, s->metadata()));
-    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
-                       std::move(t), kAudioSlots);
+    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
   }
 
   Env* const env_;
index 29b8063..c34b676 100644 (file)
@@ -139,7 +139,7 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) {
   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs"));
   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags"));
-  ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+  ASSERT_EQ(1000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
 
   int64 user_id = QueryInt("SELECT user_id FROM Users");
   int64 experiment_id = QueryInt("SELECT experiment_id FROM Experiments");
@@ -188,7 +188,7 @@ TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) {
   ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
   ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs"));
   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags"));
-  ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+  ASSERT_EQ(1000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
 }
 
 TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) {
@@ -205,7 +205,7 @@ TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) {
   TF_ASSERT_OK(writer_->WriteEvent(std::move(e)));
   TF_ASSERT_OK(writer_->Flush());
   ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags"));
-  ASSERT_EQ(20000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+  ASSERT_EQ(2000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
   int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'π'");
   int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'φ'");
   EXPECT_GT(tag1_id, 0LL);