Updated Database interface to use custom KV type rather than std::pair. Removed...
authorKevin James Matzen <kmatzen@cs.cornell.edu>
Sun, 12 Oct 2014 18:39:31 +0000 (14:39 -0400)
committerKevin James Matzen <kmatzen@cs.cornell.edu>
Tue, 14 Oct 2014 23:31:30 +0000 (19:31 -0400)
include/caffe/database.hpp
include/caffe/leveldb_database.hpp
include/caffe/lmdb_database.hpp
src/caffe/layers/data_layer.cpp
src/caffe/leveldb_database.cpp
src/caffe/lmdb_database.cpp
tools/compute_image_mean.cpp

index 23036a8..953b58c 100644 (file)
@@ -21,6 +21,11 @@ class Database {
 
   typedef vector<char> buffer_t;
 
+  struct KV {
+    buffer_t key;
+    buffer_t value;
+  };
+
   virtual void open(const string& filename, Mode mode) = 0;
   virtual void put(buffer_t* key, buffer_t* value) = 0;
   virtual void commit() = 0;
@@ -41,10 +46,9 @@ class Database {
   class DatabaseState;
 
  public:
-  class iterator : public std::iterator<
-      std::forward_iterator_tag, pair<buffer_t, buffer_t> > {
+  class iterator : public std::iterator<std::forward_iterator_tag, KV> {
    public:
-    typedef pair<buffer_t, buffer_t> T;
+    typedef KV T;
     typedef T value_type;
     typedef T& reference_type;
     typedef T* pointer_type;
@@ -97,7 +101,7 @@ class Database {
   virtual bool equal(shared_ptr<DatabaseState> state1,
       shared_ptr<DatabaseState> state2) const = 0;
   virtual void increment(shared_ptr<DatabaseState> state) const = 0;
-  virtual pair<buffer_t, buffer_t>& dereference(
+  virtual KV& dereference(
       shared_ptr<DatabaseState> state) const = 0;
 };
 
index dda6697..1c084cb 100644 (file)
@@ -32,13 +32,13 @@ class LeveldbDatabase : public Database {
           iter_(iter) { }
 
     shared_ptr<leveldb::Iterator> iter_;
-    pair<buffer_t, buffer_t> kv_pair_;
+    KV kv_pair_;
   };
 
   bool equal(shared_ptr<DatabaseState> state1,
       shared_ptr<DatabaseState> state2) const;
   void increment(shared_ptr<DatabaseState> state) const;
-  pair<buffer_t, buffer_t>& dereference(shared_ptr<DatabaseState> state) const;
+  Database::KV& dereference(shared_ptr<DatabaseState> state) const;
 
   shared_ptr<leveldb::DB> db_;
   shared_ptr<leveldb::WriteBatch> batch_;
index 9654222..d72be3d 100644 (file)
@@ -36,13 +36,13 @@ class LmdbDatabase : public Database {
           cursor_(cursor) { }
 
     MDB_cursor* cursor_;
-    pair<buffer_t, buffer_t> kv_pair_;
+    KV kv_pair_;
   };
 
   bool equal(shared_ptr<DatabaseState> state1,
       shared_ptr<DatabaseState> state2) const;
   void increment(shared_ptr<DatabaseState> state) const;
-  pair<buffer_t, buffer_t>& dereference(shared_ptr<DatabaseState> state) const;
+  Database::KV& dereference(shared_ptr<DatabaseState> state) const;
 
   MDB_env *env_;
   MDB_dbi dbi_;
index 4d36b8e..998c00c 100644 (file)
@@ -48,7 +48,7 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
   // Read a data point, and use it to initialize the top blob.
   CHECK(iter_ != database_->end());
   Datum datum;
-  datum.ParseFromArray(iter_->second.data(), iter_->second.size());
+  datum.ParseFromArray(iter_->value.data(), iter_->value.size());
 
   // image
   int crop_size = this->layer_param_.transform_param().crop_size();
@@ -94,7 +94,7 @@ void DataLayer<Dtype>::InternalThreadEntry() {
   for (int item_id = 0; item_id < batch_size; ++item_id) {
     Datum datum;
     CHECK(iter_ != database_->end());
-    datum.ParseFromArray(iter_->second.data(), iter_->second.size());
+    datum.ParseFromArray(iter_->value.data(), iter_->value.size());
 
     // Apply data transformations (mirror, scale, crop...)
     int offset = this->prefetch_data_.offset(item_id);
index a5cdaa3..8084a6c 100644 (file)
@@ -131,7 +131,7 @@ void LeveldbDatabase::increment(shared_ptr<DatabaseState> state) const {
   }
 }
 
-pair<Database::buffer_t, Database::buffer_t>& LeveldbDatabase::dereference(
+Database::KV& LeveldbDatabase::dereference(
     shared_ptr<DatabaseState> state) const {
   shared_ptr<LeveldbState> leveldb_state =
       boost::dynamic_pointer_cast<LeveldbState>(state);
@@ -144,10 +144,14 @@ pair<Database::buffer_t, Database::buffer_t>& LeveldbDatabase::dereference(
 
   CHECK(iter->Valid());
 
-  leveldb_state->kv_pair_ = make_pair(
-      buffer_t(iter->key().data(), iter->key().data() + iter->key().size()),
-      buffer_t(iter->value().data(),
-          iter->value().data() + iter->value().size()));
+  Database::buffer_t temp_key(buffer_t(iter->key().data(),
+      iter->key().data() + iter->key().size()));
+
+  Database::buffer_t temp_value(buffer_t(iter->value().data(),
+      iter->value().data() + iter->value().size()));
+
+  leveldb_state->kv_pair_.key.swap(temp_key);
+  leveldb_state->kv_pair_.value.swap(temp_value);
   return leveldb_state->kv_pair_;
 }
 
index 0860778..54d67d5 100644 (file)
@@ -149,8 +149,7 @@ void LmdbDatabase::increment(shared_ptr<DatabaseState> state) const {
   }
 }
 
-pair<Database::buffer_t, Database::buffer_t>& LmdbDatabase::dereference(
-    shared_ptr<DatabaseState> state) const {
+Database::KV& LmdbDatabase::dereference(shared_ptr<DatabaseState> state) const {
   shared_ptr<LmdbState> lmdb_state =
       boost::dynamic_pointer_cast<LmdbState>(state);
 
@@ -168,9 +167,13 @@ pair<Database::buffer_t, Database::buffer_t>& LmdbDatabase::dereference(
   char* key_data = reinterpret_cast<char*>(mdb_key.mv_data);
   char* value_data = reinterpret_cast<char*>(mdb_val.mv_data);
 
-  lmdb_state->kv_pair_ = make_pair(
-    buffer_t(key_data, key_data + mdb_key.mv_size),
-    buffer_t(value_data, value_data + mdb_val.mv_size));
+  Database::buffer_t temp_key(key_data, key_data + mdb_key.mv_size);
+
+  Database::buffer_t temp_value(value_data,
+      value_data + mdb_val.mv_size);
+
+  lmdb_state->kv_pair_.key.swap(temp_key);
+  lmdb_state->kv_pair_.value.swap(temp_value);
 
   return lmdb_state->kv_pair_;
 }
index 11f6fb8..01e16c1 100644 (file)
@@ -36,7 +36,7 @@ int main(int argc, char** argv) {
   int count = 0;
   // load first datum
   Database::const_iterator iter = database->begin();
-  const Database::buffer_t& first_blob = iter->second;
+  const Database::buffer_t& first_blob = iter->value;
   datum.ParseFromArray(first_blob.data(), first_blob.size());
   iter = database->end();
 
@@ -54,7 +54,7 @@ int main(int argc, char** argv) {
   for (Database::const_iterator iter = database->begin();
       iter != database->end(); ++iter) {
     // just a dummy operation
-    const Database::buffer_t& blob = iter->second;
+    const Database::buffer_t& blob = iter->value;
     datum.ParseFromArray(blob.data(), blob.size());
     const std::string& data = datum.data();
     size_in_datum = std::max<int>(datum.data().size(),