From 8fef285e536a67a6b2c72b54734335aeb3069d59 Mon Sep 17 00:00:00 2001 From: Kevin James Matzen Date: Sun, 12 Oct 2014 14:39:31 -0400 Subject: [PATCH] Updated Database interface to use custom KV type rather than std::pair. Removed two buffer copies in dereference operation for DB iterators. --- include/caffe/database.hpp | 12 ++++++++---- include/caffe/leveldb_database.hpp | 4 ++-- include/caffe/lmdb_database.hpp | 4 ++-- src/caffe/layers/data_layer.cpp | 4 ++-- src/caffe/leveldb_database.cpp | 14 +++++++++----- src/caffe/lmdb_database.cpp | 13 ++++++++----- tools/compute_image_mean.cpp | 4 ++-- 7 files changed, 33 insertions(+), 22 deletions(-) diff --git a/include/caffe/database.hpp b/include/caffe/database.hpp index 23036a8..953b58c 100644 --- a/include/caffe/database.hpp +++ b/include/caffe/database.hpp @@ -21,6 +21,11 @@ class Database { typedef vector buffer_t; + struct KV { + buffer_t key; + buffer_t value; + }; + virtual void open(const string& filename, Mode mode) = 0; virtual void put(buffer_t* key, buffer_t* value) = 0; virtual void commit() = 0; @@ -41,10 +46,9 @@ class Database { class DatabaseState; public: - class iterator : public std::iterator< - std::forward_iterator_tag, pair > { + class iterator : public std::iterator { public: - typedef pair T; + typedef KV T; typedef T value_type; typedef T& reference_type; typedef T* pointer_type; @@ -97,7 +101,7 @@ class Database { virtual bool equal(shared_ptr state1, shared_ptr state2) const = 0; virtual void increment(shared_ptr state) const = 0; - virtual pair& dereference( + virtual KV& dereference( shared_ptr state) const = 0; }; diff --git a/include/caffe/leveldb_database.hpp b/include/caffe/leveldb_database.hpp index dda6697..1c084cb 100644 --- a/include/caffe/leveldb_database.hpp +++ b/include/caffe/leveldb_database.hpp @@ -32,13 +32,13 @@ class LeveldbDatabase : public Database { iter_(iter) { } shared_ptr iter_; - pair kv_pair_; + KV kv_pair_; }; bool equal(shared_ptr state1, shared_ptr state2) const; void increment(shared_ptr state) const; - pair& dereference(shared_ptr state) const; + Database::KV& dereference(shared_ptr state) const; shared_ptr db_; shared_ptr batch_; diff --git a/include/caffe/lmdb_database.hpp b/include/caffe/lmdb_database.hpp index 9654222..d72be3d 100644 --- a/include/caffe/lmdb_database.hpp +++ b/include/caffe/lmdb_database.hpp @@ -36,13 +36,13 @@ class LmdbDatabase : public Database { cursor_(cursor) { } MDB_cursor* cursor_; - pair kv_pair_; + KV kv_pair_; }; bool equal(shared_ptr state1, shared_ptr state2) const; void increment(shared_ptr state) const; - pair& dereference(shared_ptr state) const; + Database::KV& dereference(shared_ptr state) const; MDB_env *env_; MDB_dbi dbi_; diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index 4d36b8e..998c00c 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -48,7 +48,7 @@ void DataLayer::DataLayerSetUp(const vector*>& bottom, // Read a data point, and use it to initialize the top blob. CHECK(iter_ != database_->end()); Datum datum; - datum.ParseFromArray(iter_->second.data(), iter_->second.size()); + datum.ParseFromArray(iter_->value.data(), iter_->value.size()); // image int crop_size = this->layer_param_.transform_param().crop_size(); @@ -94,7 +94,7 @@ void DataLayer::InternalThreadEntry() { for (int item_id = 0; item_id < batch_size; ++item_id) { Datum datum; CHECK(iter_ != database_->end()); - datum.ParseFromArray(iter_->second.data(), iter_->second.size()); + datum.ParseFromArray(iter_->value.data(), iter_->value.size()); // Apply data transformations (mirror, scale, crop...) int offset = this->prefetch_data_.offset(item_id); diff --git a/src/caffe/leveldb_database.cpp b/src/caffe/leveldb_database.cpp index a5cdaa3..8084a6c 100644 --- a/src/caffe/leveldb_database.cpp +++ b/src/caffe/leveldb_database.cpp @@ -131,7 +131,7 @@ void LeveldbDatabase::increment(shared_ptr state) const { } } -pair& LeveldbDatabase::dereference( +Database::KV& LeveldbDatabase::dereference( shared_ptr state) const { shared_ptr leveldb_state = boost::dynamic_pointer_cast(state); @@ -144,10 +144,14 @@ pair& LeveldbDatabase::dereference( CHECK(iter->Valid()); - leveldb_state->kv_pair_ = make_pair( - buffer_t(iter->key().data(), iter->key().data() + iter->key().size()), - buffer_t(iter->value().data(), - iter->value().data() + iter->value().size())); + Database::buffer_t temp_key(buffer_t(iter->key().data(), + iter->key().data() + iter->key().size())); + + Database::buffer_t temp_value(buffer_t(iter->value().data(), + iter->value().data() + iter->value().size())); + + leveldb_state->kv_pair_.key.swap(temp_key); + leveldb_state->kv_pair_.value.swap(temp_value); return leveldb_state->kv_pair_; } diff --git a/src/caffe/lmdb_database.cpp b/src/caffe/lmdb_database.cpp index 0860778..54d67d5 100644 --- a/src/caffe/lmdb_database.cpp +++ b/src/caffe/lmdb_database.cpp @@ -149,8 +149,7 @@ void LmdbDatabase::increment(shared_ptr state) const { } } -pair& LmdbDatabase::dereference( - shared_ptr state) const { +Database::KV& LmdbDatabase::dereference(shared_ptr state) const { shared_ptr lmdb_state = boost::dynamic_pointer_cast(state); @@ -168,9 +167,13 @@ pair& LmdbDatabase::dereference( char* key_data = reinterpret_cast(mdb_key.mv_data); char* value_data = reinterpret_cast(mdb_val.mv_data); - lmdb_state->kv_pair_ = make_pair( - buffer_t(key_data, key_data + mdb_key.mv_size), - buffer_t(value_data, value_data + mdb_val.mv_size)); + Database::buffer_t temp_key(key_data, key_data + mdb_key.mv_size); + + Database::buffer_t temp_value(value_data, + value_data + mdb_val.mv_size); + + lmdb_state->kv_pair_.key.swap(temp_key); + lmdb_state->kv_pair_.value.swap(temp_value); return lmdb_state->kv_pair_; } diff --git a/tools/compute_image_mean.cpp b/tools/compute_image_mean.cpp index 11f6fb8..01e16c1 100644 --- a/tools/compute_image_mean.cpp +++ b/tools/compute_image_mean.cpp @@ -36,7 +36,7 @@ int main(int argc, char** argv) { int count = 0; // load first datum Database::const_iterator iter = database->begin(); - const Database::buffer_t& first_blob = iter->second; + const Database::buffer_t& first_blob = iter->value; datum.ParseFromArray(first_blob.data(), first_blob.size()); iter = database->end(); @@ -54,7 +54,7 @@ int main(int argc, char** argv) { for (Database::const_iterator iter = database->begin(); iter != database->end(); ++iter) { // just a dummy operation - const Database::buffer_t& blob = iter->second; + const Database::buffer_t& blob = iter->value; datum.ParseFromArray(blob.data(), blob.size()); const std::string& data = datum.data(); size_in_datum = std::max(datum.data().size(), -- 2.7.4