Added function to Database interface to retrieve keys. Exposed a bug with LMDB itera...
authorKevin James Matzen <kmatzen@cs.cornell.edu>
Mon, 13 Oct 2014 01:36:59 +0000 (21:36 -0400)
committerKevin James Matzen <kmatzen@cs.cornell.edu>
Tue, 14 Oct 2014 23:31:30 +0000 (19:31 -0400)
include/caffe/database.hpp
include/caffe/leveldb_database.hpp
include/caffe/lmdb_database.hpp
src/caffe/leveldb_database.cpp
src/caffe/lmdb_database.cpp
src/caffe/test/test_database.cpp

index 2341149..8469e84 100644 (file)
@@ -33,6 +33,8 @@ class Database {
   virtual bool commit() = 0;
   virtual void close() = 0;
 
+  virtual void keys(vector<key_type>* keys) = 0;
+
   Database() { }
   virtual ~Database() { }
 
@@ -65,7 +67,8 @@ class Database {
 
     iterator(const iterator& other)
         : parent_(other.parent_),
-          state_(other.state_->clone()) { }
+          state_(other.state_ ? other.state_->clone()
+              : shared_ptr<DatabaseState>()) { }
 
     iterator& operator=(iterator copy) {
       copy.swap(*this);
@@ -86,12 +89,12 @@ class Database {
     }
 
     iterator& operator++() {
-      parent_->increment(state_);
+      parent_->increment(&state_);
       return *this;
     }
     iterator operator++(int) {
       iterator copy(*this);
-      parent_->increment(state_);
+      parent_->increment(&state_);
       return copy;
     }
 
@@ -117,7 +120,7 @@ class Database {
 
   virtual bool equal(shared_ptr<DatabaseState> state1,
       shared_ptr<DatabaseState> state2) const = 0;
-  virtual void increment(shared_ptr<DatabaseState> state) const = 0;
+  virtual void increment(shared_ptr<DatabaseState>* state) const = 0;
   virtual KV& dereference(
       shared_ptr<DatabaseState> state) const = 0;
 };
index 9c7f70e..48cf11e 100644 (file)
@@ -6,6 +6,7 @@
 
 #include <string>
 #include <utility>
+#include <vector>
 
 #include "caffe/common.hpp"
 #include "caffe/database.hpp"
@@ -20,6 +21,8 @@ class LeveldbDatabase : public Database {
   bool commit();
   void close();
 
+  void keys(vector<key_type>* keys);
+
   const_iterator begin() const;
   const_iterator cbegin() const;
   const_iterator end() const;
@@ -62,7 +65,7 @@ class LeveldbDatabase : public Database {
 
   bool equal(shared_ptr<DatabaseState> state1,
       shared_ptr<DatabaseState> state2) const;
-  void increment(shared_ptr<DatabaseState> state) const;
+  void increment(shared_ptr<DatabaseState>* state) const;
   Database::KV& dereference(shared_ptr<DatabaseState> state) const;
 
   shared_ptr<leveldb::DB> db_;
index 107d936..a8ce60d 100644 (file)
@@ -3,6 +3,7 @@
 
 #include <string>
 #include <utility>
+#include <vector>
 
 #include "lmdb.h"
 
@@ -24,6 +25,8 @@ class LmdbDatabase : public Database {
   bool commit();
   void close();
 
+  void keys(vector<key_type>* keys);
+
   const_iterator begin() const;
   const_iterator cbegin() const;
   const_iterator end() const;
@@ -66,7 +69,7 @@ class LmdbDatabase : public Database {
 
   bool equal(shared_ptr<DatabaseState> state1,
       shared_ptr<DatabaseState> state2) const;
-  void increment(shared_ptr<DatabaseState> state) const;
+  void increment(shared_ptr<DatabaseState>* state) const;
   Database::KV& dereference(shared_ptr<DatabaseState> state) const;
 
   MDB_env* env_;
index 2061b8e..ad98f9e 100644 (file)
@@ -1,5 +1,6 @@
 #include <string>
 #include <utility>
+#include <vector>
 
 #include "caffe/leveldb_database.hpp"
 
@@ -115,6 +116,16 @@ void LeveldbDatabase::close() {
   db_.reset();
 }
 
+void LeveldbDatabase::keys(vector<key_type>* keys) {
+  LOG(INFO) << "LevelDB: Keys";
+
+  keys->clear();
+  for (Database::const_iterator iter = begin(); iter != end(); ++iter) {
+    LOG(INFO) << "KEY";
+    keys->push_back(iter->key);
+  }
+}
+
 LeveldbDatabase::const_iterator LeveldbDatabase::begin() const {
   CHECK_NOTNULL(db_.get());
   shared_ptr<leveldb::Iterator> iter(db_->NewIterator(leveldb::ReadOptions()));
@@ -122,13 +133,16 @@ LeveldbDatabase::const_iterator LeveldbDatabase::begin() const {
   if (!iter->Valid()) {
     iter.reset();
   }
-  shared_ptr<DatabaseState> state(new LeveldbState(db_, iter));
+
+  shared_ptr<DatabaseState> state;
+  if (iter) {
+    state.reset(new LeveldbState(db_, iter));
+  }
   return const_iterator(this, state);
 }
 
 LeveldbDatabase::const_iterator LeveldbDatabase::end() const {
-  shared_ptr<leveldb::Iterator> iter;
-  shared_ptr<DatabaseState> state(new LeveldbState(db_, iter));
+  shared_ptr<DatabaseState> state;
   return const_iterator(this, state);
 }
 
@@ -143,25 +157,20 @@ bool LeveldbDatabase::equal(shared_ptr<DatabaseState> state1,
   shared_ptr<LeveldbState> leveldb_state1 =
       boost::dynamic_pointer_cast<LeveldbState>(state1);
 
-  CHECK_NOTNULL(leveldb_state1.get());
-
   shared_ptr<LeveldbState> leveldb_state2 =
       boost::dynamic_pointer_cast<LeveldbState>(state2);
 
-  CHECK_NOTNULL(leveldb_state2.get());
-
-  CHECK(!leveldb_state1->iter_ || leveldb_state1->iter_->Valid());
-  CHECK(!leveldb_state2->iter_ || leveldb_state2->iter_->Valid());
+  LOG(INFO) << leveldb_state1 << " " << leveldb_state2;
 
   // The KV store doesn't really have any sort of ordering,
   // so while we can do a sequential scan over the collection,
   // we can't really use subranges.
-  return !leveldb_state1->iter_ && !leveldb_state2->iter_;
+  return !leveldb_state1 && !leveldb_state2;
 }
 
-void LeveldbDatabase::increment(shared_ptr<DatabaseState> state) const {
+void LeveldbDatabase::increment(shared_ptr<DatabaseState>* state) const {
   shared_ptr<LeveldbState> leveldb_state =
-      boost::dynamic_pointer_cast<LeveldbState>(state);
+      boost::dynamic_pointer_cast<LeveldbState>(*state);
 
   CHECK_NOTNULL(leveldb_state.get());
 
@@ -172,7 +181,7 @@ void LeveldbDatabase::increment(shared_ptr<DatabaseState> state) const {
 
   iter->Next();
   if (!iter->Valid()) {
-    iter.reset();
+    state->reset();
   }
 }
 
index eb22d15..aabb733 100644 (file)
@@ -2,6 +2,7 @@
 
 #include <string>
 #include <utility>
+#include <vector>
 
 #include "caffe/lmdb_database.hpp"
 
@@ -168,22 +169,43 @@ void LmdbDatabase::close() {
   }
 }
 
+void LmdbDatabase::keys(vector<key_type>* keys) {
+  LOG(INFO) << "LMDB: Keys";
+
+  keys->clear();
+  for (Database::const_iterator iter = begin(); iter != end(); ++iter) {
+    keys->push_back(iter->key);
+  }
+}
+
 LmdbDatabase::const_iterator LmdbDatabase::begin() const {
-  MDB_cursor* cursor;
   int retval;
-  retval = mdb_cursor_open(txn_, dbi_, &cursor);
+
+  MDB_txn* iter_txn;
+
+  retval = mdb_txn_begin(env_, NULL, MDB_RDONLY, &iter_txn);
+  CHECK_EQ(MDB_SUCCESS, retval) << "mdb_txn_begin failed "
+      << mdb_strerror(retval);
+
+  MDB_cursor* cursor;
+  retval = mdb_cursor_open(iter_txn, dbi_, &cursor);
   CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval);
   MDB_val key;
   MDB_val val;
   retval = mdb_cursor_get(cursor, &key, &val, MDB_FIRST);
-  CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval);
 
-  shared_ptr<DatabaseState> state(new LmdbState(cursor, txn_, &dbi_));
+  CHECK(MDB_SUCCESS == retval || MDB_NOTFOUND == retval)
+      << mdb_strerror(retval);
+
+  shared_ptr<DatabaseState> state;
+  if (MDB_SUCCESS == retval) {
+    state.reset(new LmdbState(cursor, iter_txn, &dbi_));
+  }
   return const_iterator(this, state);
 }
 
 LmdbDatabase::const_iterator LmdbDatabase::end() const {
-  shared_ptr<DatabaseState> state(new LmdbState(NULL, txn_, &dbi_));
+  shared_ptr<DatabaseState> state;
   return const_iterator(this, state);
 }
 
@@ -195,22 +217,18 @@ bool LmdbDatabase::equal(shared_ptr<DatabaseState> state1,
   shared_ptr<LmdbState> lmdb_state1 =
       boost::dynamic_pointer_cast<LmdbState>(state1);
 
-  CHECK_NOTNULL(lmdb_state1.get());
-
   shared_ptr<LmdbState> lmdb_state2 =
       boost::dynamic_pointer_cast<LmdbState>(state2);
 
-  CHECK_NOTNULL(lmdb_state2.get());
-
   // The KV store doesn't really have any sort of ordering,
   // so while we can do a sequential scan over the collection,
   // we can't really use subranges.
-  return !lmdb_state1->cursor_ && !lmdb_state2->cursor_;
+  return !lmdb_state1 && !lmdb_state2;
 }
 
-void LmdbDatabase::increment(shared_ptr<DatabaseState> state) const {
+void LmdbDatabase::increment(shared_ptr<DatabaseState>* state) const {
   shared_ptr<LmdbState> lmdb_state =
-      boost::dynamic_pointer_cast<LmdbState>(state);
+      boost::dynamic_pointer_cast<LmdbState>(*state);
 
   CHECK_NOTNULL(lmdb_state.get());
 
@@ -223,7 +241,7 @@ void LmdbDatabase::increment(shared_ptr<DatabaseState> state) const {
   int retval = mdb_cursor_get(cursor, &key, &val, MDB_NEXT);
   if (MDB_NOTFOUND == retval) {
     mdb_cursor_close(cursor);
-    cursor = NULL;
+    state->reset();
   } else {
     CHECK_EQ(MDB_SUCCESS, retval) << mdb_strerror(retval);
   }
index ad1c786..9c56910 100644 (file)
@@ -112,6 +112,56 @@ TYPED_TEST(DatabaseTest, TestReadWriteDoesntExistLevelDBPasses) {
   database->close();
 }
 
+TYPED_TEST(DatabaseTest, TestKeysLevelDB) {
+  string name = this->DBName();
+  shared_ptr<Database> database = DatabaseFactory("leveldb");
+  EXPECT_TRUE(database->open(name, Database::New));
+
+  Database::key_type key1 = this->TestKey();
+  Database::value_type value1 = this->TestValue();
+
+  EXPECT_TRUE(database->put(key1, value1));
+
+  Database::key_type key2 = this->TestAltKey();
+  Database::value_type value2 = this->TestAltValue();
+
+  EXPECT_TRUE(database->put(key2, value2));
+
+  EXPECT_TRUE(database->commit());
+
+  vector<Database::key_type> keys;
+  database->keys(&keys);
+
+  EXPECT_EQ(2, keys.size());
+
+  EXPECT_TRUE(this->BufferEq(keys.at(0), key1) ||
+      this->BufferEq(keys.at(0), key2));
+  EXPECT_TRUE(this->BufferEq(keys.at(1), key1) ||
+      this->BufferEq(keys.at(2), key2));
+  EXPECT_FALSE(this->BufferEq(keys.at(0), keys.at(1)));
+}
+
+TYPED_TEST(DatabaseTest, TestKeysNoCommitLevelDB) {
+  string name = this->DBName();
+  shared_ptr<Database> database = DatabaseFactory("leveldb");
+  EXPECT_TRUE(database->open(name, Database::New));
+
+  Database::key_type key1 = this->TestKey();
+  Database::value_type value1 = this->TestValue();
+
+  EXPECT_TRUE(database->put(key1, value1));
+
+  Database::key_type key2 = this->TestAltKey();
+  Database::value_type value2 = this->TestAltValue();
+
+  EXPECT_TRUE(database->put(key2, value2));
+
+  vector<Database::key_type> keys;
+  database->keys(&keys);
+
+  EXPECT_EQ(0, keys.size());
+}
+
 TYPED_TEST(DatabaseTest, TestIteratorsLevelDB) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
@@ -451,6 +501,57 @@ TYPED_TEST(DatabaseTest, TestReadWriteDoesntExistLMDBPasses) {
   database->close();
 }
 
+TYPED_TEST(DatabaseTest, TestKeysLMDB) {
+  string name = this->DBName();
+  shared_ptr<Database> database = DatabaseFactory("lmdb");
+  EXPECT_TRUE(database->open(name, Database::New));
+
+  Database::key_type key1 = this->TestKey();
+  Database::value_type value1 = this->TestValue();
+
+  EXPECT_TRUE(database->put(key1, value1));
+
+  Database::key_type key2 = this->TestAltKey();
+  Database::value_type value2 = this->TestAltValue();
+
+  EXPECT_TRUE(database->put(key2, value2));
+
+  EXPECT_TRUE(database->commit());
+
+  vector<Database::key_type> keys;
+  database->keys(&keys);
+
+  EXPECT_EQ(2, keys.size());
+
+  EXPECT_TRUE(this->BufferEq(keys.at(0), key1) ||
+      this->BufferEq(keys.at(0), key2));
+  EXPECT_TRUE(this->BufferEq(keys.at(1), key1) ||
+      this->BufferEq(keys.at(2), key2));
+  EXPECT_FALSE(this->BufferEq(keys.at(0), keys.at(1)));
+}
+
+TYPED_TEST(DatabaseTest, TestKeysNoCommitLMDB) {
+  string name = this->DBName();
+  shared_ptr<Database> database = DatabaseFactory("lmdb");
+  EXPECT_TRUE(database->open(name, Database::New));
+
+  Database::key_type key1 = this->TestKey();
+  Database::value_type value1 = this->TestValue();
+
+  EXPECT_TRUE(database->put(key1, value1));
+
+  Database::key_type key2 = this->TestAltKey();
+  Database::value_type value2 = this->TestAltValue();
+
+  EXPECT_TRUE(database->put(key2, value2));
+
+  vector<Database::key_type> keys;
+  database->keys(&keys);
+
+  EXPECT_EQ(0, keys.size());
+}
+
+
 TYPED_TEST(DatabaseTest, TestIteratorsLMDB) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");