From c31e4446884262ff3c68e5d2adf7afd5050771e4 Mon Sep 17 00:00:00 2001 From: Kevin James Matzen Date: Sun, 12 Oct 2014 16:31:17 -0400 Subject: [PATCH] Added get interface to Database. Added test cases for Database. Fixed a few bugs related to ReadOnly mode in Database in order to pass test cases. --- include/caffe/database.hpp | 1 + include/caffe/leveldb_database.hpp | 2 + include/caffe/lmdb_database.hpp | 5 +- src/caffe/leveldb_database.cpp | 22 ++ src/caffe/lmdb_database.cpp | 44 ++- src/caffe/test/test_database.cpp | 590 +++++++++++++++++++++++++++++++++++++ 6 files changed, 658 insertions(+), 6 deletions(-) create mode 100644 src/caffe/test/test_database.cpp diff --git a/include/caffe/database.hpp b/include/caffe/database.hpp index 953b58c..c61ab51 100644 --- a/include/caffe/database.hpp +++ b/include/caffe/database.hpp @@ -28,6 +28,7 @@ class Database { virtual void open(const string& filename, Mode mode) = 0; virtual void put(buffer_t* key, buffer_t* value) = 0; + virtual void get(buffer_t* key, buffer_t* value) = 0; virtual void commit() = 0; virtual void close() = 0; diff --git a/include/caffe/leveldb_database.hpp b/include/caffe/leveldb_database.hpp index 1c084cb..42f73f9 100644 --- a/include/caffe/leveldb_database.hpp +++ b/include/caffe/leveldb_database.hpp @@ -16,6 +16,7 @@ class LeveldbDatabase : public Database { public: void open(const string& filename, Mode mode); void put(buffer_t* key, buffer_t* value); + void get(buffer_t* key, buffer_t* value); void commit(); void close(); @@ -42,6 +43,7 @@ class LeveldbDatabase : public Database { shared_ptr db_; shared_ptr batch_; + bool read_only_; }; } // namespace caffe diff --git a/include/caffe/lmdb_database.hpp b/include/caffe/lmdb_database.hpp index d72be3d..4796b4d 100644 --- a/include/caffe/lmdb_database.hpp +++ b/include/caffe/lmdb_database.hpp @@ -20,6 +20,7 @@ class LmdbDatabase : public Database { void open(const string& filename, Mode mode); void put(buffer_t* key, buffer_t* value); + void get(buffer_t* key, buffer_t* value); void commit(); void close(); @@ -44,9 +45,9 @@ class LmdbDatabase : public Database { void increment(shared_ptr state) const; Database::KV& dereference(shared_ptr state) const; - MDB_env *env_; + MDB_env* env_; MDB_dbi dbi_; - MDB_txn *txn_; + MDB_txn* txn_; }; } // namespace caffe diff --git a/src/caffe/leveldb_database.cpp b/src/caffe/leveldb_database.cpp index 8084a6c..d2b37e6 100644 --- a/src/caffe/leveldb_database.cpp +++ b/src/caffe/leveldb_database.cpp @@ -14,16 +14,19 @@ void LeveldbDatabase::open(const string& filename, Mode mode) { LOG(INFO) << " mode NEW"; options.error_if_exists = true; options.create_if_missing = true; + read_only_ = false; break; case ReadWrite: LOG(INFO) << " mode RW"; options.error_if_exists = false; options.create_if_missing = true; + read_only_ = false; break; case ReadOnly: LOG(INFO) << " mode RO"; options.error_if_exists = false; options.create_if_missing = false; + read_only_ = true; break; default: LOG(FATAL) << "unknown mode " << mode; @@ -45,6 +48,8 @@ void LeveldbDatabase::open(const string& filename, Mode mode) { void LeveldbDatabase::put(buffer_t* key, buffer_t* value) { LOG(INFO) << "LevelDB: Put"; + CHECK(!read_only_); + CHECK_NOTNULL(batch_.get()); leveldb::Slice key_slice(key->data(), key->size()); @@ -53,9 +58,26 @@ void LeveldbDatabase::put(buffer_t* key, buffer_t* value) { batch_->Put(key_slice, value_slice); } +void LeveldbDatabase::get(buffer_t* key, buffer_t* value) { + LOG(INFO) << "LevelDB: Get"; + + leveldb::Slice key_slice(key->data(), key->size()); + + string value_string; + leveldb::Status status = + db_->Get(leveldb::ReadOptions(), key_slice, &value_string); + CHECK(status.ok()) << "leveldb get failed"; + + Database::buffer_t temp_value(value_string.data(), + value_string.data() + value_string.size()); + value->swap(temp_value); +} + void LeveldbDatabase::commit() { LOG(INFO) << "LevelDB: Commit"; + CHECK(!read_only_); + CHECK_NOTNULL(db_.get()); CHECK_NOTNULL(batch_.get()); diff --git a/src/caffe/lmdb_database.cpp b/src/caffe/lmdb_database.cpp index 54d67d5..952e95a 100644 --- a/src/caffe/lmdb_database.cpp +++ b/src/caffe/lmdb_database.cpp @@ -14,12 +14,24 @@ void LmdbDatabase::open(const string& filename, Mode mode) { CHECK(NULL == txn_); CHECK_EQ(0, dbi_); - if (mode == New) { - CHECK_EQ(mkdir(filename.c_str(), 0744), 0) << "mkdir " << filename - << " failed"; + int retval; + if (mode != ReadOnly) { + retval = mkdir(filename.c_str(), 0744); + switch (mode) { + case New: + CHECK_EQ(0, retval) << "mkdir " << filename << " failed"; + break; + case ReadWrite: + if (-1 == retval) { + CHECK_EQ(EEXIST, errno) << "mkdir " << filename << " failed (" + << strerror(errno) << ")"; + } + break; + default: + LOG(FATAL) << "Invalid mode " << mode; + } } - int retval; retval = mdb_env_create(&env_); CHECK_EQ(retval, MDB_SUCCESS) << "mdb_env_create failed " << mdb_strerror(retval); @@ -61,6 +73,30 @@ void LmdbDatabase::put(buffer_t* key, buffer_t* value) { << "mdb_put failed " << mdb_strerror(retval); } +void LmdbDatabase::get(buffer_t* key, buffer_t* value) { + LOG(INFO) << "LMDB: Get"; + + MDB_val mdbkey, mdbdata; + mdbkey.mv_data = key->data(); + mdbkey.mv_size = key->size(); + + int retval; + MDB_txn* get_txn; + retval = mdb_txn_begin(env_, NULL, MDB_RDONLY, &get_txn); + CHECK_EQ(MDB_SUCCESS, retval) << "mdb_txn_begin failed " + << mdb_strerror(retval); + + retval = mdb_get(get_txn, dbi_, &mdbkey, &mdbdata); + CHECK_EQ(MDB_SUCCESS, retval) << "mdb_get failed " << mdb_strerror(retval); + + mdb_txn_abort(get_txn); + + Database::buffer_t temp_value(reinterpret_cast(mdbdata.mv_data), + reinterpret_cast(mdbdata.mv_data) + mdbdata.mv_size); + + value->swap(temp_value); +} + void LmdbDatabase::commit() { LOG(INFO) << "LMDB: Commit"; diff --git a/src/caffe/test/test_database.cpp b/src/caffe/test/test_database.cpp new file mode 100644 index 0000000..dce92d8 --- /dev/null +++ b/src/caffe/test/test_database.cpp @@ -0,0 +1,590 @@ +#include +#include + +#include "caffe/util/io.hpp" + +#include "gtest/gtest.h" + +#include "caffe/database_factory.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template +class DatabaseTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + string DBName() { + string filename; + MakeTempDir(&filename); + filename += "/db"; + return filename; + } + + Database::buffer_t TestKey() { + const char* kKey = "hello"; + Database::buffer_t key(kKey, kKey + 5); + return key; + } + + Database::buffer_t TestValue() { + const char* kValue = "world"; + Database::buffer_t value(kValue, kValue + 5); + return value; + } +}; + +TYPED_TEST_CASE(DatabaseTest, TestDtypesAndDevices); + +TYPED_TEST(DatabaseTest, TestNewDoesntExistLevelDBPasses) { + shared_ptr database = DatabaseFactory("leveldb"); + database->open(this->DBName(), Database::New); + database->close(); +} + +TYPED_TEST(DatabaseTest, TestNewExistsFailsLevelDB) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::New); + database->close(); + + EXPECT_DEATH(database->open(name, Database::New), ""); +} + +TYPED_TEST(DatabaseTest, TestReadOnlyExistsLevelDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::New); + database->close(); + + database->open(name, Database::ReadOnly); + database->close(); +} + +TYPED_TEST(DatabaseTest, TestReadOnlyDoesntExistFailsLevelDB) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + EXPECT_DEATH(database->open(name, Database::ReadOnly), ""); +} + +TYPED_TEST(DatabaseTest, TestReadWriteExistsLevelDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::New); + database->close(); + + database->open(name, Database::ReadWrite); + database->close(); +} + +TYPED_TEST(DatabaseTest, TestReadWriteDoesntExistLevelDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::ReadWrite); + database->close(); +} + +TYPED_TEST(DatabaseTest, TestIteratorsLevelDB) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::New); + + const int kNumExamples = 4; + for (int i = 0; i < kNumExamples; ++i) { + stringstream ss; + ss << i; + string key = ss.str(); + ss << " here be data"; + string value = ss.str(); + Database::buffer_t key_buf(key.data(), key.data() + key.size()); + Database::buffer_t val_buf(value.data(), value.data() + value.size()); + database->put(&key_buf, &val_buf); + } + database->commit(); + + int count = 0; + for (Database::const_iterator iter = database->begin(); + iter != database->end(); ++iter) { + (void)iter; + ++count; + } + + EXPECT_EQ(kNumExamples, count); +} + +TYPED_TEST(DatabaseTest, TestNewPutLevelDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::New); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + database->put(&key, &val); + + database->commit(); + + database->close(); +} + +TYPED_TEST(DatabaseTest, TestNewCommitLevelDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::New); + + database->commit(); + + database->close(); +} + +TYPED_TEST(DatabaseTest, TestNewGetLevelDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::New); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + database->put(&key, &val); + + database->commit(); + + Database::buffer_t new_val; + + database->get(&key, &new_val); + + EXPECT_EQ(val.size(), new_val.size()); + for (size_t i = 0; i < val.size(); ++i) { + EXPECT_EQ(val.at(i), new_val.at(i)); + } + + database->close(); +} + +TYPED_TEST(DatabaseTest, TestNewGetNoCommitLevelDBFails) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::New); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + database->put(&key, &val); + + Database::buffer_t new_val; + + EXPECT_DEATH(database->get(&key, &new_val), ""); +} + + +TYPED_TEST(DatabaseTest, TestReadWritePutLevelDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::ReadWrite); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + database->put(&key, &val); + + database->commit(); + + database->close(); +} + +TYPED_TEST(DatabaseTest, TestReadWriteCommitLevelDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::ReadWrite); + + database->commit(); + + database->close(); +} + +TYPED_TEST(DatabaseTest, TestReadWriteGetLevelDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::New); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + database->put(&key, &val); + + database->commit(); + + Database::buffer_t new_val; + + database->get(&key, &new_val); + + EXPECT_EQ(val.size(), new_val.size()); + for (size_t i = 0; i < val.size(); ++i) { + EXPECT_EQ(val.at(i), new_val.at(i)); + } + + database->close(); +} + +TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitLevelDBFails) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::New); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + database->put(&key, &val); + + Database::buffer_t new_val; + + EXPECT_DEATH(database->get(&key, &new_val), ""); +} + +TYPED_TEST(DatabaseTest, TestReadOnlyPutLevelDBFails) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::New); + database->close(); + + database->open(name, Database::ReadOnly); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + EXPECT_DEATH(database->put(&key, &val), ""); +} + +TYPED_TEST(DatabaseTest, TestReadOnlyCommitLevelDBFails) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::New); + database->close(); + + database->open(name, Database::ReadOnly); + + EXPECT_DEATH(database->commit(), ""); +} + +TYPED_TEST(DatabaseTest, TestReadOnlyGetLevelDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::New); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + database->put(&key, &val); + + database->commit(); + + database->close(); + + database->open(name, Database::ReadOnly); + + Database::buffer_t new_val; + + database->get(&key, &new_val); + + EXPECT_EQ(val.size(), new_val.size()); + for (size_t i = 0; i < val.size(); ++i) { + EXPECT_EQ(val.at(i), new_val.at(i)); + } +} + +TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLevelDBFails) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("leveldb"); + database->open(name, Database::New); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + database->put(&key, &val); + + database->close(); + + database->open(name, Database::ReadOnly); + + Database::buffer_t new_val; + + EXPECT_DEATH(database->get(&key, &new_val), ""); +} + +TYPED_TEST(DatabaseTest, TestNewDoesntExistLMDBPasses) { + shared_ptr database = DatabaseFactory("lmdb"); + database->open(this->DBName(), Database::New); + database->close(); +} + +TYPED_TEST(DatabaseTest, TestNewExistsFailsLMDB) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::New); + database->close(); + + EXPECT_DEATH(database->open(name, Database::New), ""); +} + +TYPED_TEST(DatabaseTest, TestReadOnlyExistsLMDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::New); + database->close(); + + database->open(name, Database::ReadOnly); + database->close(); +} + +TYPED_TEST(DatabaseTest, TestReadOnlyDoesntExistFailsLMDB) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + EXPECT_DEATH(database->open(name, Database::ReadOnly), ""); +} + +TYPED_TEST(DatabaseTest, TestReadWriteExistsLMDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::New); + database->close(); + + database->open(name, Database::ReadWrite); + database->close(); +} + +TYPED_TEST(DatabaseTest, TestReadWriteDoesntExistLMDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::ReadWrite); + database->close(); +} + +TYPED_TEST(DatabaseTest, TestIteratorsLMDB) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::New); + + const int kNumExamples = 4; + for (int i = 0; i < kNumExamples; ++i) { + stringstream ss; + ss << i; + string key = ss.str(); + ss << " here be data"; + string value = ss.str(); + Database::buffer_t key_buf(key.data(), key.data() + key.size()); + Database::buffer_t val_buf(value.data(), value.data() + value.size()); + database->put(&key_buf, &val_buf); + } + database->commit(); + + int count = 0; + for (Database::const_iterator iter = database->begin(); + iter != database->end(); ++iter) { + (void)iter; + ++count; + } + + EXPECT_EQ(kNumExamples, count); +} + +TYPED_TEST(DatabaseTest, TestNewPutLMDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::New); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + database->put(&key, &val); + + database->commit(); + + database->close(); +} + +TYPED_TEST(DatabaseTest, TestNewCommitLMDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::New); + + database->commit(); + + database->close(); +} + +TYPED_TEST(DatabaseTest, TestNewGetLMDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::New); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + database->put(&key, &val); + + database->commit(); + + Database::buffer_t new_val; + + database->get(&key, &new_val); + + EXPECT_EQ(val.size(), new_val.size()); + for (size_t i = 0; i < val.size(); ++i) { + EXPECT_EQ(val.at(i), new_val.at(i)); + } + + database->close(); +} + +TYPED_TEST(DatabaseTest, TestNewGetNoCommitLMDBFails) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::New); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + database->put(&key, &val); + + Database::buffer_t new_val; + + EXPECT_DEATH(database->get(&key, &new_val), ""); +} + +TYPED_TEST(DatabaseTest, TestReadWritePutLMDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::ReadWrite); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + database->put(&key, &val); + + database->commit(); + + database->close(); +} + +TYPED_TEST(DatabaseTest, TestReadWriteCommitLMDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::ReadWrite); + + database->commit(); + + database->close(); +} + +TYPED_TEST(DatabaseTest, TestReadWriteGetLMDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::New); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + database->put(&key, &val); + + database->commit(); + + Database::buffer_t new_val; + + database->get(&key, &new_val); + + EXPECT_EQ(val.size(), new_val.size()); + for (size_t i = 0; i < val.size(); ++i) { + EXPECT_EQ(val.at(i), new_val.at(i)); + } + + database->close(); +} + +TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitLMDBFails) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::New); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + database->put(&key, &val); + + Database::buffer_t new_val; + + EXPECT_DEATH(database->get(&key, &new_val), ""); +} + +TYPED_TEST(DatabaseTest, TestReadOnlyPutLMDBFails) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::New); + database->close(); + + database->open(name, Database::ReadOnly); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + EXPECT_DEATH(database->put(&key, &val), ""); +} + +TYPED_TEST(DatabaseTest, TestReadOnlyCommitLMDBFails) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::New); + database->close(); + + database->open(name, Database::ReadOnly); + + EXPECT_DEATH(database->commit(), ""); +} + +TYPED_TEST(DatabaseTest, TestReadOnlyGetLMDBPasses) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::New); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + database->put(&key, &val); + + database->commit(); + + database->close(); + + database->open(name, Database::ReadOnly); + + Database::buffer_t new_val; + + database->get(&key, &new_val); + + EXPECT_EQ(val.size(), new_val.size()); + for (size_t i = 0; i < val.size(); ++i) { + EXPECT_EQ(val.at(i), new_val.at(i)); + } +} + +TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLMDBFails) { + string name = this->DBName(); + shared_ptr database = DatabaseFactory("lmdb"); + database->open(name, Database::New); + + Database::buffer_t key = this->TestKey(); + Database::buffer_t val = this->TestValue(); + + database->put(&key, &val); + + database->close(); + + database->open(name, Database::ReadOnly); + + Database::buffer_t new_val; + + EXPECT_DEATH(database->get(&key, &new_val), ""); +} + +} // namespace caffe -- 2.7.4