From 08b971feae8551ca5c7ce31a938a1f232ee56af2 Mon Sep 17 00:00:00 2001 From: Kevin James Matzen Date: Mon, 13 Oct 2014 13:16:04 -0400 Subject: [PATCH] Templated the key and value types for the Database interface. The Database is now responsible for serialization. Refactored the tests so that they reuse the same code for each value type and backend configuration. --- examples/cifar10/convert_cifar_data.cpp | 25 +- include/caffe/data_layers.hpp | 4 +- include/caffe/database.hpp | 116 +++- include/caffe/database_factory.hpp | 7 +- include/caffe/leveldb_database.hpp | 23 +- include/caffe/lmdb_database.hpp | 23 +- src/caffe/database_factory.cpp | 31 +- src/caffe/layers/data_layer.cpp | 11 +- src/caffe/leveldb_database.cpp | 95 ++-- src/caffe/lmdb_database.cpp | 113 ++-- src/caffe/test/test_data_layer.cpp | 12 +- src/caffe/test/test_database.cpp | 901 +++++++++++++------------------- tools/compute_image_mean.cpp | 16 +- tools/convert_imageset.cpp | 11 +- tools/extract_features.cpp | 14 +- 15 files changed, 689 insertions(+), 713 deletions(-) diff --git a/examples/cifar10/convert_cifar_data.cpp b/examples/cifar10/convert_cifar_data.cpp index 46ab155..f86f393 100644 --- a/examples/cifar10/convert_cifar_data.cpp +++ b/examples/cifar10/convert_cifar_data.cpp @@ -20,6 +20,7 @@ using std::string; using caffe::Database; using caffe::DatabaseFactory; +using caffe::Datum; using caffe::shared_ptr; const int kCIFARSize = 32; @@ -37,13 +38,14 @@ void read_image(std::ifstream* file, int* label, char* buffer) { void convert_dataset(const string& input_folder, const string& output_folder, const string& db_type) { - shared_ptr train_database = DatabaseFactory(db_type); + shared_ptr > train_database = + DatabaseFactory(db_type); CHECK(train_database->open(output_folder + "/cifar10_train_" + db_type, - Database::New)); + Database::New)); // Data buffer int label; char str_buffer[kCIFARImageNBytes]; - caffe::Datum datum; + Datum datum; datum.set_channels(3); datum.set_height(kCIFARSize); datum.set_width(kCIFARSize); @@ -60,22 +62,19 @@ void convert_dataset(const string& input_folder, const string& output_folder, read_image(&data_file, &label, str_buffer); datum.set_label(label); datum.set_data(str_buffer, kCIFARImageNBytes); - Database::value_type value(datum.ByteSize()); - datum.SerializeWithCachedSizesToArray( - reinterpret_cast(value.data())); int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", fileid * kCIFARBatchSize + itemid); - Database::key_type key(str_buffer, str_buffer + length); - CHECK(train_database->put(key, value)); + CHECK(train_database->put(string(str_buffer, length), datum)); } } CHECK(train_database->commit()); train_database->close(); LOG(INFO) << "Writing Testing data"; - shared_ptr test_database = DatabaseFactory(db_type); + shared_ptr > test_database = + DatabaseFactory(db_type); CHECK(test_database->open(output_folder + "/cifar10_test_" + db_type, - Database::New)); + Database::New)); // Open files std::ifstream data_file((input_folder + "/test_batch.bin").c_str(), std::ios::in | std::ios::binary); @@ -84,12 +83,8 @@ void convert_dataset(const string& input_folder, const string& output_folder, read_image(&data_file, &label, str_buffer); datum.set_label(label); datum.set_data(str_buffer, kCIFARImageNBytes); - Database::value_type value(datum.ByteSize()); - datum.SerializeWithCachedSizesToArray( - reinterpret_cast(value.data())); int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid); - Database::key_type key(str_buffer, str_buffer + length); - CHECK(test_database->put(key, value)); + CHECK(test_database->put(string(str_buffer, length), datum)); } CHECK(test_database->commit()); test_database->close(); diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp index 810f2bb..13ba4e6 100644 --- a/include/caffe/data_layers.hpp +++ b/include/caffe/data_layers.hpp @@ -100,8 +100,8 @@ class DataLayer : public BasePrefetchingDataLayer { protected: virtual void InternalThreadEntry(); - shared_ptr database_; - Database::const_iterator iter_; + shared_ptr > database_; + Database::const_iterator iter_; }; /** diff --git a/include/caffe/database.hpp b/include/caffe/database.hpp index 8469e84..711ebdd 100644 --- a/include/caffe/database.hpp +++ b/include/caffe/database.hpp @@ -11,6 +11,83 @@ namespace caffe { +namespace database_internal { + +template +struct Coder { + static bool serialize(const T& obj, string* serialized) { + return obj.SerializeToString(serialized); + } + + static bool serialize(const T& obj, vector* serialized) { + serialized->resize(obj.ByteSize()); + return obj.SerializeWithCachedSizesToArray( + reinterpret_cast(serialized->data())); + } + + static bool deserialize(const string& serialized, T* obj) { + return obj->ParseFromString(serialized); + } + + static bool deserialize(const char* data, size_t size, T* obj) { + return obj->ParseFromArray(data, size); + } +}; + +template <> +struct Coder { + static bool serialize(string obj, string* serialized) { + *serialized = obj; + return true; + } + + static bool serialize(const string& obj, vector* serialized) { + vector temp(obj.data(), obj.data() + obj.size()); + serialized->swap(temp); + return true; + } + + static bool deserialize(const string& serialized, string* obj) { + *obj = serialized; + return true; + } + + static bool deserialize(const char* data, size_t size, string* obj) { + string temp_string(data, size); + obj->swap(temp_string); + return true; + } +}; + +template <> +struct Coder > { + static bool serialize(vector obj, string* serialized) { + string tmp(obj.data(), obj.size()); + serialized->swap(tmp); + return true; + } + + static bool serialize(const vector& obj, vector* serialized) { + *serialized = obj; + return true; + } + + static bool deserialize(const string& serialized, vector* obj) { + vector tmp(serialized.data(), serialized.data() + serialized.size()); + obj->swap(tmp); + return true; + } + + static bool deserialize(const char* data, size_t size, vector* obj) { + vector tmp(data, data + size); + obj->swap(tmp); + return true; + } +}; + +} // namespace database_internal + +template class Database { public: enum Mode { @@ -19,21 +96,21 @@ class Database { ReadOnly }; - typedef vector key_type; - typedef vector value_type; + typedef K key_type; + typedef V value_type; struct KV { - key_type key; - value_type value; + K key; + V value; }; virtual bool open(const string& filename, Mode mode) = 0; - virtual bool put(const key_type& key, const value_type& value) = 0; - virtual bool get(const key_type& key, value_type* value) = 0; + virtual bool put(const K& key, const V& value) = 0; + virtual bool get(const K& key, V* value) = 0; virtual bool commit() = 0; virtual void close() = 0; - virtual void keys(vector* keys) = 0; + virtual void keys(vector* keys) = 0; Database() { } virtual ~Database() { } @@ -123,8 +200,33 @@ class Database { virtual void increment(shared_ptr* state) const = 0; virtual KV& dereference( shared_ptr state) const = 0; + + template + static bool serialize(const T& obj, string* serialized) { + return database_internal::Coder::serialize(obj, serialized); + } + + template + static bool serialize(const T& obj, vector* serialized) { + return database_internal::Coder::serialize(obj, serialized); + } + + template + static bool deserialize(const string& serialized, T* obj) { + return database_internal::Coder::deserialize(serialized, obj); + } + + template + static bool deserialize(const char* data, size_t size, T* obj) { + return database_internal::Coder::deserialize(data, size, obj); + } }; } // namespace caffe +#define INSTANTIATE_DATABASE(type) \ + template class type; \ + template class type >; \ + template class type; + #endif // CAFFE_DATABASE_H_ diff --git a/include/caffe/database_factory.hpp b/include/caffe/database_factory.hpp index 91185e9..30f191e 100644 --- a/include/caffe/database_factory.hpp +++ b/include/caffe/database_factory.hpp @@ -9,8 +9,11 @@ namespace caffe { -shared_ptr DatabaseFactory(const DataParameter_DB& type); -shared_ptr DatabaseFactory(const string& type); +template +shared_ptr > DatabaseFactory(const DataParameter_DB& type); + +template +shared_ptr > DatabaseFactory(const string& type); } // namespace caffe diff --git a/include/caffe/leveldb_database.hpp b/include/caffe/leveldb_database.hpp index 48cf11e..cd966c4 100644 --- a/include/caffe/leveldb_database.hpp +++ b/include/caffe/leveldb_database.hpp @@ -13,15 +13,24 @@ namespace caffe { -class LeveldbDatabase : public Database { +template +class LeveldbDatabase : public Database { public: + typedef Database Base; + typedef typename Base::key_type key_type; + typedef typename Base::value_type value_type; + typedef typename Base::DatabaseState DatabaseState; + typedef typename Base::Mode Mode; + typedef typename Base::const_iterator const_iterator; + typedef typename Base::KV KV; + bool open(const string& filename, Mode mode); - bool put(const key_type& key, const value_type& value); - bool get(const key_type& key, value_type* value); + bool put(const K& key, const V& value); + bool get(const K& key, V* value); bool commit(); void close(); - void keys(vector* keys); + void keys(vector* keys); const_iterator begin() const; const_iterator cbegin() const; @@ -29,11 +38,11 @@ class LeveldbDatabase : public Database { const_iterator cend() const; protected: - class LeveldbState : public Database::DatabaseState { + class LeveldbState : public DatabaseState { public: explicit LeveldbState(shared_ptr db, shared_ptr iter) - : Database::DatabaseState(), + : DatabaseState(), db_(db), iter_(iter) { } @@ -66,7 +75,7 @@ class LeveldbDatabase : public Database { bool equal(shared_ptr state1, shared_ptr state2) const; void increment(shared_ptr* state) const; - Database::KV& dereference(shared_ptr state) const; + KV& dereference(shared_ptr state) const; shared_ptr db_; shared_ptr batch_; diff --git a/include/caffe/lmdb_database.hpp b/include/caffe/lmdb_database.hpp index a8ce60d..b5a02ca 100644 --- a/include/caffe/lmdb_database.hpp +++ b/include/caffe/lmdb_database.hpp @@ -12,20 +12,29 @@ namespace caffe { -class LmdbDatabase : public Database { +template +class LmdbDatabase : public Database { public: + typedef Database Base; + typedef typename Base::key_type key_type; + typedef typename Base::value_type value_type; + typedef typename Base::DatabaseState DatabaseState; + typedef typename Base::Mode Mode; + typedef typename Base::const_iterator const_iterator; + typedef typename Base::KV KV; + LmdbDatabase() : env_(NULL), dbi_(0), txn_(NULL) { } bool open(const string& filename, Mode mode); - bool put(const key_type& key, const value_type& value); - bool get(const key_type& key, value_type* value); + bool put(const K& key, const V& value); + bool get(const K& key, V* value); bool commit(); void close(); - void keys(vector* keys); + void keys(vector* keys); const_iterator begin() const; const_iterator cbegin() const; @@ -33,10 +42,10 @@ class LmdbDatabase : public Database { const_iterator cend() const; protected: - class LmdbState : public Database::DatabaseState { + class LmdbState : public DatabaseState { public: explicit LmdbState(MDB_cursor* cursor, MDB_txn* txn, const MDB_dbi* dbi) - : Database::DatabaseState(), + : DatabaseState(), cursor_(cursor), txn_(txn), dbi_(dbi) { } @@ -70,7 +79,7 @@ class LmdbDatabase : public Database { bool equal(shared_ptr state1, shared_ptr state2) const; void increment(shared_ptr* state) const; - Database::KV& dereference(shared_ptr state) const; + KV& dereference(shared_ptr state) const; MDB_env* env_; MDB_dbi dbi_; diff --git a/src/caffe/database_factory.cpp b/src/caffe/database_factory.cpp index 062de8c..4ccd429 100644 --- a/src/caffe/database_factory.cpp +++ b/src/caffe/database_factory.cpp @@ -1,5 +1,6 @@ #include #include +#include #include "caffe/database_factory.hpp" #include "caffe/leveldb_database.hpp" @@ -7,29 +8,43 @@ namespace caffe { -shared_ptr DatabaseFactory(const DataParameter_DB& type) { +template +shared_ptr > DatabaseFactory(const DataParameter_DB& type) { switch (type) { case DataParameter_DB_LEVELDB: - return shared_ptr(new LeveldbDatabase()); + return shared_ptr >(new LeveldbDatabase()); case DataParameter_DB_LMDB: - return shared_ptr(new LmdbDatabase()); + return shared_ptr >(new LmdbDatabase()); default: LOG(FATAL) << "Unknown database type " << type; - return shared_ptr(); + return shared_ptr >(); } } -shared_ptr DatabaseFactory(const string& type) { +template +shared_ptr > DatabaseFactory(const string& type) { if ("leveldb" == type) { - return DatabaseFactory(DataParameter_DB_LEVELDB); + return DatabaseFactory(DataParameter_DB_LEVELDB); } else if ("lmdb" == type) { - return DatabaseFactory(DataParameter_DB_LMDB); + return DatabaseFactory(DataParameter_DB_LMDB); } else { LOG(FATAL) << "Unknown database type " << type; - return shared_ptr(); + return shared_ptr >(); } } +#define REGISTER_DATABASE(key_type, value_type) \ + template shared_ptr > \ + DatabaseFactory(const string& type); \ + template shared_ptr > \ + DatabaseFactory(const DataParameter_DB& type); \ + +REGISTER_DATABASE(string, string); +REGISTER_DATABASE(string, vector); +REGISTER_DATABASE(string, Datum); + +#undef REGISTER_DATABASE + } // namespace caffe diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index c379d78..6296335 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -25,10 +25,11 @@ template void DataLayer::DataLayerSetUp(const vector*>& bottom, const vector*>& top) { // Initialize DB - database_ = DatabaseFactory(this->layer_param_.data_param().backend()); + database_ = DatabaseFactory( + this->layer_param_.data_param().backend()); const string& source = this->layer_param_.data_param().source(); LOG(INFO) << "Opening database " << source; - CHECK(database_->open(source, Database::ReadOnly)); + CHECK(database_->open(source, Database::ReadOnly)); iter_ = database_->begin(); // Check if we would need to randomly skip a few data points @@ -44,8 +45,7 @@ void DataLayer::DataLayerSetUp(const vector*>& bottom, } // Read a data point, and use it to initialize the top blob. CHECK(iter_ != database_->end()); - Datum datum; - datum.ParseFromArray(iter_->value.data(), iter_->value.size()); + const Datum& datum = iter_->value; // image int crop_size = this->layer_param_.transform_param().crop_size(); @@ -89,9 +89,8 @@ void DataLayer::InternalThreadEntry() { const int batch_size = this->layer_param_.data_param().batch_size(); for (int item_id = 0; item_id < batch_size; ++item_id) { - Datum datum; CHECK(iter_ != database_->end()); - datum.ParseFromArray(iter_->value.data(), iter_->value.size()); + const Datum& datum = iter_->value; // Apply data transformations (mirror, scale, crop...) int offset = this->prefetch_data_.offset(item_id); diff --git a/src/caffe/leveldb_database.cpp b/src/caffe/leveldb_database.cpp index a163abc..267ddf0 100644 --- a/src/caffe/leveldb_database.cpp +++ b/src/caffe/leveldb_database.cpp @@ -2,28 +2,30 @@ #include #include +#include "caffe/caffe.hpp" #include "caffe/leveldb_database.hpp" namespace caffe { -bool LeveldbDatabase::open(const string& filename, Mode mode) { +template +bool LeveldbDatabase::open(const string& filename, Mode mode) { DLOG(INFO) << "LevelDB: Open " << filename; leveldb::Options options; switch (mode) { - case New: + case Base::New: DLOG(INFO) << " mode NEW"; options.error_if_exists = true; options.create_if_missing = true; read_only_ = false; break; - case ReadWrite: + case Base::ReadWrite: DLOG(INFO) << " mode RW"; options.error_if_exists = false; options.create_if_missing = true; read_only_ = false; break; - case ReadOnly: + case Base::ReadOnly: DLOG(INFO) << " mode RO"; options.error_if_exists = false; options.create_if_missing = false; @@ -52,7 +54,8 @@ bool LeveldbDatabase::open(const string& filename, Mode mode) { return true; } -bool LeveldbDatabase::put(const key_type& key, const value_type& value) { +template +bool LeveldbDatabase::put(const K& key, const V& value) { DLOG(INFO) << "LevelDB: Put"; if (read_only_) { @@ -62,36 +65,48 @@ bool LeveldbDatabase::put(const key_type& key, const value_type& value) { CHECK_NOTNULL(batch_.get()); - leveldb::Slice key_slice(key.data(), key.size()); - leveldb::Slice value_slice(value.data(), value.size()); + string serialized_key; + if (!Base::serialize(key, &serialized_key)) { + return false; + } + + string serialized_value; + if (!Base::serialize(value, &serialized_value)) { + return false; + } - batch_->Put(key_slice, value_slice); + batch_->Put(serialized_key, serialized_value); return true; } -bool LeveldbDatabase::get(const key_type& key, value_type* value) { +template +bool LeveldbDatabase::get(const K& key, V* value) { DLOG(INFO) << "LevelDB: Get"; - leveldb::Slice key_slice(key.data(), key.size()); + string serialized_key; + if (!Base::serialize(key, &serialized_key)) { + return false; + } - string value_string; + string serialized_value; leveldb::Status status = - db_->Get(leveldb::ReadOptions(), key_slice, &value_string); + db_->Get(leveldb::ReadOptions(), serialized_key, &serialized_value); if (!status.ok()) { LOG(ERROR) << "leveldb get failed"; return false; } - Database::value_type temp_value(value_string.data(), - value_string.data() + value_string.size()); - value->swap(temp_value); + if (!Base::deserialize(serialized_value, value)) { + return false; + } return true; } -bool LeveldbDatabase::commit() { +template +bool LeveldbDatabase::commit() { DLOG(INFO) << "LevelDB: Commit"; if (read_only_) { @@ -109,23 +124,27 @@ bool LeveldbDatabase::commit() { return status.ok(); } -void LeveldbDatabase::close() { +template +void LeveldbDatabase::close() { DLOG(INFO) << "LevelDB: Close"; batch_.reset(); db_.reset(); } -void LeveldbDatabase::keys(vector* keys) { +template +void LeveldbDatabase::keys(vector* keys) { DLOG(INFO) << "LevelDB: Keys"; keys->clear(); - for (Database::const_iterator iter = begin(); iter != end(); ++iter) { + for (const_iterator iter = begin(); iter != end(); ++iter) { keys->push_back(iter->key); } } -LeveldbDatabase::const_iterator LeveldbDatabase::begin() const { +template +typename LeveldbDatabase::const_iterator + LeveldbDatabase::begin() const { CHECK_NOTNULL(db_.get()); shared_ptr iter(db_->NewIterator(leveldb::ReadOptions())); iter->SeekToFirst(); @@ -140,18 +159,25 @@ LeveldbDatabase::const_iterator LeveldbDatabase::begin() const { return const_iterator(this, state); } -LeveldbDatabase::const_iterator LeveldbDatabase::end() const { +template +typename LeveldbDatabase::const_iterator + LeveldbDatabase::end() const { shared_ptr state; return const_iterator(this, state); } -LeveldbDatabase::const_iterator LeveldbDatabase::cbegin() const { +template +typename LeveldbDatabase::const_iterator + LeveldbDatabase::cbegin() const { return begin(); } -LeveldbDatabase::const_iterator LeveldbDatabase::cend() const { return end(); } +template +typename LeveldbDatabase::const_iterator + LeveldbDatabase::cend() const { return end(); } -bool LeveldbDatabase::equal(shared_ptr state1, +template +bool LeveldbDatabase::equal(shared_ptr state1, shared_ptr state2) const { shared_ptr leveldb_state1 = boost::dynamic_pointer_cast(state1); @@ -165,7 +191,8 @@ bool LeveldbDatabase::equal(shared_ptr state1, return !leveldb_state1 && !leveldb_state2; } -void LeveldbDatabase::increment(shared_ptr* state) const { +template +void LeveldbDatabase::increment(shared_ptr* state) const { shared_ptr leveldb_state = boost::dynamic_pointer_cast(*state); @@ -182,7 +209,8 @@ void LeveldbDatabase::increment(shared_ptr* state) const { } } -Database::KV& LeveldbDatabase::dereference( +template +typename Database::KV& LeveldbDatabase::dereference( shared_ptr state) const { shared_ptr leveldb_state = boost::dynamic_pointer_cast(state); @@ -195,15 +223,16 @@ Database::KV& LeveldbDatabase::dereference( CHECK(iter->Valid()); - Database::key_type temp_key(key_type(iter->key().data(), - iter->key().data() + iter->key().size())); + const leveldb::Slice& key = iter->key(); + const leveldb::Slice& value = iter->value(); + CHECK(Base::deserialize(key.data(), key.size(), + &leveldb_state->kv_pair_.key)); + CHECK(Base::deserialize(value.data(), value.size(), + &leveldb_state->kv_pair_.value)); - Database::value_type temp_value(value_type(iter->value().data(), - iter->value().data() + iter->value().size())); - - leveldb_state->kv_pair_.key.swap(temp_key); - leveldb_state->kv_pair_.value.swap(temp_value); return leveldb_state->kv_pair_; } +INSTANTIATE_DATABASE(LeveldbDatabase); + } // namespace caffe diff --git a/src/caffe/lmdb_database.cpp b/src/caffe/lmdb_database.cpp index de94391..22a08b2 100644 --- a/src/caffe/lmdb_database.cpp +++ b/src/caffe/lmdb_database.cpp @@ -4,11 +4,13 @@ #include #include +#include "caffe/caffe.hpp" #include "caffe/lmdb_database.hpp" namespace caffe { -bool LmdbDatabase::open(const string& filename, Mode mode) { +template +bool LmdbDatabase::open(const string& filename, Mode mode) { DLOG(INFO) << "LMDB: Open " << filename; CHECK(NULL == env_); @@ -16,16 +18,16 @@ bool LmdbDatabase::open(const string& filename, Mode mode) { CHECK_EQ(0, dbi_); int retval; - if (mode != ReadOnly) { + if (mode != Base::ReadOnly) { retval = mkdir(filename.c_str(), 0744); switch (mode) { - case New: + case Base::New: if (0 != retval) { LOG(ERROR) << "mkdir " << filename << " failed"; return false; } break; - case ReadWrite: + case Base::ReadWrite: if (-1 == retval && EEXIST != errno) { LOG(ERROR) << "mkdir " << filename << " failed (" << strerror(errno) << ")"; @@ -52,7 +54,7 @@ bool LmdbDatabase::open(const string& filename, Mode mode) { int flag1 = 0; int flag2 = 0; - if (mode == ReadOnly) { + if (mode == Base::ReadOnly) { flag1 = MDB_RDONLY | MDB_NOTLS; flag2 = MDB_RDONLY; } @@ -78,18 +80,27 @@ bool LmdbDatabase::open(const string& filename, Mode mode) { return true; } -bool LmdbDatabase::put(const key_type& key, const value_type& value) { +template +bool LmdbDatabase::put(const K& key, const V& value) { DLOG(INFO) << "LMDB: Put"; - // MDB_val::mv_size is not const, so we need to make a local copy. - key_type local_key = key; - value_type local_value = value; + vector serialized_key; + if (!Base::serialize(key, &serialized_key)) { + LOG(ERROR) << "failed to serialize key"; + return false; + } + + vector serialized_value; + if (!Base::serialize(value, &serialized_value)) { + LOG(ERROR) << "failed to serialized value"; + return false; + } MDB_val mdbkey, mdbdata; - mdbdata.mv_size = local_value.size(); - mdbdata.mv_data = local_value.data(); - mdbkey.mv_size = local_key.size(); - mdbkey.mv_data = local_key.data(); + mdbdata.mv_size = serialized_value.size(); + mdbdata.mv_data = serialized_value.data(); + mdbkey.mv_size = serialized_key.size(); + mdbkey.mv_data = serialized_key.data(); CHECK_NOTNULL(txn_); CHECK_NE(0, dbi_); @@ -103,14 +114,19 @@ bool LmdbDatabase::put(const key_type& key, const value_type& value) { return true; } -bool LmdbDatabase::get(const key_type& key, value_type* value) { +template +bool LmdbDatabase::get(const K& key, V* value) { DLOG(INFO) << "LMDB: Get"; - key_type local_key = key; + vector serialized_key; + if (!Base::serialize(key, &serialized_key)) { + LOG(ERROR) << "failed to serialized key"; + return false; + } MDB_val mdbkey, mdbdata; - mdbkey.mv_data = local_key.data(); - mdbkey.mv_size = local_key.size(); + mdbkey.mv_data = serialized_key.data(); + mdbkey.mv_size = serialized_key.size(); int retval; MDB_txn* get_txn; @@ -128,15 +144,17 @@ bool LmdbDatabase::get(const key_type& key, value_type* value) { mdb_txn_abort(get_txn); - Database::value_type temp_value(reinterpret_cast(mdbdata.mv_data), - reinterpret_cast(mdbdata.mv_data) + mdbdata.mv_size); - - value->swap(temp_value); + if (!Base::deserialize(reinterpret_cast(mdbdata.mv_data), + mdbdata.mv_size, value)) { + LOG(ERROR) << "failed to deserialize value"; + return false; + } return true; } -bool LmdbDatabase::commit() { +template +bool LmdbDatabase::commit() { DLOG(INFO) << "LMDB: Commit"; CHECK_NOTNULL(txn_); @@ -157,7 +175,8 @@ bool LmdbDatabase::commit() { return true; } -void LmdbDatabase::close() { +template +void LmdbDatabase::close() { DLOG(INFO) << "LMDB: Close"; if (env_ && dbi_) { @@ -169,16 +188,19 @@ void LmdbDatabase::close() { } } -void LmdbDatabase::keys(vector* keys) { +template +void LmdbDatabase::keys(vector* keys) { DLOG(INFO) << "LMDB: Keys"; keys->clear(); - for (Database::const_iterator iter = begin(); iter != end(); ++iter) { + for (const_iterator iter = begin(); iter != end(); ++iter) { keys->push_back(iter->key); } } -LmdbDatabase::const_iterator LmdbDatabase::begin() const { +template +typename LmdbDatabase::const_iterator + LmdbDatabase::begin() const { int retval; MDB_txn* iter_txn; @@ -204,15 +226,23 @@ LmdbDatabase::const_iterator LmdbDatabase::begin() const { return const_iterator(this, state); } -LmdbDatabase::const_iterator LmdbDatabase::end() const { +template +typename LmdbDatabase::const_iterator + LmdbDatabase::end() const { shared_ptr state; return const_iterator(this, state); } -LmdbDatabase::const_iterator LmdbDatabase::cbegin() const { return begin(); } -LmdbDatabase::const_iterator LmdbDatabase::cend() const { return end(); } +template +typename LmdbDatabase::const_iterator + LmdbDatabase::cbegin() const { return begin(); } -bool LmdbDatabase::equal(shared_ptr state1, +template +typename LmdbDatabase::const_iterator + LmdbDatabase::cend() const { return end(); } + +template +bool LmdbDatabase::equal(shared_ptr state1, shared_ptr state2) const { shared_ptr lmdb_state1 = boost::dynamic_pointer_cast(state1); @@ -226,7 +256,8 @@ bool LmdbDatabase::equal(shared_ptr state1, return !lmdb_state1 && !lmdb_state2; } -void LmdbDatabase::increment(shared_ptr* state) const { +template +void LmdbDatabase::increment(shared_ptr* state) const { shared_ptr lmdb_state = boost::dynamic_pointer_cast(*state); @@ -247,7 +278,9 @@ void LmdbDatabase::increment(shared_ptr* state) const { } } -Database::KV& LmdbDatabase::dereference(shared_ptr state) const { +template +typename Database::KV& LmdbDatabase::dereference( + shared_ptr state) const { shared_ptr lmdb_state = boost::dynamic_pointer_cast(state); @@ -262,18 +295,14 @@ Database::KV& LmdbDatabase::dereference(shared_ptr state) const { int retval = mdb_cursor_get(cursor, &mdb_key, &mdb_val, MDB_GET_CURRENT); CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); - char* key_data = reinterpret_cast(mdb_key.mv_data); - char* value_data = reinterpret_cast(mdb_val.mv_data); - - Database::key_type temp_key(key_data, key_data + mdb_key.mv_size); - - Database::value_type temp_value(value_data, - value_data + mdb_val.mv_size); - - lmdb_state->kv_pair_.key.swap(temp_key); - lmdb_state->kv_pair_.value.swap(temp_value); + CHECK(Base::deserialize(reinterpret_cast(mdb_key.mv_data), + mdb_key.mv_size, &lmdb_state->kv_pair_.key)); + CHECK(Base::deserialize(reinterpret_cast(mdb_val.mv_data), + mdb_val.mv_size, &lmdb_state->kv_pair_.value)); return lmdb_state->kv_pair_; } +INSTANTIATE_DATABASE(LmdbDatabase); + } // namespace caffe diff --git a/src/caffe/test/test_data_layer.cpp b/src/caffe/test/test_data_layer.cpp index 5f98206..1c3ec1f 100644 --- a/src/caffe/test/test_data_layer.cpp +++ b/src/caffe/test/test_data_layer.cpp @@ -39,8 +39,9 @@ class DataLayerTest : public MultiDeviceTest { void Fill(const bool unique_pixels, DataParameter_DB backend) { backend_ = backend; LOG(INFO) << "Using temporary database " << *filename_; - shared_ptr database = DatabaseFactory(backend_); - CHECK(database->open(*filename_, Database::New)); + shared_ptr > database = + DatabaseFactory(backend_); + CHECK(database->open(*filename_, Database::New)); for (int i = 0; i < 5; ++i) { Datum datum; datum.set_label(i); @@ -54,12 +55,7 @@ class DataLayerTest : public MultiDeviceTest { } stringstream ss; ss << i; - string key_str = ss.str(); - Database::key_type key(key_str.c_str(), key_str.c_str() + key_str.size()); - Database::value_type value(datum.ByteSize()); - datum.SerializeWithCachedSizesToArray( - reinterpret_cast(value.data())); - CHECK(database->put(key, value)); + CHECK(database->put(ss.str(), datum)); } CHECK(database->commit()); database->close(); diff --git a/src/caffe/test/test_database.cpp b/src/caffe/test/test_database.cpp index 9c56910..9a7d6de 100644 --- a/src/caffe/test/test_database.cpp +++ b/src/caffe/test/test_database.cpp @@ -11,551 +11,303 @@ namespace caffe { -template -class DatabaseTest : public MultiDeviceTest { - typedef typename TypeParam::Dtype Dtype; - - protected: - string DBName() { - string filename; - MakeTempDir(&filename); - filename += "/db"; - return filename; - } - - Database::key_type TestKey() { - const char* kKey = "hello"; - Database::key_type key(kKey, kKey + 5); - return key; - } - - Database::value_type TestValue() { - const char* kValue = "world"; - Database::value_type value(kValue, kValue + 5); - return value; - } - - Database::key_type TestAltKey() { - const char* kKey = "foo"; - Database::key_type key(kKey, kKey + 3); - return key; - } - - Database::value_type TestAltValue() { - const char* kValue = "bar"; - Database::value_type value(kValue, kValue + 3); - return value; - } - - template - bool BufferEq(const T& buf1, const T& buf2) { - if (buf1.size() != buf2.size()) { - return false; - } - for (size_t i = 0; i < buf1.size(); ++i) { - if (buf1.at(i) != buf2.at(i)) { - return false; - } - } +namespace DatabaseTest_internal { - return true; - } +template +struct TestData { + static T TestValue(); + static T TestAltValue(); + static bool equals(const T& a, const T& b); }; -TYPED_TEST_CASE(DatabaseTest, TestDtypesAndDevices); - -TYPED_TEST(DatabaseTest, TestNewDoesntExistLevelDBPasses) { - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(this->DBName(), Database::New)); - database->close(); +template <> +string TestData::TestValue() { + return "world"; } -TYPED_TEST(DatabaseTest, TestNewExistsFailsLevelDB) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); - database->close(); - - EXPECT_FALSE(database->open(name, Database::New)); +template <> +string TestData::TestAltValue() { + return "bar"; } -TYPED_TEST(DatabaseTest, TestReadOnlyExistsLevelDBPasses) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); - database->close(); - - EXPECT_TRUE(database->open(name, Database::ReadOnly)); - database->close(); +template <> +bool TestData::equals(const string& a, const string& b) { + return a == b; } -TYPED_TEST(DatabaseTest, TestReadOnlyDoesntExistFailsLevelDB) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_FALSE(database->open(name, Database::ReadOnly)); -} - -TYPED_TEST(DatabaseTest, TestReadWriteExistsLevelDBPasses) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); - database->close(); - - EXPECT_TRUE(database->open(name, Database::ReadWrite)); - database->close(); +template <> +vector TestData >::TestValue() { + string str = "world"; + vector val(str.data(), str.data() + str.size()); + return val; } -TYPED_TEST(DatabaseTest, TestReadWriteDoesntExistLevelDBPasses) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::ReadWrite)); - database->close(); +template <> +vector TestData >::TestAltValue() { + string str = "bar"; + vector val(str.data(), str.data() + str.size()); + return val; } -TYPED_TEST(DatabaseTest, TestKeysLevelDB) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); - - Database::key_type key1 = this->TestKey(); - Database::value_type value1 = this->TestValue(); - - EXPECT_TRUE(database->put(key1, value1)); - - Database::key_type key2 = this->TestAltKey(); - Database::value_type value2 = this->TestAltValue(); - - EXPECT_TRUE(database->put(key2, value2)); - - EXPECT_TRUE(database->commit()); - - vector keys; - database->keys(&keys); - - EXPECT_EQ(2, keys.size()); - - EXPECT_TRUE(this->BufferEq(keys.at(0), key1) || - this->BufferEq(keys.at(0), key2)); - EXPECT_TRUE(this->BufferEq(keys.at(1), key1) || - this->BufferEq(keys.at(2), key2)); - EXPECT_FALSE(this->BufferEq(keys.at(0), keys.at(1))); -} - -TYPED_TEST(DatabaseTest, TestKeysNoCommitLevelDB) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); - - Database::key_type key1 = this->TestKey(); - Database::value_type value1 = this->TestValue(); - - EXPECT_TRUE(database->put(key1, value1)); - - Database::key_type key2 = this->TestAltKey(); - Database::value_type value2 = this->TestAltValue(); - - EXPECT_TRUE(database->put(key2, value2)); - - vector keys; - database->keys(&keys); - - EXPECT_EQ(0, keys.size()); -} - -TYPED_TEST(DatabaseTest, TestIteratorsLevelDB) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); - - const int kNumExamples = 4; - for (int i = 0; i < kNumExamples; ++i) { - stringstream ss; - ss << i; - string key = ss.str(); - ss << " here be data"; - string value = ss.str(); - Database::key_type key_buf(key.data(), key.data() + key.size()); - Database::value_type val_buf(value.data(), value.data() + value.size()); - EXPECT_TRUE(database->put(key_buf, val_buf)); +template <> +bool TestData >::equals(const vector& a, + const vector& b) { + if (a.size() != b.size()) { + return false; } - EXPECT_TRUE(database->commit()); - - int count = 0; - for (Database::const_iterator iter = database->begin(); - iter != database->end(); ++iter) { - (void)iter; - ++count; + for (size_t i = 0; i < a.size(); ++i) { + if (a.at(i) != b.at(i)) { + return false; + } } - EXPECT_EQ(kNumExamples, count); + return true; } -TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLevelDB) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); - - Database::key_type key1 = this->TestAltKey(); - Database::value_type value1 = this->TestAltValue(); - - Database::key_type key2 = this->TestKey(); - Database::value_type value2 = this->TestValue(); - - EXPECT_TRUE(database->put(key1, value1)); - EXPECT_TRUE(database->put(key2, value2)); - EXPECT_TRUE(database->commit()); - - Database::const_iterator iter1 = database->begin(); - - EXPECT_FALSE(database->end() == iter1); - - EXPECT_TRUE(this->BufferEq(iter1->key, key1)); - - Database::const_iterator iter2 = ++iter1; - - EXPECT_FALSE(database->end() == iter1); - EXPECT_FALSE(database->end() == iter2); - - EXPECT_TRUE(this->BufferEq(iter2->key, key2)); - - Database::const_iterator iter3 = ++iter2; - - EXPECT_TRUE(database->end() == iter3); - - database->close(); +template <> +Datum TestData::TestValue() { + Datum datum; + datum.set_channels(3); + datum.set_height(32); + datum.set_width(32); + datum.set_data(string(32 * 32 * 3 * 4, ' ')); + datum.set_label(0); + return datum; } -TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLevelDB) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); - - Database::key_type key1 = this->TestAltKey(); - Database::value_type value1 = this->TestAltValue(); - - Database::key_type key2 = this->TestKey(); - Database::value_type value2 = this->TestValue(); - - EXPECT_TRUE(database->put(key1, value1)); - EXPECT_TRUE(database->put(key2, value2)); - EXPECT_TRUE(database->commit()); - - Database::const_iterator iter1 = database->begin(); - - EXPECT_FALSE(database->end() == iter1); - - EXPECT_TRUE(this->BufferEq(iter1->key, key1)); - - Database::const_iterator iter2 = iter1++; - - EXPECT_FALSE(database->end() == iter1); - EXPECT_FALSE(database->end() == iter2); - - EXPECT_TRUE(this->BufferEq(iter2->key, key1)); - EXPECT_TRUE(this->BufferEq(iter1->key, key2)); - - Database::const_iterator iter3 = iter1++; - - EXPECT_FALSE(database->end() == iter3); - EXPECT_TRUE(this->BufferEq(iter3->key, key2)); - EXPECT_TRUE(database->end() == iter1); - - database->close(); -} - -TYPED_TEST(DatabaseTest, TestNewPutLevelDBPasses) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); - - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); - - EXPECT_TRUE(database->put(key, val)); - - EXPECT_TRUE(database->commit()); - - database->close(); +template <> +Datum TestData::TestAltValue() { + Datum datum; + datum.set_channels(1); + datum.set_height(64); + datum.set_width(64); + datum.set_data(string(64 * 64 * 1 * 4, ' ')); + datum.set_label(1); + return datum; } -TYPED_TEST(DatabaseTest, TestNewCommitLevelDBPasses) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); +template <> +bool TestData::equals(const Datum& a, const Datum& b) { + string serialized_a; + a.SerializeToString(&serialized_a); - EXPECT_TRUE(database->commit()); + string serialized_b; + b.SerializeToString(&serialized_b); - database->close(); + return serialized_a == serialized_b; } -TYPED_TEST(DatabaseTest, TestNewGetLevelDBPasses) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); - - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); - - EXPECT_TRUE(database->put(key, val)); - - EXPECT_TRUE(database->commit()); - - Database::value_type new_val; - - EXPECT_TRUE(database->get(key, &new_val)); - - EXPECT_TRUE(this->BufferEq(val, new_val)); - - database->close(); -} - -TYPED_TEST(DatabaseTest, TestNewGetNoCommitLevelDBFails) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); - - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); - - EXPECT_TRUE(database->put(key, val)); +} // namespace DatabaseTest_internal - Database::value_type new_val; - - EXPECT_FALSE(database->get(key, &new_val)); -} +#define UNPACK_TYPES \ + typedef typename TypeParam::value_type value_type; \ + const DataParameter_DB backend = TypeParam::backend; +template +class DatabaseTest : public ::testing::Test { + protected: + typedef typename TypeParam::value_type value_type; -TYPED_TEST(DatabaseTest, TestReadWritePutLevelDBPasses) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::ReadWrite)); + string DBName() { + string filename; + MakeTempDir(&filename); + filename += "/db"; + return filename; + } - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); + string TestKey() { + return "hello"; + } - EXPECT_TRUE(database->put(key, val)); + value_type TestValue() { + return DatabaseTest_internal::TestData::TestValue(); + } - EXPECT_TRUE(database->commit()); + string TestAltKey() { + return "foo"; + } - database->close(); -} + value_type TestAltValue() { + return DatabaseTest_internal::TestData::TestAltValue(); + } -TYPED_TEST(DatabaseTest, TestReadWriteCommitLevelDBPasses) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::ReadWrite)); + template + bool equals(const T& a, const T& b) { + return DatabaseTest_internal::TestData::equals(a, b); + } +}; - EXPECT_TRUE(database->commit()); +struct StringLeveldb { + typedef string value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB StringLeveldb::backend = DataParameter_DB_LEVELDB; - database->close(); -} +struct StringLmdb { + typedef string value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB StringLmdb::backend = DataParameter_DB_LEVELDB; -TYPED_TEST(DatabaseTest, TestReadWriteGetLevelDBPasses) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); +struct VectorLeveldb { + typedef vector value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB VectorLeveldb::backend = DataParameter_DB_LEVELDB; - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); +struct VectorLmdb { + typedef vector value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB VectorLmdb::backend = DataParameter_DB_LEVELDB; - EXPECT_TRUE(database->put(key, val)); +struct DatumLeveldb { + typedef Datum value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB DatumLeveldb::backend = DataParameter_DB_LEVELDB; - EXPECT_TRUE(database->commit()); +struct DatumLmdb { + typedef Datum value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB DatumLmdb::backend = DataParameter_DB_LEVELDB; - Database::value_type new_val; +typedef ::testing::Types TestTypes; - EXPECT_TRUE(database->get(key, &new_val)); +TYPED_TEST_CASE(DatabaseTest, TestTypes); - EXPECT_TRUE(this->BufferEq(val, new_val)); +TYPED_TEST(DatabaseTest, TestNewDoesntExistPasses) { + UNPACK_TYPES; + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(this->DBName(), + Database::New)); database->close(); } -TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitLevelDBFails) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); - - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); - - EXPECT_TRUE(database->put(key, val)); - - Database::value_type new_val; - - EXPECT_FALSE(database->get(key, &new_val)); -} +TYPED_TEST(DatabaseTest, TestNewExistsFails) { + UNPACK_TYPES; -TYPED_TEST(DatabaseTest, TestReadOnlyPutLevelDBFails) { string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); database->close(); - EXPECT_TRUE(database->open(name, Database::ReadOnly)); - - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); - - EXPECT_FALSE(database->put(key, val)); + EXPECT_FALSE(database->open(name, Database::New)); } -TYPED_TEST(DatabaseTest, TestReadOnlyCommitLevelDBFails) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); - database->close(); - - EXPECT_TRUE(database->open(name, Database::ReadOnly)); - - EXPECT_FALSE(database->commit()); -} +TYPED_TEST(DatabaseTest, TestReadOnlyExistsPasses) { + UNPACK_TYPES; -TYPED_TEST(DatabaseTest, TestReadOnlyGetLevelDBPasses) { string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); - - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); - - EXPECT_TRUE(database->put(key, val)); - - EXPECT_TRUE(database->commit()); - + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); database->close(); - EXPECT_TRUE(database->open(name, Database::ReadOnly)); - - Database::value_type new_val; - - EXPECT_TRUE(database->get(key, &new_val)); - - EXPECT_TRUE(this->BufferEq(val, new_val)); -} - -TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLevelDBFails) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("leveldb"); - EXPECT_TRUE(database->open(name, Database::New)); - - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); - - EXPECT_TRUE(database->put(key, val)); - + EXPECT_TRUE(database->open(name, Database::ReadOnly)); database->close(); - - EXPECT_TRUE(database->open(name, Database::ReadOnly)); - - Database::value_type new_val; - - EXPECT_FALSE(database->get(key, &new_val)); } -TYPED_TEST(DatabaseTest, TestNewDoesntExistLMDBPasses) { - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(this->DBName(), Database::New)); - database->close(); -} +TYPED_TEST(DatabaseTest, TestReadOnlyDoesntExistFails) { + UNPACK_TYPES; -TYPED_TEST(DatabaseTest, TestNewExistsFailsLMDB) { string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); - database->close(); - - EXPECT_FALSE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_FALSE(database->open(name, Database::ReadOnly)); } -TYPED_TEST(DatabaseTest, TestReadOnlyExistsLMDBPasses) { +TYPED_TEST(DatabaseTest, TestReadWriteExistsPasses) { + UNPACK_TYPES; + string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); database->close(); - EXPECT_TRUE(database->open(name, Database::ReadOnly)); + EXPECT_TRUE(database->open(name, Database::ReadWrite)); database->close(); } -TYPED_TEST(DatabaseTest, TestReadOnlyDoesntExistFailsLMDB) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_FALSE(database->open(name, Database::ReadOnly)); -} +TYPED_TEST(DatabaseTest, TestReadWriteDoesntExistPasses) { + UNPACK_TYPES; -TYPED_TEST(DatabaseTest, TestReadWriteExistsLMDBPasses) { string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); - database->close(); - - EXPECT_TRUE(database->open(name, Database::ReadWrite)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::ReadWrite)); database->close(); } -TYPED_TEST(DatabaseTest, TestReadWriteDoesntExistLMDBPasses) { - string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::ReadWrite)); - database->close(); -} +TYPED_TEST(DatabaseTest, TestKeys) { + UNPACK_TYPES; -TYPED_TEST(DatabaseTest, TestKeysLMDB) { string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); - Database::key_type key1 = this->TestKey(); - Database::value_type value1 = this->TestValue(); + string key1 = this->TestKey(); + value_type value1 = this->TestValue(); EXPECT_TRUE(database->put(key1, value1)); - Database::key_type key2 = this->TestAltKey(); - Database::value_type value2 = this->TestAltValue(); + string key2 = this->TestAltKey(); + value_type value2 = this->TestAltValue(); EXPECT_TRUE(database->put(key2, value2)); EXPECT_TRUE(database->commit()); - vector keys; + vector keys; database->keys(&keys); EXPECT_EQ(2, keys.size()); - EXPECT_TRUE(this->BufferEq(keys.at(0), key1) || - this->BufferEq(keys.at(0), key2)); - EXPECT_TRUE(this->BufferEq(keys.at(1), key1) || - this->BufferEq(keys.at(2), key2)); - EXPECT_FALSE(this->BufferEq(keys.at(0), keys.at(1))); + EXPECT_TRUE(this->equals(keys.at(0), key1) || + this->equals(keys.at(0), key2)); + EXPECT_TRUE(this->equals(keys.at(1), key1) || + this->equals(keys.at(2), key2)); + EXPECT_FALSE(this->equals(keys.at(0), keys.at(1))); } -TYPED_TEST(DatabaseTest, TestKeysNoCommitLMDB) { +TYPED_TEST(DatabaseTest, TestKeysNoCommit) { + UNPACK_TYPES; + string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); - Database::key_type key1 = this->TestKey(); - Database::value_type value1 = this->TestValue(); + string key1 = this->TestKey(); + value_type value1 = this->TestValue(); EXPECT_TRUE(database->put(key1, value1)); - Database::key_type key2 = this->TestAltKey(); - Database::value_type value2 = this->TestAltValue(); + string key2 = this->TestAltKey(); + value_type value2 = this->TestAltValue(); EXPECT_TRUE(database->put(key2, value2)); - vector keys; + vector keys; database->keys(&keys); EXPECT_EQ(0, keys.size()); } +TYPED_TEST(DatabaseTest, TestIterators) { + UNPACK_TYPES; -TYPED_TEST(DatabaseTest, TestIteratorsLMDB) { string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); const int kNumExamples = 4; for (int i = 0; i < kNumExamples; ++i) { @@ -563,16 +315,14 @@ TYPED_TEST(DatabaseTest, TestIteratorsLMDB) { ss << i; string key = ss.str(); ss << " here be data"; - string value = ss.str(); - Database::key_type key_buf(key.data(), key.data() + key.size()); - Database::value_type val_buf(value.data(), value.data() + value.size()); - EXPECT_TRUE(database->put(key_buf, val_buf)); + value_type value = this->TestValue(); + EXPECT_TRUE(database->put(key, value)); } EXPECT_TRUE(database->commit()); int count = 0; - for (Database::const_iterator iter = database->begin(); - iter != database->end(); ++iter) { + typedef typename Database::const_iterator Iter; + for (Iter iter = database->begin(); iter != database->end(); ++iter) { (void)iter; ++count; } @@ -580,266 +330,313 @@ TYPED_TEST(DatabaseTest, TestIteratorsLMDB) { EXPECT_EQ(kNumExamples, count); } -TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLMDB) { +TYPED_TEST(DatabaseTest, TestIteratorsPreIncrement) { + UNPACK_TYPES; + string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); - Database::key_type key1 = this->TestAltKey(); - Database::value_type value1 = this->TestAltValue(); + string key1 = this->TestAltKey(); + value_type value1 = this->TestAltValue(); - Database::key_type key2 = this->TestKey(); - Database::value_type value2 = this->TestValue(); + string key2 = this->TestKey(); + value_type value2 = this->TestValue(); EXPECT_TRUE(database->put(key1, value1)); EXPECT_TRUE(database->put(key2, value2)); EXPECT_TRUE(database->commit()); - Database::const_iterator iter1 = database->begin(); + typename Database::const_iterator iter1 = + database->begin(); EXPECT_FALSE(database->end() == iter1); - EXPECT_TRUE(this->BufferEq(iter1->key, key1)); + EXPECT_TRUE(this->equals(iter1->key, key1)); - Database::const_iterator iter2 = ++iter1; + typename Database::const_iterator iter2 = ++iter1; EXPECT_FALSE(database->end() == iter1); EXPECT_FALSE(database->end() == iter2); - EXPECT_TRUE(this->BufferEq(iter2->key, key2)); + EXPECT_TRUE(this->equals(iter2->key, key2)); - Database::const_iterator iter3 = ++iter2; + typename Database::const_iterator iter3 = ++iter2; EXPECT_TRUE(database->end() == iter3); database->close(); } -TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLMDB) { +TYPED_TEST(DatabaseTest, TestIteratorsPostIncrement) { + UNPACK_TYPES; + string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); - Database::key_type key1 = this->TestAltKey(); - Database::value_type value1 = this->TestAltValue(); + string key1 = this->TestAltKey(); + value_type value1 = this->TestAltValue(); - Database::key_type key2 = this->TestKey(); - Database::value_type value2 = this->TestValue(); + string key2 = this->TestKey(); + value_type value2 = this->TestValue(); EXPECT_TRUE(database->put(key1, value1)); EXPECT_TRUE(database->put(key2, value2)); EXPECT_TRUE(database->commit()); - Database::const_iterator iter1 = database->begin(); + typename Database::const_iterator iter1 = + database->begin(); EXPECT_FALSE(database->end() == iter1); - EXPECT_TRUE(this->BufferEq(iter1->key, key1)); + EXPECT_TRUE(this->equals(iter1->key, key1)); - Database::const_iterator iter2 = iter1++; + typename Database::const_iterator iter2 = iter1++; EXPECT_FALSE(database->end() == iter1); EXPECT_FALSE(database->end() == iter2); - EXPECT_TRUE(this->BufferEq(iter2->key, key1)); - EXPECT_TRUE(this->BufferEq(iter1->key, key2)); + EXPECT_TRUE(this->equals(iter2->key, key1)); + EXPECT_TRUE(this->equals(iter1->key, key2)); - Database::const_iterator iter3 = iter1++; + typename Database::const_iterator iter3 = iter1++; EXPECT_FALSE(database->end() == iter3); - EXPECT_TRUE(this->BufferEq(iter3->key, key2)); + EXPECT_TRUE(this->equals(iter3->key, key2)); EXPECT_TRUE(database->end() == iter1); database->close(); } -TYPED_TEST(DatabaseTest, TestNewPutLMDBPasses) { +TYPED_TEST(DatabaseTest, TestNewPutPasses) { + UNPACK_TYPES; + string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); + string key = this->TestKey(); + value_type value = this->TestValue(); - EXPECT_TRUE(database->put(key, val)); + EXPECT_TRUE(database->put(key, value)); EXPECT_TRUE(database->commit()); database->close(); } -TYPED_TEST(DatabaseTest, TestNewCommitLMDBPasses) { +TYPED_TEST(DatabaseTest, TestNewCommitPasses) { + UNPACK_TYPES; + string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); EXPECT_TRUE(database->commit()); database->close(); } -TYPED_TEST(DatabaseTest, TestNewGetLMDBPasses) { +TYPED_TEST(DatabaseTest, TestNewGetPasses) { + UNPACK_TYPES; + string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); + string key = this->TestKey(); + value_type value = this->TestValue(); - EXPECT_TRUE(database->put(key, val)); + EXPECT_TRUE(database->put(key, value)); EXPECT_TRUE(database->commit()); - Database::value_type new_val; + value_type new_value; - EXPECT_TRUE(database->get(key, &new_val)); + EXPECT_TRUE(database->get(key, &new_value)); - EXPECT_TRUE(this->BufferEq(val, new_val)); + EXPECT_TRUE(this->equals(value, new_value)); database->close(); } -TYPED_TEST(DatabaseTest, TestNewGetNoCommitLMDBFails) { +TYPED_TEST(DatabaseTest, TestNewGetNoCommitFails) { + UNPACK_TYPES; + string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); + string key = this->TestKey(); + value_type value = this->TestValue(); - EXPECT_TRUE(database->put(key, val)); + EXPECT_TRUE(database->put(key, value)); - Database::value_type new_val; + value_type new_value; - EXPECT_FALSE(database->get(key, &new_val)); + EXPECT_FALSE(database->get(key, &new_value)); } -TYPED_TEST(DatabaseTest, TestReadWritePutLMDBPasses) { + +TYPED_TEST(DatabaseTest, TestReadWritePutPasses) { + UNPACK_TYPES; + string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::ReadWrite)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::ReadWrite)); - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); + string key = this->TestKey(); + value_type value = this->TestValue(); - EXPECT_TRUE(database->put(key, val)); + EXPECT_TRUE(database->put(key, value)); EXPECT_TRUE(database->commit()); database->close(); } -TYPED_TEST(DatabaseTest, TestReadWriteCommitLMDBPasses) { +TYPED_TEST(DatabaseTest, TestReadWriteCommitPasses) { + UNPACK_TYPES; + string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::ReadWrite)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::ReadWrite)); EXPECT_TRUE(database->commit()); database->close(); } -TYPED_TEST(DatabaseTest, TestReadWriteGetLMDBPasses) { +TYPED_TEST(DatabaseTest, TestReadWriteGetPasses) { + UNPACK_TYPES; + string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); + string key = this->TestKey(); + value_type value = this->TestValue(); - EXPECT_TRUE(database->put(key, val)); + EXPECT_TRUE(database->put(key, value)); EXPECT_TRUE(database->commit()); - Database::value_type new_val; + value_type new_value; - EXPECT_TRUE(database->get(key, &new_val)); + EXPECT_TRUE(database->get(key, &new_value)); - EXPECT_TRUE(this->BufferEq(val, new_val)); + EXPECT_TRUE(this->equals(value, new_value)); database->close(); } -TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitLMDBFails) { +TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitFails) { + UNPACK_TYPES; + string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); + string key = this->TestKey(); + value_type value = this->TestValue(); - EXPECT_TRUE(database->put(key, val)); + EXPECT_TRUE(database->put(key, value)); - Database::value_type new_val; + value_type new_value; - EXPECT_FALSE(database->get(key, &new_val)); + EXPECT_FALSE(database->get(key, &new_value)); } -TYPED_TEST(DatabaseTest, TestReadOnlyPutLMDBFails) { +TYPED_TEST(DatabaseTest, TestReadOnlyPutFails) { + UNPACK_TYPES; + string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); database->close(); - EXPECT_TRUE(database->open(name, Database::ReadOnly)); + EXPECT_TRUE(database->open(name, Database::ReadOnly)); - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); + string key = this->TestKey(); + value_type value = this->TestValue(); - EXPECT_FALSE(database->put(key, val)); + EXPECT_FALSE(database->put(key, value)); } -TYPED_TEST(DatabaseTest, TestReadOnlyCommitLMDBFails) { +TYPED_TEST(DatabaseTest, TestReadOnlyCommitFails) { + UNPACK_TYPES; + string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); database->close(); - EXPECT_TRUE(database->open(name, Database::ReadOnly)); + EXPECT_TRUE(database->open(name, Database::ReadOnly)); EXPECT_FALSE(database->commit()); } -TYPED_TEST(DatabaseTest, TestReadOnlyGetLMDBPasses) { +TYPED_TEST(DatabaseTest, TestReadOnlyGetPasses) { + UNPACK_TYPES; + string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); + string key = this->TestKey(); + value_type value = this->TestValue(); - EXPECT_TRUE(database->put(key, val)); + EXPECT_TRUE(database->put(key, value)); EXPECT_TRUE(database->commit()); database->close(); - EXPECT_TRUE(database->open(name, Database::ReadOnly)); + EXPECT_TRUE(database->open(name, Database::ReadOnly)); - Database::value_type new_val; + value_type new_value; - EXPECT_TRUE(database->get(key, &new_val)); + EXPECT_TRUE(database->get(key, &new_value)); - EXPECT_TRUE(this->BufferEq(val, new_val)); + EXPECT_TRUE(this->equals(value, new_value)); } -TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLMDBFails) { +TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitFails) { + UNPACK_TYPES; + string name = this->DBName(); - shared_ptr database = DatabaseFactory("lmdb"); - EXPECT_TRUE(database->open(name, Database::New)); + shared_ptr > database = + DatabaseFactory(backend); + EXPECT_TRUE(database->open(name, Database::New)); - Database::key_type key = this->TestKey(); - Database::value_type val = this->TestValue(); + string key = this->TestKey(); + value_type value = this->TestValue(); - EXPECT_TRUE(database->put(key, val)); + EXPECT_TRUE(database->put(key, value)); database->close(); - EXPECT_TRUE(database->open(name, Database::ReadOnly)); + EXPECT_TRUE(database->open(name, Database::ReadOnly)); - Database::value_type new_val; + value_type new_value; - EXPECT_FALSE(database->get(key, &new_val)); + EXPECT_FALSE(database->get(key, &new_value)); } +#undef UNPACK_TYPES + } // namespace caffe diff --git a/tools/compute_image_mean.cpp b/tools/compute_image_mean.cpp index d13c4a0..f1a7967 100644 --- a/tools/compute_image_mean.cpp +++ b/tools/compute_image_mean.cpp @@ -26,18 +26,17 @@ int main(int argc, char** argv) { db_backend = std::string(argv[3]); } - caffe::shared_ptr database = caffe::DatabaseFactory(db_backend); + caffe::shared_ptr > database = + caffe::DatabaseFactory(db_backend); // Open db - CHECK(database->open(argv[1], Database::ReadOnly)); + CHECK(database->open(argv[1], Database::ReadOnly)); - Datum datum; BlobProto sum_blob; int count = 0; // load first datum - Database::const_iterator iter = database->begin(); - const Database::value_type& first_blob = iter->value; - datum.ParseFromArray(first_blob.data(), first_blob.size()); + Database::const_iterator iter = database->begin(); + const Datum& datum = iter->value; sum_blob.set_num(1); sum_blob.set_channels(datum.channels()); @@ -50,11 +49,10 @@ int main(int argc, char** argv) { sum_blob.add_data(0.); } LOG(INFO) << "Starting Iteration"; - for (Database::const_iterator iter = database->begin(); + for (Database::const_iterator iter = database->begin(); iter != database->end(); ++iter) { // just a dummy operation - const Database::value_type& blob = iter->value; - datum.ParseFromArray(blob.data(), blob.size()); + const Datum& datum = iter->value; const std::string& data = datum.data(); size_in_datum = std::max(datum.data().size(), datum.float_data_size()); diff --git a/tools/convert_imageset.cpp b/tools/convert_imageset.cpp index 1cdca7e..2ba3e3c 100644 --- a/tools/convert_imageset.cpp +++ b/tools/convert_imageset.cpp @@ -78,10 +78,11 @@ int main(int argc, char** argv) { int resize_width = std::max(0, FLAGS_resize_width); // Open new db - shared_ptr database = DatabaseFactory(db_backend); + shared_ptr > database = + DatabaseFactory(db_backend); // Open db - CHECK(database->open(db_path, Database::New)); + CHECK(database->open(db_path, Database::New)); // Storing to db std::string root_folder(argv[1]); @@ -110,13 +111,9 @@ int main(int argc, char** argv) { // sequential int length = snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id, lines[line_id].first.c_str()); - Database::value_type value(datum.ByteSize()); - datum.SerializeWithCachedSizesToArray( - reinterpret_cast(value.data())); - Database::key_type keystr(key_cstr, key_cstr + length); // Put in db - CHECK(database->put(keystr, value)); + CHECK(database->put(string(key_cstr, length), datum)); if (++count % 1000 == 0) { // Commit txn diff --git a/tools/extract_features.cpp b/tools/extract_features.cpp index 1340192..47565a8 100644 --- a/tools/extract_features.cpp +++ b/tools/extract_features.cpp @@ -121,11 +121,13 @@ int feature_extraction_pipeline(int argc, char** argv) { int num_mini_batches = atoi(argv[++arg_pos]); - std::vector > feature_dbs; + std::vector > > feature_dbs; for (size_t i = 0; i < num_features; ++i) { LOG(INFO)<< "Opening database " << database_names[i]; - shared_ptr database = DatabaseFactory(argv[++arg_pos]); - CHECK(database->open(database_names.at(i), Database::New)); + shared_ptr > database = + DatabaseFactory(argv[++arg_pos]); + CHECK(database->open(database_names.at(i), + Database::New)); feature_dbs.push_back(database); } @@ -155,13 +157,9 @@ int feature_extraction_pipeline(int argc, char** argv) { for (int d = 0; d < dim_features; ++d) { datum.add_float_data(feature_blob_data[d]); } - Database::value_type value(datum.ByteSize()); - datum.SerializeWithCachedSizesToArray( - reinterpret_cast(value.data())); int length = snprintf(key_str, kMaxKeyStrLength, "%d", image_indices[i]); - Database::key_type key(key_str, key_str + length); - CHECK(feature_dbs.at(i)->put(key, value)); + CHECK(feature_dbs.at(i)->put(std::string(key_str, length), datum)); ++image_indices[i]; if (image_indices[i] % 1000 == 0) { CHECK(feature_dbs.at(i)->commit()); -- 2.7.4