From c86ed418b9e9a89a98de9c5d339a595ac7ebc3c2 Mon Sep 17 00:00:00 2001 From: Kevin James Matzen Date: Sun, 12 Oct 2014 18:43:44 -0400 Subject: [PATCH] Added some tests for the Database iterator interface. Updated the post-increment operator so that it forks off a copy of the LevelDB or LMDB iterator/cursor when necessary. Neither of these APIs allows you to directly copy an iterator or cursor, so I create a new iterator and seek to the key that the original is currently positioned on. This means the pre-increment operator can be much cheaper than the post-increment operator. --- include/caffe/database.hpp | 15 +++ include/caffe/leveldb_database.hpp | 18 ++- include/caffe/lmdb_database.hpp | 28 ++++- src/caffe/leveldb_database.cpp | 4 +- src/caffe/lmdb_database.cpp | 4 +- src/caffe/test/test_database.cpp | 218 +++++++++++++++++++++++++++++++++---- 6 files changed, 256 insertions(+), 31 deletions(-) diff --git a/include/caffe/database.hpp b/include/caffe/database.hpp index c61ab51..08baf2f 100644 --- a/include/caffe/database.hpp +++ b/include/caffe/database.hpp @@ -62,6 +62,20 @@ class Database { state_(state) { } ~iterator() { } + iterator(const iterator& other) + : parent_(other.parent_), + state_(other.state_->clone()) { } + + iterator& operator=(iterator copy) { + copy.swap(*this); + return *this; + } + + void swap(iterator& other) throw() { + std::swap(this->parent_, other.parent_); + std::swap(this->state_, other.state_); + } + bool operator==(const iterator& other) const { return parent_->equal(state_, other.state_); } @@ -97,6 +111,7 @@ class Database { class DatabaseState { public: virtual ~DatabaseState() { } + virtual shared_ptr clone() = 0; }; virtual bool equal(shared_ptr state1, diff --git a/include/caffe/leveldb_database.hpp b/include/caffe/leveldb_database.hpp index 42f73f9..03bfd38 100644 --- a/include/caffe/leveldb_database.hpp +++ b/include/caffe/leveldb_database.hpp @@ -28,10 +28,26 @@ class LeveldbDatabase : public Database { protected: class
LeveldbState : public Database::DatabaseState { public: - explicit LeveldbState(shared_ptr iter) + explicit LeveldbState(shared_ptr db, + shared_ptr iter) : Database::DatabaseState(), + db_(db), iter_(iter) { } + shared_ptr clone() { + shared_ptr new_iter; + + if (iter_.get()) { + new_iter.reset(db_->NewIterator(leveldb::ReadOptions())); + CHECK(iter_->Valid()); + new_iter->Seek(iter_->key()); + CHECK(new_iter->Valid()); + } + + return shared_ptr(new LeveldbState(db_, new_iter)); + } + + shared_ptr db_; shared_ptr iter_; KV kv_pair_; }; diff --git a/include/caffe/lmdb_database.hpp b/include/caffe/lmdb_database.hpp index 4796b4d..7b532f8 100644 --- a/include/caffe/lmdb_database.hpp +++ b/include/caffe/lmdb_database.hpp @@ -32,11 +32,35 @@ class LmdbDatabase : public Database { protected: class LmdbState : public Database::DatabaseState { public: - explicit LmdbState(MDB_cursor* cursor) + explicit LmdbState(MDB_cursor* cursor, MDB_txn* txn, const MDB_dbi* dbi) : Database::DatabaseState(), - cursor_(cursor) { } + cursor_(cursor), + txn_(txn), + dbi_(dbi) { } + + shared_ptr clone() { + MDB_cursor* new_cursor; + + if (cursor_) { + int retval; + retval = mdb_cursor_open(txn_, *dbi_, &new_cursor); + CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); + MDB_val key; + MDB_val val; + retval = mdb_cursor_get(cursor_, &key, &val, MDB_GET_CURRENT); + CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); + retval = mdb_cursor_get(new_cursor, &key, &val, MDB_SET); + CHECK_EQ(MDB_SUCCESS, retval) << mdb_strerror(retval); + } else { + new_cursor = cursor_; + } + + return shared_ptr(new LmdbState(new_cursor, txn_, dbi_)); + } MDB_cursor* cursor_; + MDB_txn* txn_; + const MDB_dbi* dbi_; KV kv_pair_; }; diff --git a/src/caffe/leveldb_database.cpp b/src/caffe/leveldb_database.cpp index d2b37e6..51d50cc 100644 --- a/src/caffe/leveldb_database.cpp +++ b/src/caffe/leveldb_database.cpp @@ -99,13 +99,13 @@ LeveldbDatabase::const_iterator LeveldbDatabase::begin() const { if 
(!iter->Valid()) { iter.reset(); } - shared_ptr state(new LeveldbState(iter)); + shared_ptr state(new LeveldbState(db_, iter)); return const_iterator(this, state); } LeveldbDatabase::const_iterator LeveldbDatabase::end() const { shared_ptr iter; - shared_ptr state(new LeveldbState(iter)); + shared_ptr state(new LeveldbState(db_, iter)); return const_iterator(this, state); } diff --git a/src/caffe/lmdb_database.cpp b/src/caffe/lmdb_database.cpp index 952e95a..a546c8c 100644 --- a/src/caffe/lmdb_database.cpp +++ b/src/caffe/lmdb_database.cpp @@ -134,12 +134,12 @@ LmdbDatabase::const_iterator LmdbDatabase::begin() const { retval = mdb_cursor_get(cursor, &key, &val, MDB_FIRST); CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); - shared_ptr state(new LmdbState(cursor)); + shared_ptr state(new LmdbState(cursor, txn_, &dbi_)); return const_iterator(this, state); } LmdbDatabase::const_iterator LmdbDatabase::end() const { - shared_ptr state(new LmdbState(NULL)); + shared_ptr state(new LmdbState(NULL, txn_, &dbi_)); return const_iterator(this, state); } diff --git a/src/caffe/test/test_database.cpp b/src/caffe/test/test_database.cpp index dce92d8..5d5f4ea 100644 --- a/src/caffe/test/test_database.cpp +++ b/src/caffe/test/test_database.cpp @@ -34,6 +34,32 @@ class DatabaseTest : public MultiDeviceTest { Database::buffer_t value(kValue, kValue + 5); return value; } + + Database::buffer_t TestAltKey() { + const char* kKey = "foo"; + Database::buffer_t key(kKey, kKey + 3); + return key; + } + + Database::buffer_t TestAltValue() { + const char* kValue = "bar"; + Database::buffer_t value(kValue, kValue + 3); + return value; + } + + bool BufferEq(const Database::buffer_t& buf1, + const Database::buffer_t& buf2) { + if (buf1.size() != buf2.size()) { + return false; + } + for (size_t i = 0; i < buf1.size(); ++i) { + if (buf1.at(i) != buf2.at(i)) { + return false; + } + } + + return true; + } }; TYPED_TEST_CASE(DatabaseTest, TestDtypesAndDevices); @@ -114,6 +140,87 @@ 
TYPED_TEST(DatabaseTest, TestIteratorsLevelDB) { EXPECT_EQ(kNumExamples, count); } +TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLevelDB) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::New); + + Database::buffer_t key1 = this->TestAltKey(); + Database::buffer_t value1 = this->TestAltValue(); + + Database::buffer_t key2 = this->TestKey(); + Database::buffer_t value2 = this->TestValue(); + + database->put(&key1, &value1); + database->put(&key2, &value2); + database->commit(); + + Database::const_iterator iter1 = database->begin(); + + EXPECT_FALSE(database->end() == iter1); + + EXPECT_TRUE(this->BufferEq(iter1->key, key1)); + + Database::const_iterator iter2 = ++iter1; + + EXPECT_FALSE(database->end() == iter1); + EXPECT_FALSE(database->end() == iter2); + + EXPECT_TRUE(this->BufferEq(iter2->key, key2)); + + Database::const_iterator iter3 = ++iter2; + + EXPECT_TRUE(database->end() == iter3); + + iter1 = database->end(); + iter2 = database->end(); + iter3 = database->end(); + + database->close(); +} + +TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLevelDB) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::New); + + Database::buffer_t key1 = this->TestAltKey(); + Database::buffer_t value1 = this->TestAltValue(); + + Database::buffer_t key2 = this->TestKey(); + Database::buffer_t value2 = this->TestValue(); + + database->put(&key1, &value1); + database->put(&key2, &value2); + database->commit(); + + Database::const_iterator iter1 = database->begin(); + + EXPECT_FALSE(database->end() == iter1); + + EXPECT_TRUE(this->BufferEq(iter1->key, key1)); + + Database::const_iterator iter2 = iter1++; + + EXPECT_FALSE(database->end() == iter1); + EXPECT_FALSE(database->end() == iter2); + + EXPECT_TRUE(this->BufferEq(iter2->key, key1)); + EXPECT_TRUE(this->BufferEq(iter1->key, key2)); + + Database::const_iterator iter3 = iter1++; + + 
EXPECT_FALSE(database->end() == iter3); + EXPECT_TRUE(this->BufferEq(iter3->key, key2)); + EXPECT_TRUE(database->end() == iter1); + + iter1 = database->end(); + iter2 = database->end(); + iter3 = database->end(); + + database->close(); +} + TYPED_TEST(DatabaseTest, TestNewPutLevelDBPasses) { string name = this->DBName(); shared_ptr database = DatabaseFactory("leveldb"); @@ -155,10 +262,7 @@ TYPED_TEST(DatabaseTest, TestNewGetLevelDBPasses) { database->get(&key, &new_val); - EXPECT_EQ(val.size(), new_val.size()); - for (size_t i = 0; i < val.size(); ++i) { - EXPECT_EQ(val.at(i), new_val.at(i)); - } + EXPECT_TRUE(this->BufferEq(val, new_val)); database->close(); } @@ -220,10 +324,7 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetLevelDBPasses) { database->get(&key, &new_val); - EXPECT_EQ(val.size(), new_val.size()); - for (size_t i = 0; i < val.size(); ++i) { - EXPECT_EQ(val.at(i), new_val.at(i)); - } + EXPECT_TRUE(this->BufferEq(val, new_val)); database->close(); } @@ -288,10 +389,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLevelDBPasses) { database->get(&key, &new_val); - EXPECT_EQ(val.size(), new_val.size()); - for (size_t i = 0; i < val.size(); ++i) { - EXPECT_EQ(val.at(i), new_val.at(i)); - } + EXPECT_TRUE(this->BufferEq(val, new_val)); } TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLevelDBFails) { @@ -389,6 +487,87 @@ TYPED_TEST(DatabaseTest, TestIteratorsLMDB) { EXPECT_EQ(kNumExamples, count); } +TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLMDB) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::New); + + Database::buffer_t key1 = this->TestAltKey(); + Database::buffer_t value1 = this->TestAltValue(); + + Database::buffer_t key2 = this->TestKey(); + Database::buffer_t value2 = this->TestValue(); + + database->put(&key1, &value1); + database->put(&key2, &value2); + database->commit(); + + Database::const_iterator iter1 = database->begin(); + + EXPECT_FALSE(database->end() == iter1); + + 
EXPECT_TRUE(this->BufferEq(iter1->key, key1)); + + Database::const_iterator iter2 = ++iter1; + + EXPECT_FALSE(database->end() == iter1); + EXPECT_FALSE(database->end() == iter2); + + EXPECT_TRUE(this->BufferEq(iter2->key, key2)); + + Database::const_iterator iter3 = ++iter2; + + EXPECT_TRUE(database->end() == iter3); + + iter1 = database->end(); + iter2 = database->end(); + iter3 = database->end(); + + database->close(); +} + +TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLMDB) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::New); + + Database::buffer_t key1 = this->TestAltKey(); + Database::buffer_t value1 = this->TestAltValue(); + + Database::buffer_t key2 = this->TestKey(); + Database::buffer_t value2 = this->TestValue(); + + database->put(&key1, &value1); + database->put(&key2, &value2); + database->commit(); + + Database::const_iterator iter1 = database->begin(); + + EXPECT_FALSE(database->end() == iter1); + + EXPECT_TRUE(this->BufferEq(iter1->key, key1)); + + Database::const_iterator iter2 = iter1++; + + EXPECT_FALSE(database->end() == iter1); + EXPECT_FALSE(database->end() == iter2); + + EXPECT_TRUE(this->BufferEq(iter2->key, key1)); + EXPECT_TRUE(this->BufferEq(iter1->key, key2)); + + Database::const_iterator iter3 = iter1++; + + EXPECT_FALSE(database->end() == iter3); + EXPECT_TRUE(this->BufferEq(iter3->key, key2)); + EXPECT_TRUE(database->end() == iter1); + + iter1 = database->end(); + iter2 = database->end(); + iter3 = database->end(); + + database->close(); +} + TYPED_TEST(DatabaseTest, TestNewPutLMDBPasses) { string name = this->DBName(); shared_ptr database = DatabaseFactory("lmdb"); @@ -430,10 +609,7 @@ TYPED_TEST(DatabaseTest, TestNewGetLMDBPasses) { database->get(&key, &new_val); - EXPECT_EQ(val.size(), new_val.size()); - for (size_t i = 0; i < val.size(); ++i) { - EXPECT_EQ(val.at(i), new_val.at(i)); - } + EXPECT_TRUE(this->BufferEq(val, new_val)); database->close(); } @@ 
-494,10 +670,7 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetLMDBPasses) { database->get(&key, &new_val); - EXPECT_EQ(val.size(), new_val.size()); - for (size_t i = 0; i < val.size(); ++i) { - EXPECT_EQ(val.at(i), new_val.at(i)); - } + EXPECT_TRUE(this->BufferEq(val, new_val)); database->close(); } @@ -562,10 +735,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLMDBPasses) { database->get(&key, &new_val); - EXPECT_EQ(val.size(), new_val.size()); - for (size_t i = 0; i < val.size(); ++i) { - EXPECT_EQ(val.at(i), new_val.at(i)); - } + EXPECT_TRUE(this->BufferEq(val, new_val)); } TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLMDBFails) { -- 2.7.4