From 08e2cdedbbe4d7178597d6dc1479eccfcce3b0d5 Mon Sep 17 00:00:00 2001 From: Kevin James Matzen Date: Sun, 12 Oct 2014 21:36:59 -0400 Subject: [PATCH] Added function to Database interface to retrieve keys. Exposed a bug with LMDB iterators. Fix the bug and updated how invalid iterators are represented. --- include/caffe/database.hpp | 11 ++-- include/caffe/leveldb_database.hpp | 5 +- include/caffe/lmdb_database.hpp | 5 +- src/caffe/leveldb_database.cpp | 35 ++++++++----- src/caffe/lmdb_database.cpp | 44 +++++++++++----- src/caffe/test/test_database.cpp | 101 +++++++++++++++++++++++++++++++++++++ 6 files changed, 169 insertions(+), 32 deletions(-) diff --git a/include/caffe/database.hpp b/include/caffe/database.hpp index 2341149..8469e84 100644 --- a/include/caffe/database.hpp +++ b/include/caffe/database.hpp @@ -33,6 +33,8 @@ class Database { virtual bool commit() = 0; virtual void close() = 0; + virtual void keys(vector* keys) = 0; + Database() { } virtual ~Database() { } @@ -65,7 +67,8 @@ class Database { iterator(const iterator& other) : parent_(other.parent_), - state_(other.state_->clone()) { } + state_(other.state_ ? 
other.state_->clone() + : shared_ptr()) { } iterator& operator=(iterator copy) { copy.swap(*this); @@ -86,12 +89,12 @@ class Database { } iterator& operator++() { - parent_->increment(state_); + parent_->increment(&state_); return *this; } iterator operator++(int) { iterator copy(*this); - parent_->increment(state_); + parent_->increment(&state_); return copy; } @@ -117,7 +120,7 @@ class Database { virtual bool equal(shared_ptr state1, shared_ptr state2) const = 0; - virtual void increment(shared_ptr state) const = 0; + virtual void increment(shared_ptr* state) const = 0; virtual KV& dereference( shared_ptr state) const = 0; }; diff --git a/include/caffe/leveldb_database.hpp b/include/caffe/leveldb_database.hpp index 9c7f70e..48cf11e 100644 --- a/include/caffe/leveldb_database.hpp +++ b/include/caffe/leveldb_database.hpp @@ -6,6 +6,7 @@ #include #include +#include #include "caffe/common.hpp" #include "caffe/database.hpp" @@ -20,6 +21,8 @@ class LeveldbDatabase : public Database { bool commit(); void close(); + void keys(vector* keys); + const_iterator begin() const; const_iterator cbegin() const; const_iterator end() const; @@ -62,7 +65,7 @@ class LeveldbDatabase : public Database { bool equal(shared_ptr state1, shared_ptr state2) const; - void increment(shared_ptr state) const; + void increment(shared_ptr* state) const; Database::KV& dereference(shared_ptr state) const; shared_ptr db_; diff --git a/include/caffe/lmdb_database.hpp b/include/caffe/lmdb_database.hpp index 107d936..a8ce60d 100644 --- a/include/caffe/lmdb_database.hpp +++ b/include/caffe/lmdb_database.hpp @@ -3,6 +3,7 @@ #include #include +#include #include "lmdb.h" @@ -24,6 +25,8 @@ class LmdbDatabase : public Database { bool commit(); void close(); + void keys(vector* keys); + const_iterator begin() const; const_iterator cbegin() const; const_iterator end() const; @@ -66,7 +69,7 @@ class LmdbDatabase : public Database { bool equal(shared_ptr state1, shared_ptr state2) const; - void 
increment(shared_ptr state) const; + void increment(shared_ptr* state) const; Database::KV& dereference(shared_ptr state) const; MDB_env* env_; diff --git a/src/caffe/leveldb_database.cpp b/src/caffe/leveldb_database.cpp index 2061b8e..ad98f9e 100644 --- a/src/caffe/leveldb_database.cpp +++ b/src/caffe/leveldb_database.cpp @@ -1,5 +1,6 @@ #include #include +#include #include "caffe/leveldb_database.hpp" @@ -115,6 +116,16 @@ void LeveldbDatabase::close() { db_.reset(); } +void LeveldbDatabase::keys(vector* keys) { + LOG(INFO) << "LevelDB: Keys"; + + keys->clear(); + for (Database::const_iterator iter = begin(); iter != end(); ++iter) { + LOG(INFO) << "KEY"; + keys->push_back(iter->key); + } +} + LeveldbDatabase::const_iterator LeveldbDatabase::begin() const { CHECK_NOTNULL(db_.get()); shared_ptr iter(db_->NewIterator(leveldb::ReadOptions())); @@ -122,13 +133,16 @@ LeveldbDatabase::const_iterator LeveldbDatabase::begin() const { if (!iter->Valid()) { iter.reset(); } - shared_ptr state(new LeveldbState(db_, iter)); + + shared_ptr state; + if (iter) { + state.reset(new LeveldbState(db_, iter)); + } return const_iterator(this, state); } LeveldbDatabase::const_iterator LeveldbDatabase::end() const { - shared_ptr iter; - shared_ptr state(new LeveldbState(db_, iter)); + shared_ptr state; return const_iterator(this, state); } @@ -143,25 +157,20 @@ bool LeveldbDatabase::equal(shared_ptr state1, shared_ptr leveldb_state1 = boost::dynamic_pointer_cast(state1); - CHECK_NOTNULL(leveldb_state1.get()); - shared_ptr leveldb_state2 = boost::dynamic_pointer_cast(state2); - CHECK_NOTNULL(leveldb_state2.get()); - - CHECK(!leveldb_state1->iter_ || leveldb_state1->iter_->Valid()); - CHECK(!leveldb_state2->iter_ || leveldb_state2->iter_->Valid()); + LOG(INFO) << leveldb_state1 << " " << leveldb_state2; // The KV store doesn't really have any sort of ordering, // so while we can do a sequential scan over the collection, // we can't really use subranges. 
- return !leveldb_state1->iter_ && !leveldb_state2->iter_; + return !leveldb_state1 && !leveldb_state2; } -void LeveldbDatabase::increment(shared_ptr state) const { +void LeveldbDatabase::increment(shared_ptr* state) const { shared_ptr leveldb_state = - boost::dynamic_pointer_cast(state); + boost::dynamic_pointer_cast(*state); CHECK_NOTNULL(leveldb_state.get()); @@ -172,7 +181,7 @@ void LeveldbDatabase::increment(shared_ptr state) const { iter->Next(); if (!iter->Valid()) { - iter.reset(); + state->reset(); } } diff --git a/src/caffe/lmdb_database.cpp b/src/caffe/lmdb_database.cpp index eb22d15..aabb733 100644 --- a/src/caffe/lmdb_database.cpp +++ b/src/caffe/lmdb_database.cpp @@ -2,6 +2,7 @@ #include #include +#include #include "caffe/lmdb_database.hpp" @@ -168,22 +169,43 @@ void LmdbDatabase::close() { } } +void LmdbDatabase::keys(vector* keys) { + LOG(INFO) << "LMDB: Keys"; + + keys->clear(); + for (Database::const_iterator iter = begin(); iter != end(); ++iter) { + keys->push_back(iter->key); + } +} + LmdbDatabase::const_iterator LmdbDatabase::begin() const { - MDB_cursor* cursor; int retval; - retval = mdb_cursor_open(txn_, dbi_, &cursor); + + MDB_txn* iter_txn; + + retval = mdb_txn_begin(env_, NULL, MDB_RDONLY, &iter_txn); + CHECK_EQ(MDB_SUCCESS, retval) << "mdb_txn_begin failed " + << mdb_strerror(retval); + + MDB_cursor* cursor; + retval = mdb_cursor_open(iter_txn, dbi_, &cursor); CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); MDB_val key; MDB_val val; retval = mdb_cursor_get(cursor, &key, &val, MDB_FIRST); - CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); - shared_ptr state(new LmdbState(cursor, txn_, &dbi_)); + CHECK(MDB_SUCCESS == retval || MDB_NOTFOUND == retval) + << mdb_strerror(retval); + + shared_ptr state; + if (MDB_SUCCESS == retval) { + state.reset(new LmdbState(cursor, iter_txn, &dbi_)); + } return const_iterator(this, state); } LmdbDatabase::const_iterator LmdbDatabase::end() const { - shared_ptr state(new LmdbState(NULL, txn_, 
&dbi_)); + shared_ptr state; return const_iterator(this, state); } @@ -195,22 +217,18 @@ bool LmdbDatabase::equal(shared_ptr state1, shared_ptr lmdb_state1 = boost::dynamic_pointer_cast(state1); - CHECK_NOTNULL(lmdb_state1.get()); - shared_ptr lmdb_state2 = boost::dynamic_pointer_cast(state2); - CHECK_NOTNULL(lmdb_state2.get()); - // The KV store doesn't really have any sort of ordering, // so while we can do a sequential scan over the collection, // we can't really use subranges. - return !lmdb_state1->cursor_ && !lmdb_state2->cursor_; + return !lmdb_state1 && !lmdb_state2; } -void LmdbDatabase::increment(shared_ptr state) const { +void LmdbDatabase::increment(shared_ptr* state) const { shared_ptr lmdb_state = - boost::dynamic_pointer_cast(state); + boost::dynamic_pointer_cast(*state); CHECK_NOTNULL(lmdb_state.get()); @@ -223,7 +241,7 @@ void LmdbDatabase::increment(shared_ptr state) const { int retval = mdb_cursor_get(cursor, &key, &val, MDB_NEXT); if (MDB_NOTFOUND == retval) { mdb_cursor_close(cursor); - cursor = NULL; + state->reset(); } else { CHECK_EQ(MDB_SUCCESS, retval) << mdb_strerror(retval); } diff --git a/src/caffe/test/test_database.cpp b/src/caffe/test/test_database.cpp index ad1c786..9c56910 100644 --- a/src/caffe/test/test_database.cpp +++ b/src/caffe/test/test_database.cpp @@ -112,6 +112,56 @@ TYPED_TEST(DatabaseTest, TestReadWriteDoesntExistLevelDBPasses) { database->close(); } +TYPED_TEST(DatabaseTest, TestKeysLevelDB) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + EXPECT_TRUE(database->open(name, Database::New)); + + Database::key_type key1 = this->TestKey(); + Database::value_type value1 = this->TestValue(); + + EXPECT_TRUE(database->put(key1, value1)); + + Database::key_type key2 = this->TestAltKey(); + Database::value_type value2 = this->TestAltValue(); + + EXPECT_TRUE(database->put(key2, value2)); + + EXPECT_TRUE(database->commit()); + + vector keys; + database->keys(&keys); + + EXPECT_EQ(2, 
keys.size()); + + EXPECT_TRUE(this->BufferEq(keys.at(0), key1) || + this->BufferEq(keys.at(0), key2)); + EXPECT_TRUE(this->BufferEq(keys.at(1), key1) || + this->BufferEq(keys.at(1), key2)); + EXPECT_FALSE(this->BufferEq(keys.at(0), keys.at(1))); +} + +TYPED_TEST(DatabaseTest, TestKeysNoCommitLevelDB) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + EXPECT_TRUE(database->open(name, Database::New)); + + Database::key_type key1 = this->TestKey(); + Database::value_type value1 = this->TestValue(); + + EXPECT_TRUE(database->put(key1, value1)); + + Database::key_type key2 = this->TestAltKey(); + Database::value_type value2 = this->TestAltValue(); + + EXPECT_TRUE(database->put(key2, value2)); + + vector keys; + database->keys(&keys); + + EXPECT_EQ(0, keys.size()); +} + TYPED_TEST(DatabaseTest, TestIteratorsLevelDB) { string name = this->DBName(); shared_ptr database = DatabaseFactory("leveldb"); @@ -451,6 +501,57 @@ TYPED_TEST(DatabaseTest, TestReadWriteDoesntExistLMDBPasses) { database->close(); } +TYPED_TEST(DatabaseTest, TestKeysLMDB) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + EXPECT_TRUE(database->open(name, Database::New)); + + Database::key_type key1 = this->TestKey(); + Database::value_type value1 = this->TestValue(); + + EXPECT_TRUE(database->put(key1, value1)); + + Database::key_type key2 = this->TestAltKey(); + Database::value_type value2 = this->TestAltValue(); + + EXPECT_TRUE(database->put(key2, value2)); + + EXPECT_TRUE(database->commit()); + + vector keys; + database->keys(&keys); + + EXPECT_EQ(2, keys.size()); + + EXPECT_TRUE(this->BufferEq(keys.at(0), key1) || + this->BufferEq(keys.at(0), key2)); + EXPECT_TRUE(this->BufferEq(keys.at(1), key1) || + this->BufferEq(keys.at(1), key2)); + EXPECT_FALSE(this->BufferEq(keys.at(0), keys.at(1))); +} + +TYPED_TEST(DatabaseTest, TestKeysNoCommitLMDB) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + 
EXPECT_TRUE(database->open(name, Database::New)); + + Database::key_type key1 = this->TestKey(); + Database::value_type value1 = this->TestValue(); + + EXPECT_TRUE(database->put(key1, value1)); + + Database::key_type key2 = this->TestAltKey(); + Database::value_type value2 = this->TestAltValue(); + + EXPECT_TRUE(database->put(key2, value2)); + + vector keys; + database->keys(&keys); + + EXPECT_EQ(0, keys.size()); +} + + TYPED_TEST(DatabaseTest, TestIteratorsLMDB) { string name = this->DBName(); shared_ptr database = DatabaseFactory("lmdb"); -- 2.7.4