From 6ad4f951133535f8fd43a8b6d4ece85f33465aca Mon Sep 17 00:00:00 2001
From: Kevin James Matzen
Date: Tue, 7 Oct 2014 15:37:35 -0400
Subject: [PATCH] Refactored leveldb and lmdb code.

---
 include/caffe/data_layers.hpp          |  14 +--
 include/caffe/database.hpp             | 103 ++++++++++++++++++++++
 include/caffe/database_factory.hpp     |  16 ++++
 include/caffe/leveldb_database.hpp     |  52 +++++++++++
 include/caffe/lmdb_database.hpp        |  54 ++++++++++++
 include/caffe/util/io.hpp              |   7 --
 src/caffe/database_factory.cpp         |  33 +++++++
 src/caffe/layers/data_layer.cpp        | 129 ++++-----------------------
 src/caffe/leveldb_database.cpp         | 151 ++++++++++++++++++++++++++++++++
 src/caffe/lmdb_database.cpp            | 154 +++++++++++++++++++++++++++++++++
 src/caffe/proto/caffe.proto            |   6 +-
 src/caffe/test/test_data_layer.cpp     |  93 +++++---------------
 src/caffe/test/test_hdf5data_layer.cpp |   2 -
 src/caffe/util/io.cpp                  |  10 ---
 tools/compute_image_mean.cpp           | 132 ++++++----------------------
 tools/convert_imageset.cpp             |  81 ++---------------
 tools/extract_features.cpp             |  50 +++++------
 17 files changed, 660 insertions(+), 427 deletions(-)
 create mode 100644 include/caffe/database.hpp
 create mode 100644 include/caffe/database_factory.hpp
 create mode 100644 include/caffe/leveldb_database.hpp
 create mode 100644 include/caffe/lmdb_database.hpp
 create mode 100644 src/caffe/database_factory.cpp
 create mode 100644 src/caffe/leveldb_database.cpp
 create mode 100644 src/caffe/lmdb_database.cpp

diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp
index 1ed26c1..810f2bb 100644
--- a/include/caffe/data_layers.hpp
+++ b/include/caffe/data_layers.hpp
@@ -7,12 +7,11 @@
 #include "boost/scoped_ptr.hpp"
 #include "hdf5.h"
-#include "leveldb/db.h"
-#include "lmdb.h"
 
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
 #include "caffe/data_transformer.hpp"
+#include "caffe/database.hpp"
 #include "caffe/filler.hpp"
 #include "caffe/internal_thread.hpp"
 #include "caffe/layer.hpp"
@@ -101,15 +100,8 @@ class DataLayer : public BasePrefetchingDataLayer<Dtype> {
 protected:
   virtual void InternalThreadEntry();
 
-  // LEVELDB
-  shared_ptr<leveldb::DB> db_;
-  shared_ptr<leveldb::Iterator> iter_;
-  // LMDB
-  MDB_env* mdb_env_;
-  MDB_dbi mdb_dbi_;
-  MDB_txn* mdb_txn_;
-  MDB_cursor* mdb_cursor_;
-  MDB_val mdb_key_, mdb_value_;
+  shared_ptr<Database> database_;
+  Database::const_iterator iter_;
 };
 
 /**
diff --git a/include/caffe/database.hpp b/include/caffe/database.hpp
new file mode 100644
index 0000000..4a1a25e
--- /dev/null
+++ b/include/caffe/database.hpp
@@ -0,0 +1,103 @@
+#ifndef CAFFE_DATABASE_H_
+#define CAFFE_DATABASE_H_
+
+#include <algorithm>
+#include <iterator>
+#include <string>
+#include <utility>
+
+#include "caffe/common.hpp"
+
+namespace caffe {
+
+class Database {
+ public:
+  enum Mode {
+    New,
+    ReadWrite,
+    ReadOnly
+  };
+
+  virtual void open(const string& filename, Mode mode) = 0;
+  virtual void put(const string& key, const string& value) = 0;
+  virtual void commit() = 0;
+  virtual void close() = 0;
+
+  Database() { }
+  virtual ~Database() { }
+
+  class iterator;
+  typedef iterator const_iterator;
+
+  virtual const_iterator begin() const = 0;
+  virtual const_iterator cbegin() const = 0;
+  virtual const_iterator end() const = 0;
+  virtual const_iterator cend() const = 0;
+
+ protected:
+  class DatabaseState;
+
+ public:
+  class iterator :
+      public std::iterator<std::forward_iterator_tag, pair<string, string> > {
+   public:
+    typedef pair<string, string> T;
+    typedef T value_type;
+    typedef T& reference_type;
+    typedef T* pointer_type;
+
+    iterator()
+        : parent_(NULL) { }
+
+    iterator(const Database* parent, shared_ptr<DatabaseState> state)
+        : parent_(parent),
+          state_(state) { }
+    ~iterator() { }
+
+    bool operator==(const iterator& other) const {
+      return parent_->equal(state_, other.state_);
+    }
+
+    bool operator!=(const iterator& other) const {
+      return !(*this == other);
+    }
+
+    iterator& operator++() {
+      parent_->increment(state_);
+      return *this;
+    }
+    iterator operator++(int) {
+      iterator copy(*this);
+      parent_->increment(state_);
+      return copy;
+    }
+
+    reference_type operator*() const {
+      return parent_->dereference(state_);
+    }
+
+    pointer_type operator->() const {
+      return &parent_->dereference(state_);
+    }
+
+   protected:
+    const Database* parent_;
+    shared_ptr<DatabaseState> state_;
+  };
+
+ protected:
+  class DatabaseState {
+   public:
+    virtual ~DatabaseState() { }
+  };
+
+  virtual bool equal(shared_ptr<DatabaseState> state1,
+      shared_ptr<DatabaseState> state2) const = 0;
+  virtual void increment(shared_ptr<DatabaseState> state) const = 0;
+  virtual pair<string, string>& dereference(
+      shared_ptr<DatabaseState> state) const = 0;
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_DATABASE_H_
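For orientation, a minimal sketch (not part of the patch) of how the Database interface above is meant to be consumed; DumpEntrySizes and its already-opened database argument are hypothetical:

    // Sequential scan: every backend yields pair<string, string> entries,
    // so callers no longer branch on LevelDB vs. LMDB.
    #include "caffe/database.hpp"

    void DumpEntrySizes(const caffe::Database& database) {
      for (caffe::Database::const_iterator it = database.begin();
           it != database.end(); ++it) {
        LOG(INFO) << it->first << " -> " << it->second.size() << " bytes";
      }
    }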
diff --git a/include/caffe/database_factory.hpp b/include/caffe/database_factory.hpp
new file mode 100644
index 0000000..a6e39e7
--- /dev/null
+++ b/include/caffe/database_factory.hpp
@@ -0,0 +1,16 @@
+#ifndef CAFFE_DATABASE_FACTORY_H_
+#define CAFFE_DATABASE_FACTORY_H_
+
+#include <string>
+
+#include "caffe/common.hpp"
+#include "caffe/database.hpp"
+
+namespace caffe {
+
+shared_ptr<Database> DatabaseFactory(const DataParameter_DB& type);
+shared_ptr<Database> DatabaseFactory(const string& type);
+
+}  // namespace caffe
+
+#endif  // CAFFE_DATABASE_FACTORY_H_
diff --git a/include/caffe/leveldb_database.hpp b/include/caffe/leveldb_database.hpp
new file mode 100644
index 0000000..5daf0e5
--- /dev/null
+++ b/include/caffe/leveldb_database.hpp
@@ -0,0 +1,52 @@
+#ifndef CAFFE_LEVELDB_DATABASE_H_
+#define CAFFE_LEVELDB_DATABASE_H_
+
+#include <string>
+#include <utility>
+
+#include <leveldb/db.h>
+#include <leveldb/write_batch.h>
+
+#include "caffe/common.hpp"
+#include "caffe/database.hpp"
+
+namespace caffe {
+
+class LeveldbDatabase : public Database {
+ public:
+  void open(const string& filename, Mode mode);
+  void put(const string& key, const string& value);
+  void commit();
+  void close();
+
+  ~LeveldbDatabase() { this->close(); }
+
+ protected:
+  class LeveldbState : public Database::DatabaseState {
+   public:
+    explicit LeveldbState(shared_ptr<leveldb::Iterator> iter)
+        : Database::DatabaseState(),
+          iter_(iter) { }
+
+    shared_ptr<leveldb::Iterator> iter_;
+    pair<string, string> kv_pair_;
+  };
+
+  bool equal(shared_ptr<DatabaseState> state1,
+      shared_ptr<DatabaseState> state2) const;
+  void increment(shared_ptr<DatabaseState> state) const;
+  pair<string, string>& dereference(shared_ptr<DatabaseState> state) const;
+
+  const_iterator begin() const;
+  const_iterator cbegin() const;
+  const_iterator end() const;
+  const_iterator cend() const;
+
+ protected:
+  shared_ptr<leveldb::DB> db_;
+  shared_ptr<leveldb::WriteBatch> batch_;
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_LEVELDB_DATABASE_H_
diff --git a/include/caffe/lmdb_database.hpp b/include/caffe/lmdb_database.hpp
new file mode 100644
index 0000000..f275cb4
--- /dev/null
+++ b/include/caffe/lmdb_database.hpp
@@ -0,0 +1,54 @@
+#ifndef CAFFE_LMDB_DATABASE_H_
+#define CAFFE_LMDB_DATABASE_H_
+
+#include <string>
+#include <utility>
+
+#include "lmdb.h"
+
+#include "caffe/common.hpp"
+#include "caffe/database.hpp"
+
+namespace caffe {
+
+class LmdbDatabase : public Database {
+ public:
+  LmdbDatabase()
+      : dbi_(0) { }
+  ~LmdbDatabase() { this->close(); }
+
+  void open(const string& filename, Mode mode);
+  void put(const string& key, const string& value);
+  void commit();
+  void close();
+
+ protected:
+  class LmdbState : public Database::DatabaseState {
+   public:
+    explicit LmdbState(MDB_cursor* cursor)
+        : Database::DatabaseState(),
+          cursor_(cursor) { }
+
+    MDB_cursor* cursor_;
+    pair<string, string> kv_pair_;
+  };
+
+  bool equal(shared_ptr<DatabaseState> state1,
+      shared_ptr<DatabaseState> state2) const;
+  void increment(shared_ptr<DatabaseState> state) const;
+  pair<string, string>& dereference(shared_ptr<DatabaseState> state) const;
+
+ protected:
+  const_iterator begin() const;
+  const_iterator cbegin() const;
+  const_iterator end() const;
+  const_iterator cend() const;
+
+  MDB_env *env_ = NULL;
+  MDB_dbi dbi_;
+  MDB_txn *txn_ = NULL;
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_LMDB_DATABASE_H_
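A minimal write-path sketch against the two headers above, assuming the factory declared in database_factory.hpp; ToyWrite, the path, and the key/value are hypothetical:

    #include "caffe/database_factory.hpp"

    void ToyWrite() {
      // Either backend satisfies the same contract: open, put, commit, close.
      caffe::shared_ptr<caffe::Database> db = caffe::DatabaseFactory("leveldb");
      db->open("/tmp/toy_db", caffe::Database::New);
      db->put("key0", "value0");
      db->commit();  // flushes the buffered WriteBatch (LevelDB) or txn (LMDB)
      db->close();
    }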
diff --git a/include/caffe/util/io.hpp b/include/caffe/util/io.hpp
index e518979..1124bef 100644
--- a/include/caffe/util/io.hpp
+++ b/include/caffe/util/io.hpp
@@ -17,11 +17,6 @@
 
 #define HDF5_NUM_DIMS 4
 
-namespace leveldb {
-// Forward declaration for leveldb::Options to be used in GetlevelDBOptions().
-struct Options;
-}
-
 namespace caffe {
 
 using ::google::protobuf::Message;
@@ -132,8 +127,6 @@ inline cv::Mat ReadImageToCVMat(const string& filename) {
 
 void CVMatToDatum(const cv::Mat& cv_img, Datum* datum);
 #endif
 
-leveldb::Options GetLevelDBOptions();
-
 template <typename Dtype>
 void hdf5_load_nd_dataset_helper(
     hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
diff --git a/src/caffe/database_factory.cpp b/src/caffe/database_factory.cpp
new file mode 100644
index 0000000..393635b
--- /dev/null
+++ b/src/caffe/database_factory.cpp
@@ -0,0 +1,33 @@
+#include <string>
+#include <utility>
+
+#include "caffe/database_factory.hpp"
+#include "caffe/leveldb_database.hpp"
+#include "caffe/lmdb_database.hpp"
+
+namespace caffe {
+
+shared_ptr<Database> DatabaseFactory(const DataParameter_DB& type) {
+  switch (type) {
+  case DataParameter_DB_LEVELDB:
+    return shared_ptr<Database>(new LeveldbDatabase());
+  case DataParameter_DB_LMDB:
+    return shared_ptr<Database>(new LmdbDatabase());
+  default:
+    LOG(FATAL) << "Unknown database type " << type;
+  }
+}
+
+shared_ptr<Database> DatabaseFactory(const string& type) {
+  if ("leveldb" == type) {
+    return DatabaseFactory(DataParameter_DB_LEVELDB);
+  } else if ("lmdb" == type) {
+    return DatabaseFactory(DataParameter_DB_LMDB);
+  } else {
+    LOG(FATAL) << "Unknown database type " << type;
+  }
+}
+
+}  // namespace caffe
+
+
diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp
index daa3289..b746bc8 100644
--- a/src/caffe/layers/data_layer.cpp
+++ b/src/caffe/layers/data_layer.cpp
@@ -1,4 +1,3 @@
-#include <leveldb/db.h>
 #include <stdint.h>
 
 #include <string>
@@ -6,6 +5,7 @@
 
 #include "caffe/common.hpp"
 #include "caffe/data_layers.hpp"
+#include "caffe/database_factory.hpp"
 #include "caffe/layer.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/io.hpp"
@@ -18,60 +18,18 @@
 template <typename Dtype>
 DataLayer<Dtype>::~DataLayer() {
   this->JoinPrefetchThread();
   // clean up the database resources
-  switch (this->layer_param_.data_param().backend()) {
-  case DataParameter_DB_LEVELDB:
-    break;  // do nothing
-  case DataParameter_DB_LMDB:
-    mdb_cursor_close(mdb_cursor_);
-    mdb_close(mdb_env_, mdb_dbi_);
-    mdb_txn_abort(mdb_txn_);
-    mdb_env_close(mdb_env_);
-    break;
-  default:
-    LOG(FATAL) << "Unknown database backend";
-  }
+  iter_ = database_->end();
+  database_->close();
 }
 
 template <typename Dtype>
 void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
   // Initialize DB
-  switch (this->layer_param_.data_param().backend()) {
-  case DataParameter_DB_LEVELDB:
-    {
-    leveldb::DB* db_temp;
-    leveldb::Options options = GetLevelDBOptions();
-    options.create_if_missing = false;
-    LOG(INFO) << "Opening leveldb " << this->layer_param_.data_param().source();
-    leveldb::Status status = leveldb::DB::Open(
-        options, this->layer_param_.data_param().source(), &db_temp);
-    CHECK(status.ok()) << "Failed to open leveldb "
-                       << this->layer_param_.data_param().source() << std::endl
-                       << status.ToString();
-    db_.reset(db_temp);
-    iter_.reset(db_->NewIterator(leveldb::ReadOptions()));
-    iter_->SeekToFirst();
-    }
-    break;
-  case DataParameter_DB_LMDB:
-    CHECK_EQ(mdb_env_create(&mdb_env_), MDB_SUCCESS) << "mdb_env_create failed";
-    CHECK_EQ(mdb_env_set_mapsize(mdb_env_, 1099511627776), MDB_SUCCESS);  // 1TB
-    CHECK_EQ(mdb_env_open(mdb_env_,
-             this->layer_param_.data_param().source().c_str(),
-             MDB_RDONLY|MDB_NOTLS, 0664), MDB_SUCCESS) << "mdb_env_open failed";
-    CHECK_EQ(mdb_txn_begin(mdb_env_, NULL, MDB_RDONLY, &mdb_txn_), MDB_SUCCESS)
-        << "mdb_txn_begin failed";
-    CHECK_EQ(mdb_open(mdb_txn_, NULL, 0, &mdb_dbi_), MDB_SUCCESS)
-        << "mdb_open failed";
-    CHECK_EQ(mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_), MDB_SUCCESS)
-        << "mdb_cursor_open failed";
-    LOG(INFO) << "Opening lmdb " << this->layer_param_.data_param().source();
-    CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST),
-        MDB_SUCCESS) << "mdb_cursor_get failed";
-    break;
-  default:
-    LOG(FATAL) << "Unknown database backend";
-  }
+  database_ = DatabaseFactory(this->layer_param_.data_param().backend());
+  LOG(INFO) << "Opening database " << this->layer_param_.data_param().source();
+  database_->open(this->layer_param_.data_param().source(), Database::ReadOnly);
+  iter_ = database_->begin();
 
   // Check if we would need to randomly skip a few data points
   if (this->layer_param_.data_param().rand_skip()) {
@@ -79,37 +37,16 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
         this->layer_param_.data_param().rand_skip();
     LOG(INFO) << "Skipping first " << skip << " data points.";
     while (skip-- > 0) {
-      switch (this->layer_param_.data_param().backend()) {
-      case DataParameter_DB_LEVELDB:
-        iter_->Next();
-        if (!iter_->Valid()) {
-          iter_->SeekToFirst();
-        }
-        break;
-      case DataParameter_DB_LMDB:
-        if (mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT)
-            != MDB_SUCCESS) {
-          CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_,
-                   MDB_FIRST), MDB_SUCCESS);
-        }
-        break;
-      default:
-        LOG(FATAL) << "Unknown database backend";
+      LOG(INFO) << iter_->first;
+      if (++iter_ == database_->end()) {
+        iter_ = database_->begin();
       }
     }
   }
   // Read a data point, and use it to initialize the top blob.
+  CHECK(iter_ != database_->end());
   Datum datum;
-  switch (this->layer_param_.data_param().backend()) {
-  case DataParameter_DB_LEVELDB:
-    datum.ParseFromString(iter_->value().ToString());
-    break;
-  case DataParameter_DB_LMDB:
-    datum.ParseFromArray(mdb_value_.mv_data, mdb_value_.mv_size);
-    break;
-  default:
-    LOG(FATAL) << "Unknown database backend";
-  }
+  datum.ParseFromString(iter_->second);
 
   // image
   int crop_size = this->layer_param_.transform_param().crop_size();
@@ -153,23 +90,10 @@ void DataLayer<Dtype>::InternalThreadEntry() {
   const int batch_size = this->layer_param_.data_param().batch_size();
 
   for (int item_id = 0; item_id < batch_size; ++item_id) {
-    // get a blob
     Datum datum;
-    switch (this->layer_param_.data_param().backend()) {
-    case DataParameter_DB_LEVELDB:
-      CHECK(iter_);
-      CHECK(iter_->Valid());
-      datum.ParseFromString(iter_->value().ToString());
-      break;
-    case DataParameter_DB_LMDB:
-      CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_,
-              &mdb_value_, MDB_GET_CURRENT), MDB_SUCCESS);
-      datum.ParseFromArray(mdb_value_.mv_data,
-          mdb_value_.mv_size);
-      break;
-    default:
-      LOG(FATAL) << "Unknown database backend";
-    }
+    CHECK(iter_ != database_->end());
+    datum.ParseFromString(iter_->second);
+
     // Apply data transformations (mirror, scale, crop...)
     int offset = this->prefetch_data_.offset(item_id);
     this->transformed_data_.set_cpu_data(top_data + offset);
@@ -179,26 +103,9 @@ void DataLayer<Dtype>::InternalThreadEntry() {
     }
     // go to the next iter
-    switch (this->layer_param_.data_param().backend()) {
-    case DataParameter_DB_LEVELDB:
-      iter_->Next();
-      if (!iter_->Valid()) {
-        // We have reached the end. Restart from the first.
-        DLOG(INFO) << "Restarting data prefetching from start.";
-        iter_->SeekToFirst();
-      }
-      break;
-    case DataParameter_DB_LMDB:
-      if (mdb_cursor_get(mdb_cursor_, &mdb_key_,
-              &mdb_value_, MDB_NEXT) != MDB_SUCCESS) {
-        // We have reached the end. Restart from the first.
-        DLOG(INFO) << "Restarting data prefetching from start.";
-        CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_,
-                &mdb_value_, MDB_FIRST), MDB_SUCCESS);
-      }
-      break;
-    default:
-      LOG(FATAL) << "Unknown database backend";
+    ++iter_;
+    if (iter_ == database_->end()) {
+      iter_ = database_->begin();
     }
   }
 }
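The read path the data layer now follows, as a standalone sketch (ReadFirstDatum is a hypothetical helper, not in the patch): a Datum parses straight from the iterator's value, identically for both backends:

    #include "caffe/database.hpp"
    #include "caffe/proto/caffe.pb.h"

    caffe::Datum ReadFirstDatum(const caffe::Database& database) {
      CHECK(database.begin() != database.end()) << "empty database";
      caffe::Datum datum;
      datum.ParseFromString(database.begin()->second);  // value is the wire form
      return datum;
    }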
diff --git a/src/caffe/leveldb_database.cpp b/src/caffe/leveldb_database.cpp
new file mode 100644
index 0000000..be7ac7f
--- /dev/null
+++ b/src/caffe/leveldb_database.cpp
@@ -0,0 +1,151 @@
+#include <string>
+#include <utility>
+
+#include "caffe/leveldb_database.hpp"
+
+namespace caffe {
+
+void LeveldbDatabase::open(const string& filename, Mode mode) {
+  LOG(INFO) << "LevelDB: Open " << filename;
+
+  leveldb::Options options;
+  switch (mode) {
+  case New:
+    LOG(INFO) << " mode NEW";
+    options.error_if_exists = true;
+    options.create_if_missing = true;
+    break;
+  case ReadWrite:
+    LOG(INFO) << " mode RW";
+    options.error_if_exists = false;
+    options.create_if_missing = true;
+    break;
+  case ReadOnly:
+    LOG(INFO) << " mode RO";
+    options.error_if_exists = false;
+    options.create_if_missing = false;
+    break;
+  default:
+    LOG(FATAL) << "unknown mode " << mode;
+  }
+  options.write_buffer_size = 268435456;
+  options.max_open_files = 100;
+
+  leveldb::DB* db;
+
+  LOG(INFO) << "Opening leveldb " << filename;
+  leveldb::Status status = leveldb::DB::Open(
+      options, filename, &db);
+  db_.reset(db);
+  CHECK(status.ok()) << "Failed to open leveldb " << filename
+                     << ". Does it already exist?";
+  batch_.reset(new leveldb::WriteBatch());
+}
+
+void LeveldbDatabase::put(const string& key, const string& value) {
+  LOG(INFO) << "LevelDB: Put " << key;
+
+  CHECK_NOTNULL(batch_.get());
+  batch_->Put(key, value);
+}
+
+void LeveldbDatabase::commit() {
+  LOG(INFO) << "LevelDB: Commit";
+
+  CHECK_NOTNULL(db_.get());
+  CHECK_NOTNULL(batch_.get());
+
+  db_->Write(leveldb::WriteOptions(), batch_.get());
+  batch_.reset(new leveldb::WriteBatch());
+}
+
+void LeveldbDatabase::close() {
+  LOG(INFO) << "LevelDB: Close";
+
+  if (batch_ && db_) {
+    this->commit();
+  }
+  batch_.reset();
+  db_.reset();
+}
+
+LeveldbDatabase::const_iterator LeveldbDatabase::begin() const {
+  CHECK_NOTNULL(db_.get());
+  shared_ptr<leveldb::Iterator> iter(db_->NewIterator(leveldb::ReadOptions()));
+  iter->SeekToFirst();
+  if (!iter->Valid()) {
+    iter.reset();
+  }
+  shared_ptr<DatabaseState> state(new LeveldbState(iter));
+  return const_iterator(this, state);
+}
+
+LeveldbDatabase::const_iterator LeveldbDatabase::end() const {
+  shared_ptr<leveldb::Iterator> iter;
+  shared_ptr<DatabaseState> state(new LeveldbState(iter));
+  return const_iterator(this, state);
+}
+
+LeveldbDatabase::const_iterator LeveldbDatabase::cbegin() const {
+  return begin();
+}
+
+LeveldbDatabase::const_iterator LeveldbDatabase::cend() const { return end(); }
+
+bool LeveldbDatabase::equal(shared_ptr<DatabaseState> state1,
+    shared_ptr<DatabaseState> state2) const {
+  shared_ptr<LeveldbState> leveldb_state1 =
+      boost::dynamic_pointer_cast<LeveldbState>(state1);
+
+  CHECK_NOTNULL(leveldb_state1.get());
+
+  shared_ptr<LeveldbState> leveldb_state2 =
+      boost::dynamic_pointer_cast<LeveldbState>(state2);
+
+  CHECK_NOTNULL(leveldb_state2.get());
+
+  CHECK(!leveldb_state1->iter_ || leveldb_state1->iter_->Valid());
+  CHECK(!leveldb_state2->iter_ || leveldb_state2->iter_->Valid());
+
+  // The KV store doesn't really have any sort of ordering,
+  // so while we can do a sequential scan over the collection,
+  // we can't really use subranges.
+  return !leveldb_state1->iter_ && !leveldb_state2->iter_;
+}
+
+void LeveldbDatabase::increment(shared_ptr<DatabaseState> state) const {
+  shared_ptr<LeveldbState> leveldb_state =
+      boost::dynamic_pointer_cast<LeveldbState>(state);
+
+  CHECK_NOTNULL(leveldb_state.get());
+
+  shared_ptr<leveldb::Iterator>& iter = leveldb_state->iter_;
+
+  CHECK_NOTNULL(iter.get());
+  CHECK(iter->Valid());
+
+  iter->Next();
+  if (!iter->Valid()) {
+    iter.reset();
+  }
+}
+
+pair<string, string>& LeveldbDatabase::dereference(
+    shared_ptr<DatabaseState> state) const {
+  shared_ptr<LeveldbState> leveldb_state =
+      boost::dynamic_pointer_cast<LeveldbState>(state);
+
+  CHECK_NOTNULL(leveldb_state.get());
+
+  shared_ptr<leveldb::Iterator>& iter = leveldb_state->iter_;
+
+  CHECK_NOTNULL(iter.get());
+
+  CHECK(iter->Valid());
+
+  leveldb_state->kv_pair_ = make_pair(iter->key().ToString(),
+                                      iter->value().ToString());
+  return leveldb_state->kv_pair_;
+}
+
+}  // namespace caffe
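Both implementations encode end() as a state whose underlying handle is null (a reset shared_ptr above, a NULL cursor below), so equality only distinguishes exhausted iterators from live ones. A sketch of what that buys a caller; CountEntries, the path, and the backend choice are hypothetical:

    #include "caffe/database_factory.hpp"

    int CountEntries(const std::string& path, const std::string& backend) {
      caffe::shared_ptr<caffe::Database> db = caffe::DatabaseFactory(backend);
      db->open(path, caffe::Database::ReadOnly);
      int count = 0;
      for (caffe::Database::const_iterator it = db->begin();
           it != db->end(); ++it) {
        ++count;  // increment() nulls the handle once the scan is exhausted,
                  // making the iterator compare equal to end()
      }
      db->close();
      return count;
    }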
diff --git a/src/caffe/lmdb_database.cpp b/src/caffe/lmdb_database.cpp
new file mode 100644
index 0000000..796bbc9
--- /dev/null
+++ b/src/caffe/lmdb_database.cpp
@@ -0,0 +1,154 @@
+#include <sys/stat.h>
+
+#include <string>
+#include <utility>
+
+#include "caffe/lmdb_database.hpp"
+
+namespace caffe {
+
+void LmdbDatabase::open(const string& filename, Mode mode) {
+  LOG(INFO) << "LMDB: Open " << filename;
+
+  CHECK(NULL == env_);
+  CHECK(NULL == txn_);
+  CHECK_EQ(0, dbi_);
+
+  if (mode != ReadOnly) {
+    CHECK_EQ(mkdir(filename.c_str(), 0744), 0) << "mkdir " << filename
+                                               << " failed";
+  }
+
+  CHECK_EQ(mdb_env_create(&env_), MDB_SUCCESS) << "mdb_env_create failed";
+  CHECK_EQ(mdb_env_set_mapsize(env_, 1099511627776), MDB_SUCCESS)  // 1TB
+      << "mdb_env_set_mapsize failed";
+
+  int flag1 = 0;
+  int flag2 = 0;
+  if (mode == ReadOnly) {
+    flag1 = MDB_RDONLY | MDB_NOTLS;
+    flag2 = MDB_RDONLY;
+  }
+
+  CHECK_EQ(mdb_env_open(env_, filename.c_str(), flag1, 0664), MDB_SUCCESS)
+      << "mdb_env_open failed";
+  CHECK_EQ(mdb_txn_begin(env_, NULL, flag2, &txn_), MDB_SUCCESS)
+      << "mdb_txn_begin failed";
+  CHECK_EQ(mdb_open(txn_, NULL, 0, &dbi_), MDB_SUCCESS) << "mdb_open failed";
+}
+
+void LmdbDatabase::put(const string& key, const string& value) {
+  LOG(INFO) << "LMDB: Put " << key;
+
+  MDB_val mdbkey, mdbdata;
+  mdbdata.mv_size = value.size();
+  mdbdata.mv_data = const_cast<char*>(&value[0]);
+  mdbkey.mv_size = key.size();
+  mdbkey.mv_data = const_cast<char*>(&key[0]);
+
+  CHECK_NOTNULL(txn_);
+  CHECK_NE(0, dbi_);
+
+  CHECK_EQ(mdb_put(txn_, dbi_, &mdbkey, &mdbdata, 0), MDB_SUCCESS)
+      << "mdb_put failed";
+}
+
+void LmdbDatabase::commit() {
+  LOG(INFO) << "LMDB: Commit";
+
+  CHECK_NOTNULL(txn_);
+
+  CHECK_EQ(mdb_txn_commit(txn_), MDB_SUCCESS) << "mdb_txn_commit failed";
+
+  // mdb_txn_commit frees the transaction handle, so begin a fresh one for
+  // any puts that follow; mirror the environment's read-only flag.
+  unsigned int env_flags = 0;
+  CHECK_EQ(mdb_env_get_flags(env_, &env_flags), MDB_SUCCESS);
+  CHECK_EQ(mdb_txn_begin(env_, NULL, env_flags & MDB_RDONLY, &txn_),
+      MDB_SUCCESS) << "mdb_txn_begin failed";
+}
+
+void LmdbDatabase::close() {
+  LOG(INFO) << "LMDB: Close";
+
+  if (env_ && dbi_ && txn_) {
+    this->commit();
+    // commit() reopened a transaction; release it before closing the env.
+    mdb_txn_abort(txn_);
+    txn_ = NULL;
+  }
+
+  if (env_ && dbi_) {
+    mdb_close(env_, dbi_);
+    mdb_env_close(env_);
+    env_ = NULL;
+    dbi_ = 0;
+    txn_ = NULL;
+  }
+}
+
+LmdbDatabase::const_iterator LmdbDatabase::begin() const {
+  MDB_cursor* cursor;
+  CHECK_EQ(mdb_cursor_open(txn_, dbi_, &cursor), MDB_SUCCESS);
+  MDB_val key;
+  MDB_val val;
+  CHECK_EQ(mdb_cursor_get(cursor, &key, &val, MDB_FIRST), MDB_SUCCESS);
+
+  shared_ptr<DatabaseState> state(new LmdbState(cursor));
+  return const_iterator(this, state);
+}
+
+LmdbDatabase::const_iterator LmdbDatabase::end() const {
+  shared_ptr<DatabaseState> state(new LmdbState(NULL));
+  return const_iterator(this, state);
+}
+
+LmdbDatabase::const_iterator LmdbDatabase::cbegin() const { return begin(); }
+LmdbDatabase::const_iterator LmdbDatabase::cend() const { return end(); }
+
+bool LmdbDatabase::equal(shared_ptr<DatabaseState> state1,
+    shared_ptr<DatabaseState> state2) const {
+  shared_ptr<LmdbState> lmdb_state1 =
+      boost::dynamic_pointer_cast<LmdbState>(state1);
+
+  CHECK_NOTNULL(lmdb_state1.get());
+
+  shared_ptr<LmdbState> lmdb_state2 =
+      boost::dynamic_pointer_cast<LmdbState>(state2);
+
+  CHECK_NOTNULL(lmdb_state2.get());
+
+  // The KV store doesn't really have any sort of ordering,
+  // so while we can do a sequential scan over the collection,
+  // we can't really use subranges.
+  return !lmdb_state1->cursor_ && !lmdb_state2->cursor_;
+}
+
+void LmdbDatabase::increment(shared_ptr<DatabaseState> state) const {
+  shared_ptr<LmdbState> lmdb_state =
+      boost::dynamic_pointer_cast<LmdbState>(state);
+
+  CHECK_NOTNULL(lmdb_state.get());
+
+  MDB_cursor*& cursor = lmdb_state->cursor_;
+
+  MDB_val key;
+  MDB_val val;
+  if (MDB_SUCCESS != mdb_cursor_get(cursor, &key, &val, MDB_NEXT)) {
+    mdb_cursor_close(cursor);
+    cursor = NULL;
+  }
+}
+
+pair<string, string>& LmdbDatabase::dereference(
+    shared_ptr<DatabaseState> state) const {
+  shared_ptr<LmdbState> lmdb_state =
+      boost::dynamic_pointer_cast<LmdbState>(state);
+
+  CHECK_NOTNULL(lmdb_state.get());
+
+  MDB_cursor*& cursor = lmdb_state->cursor_;
+
+  MDB_val mdb_key;
+  MDB_val mdb_val;
+  CHECK_EQ(mdb_cursor_get(cursor, &mdb_key, &mdb_val, MDB_GET_CURRENT),
+      MDB_SUCCESS);
+
+  lmdb_state->kv_pair_ = make_pair(
+      string(reinterpret_cast<char*>(mdb_key.mv_data), mdb_key.mv_size),
+      string(reinterpret_cast<char*>(mdb_val.mv_data), mdb_val.mv_size));
+
+  return lmdb_state->kv_pair_;
+}
+
+}  // namespace caffe
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index f0404a0..f0dba09 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -423,7 +423,7 @@ message DataParameter {
   // The rand_skip variable is for the data layer to skip a few data points
   // to avoid all asynchronous sgd clients to start at the same point. The skip
   // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
-  // be larger than the number of keys in the leveldb.
+  // be larger than the number of keys in the database.
   optional uint32 rand_skip = 7 [default = 0];
   optional DB backend = 8 [default = LEVELDB];
   // DEPRECATED. See TransformationParameter. For data pre-processing, we can do
@@ -518,7 +518,7 @@ message ImageDataParameter {
   // The rand_skip variable is for the data layer to skip a few data points
   // to avoid all asynchronous sgd clients to start at the same point. The skip
   // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
-  // be larger than the number of keys in the leveldb.
+  // be larger than the number of keys in the database.
   optional uint32 rand_skip = 7 [default = 0];
   // Whether or not ImageLayer should shuffle the list of files at every epoch.
   optional bool shuffle = 8 [default = false];
@@ -767,7 +767,7 @@ message V0LayerParameter {
   // The rand_skip variable is for the data layer to skip a few data points
   // to avoid all asynchronous sgd clients to start at the same point. The skip
   // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
-  // be larger than the number of keys in the leveldb.
+  // be larger than the number of keys in the database.
   optional uint32 rand_skip = 53 [default = 0];
 
   // Fields related to detection (det_*)
diff --git a/src/caffe/test/test_data_layer.cpp b/src/caffe/test/test_data_layer.cpp
index 657ffde..d99b5e3 100644
--- a/src/caffe/test/test_data_layer.cpp
+++ b/src/caffe/test/test_data_layer.cpp
@@ -2,10 +2,10 @@
 #include <vector>
 
 #include "gtest/gtest.h"
-#include "leveldb/db.h"
 
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
+#include "caffe/database_factory.hpp"
 #include "caffe/filler.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/io.hpp"
@@ -36,16 +36,11 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
   // Fill the LevelDB with data: if unique_pixels, each pixel is unique but
   // all images are the same; else each image is unique but all pixels within
   // an image are the same.
-  void FillLevelDB(const bool unique_pixels) {
-    backend_ = DataParameter_DB_LEVELDB;
-    LOG(INFO) << "Using temporary leveldb " << *filename_;
-    leveldb::DB* db;
-    leveldb::Options options;
-    options.error_if_exists = true;
-    options.create_if_missing = true;
-    leveldb::Status status =
-        leveldb::DB::Open(options, filename_->c_str(), &db);
-    CHECK(status.ok());
+  void Fill(const bool unique_pixels, DataParameter_DB backend) {
+    backend_ = backend;
+    LOG(INFO) << "Using temporary database " << *filename_;
+    shared_ptr<Database> database = DatabaseFactory(backend_);
+    database->open(*filename_, Database::New);
     for (int i = 0; i < 5; ++i) {
       Datum datum;
       datum.set_label(i);
       datum.set_channels(2);
       datum.set_height(3);
       datum.set_width(4);
       std::string* data = datum.mutable_data();
       for (int j = 0; j < 24; ++j) {
         int datum = unique_pixels ? j : i;
         data->push_back(static_cast<uint8_t>(datum));
       }
       stringstream ss;
       ss << i;
-      db->Put(leveldb::WriteOptions(), ss.str(), datum.SerializeAsString());
+      database->put(ss.str(), datum.SerializeAsString());
     }
-    delete db;
-  }
-
-  // Fill the LMDB with data: unique_pixels has same meaning as in FillLevelDB.
-  void FillLMDB(const bool unique_pixels) {
-    backend_ = DataParameter_DB_LMDB;
-    LOG(INFO) << "Using temporary lmdb " << *filename_;
-    CHECK_EQ(mkdir(filename_->c_str(), 0744), 0) << "mkdir " << filename_
-                                                 << "failed";
-    MDB_env *env;
-    MDB_dbi dbi;
-    MDB_val mdbkey, mdbdata;
-    MDB_txn *txn;
-    CHECK_EQ(mdb_env_create(&env), MDB_SUCCESS) << "mdb_env_create failed";
-    CHECK_EQ(mdb_env_set_mapsize(env, 1099511627776), MDB_SUCCESS)  // 1TB
-        << "mdb_env_set_mapsize failed";
-    CHECK_EQ(mdb_env_open(env, filename_->c_str(), 0, 0664), MDB_SUCCESS)
-        << "mdb_env_open failed";
-    CHECK_EQ(mdb_txn_begin(env, NULL, 0, &txn), MDB_SUCCESS)
-        << "mdb_txn_begin failed";
-    CHECK_EQ(mdb_open(txn, NULL, 0, &dbi), MDB_SUCCESS) << "mdb_open failed";
-
-    for (int i = 0; i < 5; ++i) {
-      Datum datum;
-      datum.set_label(i);
-      datum.set_channels(2);
-      datum.set_height(3);
-      datum.set_width(4);
-      std::string* data = datum.mutable_data();
-      for (int j = 0; j < 24; ++j) {
-        int datum = unique_pixels ? j : i;
-        data->push_back(static_cast<uint8_t>(datum));
-      }
-      stringstream ss;
-      ss << i;
-
-      string value;
-      datum.SerializeToString(&value);
-      mdbdata.mv_size = value.size();
-      mdbdata.mv_data = reinterpret_cast<void*>(&value[0]);
-      string keystr = ss.str();
-      mdbkey.mv_size = keystr.size();
-      mdbkey.mv_data = reinterpret_cast<void*>(&keystr[0]);
-      CHECK_EQ(mdb_put(txn, dbi, &mdbkey, &mdbdata, 0), MDB_SUCCESS)
-          << "mdb_put failed";
-    }
-    CHECK_EQ(mdb_txn_commit(txn), MDB_SUCCESS) << "mdb_txn_commit failed";
-    mdb_close(env, dbi);
-    mdb_env_close(env);
+    database->close();
   }
 
   void TestRead() {
@@ -234,7 +181,7 @@
       }
       crop_sequence.push_back(iter_crop_sequence);
     }
-  }  // destroy 1st data layer and unlock the leveldb
+  }  // destroy 1st data layer and unlock the database
 
   // Get crop sequence after reseeding Caffe with 1701.
   // Check that the sequence is the same as the original.
@@ -289,7 +236,7 @@
       }
       crop_sequence.push_back(iter_crop_sequence);
     }
-  }  // destroy 1st data layer and unlock the leveldb
+  }  // destroy 1st data layer and unlock the database
 
   // Get crop sequence continuing from previous Caffe RNG state; reseed
   // srand with 1701. Check that the sequence differs from the original.
@@ -327,14 +274,14 @@ TYPED_TEST_CASE(DataLayerTest, TestDtypesAndDevices);
 
 TYPED_TEST(DataLayerTest, TestReadLevelDB) {
   const bool unique_pixels = false;  // all pixels the same; images different
-  this->FillLevelDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LEVELDB);
   this->TestRead();
 }
 
 TYPED_TEST(DataLayerTest, TestReadCropTrainLevelDB) {
   Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
-  this->FillLevelDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LEVELDB);
   this->TestReadCrop();
 }
 
@@ -343,7 +290,7 @@ TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceSeededLevelDB) {
   Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
-  this->FillLevelDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LEVELDB);
   this->TestReadCropTrainSequenceSeeded();
 }
 
@@ -352,27 +299,27 @@ TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceUnseededLevelDB) {
   Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
-  this->FillLevelDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LEVELDB);
   this->TestReadCropTrainSequenceUnseeded();
 }
 
 TYPED_TEST(DataLayerTest, TestReadCropTestLevelDB) {
   Caffe::set_phase(Caffe::TEST);
   const bool unique_pixels = true;  // all images the same; pixels different
-  this->FillLevelDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LEVELDB);
   this->TestReadCrop();
 }
 
 TYPED_TEST(DataLayerTest, TestReadLMDB) {
   const bool unique_pixels = false;  // all pixels the same; images different
-  this->FillLMDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LMDB);
   this->TestRead();
 }
 
 TYPED_TEST(DataLayerTest, TestReadCropTrainLMDB) {
   Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
-  this->FillLMDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LMDB);
   this->TestReadCrop();
 }
 
@@ -381,7 +328,7 @@ TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceSeededLMDB) {
   Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
-  this->FillLMDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LMDB);
   this->TestReadCropTrainSequenceSeeded();
 }
 
@@ -390,14 +337,14 @@ TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceUnseededLMDB) {
   Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
-  this->FillLMDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LMDB);
   this->TestReadCropTrainSequenceUnseeded();
 }
 
 TYPED_TEST(DataLayerTest, TestReadCropTestLMDB) {
   Caffe::set_phase(Caffe::TEST);
   const bool unique_pixels = true;  // all images the same; pixels different
-  this->FillLMDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LMDB);
   this->TestReadCrop();
 }
diff --git a/src/caffe/test/test_hdf5data_layer.cpp b/src/caffe/test/test_hdf5data_layer.cpp
index 41a3a83..8d3b3d1 100644
--- a/src/caffe/test/test_hdf5data_layer.cpp
+++ b/src/caffe/test/test_hdf5data_layer.cpp
@@ -1,8 +1,6 @@
 #include <string>
 #include <vector>
 
-#include "leveldb/db.h"
-
 #include "gtest/gtest.h"
 
 #include "caffe/blob.hpp"
diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp
index 36510d6..a4a6627 100644
--- a/src/caffe/util/io.cpp
+++ b/src/caffe/util/io.cpp
@@ -2,7 +2,6 @@
 #include <google/protobuf/io/coded_stream.h>
 #include <google/protobuf/io/zero_copy_stream_impl.h>
 #include <google/protobuf/text_format.h>
-#include <leveldb/db.h>
 #include <opencv2/core/core.hpp>
 #include <opencv2/highgui/highgui.hpp>
 #include <opencv2/imgproc/imgproc.hpp>
@@ -122,15 +121,6 @@ void CVMatToDatum(const cv::Mat& cv_img, Datum* datum) {
   datum->set_data(buffer);
 }
 
-
-leveldb::Options GetLevelDBOptions() {
-  // In default, we will return the leveldb option and set the max open files
-  // in order to avoid using up the operating system's limit.
-  leveldb::Options options;
-  options.max_open_files = 100;
-  return options;
-}
-
 // Verifies format of data stored in HDF5 file and reshapes blob accordingly.
 template <typename Dtype>
 void hdf5_load_nd_dataset_helper(
diff --git a/tools/compute_image_mean.cpp b/tools/compute_image_mean.cpp
index 6adde8b..e59bbf1 100644
--- a/tools/compute_image_mean.cpp
+++ b/tools/compute_image_mean.cpp
@@ -1,14 +1,14 @@
 #include <glog/logging.h>
-#include <leveldb/db.h>
-#include <lmdb.h>
 #include <stdint.h>
 
 #include <algorithm>
 #include <string>
 
+#include "caffe/database_factory.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/io.hpp"
 
+using caffe::Database;
 using caffe::Datum;
 using caffe::BlobProto;
 using std::max;
@@ -26,57 +26,16 @@ int main(int argc, char** argv) {
     db_backend = std::string(argv[3]);
   }
 
-  // leveldb
-  leveldb::DB* db;
-  leveldb::Options options;
-  options.create_if_missing = false;
-  leveldb::Iterator* it = NULL;
-  // lmdb
-  MDB_env* mdb_env;
-  MDB_dbi mdb_dbi;
-  MDB_val mdb_key, mdb_value;
-  MDB_txn* mdb_txn;
-  MDB_cursor* mdb_cursor;
+  caffe::shared_ptr<Database> database = caffe::DatabaseFactory(db_backend);
 
   // Open db
-  if (db_backend == "leveldb") {  // leveldb
-    LOG(INFO) << "Opening leveldb " << argv[1];
-    leveldb::Status status = leveldb::DB::Open(
-        options, argv[1], &db);
-    CHECK(status.ok()) << "Failed to open leveldb " << argv[1];
-    leveldb::ReadOptions read_options;
-    read_options.fill_cache = false;
-    it = db->NewIterator(read_options);
-    it->SeekToFirst();
-  } else if (db_backend == "lmdb") {  // lmdb
-    LOG(INFO) << "Opening lmdb " << argv[1];
-    CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed";
-    CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS);  // 1TB
-    CHECK_EQ(mdb_env_open(mdb_env, argv[1], MDB_RDONLY, 0664),
-        MDB_SUCCESS) << "mdb_env_open failed";
-    CHECK_EQ(mdb_txn_begin(mdb_env, NULL, MDB_RDONLY, &mdb_txn), MDB_SUCCESS)
-        << "mdb_txn_begin failed";
-    CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS)
-        << "mdb_open failed";
-    CHECK_EQ(mdb_cursor_open(mdb_txn, mdb_dbi, &mdb_cursor), MDB_SUCCESS)
-        << "mdb_cursor_open failed";
-    CHECK_EQ(mdb_cursor_get(mdb_cursor, &mdb_key, &mdb_value, MDB_FIRST),
-        MDB_SUCCESS);
-  } else {
-    LOG(FATAL) << "Unknown db backend " << db_backend;
-  }
+  database->open(argv[1], Database::ReadOnly);
 
   Datum datum;
   BlobProto sum_blob;
   int count = 0;
   // load first datum
-  if (db_backend == "leveldb") {
-    datum.ParseFromString(it->value().ToString());
-  } else if (db_backend == "lmdb") {
-    datum.ParseFromArray(mdb_value.mv_data, mdb_value.mv_size);
-  } else {
-    LOG(FATAL) << "Unknown db backend " << db_backend;
-  }
+  datum.ParseFromString(database->begin()->second);
 
   sum_blob.set_num(1);
   sum_blob.set_channels(datum.channels());
@@ -89,59 +48,29 @@ int main(int argc, char** argv) {
     sum_blob.add_data(0.);
   }
   LOG(INFO) << "Starting Iteration";
-  if (db_backend == "leveldb") {  // leveldb
-    for (it->SeekToFirst(); it->Valid(); it->Next()) {
-      // just a dummy operation
-      datum.ParseFromString(it->value().ToString());
-      const std::string& data = datum.data();
-      size_in_datum = std::max<int>(datum.data().size(),
-          datum.float_data_size());
-      CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " <<
-          size_in_datum;
-      if (data.size() != 0) {
-        for (int i = 0; i < size_in_datum; ++i) {
-          sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);
-        }
-      } else {
-        for (int i = 0; i < size_in_datum; ++i) {
-          sum_blob.set_data(i, sum_blob.data(i) +
-              static_cast<float>(datum.float_data(i)));
-        }
+  for (Database::const_iterator iter = database->begin();
+      iter != database->end(); ++iter) {
+    // just a dummy operation
+    datum.ParseFromString(iter->second);
+    const std::string& data = datum.data();
+    size_in_datum = std::max<int>(datum.data().size(),
+        datum.float_data_size());
+    CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " <<
+        size_in_datum;
+    if (data.size() != 0) {
+      for (int i = 0; i < size_in_datum; ++i) {
+        sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);
       }
-      ++count;
-      if (count % 10000 == 0) {
-        LOG(ERROR) << "Processed " << count << " files.";
+    } else {
+      for (int i = 0; i < size_in_datum; ++i) {
+        sum_blob.set_data(i, sum_blob.data(i) +
+            static_cast<float>(datum.float_data(i)));
       }
     }
-  } else if (db_backend == "lmdb") {  // lmdb
-    CHECK_EQ(mdb_cursor_get(mdb_cursor, &mdb_key, &mdb_value, MDB_FIRST),
-        MDB_SUCCESS);
-    do {
-      // just a dummy operation
-      datum.ParseFromArray(mdb_value.mv_data, mdb_value.mv_size);
-      const std::string& data = datum.data();
-      size_in_datum = std::max<int>(datum.data().size(),
-          datum.float_data_size());
-      CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " <<
-          size_in_datum;
-      if (data.size() != 0) {
-        for (int i = 0; i < size_in_datum; ++i) {
-          sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);
-        }
-      } else {
-        for (int i = 0; i < size_in_datum; ++i) {
-          sum_blob.set_data(i, sum_blob.data(i) +
-              static_cast<float>(datum.float_data(i)));
-        }
-      }
-      ++count;
-      if (count % 10000 == 0) {
-        LOG(ERROR) << "Processed " << count << " files.";
-      }
-    } while (mdb_cursor_get(mdb_cursor, &mdb_key, &mdb_value, MDB_NEXT)
-        == MDB_SUCCESS);
-  } else {
-    LOG(FATAL) << "Unknown db backend " << db_backend;
+    ++count;
+    if (count % 10000 == 0) {
+      LOG(ERROR) << "Processed " << count << " files.";
+    }
   }
 
   if (count % 10000 != 0) {
@@ -155,15 +84,6 @@ int main(int argc, char** argv) {
   WriteProtoToBinaryFile(sum_blob, argv[2]);
   // Clean up
-  if (db_backend == "leveldb") {
-    delete db;
-  } else if (db_backend == "lmdb") {
-    mdb_cursor_close(mdb_cursor);
-    mdb_close(mdb_env, mdb_dbi);
-    mdb_txn_abort(mdb_txn);
-    mdb_env_close(mdb_env);
-  } else {
-    LOG(FATAL) << "Unknown db backend " << db_backend;
-  }
+  database->close();
   return 0;
 }
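convert_imageset below keeps its every-1000-puts commit cadence, now expressed through Database::commit(). The same pattern in isolation (BulkLoad and its item source are hypothetical, not part of the patch):

    #include <string>
    #include <utility>
    #include <vector>

    #include "caffe/database.hpp"

    void BulkLoad(caffe::Database* database,
        const std::vector<std::pair<std::string, std::string> >& items) {
      int count = 0;
      for (size_t i = 0; i < items.size(); ++i) {
        database->put(items[i].first, items[i].second);
        if (++count % 1000 == 0) {
          database->commit();  // bound the size of the pending batch/txn
        }
      }
      if (count % 1000 != 0) {
        database->commit();  // flush the final partial batch
      }
    }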
diff --git a/tools/convert_imageset.cpp b/tools/convert_imageset.cpp
index 7c8c1da..6f03a9d 100644
--- a/tools/convert_imageset.cpp
+++ b/tools/convert_imageset.cpp
@@ -10,10 +10,6 @@
 #include <glog/logging.h>
 #include <google/protobuf/text_format.h>
-#include <leveldb/db.h>
-#include <leveldb/write_batch.h>
-#include <lmdb.h>
-#include <sys/stat.h>
 
 #include <algorithm>
 #include <fstream>  // NOLINT(readability/streams)
 #include <string>
 #include <utility>
 
+#include "caffe/database_factory.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/io.hpp"
 #include "caffe/util/rng.hpp"
@@ -81,43 +78,10 @@
   int resize_width = std::max(0, FLAGS_resize_width);
 
   // Open new db
-  // lmdb
-  MDB_env *mdb_env;
-  MDB_dbi mdb_dbi;
-  MDB_val mdb_key, mdb_data;
-  MDB_txn *mdb_txn;
-  // leveldb
-  leveldb::DB* db;
-  leveldb::Options options;
-  options.error_if_exists = true;
-  options.create_if_missing = true;
-  options.write_buffer_size = 268435456;
-  leveldb::WriteBatch* batch = NULL;
+  shared_ptr<Database> database = DatabaseFactory(db_backend);
 
   // Open db
-  if (db_backend == "leveldb") {  // leveldb
-    LOG(INFO) << "Opening leveldb " << db_path;
-    leveldb::Status status = leveldb::DB::Open(
-        options, db_path, &db);
-    CHECK(status.ok()) << "Failed to open leveldb " << db_path
-        << ". Is it already existing?";
-    batch = new leveldb::WriteBatch();
-  } else if (db_backend == "lmdb") {  // lmdb
-    LOG(INFO) << "Opening lmdb " << db_path;
-    CHECK_EQ(mkdir(db_path, 0744), 0)
-        << "mkdir " << db_path << "failed";
-    CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed";
-    CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS)  // 1TB
-        << "mdb_env_set_mapsize failed";
-    CHECK_EQ(mdb_env_open(mdb_env, db_path, 0, 0664), MDB_SUCCESS)
-        << "mdb_env_open failed";
-    CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
-        << "mdb_txn_begin failed";
-    CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS)
-        << "mdb_open failed. Does the lmdb already exist? ";
-  } else {
-    LOG(FATAL) << "Unknown db backend " << db_backend;
-  }
+  database->open(db_path, Database::New);
 
   // Storing to db
   std::string root_folder(argv[1]);
@@ -151,50 +115,19 @@
     std::string keystr(key_cstr);
 
     // Put in db
-    if (db_backend == "leveldb") {  // leveldb
-      batch->Put(keystr, value);
-    } else if (db_backend == "lmdb") {  // lmdb
-      mdb_data.mv_size = value.size();
-      mdb_data.mv_data = reinterpret_cast<void*>(&value[0]);
-      mdb_key.mv_size = keystr.size();
-      mdb_key.mv_data = reinterpret_cast<void*>(&keystr[0]);
-      CHECK_EQ(mdb_put(mdb_txn, mdb_dbi, &mdb_key, &mdb_data, 0), MDB_SUCCESS)
-          << "mdb_put failed";
-    } else {
-      LOG(FATAL) << "Unknown db backend " << db_backend;
-    }
+    database->put(keystr, value);
 
     if (++count % 1000 == 0) {
       // Commit txn
-      if (db_backend == "leveldb") {  // leveldb
-        db->Write(leveldb::WriteOptions(), batch);
-        delete batch;
-        batch = new leveldb::WriteBatch();
-      } else if (db_backend == "lmdb") {  // lmdb
-        CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS)
-            << "mdb_txn_commit failed";
-        CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
-            << "mdb_txn_begin failed";
-      } else {
-        LOG(FATAL) << "Unknown db backend " << db_backend;
-      }
+      database->commit();
       LOG(ERROR) << "Processed " << count << " files.";
     }
   }
   // write the last batch
   if (count % 1000 != 0) {
-    if (db_backend == "leveldb") {  // leveldb
-      db->Write(leveldb::WriteOptions(), batch);
-      delete batch;
-      delete db;
-    } else if (db_backend == "lmdb") {  // lmdb
-      CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) << "mdb_txn_commit failed";
-      mdb_close(mdb_env, mdb_dbi);
-      mdb_env_close(mdb_env);
-    } else {
-      LOG(FATAL) << "Unknown db backend " << db_backend;
-    }
+    database->commit();
     LOG(ERROR) << "Processed " << count << " files.";
   }
+  database->close();
   return 0;
 }
diff --git a/tools/extract_features.cpp b/tools/extract_features.cpp
index 299b311..b3ad8e6 100644
--- a/tools/extract_features.cpp
+++ b/tools/extract_features.cpp
@@ -4,11 +4,10 @@
 #include "boost/algorithm/string.hpp"
 #include "google/protobuf/text_format.h"
-#include "leveldb/db.h"
-#include "leveldb/write_batch.h"
 
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
+#include "caffe/database_factory.hpp"
 #include "caffe/net.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/io.hpp"
@@ -17,6 +16,8 @@
 using boost::shared_ptr;
 using caffe::Blob;
 using caffe::Caffe;
+using caffe::Database;
+using caffe::DatabaseFactory;
 using caffe::Datum;
 using caffe::Net;
 
@@ -38,12 +39,12 @@ int feature_extraction_pipeline(int argc, char** argv) {
     " extract features of the input data produced by the net.\n"
     "Usage: extract_features  pretrained_net_param"
     "  feature_extraction_proto_file  extract_feature_blob_name1[,name2,...]"
-    "  save_feature_leveldb_name1[,name2,...]  num_mini_batches  [CPU/GPU]"
+    "  save_feature_database_name1[,name2,...]  num_mini_batches  [CPU/GPU]"
    "  [DEVICE_ID=0]\n"
    "Note: you can extract multiple features in one pass by specifying"
-    " multiple feature blob names and leveldb names seperated by ','."
+    " multiple feature blob names and database names separated by ','."
    " The names cannot contain white space characters and the number of blobs"
-    " and leveldbs must be equal.";
+    " and databases must be equal.";
    return 1;
  }
  int arg_pos = num_required_args;
@@ -104,12 +105,12 @@ int feature_extraction_pipeline(int argc, char** argv) {
  std::vector<std::string> blob_names;
  boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(","));
 
-  std::string save_feature_leveldb_names(argv[++arg_pos]);
-  std::vector<std::string> leveldb_names;
-  boost::split(leveldb_names, save_feature_leveldb_names,
+  std::string save_feature_database_names(argv[++arg_pos]);
+  std::vector<std::string> database_names;
+  boost::split(database_names, save_feature_database_names,
               boost::is_any_of(","));
-  CHECK_EQ(blob_names.size(), leveldb_names.size()) <<
-      " the number of blob names and leveldb names must be equal";
+  CHECK_EQ(blob_names.size(), database_names.size()) <<
+      " the number of blob names and database names must be equal";
  size_t num_features = blob_names.size();
 
  for (size_t i = 0; i < num_features; i++) {
@@ -118,19 +119,12 @@ int feature_extraction_pipeline(int argc, char** argv) {
        << " in the network " << feature_extraction_proto;
  }
 
-  leveldb::Options options;
-  options.error_if_exists = true;
-  options.create_if_missing = true;
-  options.write_buffer_size = 268435456;
-  std::vector<shared_ptr<leveldb::DB> > feature_dbs;
+  std::vector<shared_ptr<Database> > feature_dbs;
  for (size_t i = 0; i < num_features; ++i) {
-    LOG(INFO)<< "Opening leveldb " << leveldb_names[i];
-    leveldb::DB* db;
-    leveldb::Status status = leveldb::DB::Open(options,
-                                               leveldb_names[i].c_str(),
-                                               &db);
-    CHECK(status.ok()) << "Failed to open leveldb " << leveldb_names[i];
-    feature_dbs.push_back(shared_ptr<leveldb::DB>(db));
+    LOG(INFO)<< "Opening database " << database_names[i];
+    shared_ptr<Database> database = DatabaseFactory("leveldb");
+    database->open(database_names.at(i), Database::New);
+    feature_dbs.push_back(database);
  }
 
  int num_mini_batches = atoi(argv[++arg_pos]);
@@ -138,9 +132,6 @@ int feature_extraction_pipeline(int argc, char** argv) {
  LOG(ERROR)<< "Extacting Features";
 
  Datum datum;
-  std::vector<shared_ptr<leveldb::WriteBatch> > feature_batches(
-      num_features,
-      shared_ptr<leveldb::WriteBatch>(new leveldb::WriteBatch()));
  const int kMaxKeyStrLength = 100;
  char key_str[kMaxKeyStrLength];
  std::vector<Blob<float>*> input_vec;
@@ -167,14 +158,12 @@ int feature_extraction_pipeline(int argc, char** argv) {
        std::string value;
        datum.SerializeToString(&value);
        snprintf(key_str, kMaxKeyStrLength, "%d", image_indices[i]);
-        feature_batches[i]->Put(std::string(key_str), value);
+        feature_dbs.at(i)->put(std::string(key_str), value);
        ++image_indices[i];
        if (image_indices[i] % 1000 == 0) {
-          feature_dbs[i]->Write(leveldb::WriteOptions(),
-                                feature_batches[i].get());
+          feature_dbs.at(i)->commit();
          LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
              " query images for feature blob " << blob_names[i];
-          feature_batches[i].reset(new leveldb::WriteBatch());
        }
      }  // for (int n = 0; n < batch_size; ++n)
    }  // for (int i = 0; i < num_features; ++i)
@@ -182,10 +171,11 @@ int feature_extraction_pipeline(int argc, char** argv) {
  // write the last batch
  for (int i = 0; i < num_features; ++i) {
    if (image_indices[i] % 1000 != 0) {
-      feature_dbs[i]->Write(leveldb::WriteOptions(), feature_batches[i].get());
+      feature_dbs.at(i)->commit();
    }
LOG(ERROR)<< "Extracted features of " << image_indices[i] << " query images for feature blob " << blob_names[i]; + feature_dbs.at(i)->close(); } LOG(ERROR)<< "Successfully extracted the features!"; -- 2.7.4