Refactored leveldb and lmdb code.
authorKevin James Matzen <kmatzen@cs.cornell.edu>
Tue, 7 Oct 2014 19:37:35 +0000 (15:37 -0400)
committerKevin James Matzen <kmatzen@cs.cornell.edu>
Tue, 14 Oct 2014 23:25:01 +0000 (19:25 -0400)
17 files changed:
include/caffe/data_layers.hpp
include/caffe/database.hpp [new file with mode: 0644]
include/caffe/database_factory.hpp [new file with mode: 0644]
include/caffe/leveldb_database.hpp [new file with mode: 0644]
include/caffe/lmdb_database.hpp [new file with mode: 0644]
include/caffe/util/io.hpp
src/caffe/database_factory.cpp [new file with mode: 0644]
src/caffe/layers/data_layer.cpp
src/caffe/leveldb_database.cpp [new file with mode: 0644]
src/caffe/lmdb_database.cpp [new file with mode: 0644]
src/caffe/proto/caffe.proto
src/caffe/test/test_data_layer.cpp
src/caffe/test/test_hdf5data_layer.cpp
src/caffe/util/io.cpp
tools/compute_image_mean.cpp
tools/convert_imageset.cpp
tools/extract_features.cpp

index 1ed26c1..810f2bb 100644 (file)
@@ -7,12 +7,11 @@
 
 #include "boost/scoped_ptr.hpp"
 #include "hdf5.h"
-#include "leveldb/db.h"
-#include "lmdb.h"
 
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
 #include "caffe/data_transformer.hpp"
+#include "caffe/database.hpp"
 #include "caffe/filler.hpp"
 #include "caffe/internal_thread.hpp"
 #include "caffe/layer.hpp"
@@ -101,15 +100,8 @@ class DataLayer : public BasePrefetchingDataLayer<Dtype> {
  protected:
   virtual void InternalThreadEntry();
 
-  // LEVELDB
-  shared_ptr<leveldb::DB> db_;
-  shared_ptr<leveldb::Iterator> iter_;
-  // LMDB
-  MDB_env* mdb_env_;
-  MDB_dbi mdb_dbi_;
-  MDB_txn* mdb_txn_;
-  MDB_cursor* mdb_cursor_;
-  MDB_val mdb_key_, mdb_value_;
+  shared_ptr<Database> database_;
+  Database::const_iterator iter_;
 };
 
 /**
diff --git a/include/caffe/database.hpp b/include/caffe/database.hpp
new file mode 100644 (file)
index 0000000..4a1a25e
--- /dev/null
@@ -0,0 +1,103 @@
+#ifndef CAFFE_DATABASE_H_
+#define CAFFE_DATABASE_H_
+
+#include <algorithm>
+#include <iterator>
+#include <string>
+#include <utility>
+
+#include "caffe/common.hpp"
+
+namespace caffe {
+
+class Database {
+ public:
+  enum Mode {
+    New,
+    ReadWrite,
+    ReadOnly
+  };
+
+  virtual void open(const string& filename, Mode mode) = 0;
+  virtual void put(const string& key, const string& value) = 0;
+  virtual void commit() = 0;
+  virtual void close() = 0;
+
+  Database() { }
+  virtual ~Database() { }
+
+  class iterator;
+  typedef iterator const_iterator;
+
+  virtual const_iterator begin() const = 0;
+  virtual const_iterator cbegin() const = 0;
+  virtual const_iterator end() const = 0;
+  virtual const_iterator cend() const = 0;
+
+ protected:
+  class DatabaseState;
+
+ public:
+  class iterator :
+      public std::iterator<std::forward_iterator_tag, pair<string, string> > {
+   public:
+    typedef pair<string, string> T;
+    typedef T value_type;
+    typedef T& reference_type;
+    typedef T* pointer_type;
+
+    iterator()
+        : parent_(NULL) { }
+
+    iterator(const Database* parent, shared_ptr<DatabaseState> state)
+        : parent_(parent),
+          state_(state) { }
+    ~iterator() { }
+
+    bool operator==(const iterator& other) const {
+      return parent_->equal(state_, other.state_);
+    }
+
+    bool operator!=(const iterator& other) const {
+      return !(*this == other);
+    }
+
+    iterator& operator++() {
+      parent_->increment(state_);
+      return *this;
+    }
+    iterator operator++(int) {
+      iterator copy(*this);
+      parent_->increment(state_);
+      return copy;
+    }
+
+    reference_type operator*() const {
+      return parent_->dereference(state_);
+    }
+
+    pointer_type operator->() const {
+      return &parent_->dereference(state_);
+    }
+
+   protected:
+    const Database* parent_;
+    shared_ptr<DatabaseState> state_;
+  };
+
+ protected:
+  class DatabaseState {
+   public:
+    virtual ~DatabaseState() { }
+  };
+
+  virtual bool equal(shared_ptr<DatabaseState> state1,
+      shared_ptr<DatabaseState> state2) const = 0;
+  virtual void increment(shared_ptr<DatabaseState> state) const = 0;
+  virtual pair<string, string>& dereference(
+      shared_ptr<DatabaseState> state) const = 0;
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_DATABASE_H_
diff --git a/include/caffe/database_factory.hpp b/include/caffe/database_factory.hpp
new file mode 100644 (file)
index 0000000..a6e39e7
--- /dev/null
@@ -0,0 +1,16 @@
+#ifndef CAFFE_DATABASE_FACTORY_H_
+#define CAFFE_DATABASE_FACTORY_H_
+
+#include <string>
+
+#include "caffe/common.hpp"
+#include "caffe/database.hpp"
+
+namespace caffe {
+
+shared_ptr<Database> DatabaseFactory(const DataParameter_DB& type);
+shared_ptr<Database> DatabaseFactory(const string& type);
+
+}  // namespace caffe
+
+#endif  // CAFFE_DATABASE_FACTORY_H_
diff --git a/include/caffe/leveldb_database.hpp b/include/caffe/leveldb_database.hpp
new file mode 100644 (file)
index 0000000..5daf0e5
--- /dev/null
@@ -0,0 +1,52 @@
+#ifndef CAFFE_LEVELDB_DATABASE_H_
+#define CAFFE_LEVELDB_DATABASE_H_
+
+#include <leveldb/db.h>
+#include <leveldb/write_batch.h>
+
+#include <string>
+#include <utility>
+
+#include "caffe/common.hpp"
+#include "caffe/database.hpp"
+
+namespace caffe {
+
+class LeveldbDatabase : public Database {
+ public:
+  void open(const string& filename, Mode mode);
+  void put(const string& key, const string& value);
+  void commit();
+  void close();
+
+  ~LeveldbDatabase() { this->close(); }
+
+ protected:
+  class LeveldbState : public Database::DatabaseState {
+   public:
+    explicit LeveldbState(shared_ptr<leveldb::Iterator> iter)
+        : Database::DatabaseState(),
+          iter_(iter) { }
+
+    shared_ptr<leveldb::Iterator> iter_;
+    pair<string, string> kv_pair_;
+  };
+
+  bool equal(shared_ptr<DatabaseState> state1,
+      shared_ptr<DatabaseState> state2) const;
+  void increment(shared_ptr<DatabaseState> state) const;
+  pair<string, string>& dereference(shared_ptr<DatabaseState> state) const;
+
+  const_iterator begin() const;
+  const_iterator cbegin() const;
+  const_iterator end() const;
+  const_iterator cend() const;
+
+ protected:
+  shared_ptr<leveldb::DB> db_;
+  shared_ptr<leveldb::WriteBatch> batch_;
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_LEVELDB_DATABASE_H_
diff --git a/include/caffe/lmdb_database.hpp b/include/caffe/lmdb_database.hpp
new file mode 100644 (file)
index 0000000..f275cb4
--- /dev/null
@@ -0,0 +1,54 @@
+#ifndef CAFFE_LMDB_DATABASE_H_
+#define CAFFE_LMDB_DATABASE_H_
+
+#include <string>
+#include <utility>
+
+#include "lmdb.h"
+
+#include "caffe/common.hpp"
+#include "caffe/database.hpp"
+
+namespace caffe {
+
+class LmdbDatabase : public Database {
+ public:
+  LmdbDatabase()
+      : dbi_(0) { }
+  ~LmdbDatabase() { this->close(); }
+
+  void open(const string& filename, Mode mode);
+  void put(const string& key, const string& value);
+  void commit();
+  void close();
+
+ protected:
+  class LmdbState : public Database::DatabaseState {
+   public:
+    explicit LmdbState(MDB_cursor* cursor)
+        : Database::DatabaseState(),
+          cursor_(cursor) { }
+
+    MDB_cursor* cursor_;
+    pair<string, string> kv_pair_;
+  };
+
+  bool equal(shared_ptr<DatabaseState> state1,
+      shared_ptr<DatabaseState> state2) const;
+  void increment(shared_ptr<DatabaseState> state) const;
+  pair<string, string>& dereference(shared_ptr<DatabaseState> state) const;
+
+ protected:
+  const_iterator begin() const;
+  const_iterator cbegin() const;
+  const_iterator end() const;
+  const_iterator cend() const;
+
+  MDB_env *env_ = NULL;
+  MDB_dbi dbi_;
+  MDB_txn *txn_ = NULL;
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_LMDB_DATABASE_H_
index e518979..1124bef 100644 (file)
 
 #define HDF5_NUM_DIMS 4
 
-namespace leveldb {
-// Forward declaration for leveldb::Options to be used in GetlevelDBOptions().
-struct Options;
-}
-
 namespace caffe {
 
 using ::google::protobuf::Message;
@@ -132,8 +127,6 @@ inline cv::Mat ReadImageToCVMat(const string& filename) {
 void CVMatToDatum(const cv::Mat& cv_img, Datum* datum);
 #endif
 
-leveldb::Options GetLevelDBOptions();
-
 template <typename Dtype>
 void hdf5_load_nd_dataset_helper(
   hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
diff --git a/src/caffe/database_factory.cpp b/src/caffe/database_factory.cpp
new file mode 100644 (file)
index 0000000..393635b
--- /dev/null
@@ -0,0 +1,33 @@
+#include <string>
+#include <utility>
+
+#include "caffe/database_factory.hpp"
+#include "caffe/leveldb_database.hpp"
+#include "caffe/lmdb_database.hpp"
+
+namespace caffe {
+
+shared_ptr<Database> DatabaseFactory(const DataParameter_DB& type) {
+  switch (type) {
+  case DataParameter_DB_LEVELDB:
+    return shared_ptr<Database>(new LeveldbDatabase());
+  case DataParameter_DB_LMDB:
+    return shared_ptr<Database>(new LmdbDatabase());
+  default:
+    LOG(FATAL) << "Unknown database type " << type;
+  }
+}
+
+shared_ptr<Database> DatabaseFactory(const string& type) {
+  if ("leveldb" == type) {
+    return DatabaseFactory(DataParameter_DB_LEVELDB);
+  } else if ("lmdb" == type) {
+    return DatabaseFactory(DataParameter_DB_LMDB);
+  } else {
+    LOG(FATAL) << "Unknown database type " << type;
+  }
+}
+
+}  // namespace caffe
+
+
index daa3289..b746bc8 100644 (file)
@@ -1,4 +1,3 @@
-#include <leveldb/db.h>
 #include <stdint.h>
 
 #include <string>
@@ -6,6 +5,7 @@
 
 #include "caffe/common.hpp"
 #include "caffe/data_layers.hpp"
+#include "caffe/database_factory.hpp"
 #include "caffe/layer.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/io.hpp"
@@ -18,60 +18,18 @@ template <typename Dtype>
 DataLayer<Dtype>::~DataLayer<Dtype>() {
   this->JoinPrefetchThread();
   // clean up the database resources
-  switch (this->layer_param_.data_param().backend()) {
-  case DataParameter_DB_LEVELDB:
-    break;  // do nothing
-  case DataParameter_DB_LMDB:
-    mdb_cursor_close(mdb_cursor_);
-    mdb_close(mdb_env_, mdb_dbi_);
-    mdb_txn_abort(mdb_txn_);
-    mdb_env_close(mdb_env_);
-    break;
-  default:
-    LOG(FATAL) << "Unknown database backend";
-  }
+  iter_ = database_->end();
+  database_->close();
 }
 
 template <typename Dtype>
 void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
   // Initialize DB
-  switch (this->layer_param_.data_param().backend()) {
-  case DataParameter_DB_LEVELDB:
-    {
-    leveldb::DB* db_temp;
-    leveldb::Options options = GetLevelDBOptions();
-    options.create_if_missing = false;
-    LOG(INFO) << "Opening leveldb " << this->layer_param_.data_param().source();
-    leveldb::Status status = leveldb::DB::Open(
-        options, this->layer_param_.data_param().source(), &db_temp);
-    CHECK(status.ok()) << "Failed to open leveldb "
-                       << this->layer_param_.data_param().source() << std::endl
-                       << status.ToString();
-    db_.reset(db_temp);
-    iter_.reset(db_->NewIterator(leveldb::ReadOptions()));
-    iter_->SeekToFirst();
-    }
-    break;
-  case DataParameter_DB_LMDB:
-    CHECK_EQ(mdb_env_create(&mdb_env_), MDB_SUCCESS) << "mdb_env_create failed";
-    CHECK_EQ(mdb_env_set_mapsize(mdb_env_, 1099511627776), MDB_SUCCESS);  // 1TB
-    CHECK_EQ(mdb_env_open(mdb_env_,
-             this->layer_param_.data_param().source().c_str(),
-             MDB_RDONLY|MDB_NOTLS, 0664), MDB_SUCCESS) << "mdb_env_open failed";
-    CHECK_EQ(mdb_txn_begin(mdb_env_, NULL, MDB_RDONLY, &mdb_txn_), MDB_SUCCESS)
-        << "mdb_txn_begin failed";
-    CHECK_EQ(mdb_open(mdb_txn_, NULL, 0, &mdb_dbi_), MDB_SUCCESS)
-        << "mdb_open failed";
-    CHECK_EQ(mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_), MDB_SUCCESS)
-        << "mdb_cursor_open failed";
-    LOG(INFO) << "Opening lmdb " << this->layer_param_.data_param().source();
-    CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST),
-        MDB_SUCCESS) << "mdb_cursor_get failed";
-    break;
-  default:
-    LOG(FATAL) << "Unknown database backend";
-  }
+  database_ = DatabaseFactory(this->layer_param_.data_param().backend());
+  LOG(INFO) << "Opening database " << this->layer_param_.data_param().source();
+  database_->open(this->layer_param_.data_param().source(), Database::ReadOnly);
+  iter_ = database_->begin();
 
   // Check if we would need to randomly skip a few data points
   if (this->layer_param_.data_param().rand_skip()) {
@@ -79,37 +37,16 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
                         this->layer_param_.data_param().rand_skip();
     LOG(INFO) << "Skipping first " << skip << " data points.";
     while (skip-- > 0) {
-      switch (this->layer_param_.data_param().backend()) {
-      case DataParameter_DB_LEVELDB:
-        iter_->Next();
-        if (!iter_->Valid()) {
-          iter_->SeekToFirst();
-        }
-        break;
-      case DataParameter_DB_LMDB:
-        if (mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT)
-            != MDB_SUCCESS) {
-          CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_,
-                   MDB_FIRST), MDB_SUCCESS);
-        }
-        break;
-      default:
-        LOG(FATAL) << "Unknown database backend";
+      LOG(INFO) << iter_->first;
+      if (++iter_ == database_->end()) {
+        iter_ = database_->begin();
       }
     }
   }
   // Read a data point, and use it to initialize the top blob.
+  CHECK(iter_ != database_->end());
   Datum datum;
-  switch (this->layer_param_.data_param().backend()) {
-  case DataParameter_DB_LEVELDB:
-    datum.ParseFromString(iter_->value().ToString());
-    break;
-  case DataParameter_DB_LMDB:
-    datum.ParseFromArray(mdb_value_.mv_data, mdb_value_.mv_size);
-    break;
-  default:
-    LOG(FATAL) << "Unknown database backend";
-  }
+  datum.ParseFromString(iter_->second);
 
   // image
   int crop_size = this->layer_param_.transform_param().crop_size();
@@ -153,23 +90,10 @@ void DataLayer<Dtype>::InternalThreadEntry() {
   const int batch_size = this->layer_param_.data_param().batch_size();
 
   for (int item_id = 0; item_id < batch_size; ++item_id) {
-    // get a blob
     Datum datum;
-    switch (this->layer_param_.data_param().backend()) {
-    case DataParameter_DB_LEVELDB:
-      CHECK(iter_);
-      CHECK(iter_->Valid());
-      datum.ParseFromString(iter_->value().ToString());
-      break;
-    case DataParameter_DB_LMDB:
-      CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_,
-              &mdb_value_, MDB_GET_CURRENT), MDB_SUCCESS);
-      datum.ParseFromArray(mdb_value_.mv_data,
-          mdb_value_.mv_size);
-      break;
-    default:
-      LOG(FATAL) << "Unknown database backend";
-    }
+    CHECK(iter_ != database_->end());
+    datum.ParseFromString(iter_->second);
+
     // Apply data transformations (mirror, scale, crop...)
     int offset = this->prefetch_data_.offset(item_id);
     this->transformed_data_.set_cpu_data(top_data + offset);
@@ -179,26 +103,9 @@ void DataLayer<Dtype>::InternalThreadEntry() {
     }
 
     // go to the next iter
-    switch (this->layer_param_.data_param().backend()) {
-    case DataParameter_DB_LEVELDB:
-      iter_->Next();
-      if (!iter_->Valid()) {
-        // We have reached the end. Restart from the first.
-        DLOG(INFO) << "Restarting data prefetching from start.";
-        iter_->SeekToFirst();
-      }
-      break;
-    case DataParameter_DB_LMDB:
-      if (mdb_cursor_get(mdb_cursor_, &mdb_key_,
-              &mdb_value_, MDB_NEXT) != MDB_SUCCESS) {
-        // We have reached the end. Restart from the first.
-        DLOG(INFO) << "Restarting data prefetching from start.";
-        CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_,
-                &mdb_value_, MDB_FIRST), MDB_SUCCESS);
-      }
-      break;
-    default:
-      LOG(FATAL) << "Unknown database backend";
+    ++iter_;
+    if (iter_ == database_->end()) {
+      iter_ = database_->begin();
     }
   }
 }
diff --git a/src/caffe/leveldb_database.cpp b/src/caffe/leveldb_database.cpp
new file mode 100644 (file)
index 0000000..be7ac7f
--- /dev/null
@@ -0,0 +1,151 @@
+#include <string>
+#include <utility>
+
+#include "caffe/leveldb_database.hpp"
+
+namespace caffe {
+
+void LeveldbDatabase::open(const string& filename, Mode mode) {
+  LOG(INFO) << "LevelDB: Open " << filename;
+
+  leveldb::Options options;
+  switch (mode) {
+  case New:
+    LOG(INFO) << " mode NEW";
+    options.error_if_exists = true;
+    options.create_if_missing = true;
+    break;
+  case ReadWrite:
+    LOG(INFO) << " mode RW";
+    options.error_if_exists = false;
+    options.create_if_missing = true;
+    break;
+  case ReadOnly:
+    LOG(INFO) << " mode RO";
+    options.error_if_exists = false;
+    options.create_if_missing = false;
+    break;
+  default:
+    LOG(FATAL) << "unknown mode " << mode;
+  }
+  options.write_buffer_size = 268435456;
+  options.max_open_files = 100;
+
+  leveldb::DB* db;
+
+  LOG(INFO) << "Opening leveldb " << filename;
+  leveldb::Status status = leveldb::DB::Open(
+      options, filename, &db);
+  db_.reset(db);
+  CHECK(status.ok()) << "Failed to open leveldb " << filename
+      << ". Is it already existing?";
+  batch_.reset(new leveldb::WriteBatch());
+}
+
+void LeveldbDatabase::put(const string& key, const string& value) {
+  LOG(INFO) << "LevelDB: Put " << key;
+
+  CHECK_NOTNULL(batch_.get());
+  batch_->Put(key, value);
+}
+
+void LeveldbDatabase::commit() {
+  LOG(INFO) << "LevelDB: Commit";
+
+  CHECK_NOTNULL(db_.get());
+  CHECK_NOTNULL(batch_.get());
+
+  db_->Write(leveldb::WriteOptions(), batch_.get());
+  batch_.reset(new leveldb::WriteBatch());
+}
+
+void LeveldbDatabase::close() {
+  LOG(INFO) << "LevelDB: Close";
+
+  if (batch_ && db_) {
+    this->commit();
+  }
+  batch_.reset();
+  db_.reset();
+}
+
+LeveldbDatabase::const_iterator LeveldbDatabase::begin() const {
+  CHECK_NOTNULL(db_.get());
+  shared_ptr<leveldb::Iterator> iter(db_->NewIterator(leveldb::ReadOptions()));
+  iter->SeekToFirst();
+  if (!iter->Valid()) {
+    iter.reset();
+  }
+  shared_ptr<DatabaseState> state(new LeveldbState(iter));
+  return const_iterator(this, state);
+}
+
+LeveldbDatabase::const_iterator LeveldbDatabase::end() const {
+  shared_ptr<leveldb::Iterator> iter;
+  shared_ptr<DatabaseState> state(new LeveldbState(iter));
+  return const_iterator(this, state);
+}
+
+LeveldbDatabase::const_iterator LeveldbDatabase::cbegin() const {
+  return begin();
+}
+
+LeveldbDatabase::const_iterator LeveldbDatabase::cend() const { return end(); }
+
+bool LeveldbDatabase::equal(shared_ptr<DatabaseState> state1,
+    shared_ptr<DatabaseState> state2) const {
+  shared_ptr<LeveldbState> leveldb_state1 =
+      boost::dynamic_pointer_cast<LeveldbState>(state1);
+
+  CHECK_NOTNULL(leveldb_state1.get());
+
+  shared_ptr<LeveldbState> leveldb_state2 =
+      boost::dynamic_pointer_cast<LeveldbState>(state2);
+
+  CHECK_NOTNULL(leveldb_state2.get());
+
+  CHECK(!leveldb_state1->iter_ || leveldb_state1->iter_->Valid());
+  CHECK(!leveldb_state2->iter_ || leveldb_state2->iter_->Valid());
+
+  // The KV store doesn't really have any sort of ordering,
+  // so while we can do a sequential scan over the collection,
+  // we can't really use subranges.
+  return !leveldb_state1->iter_ && !leveldb_state2->iter_;
+}
+
+void LeveldbDatabase::increment(shared_ptr<DatabaseState> state) const {
+  shared_ptr<LeveldbState> leveldb_state =
+      boost::dynamic_pointer_cast<LeveldbState>(state);
+
+  CHECK_NOTNULL(leveldb_state.get());
+
+  shared_ptr<leveldb::Iterator>& iter = leveldb_state->iter_;
+
+  CHECK_NOTNULL(iter.get());
+  CHECK(iter->Valid());
+
+  iter->Next();
+  if (!iter->Valid()) {
+    iter.reset();
+  }
+}
+
+pair<string, string>& LeveldbDatabase::dereference(
+    shared_ptr<DatabaseState> state) const {
+  shared_ptr<LeveldbState> leveldb_state =
+      boost::dynamic_pointer_cast<LeveldbState>(state);
+
+  CHECK_NOTNULL(leveldb_state.get());
+
+  shared_ptr<leveldb::Iterator>& iter = leveldb_state->iter_;
+
+  CHECK_NOTNULL(iter.get());
+
+  CHECK(iter->Valid());
+
+  leveldb_state->kv_pair_ = make_pair(iter->key().ToString(),
+      iter->value().ToString());
+  return leveldb_state->kv_pair_;
+}
+
+}  // namespace caffe
diff --git a/src/caffe/lmdb_database.cpp b/src/caffe/lmdb_database.cpp
new file mode 100644 (file)
index 0000000..796bbc9
--- /dev/null
@@ -0,0 +1,154 @@
+#include <sys/stat.h>
+
+#include <string>
+#include <utility>
+
+#include "caffe/lmdb_database.hpp"
+
+namespace caffe {
+
+void LmdbDatabase::open(const string& filename, Mode mode) {
+  LOG(INFO) << "LMDB: Open " << filename;
+
+  CHECK(NULL == env_);
+  CHECK(NULL == txn_);
+  CHECK_EQ(0, dbi_);
+
+  if (mode != ReadOnly) {
+    CHECK_EQ(mkdir(filename.c_str(), 0744), 0) << "mkdir " << filename
+                                                << "failed";
+  }
+
+  CHECK_EQ(mdb_env_create(&env_), MDB_SUCCESS) << "mdb_env_create failed";
+  CHECK_EQ(mdb_env_set_mapsize(env_, 1099511627776), MDB_SUCCESS)  // 1TB
+      << "mdb_env_set_mapsize failed";
+
+  int flag1 = 0;
+  int flag2 = 0;
+  if (mode == ReadOnly) {
+    flag1 = MDB_RDONLY | MDB_NOTLS;
+    flag2 = MDB_RDONLY;
+  }
+
+  CHECK_EQ(mdb_env_open(env_, filename.c_str(), flag1, 0664), MDB_SUCCESS)
+      << "mdb_env_open failed";
+  CHECK_EQ(mdb_txn_begin(env_, NULL, flag2, &txn_), MDB_SUCCESS)
+      << "mdb_txn_begin failed";
+  CHECK_EQ(mdb_open(txn_, NULL, 0, &dbi_), MDB_SUCCESS) << "mdb_open failed";
+}
+
+void LmdbDatabase::put(const string& key, const string& value) {
+  LOG(INFO) << "LMDB: Put " << key;
+
+  MDB_val mdbkey, mdbdata;
+  mdbdata.mv_size = value.size();
+  mdbdata.mv_data = const_cast<char*>(&value[0]);
+  mdbkey.mv_size = key.size();
+  mdbkey.mv_data = const_cast<char*>(&key[0]);
+
+  CHECK_NOTNULL(txn_);
+  CHECK_NE(0, dbi_);
+
+  CHECK_EQ(mdb_put(txn_, dbi_, &mdbkey, &mdbdata, 0), MDB_SUCCESS)
+      << "mdb_put failed";
+}
+
+void LmdbDatabase::commit() {
+  LOG(INFO) << "LMDB: Commit";
+
+  CHECK_NOTNULL(txn_);
+
+  CHECK_EQ(mdb_txn_commit(txn_), MDB_SUCCESS) << "mdb_txn_commit failed";
+}
+
+void LmdbDatabase::close() {
+  LOG(INFO) << "LMDB: Close";
+
+  if (env_ && dbi_ && txn_) {
+    this->commit();
+  }
+
+  if (env_ && dbi_) {
+    mdb_close(env_, dbi_);
+    mdb_env_close(env_);
+    env_ = NULL;
+    dbi_ = 0;
+    txn_ = NULL;
+  }
+}
+
+LmdbDatabase::const_iterator LmdbDatabase::begin() const {
+  MDB_cursor* cursor;
+  CHECK_EQ(mdb_cursor_open(txn_, dbi_, &cursor), MDB_SUCCESS);
+  MDB_val key;
+  MDB_val val;
+  CHECK_EQ(mdb_cursor_get(cursor, &key, &val, MDB_FIRST), MDB_SUCCESS);
+
+  shared_ptr<DatabaseState> state(new LmdbState(cursor));
+  return const_iterator(this, state);
+}
+
+LmdbDatabase::const_iterator LmdbDatabase::end() const {
+  shared_ptr<DatabaseState> state(new LmdbState(NULL));
+  return const_iterator(this, state);
+}
+
+LmdbDatabase::const_iterator LmdbDatabase::cbegin() const { return begin(); }
+LmdbDatabase::const_iterator LmdbDatabase::cend() const { return end(); }
+
+bool LmdbDatabase::equal(shared_ptr<DatabaseState> state1,
+    shared_ptr<DatabaseState> state2) const {
+  shared_ptr<LmdbState> lmdb_state1 =
+      boost::dynamic_pointer_cast<LmdbState>(state1);
+
+  CHECK_NOTNULL(lmdb_state1.get());
+
+  shared_ptr<LmdbState> lmdb_state2 =
+      boost::dynamic_pointer_cast<LmdbState>(state2);
+
+  CHECK_NOTNULL(lmdb_state2.get());
+
+  // The KV store doesn't really have any sort of ordering,
+  // so while we can do a sequential scan over the collection,
+  // we can't really use subranges.
+  return !lmdb_state1->cursor_ && !lmdb_state2->cursor_;
+}
+
+void LmdbDatabase::increment(shared_ptr<DatabaseState> state) const {
+  shared_ptr<LmdbState> lmdb_state =
+      boost::dynamic_pointer_cast<LmdbState>(state);
+
+  CHECK_NOTNULL(lmdb_state.get());
+
+  MDB_cursor*& cursor = lmdb_state->cursor_;
+
+  MDB_val key;
+  MDB_val val;
+  if (MDB_SUCCESS != mdb_cursor_get(cursor, &key, &val, MDB_NEXT)) {
+    mdb_cursor_close(cursor);
+    cursor = NULL;
+  }
+}
+
+pair<string, string>& LmdbDatabase::dereference(
+    shared_ptr<DatabaseState> state) const {
+  shared_ptr<LmdbState> lmdb_state =
+      boost::dynamic_pointer_cast<LmdbState>(state);
+
+  CHECK_NOTNULL(lmdb_state.get());
+
+  MDB_cursor*& cursor = lmdb_state->cursor_;
+
+  MDB_val mdb_key;
+  MDB_val mdb_val;
+  CHECK_EQ(mdb_cursor_get(cursor, &mdb_key, &mdb_val, MDB_GET_CURRENT),
+      MDB_SUCCESS);
+
+  lmdb_state->kv_pair_ = make_pair(
+    string(reinterpret_cast<char*>(mdb_key.mv_data), mdb_key.mv_size),
+    string(reinterpret_cast<char*>(mdb_val.mv_data), mdb_val.mv_size));
+
+  return lmdb_state->kv_pair_;
+}
+
+}  // namespace caffe
index f0404a0..f0dba09 100644 (file)
@@ -423,7 +423,7 @@ message DataParameter {
   // The rand_skip variable is for the data layer to skip a few data points
   // to avoid all asynchronous sgd clients to start at the same point. The skip
   // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
-  // be larger than the number of keys in the leveldb.
+  // be larger than the number of keys in the database.
   optional uint32 rand_skip = 7 [default = 0];
   optional DB backend = 8 [default = LEVELDB];
   // DEPRECATED. See TransformationParameter. For data pre-processing, we can do
@@ -518,7 +518,7 @@ message ImageDataParameter {
   // The rand_skip variable is for the data layer to skip a few data points
   // to avoid all asynchronous sgd clients to start at the same point. The skip
   // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
-  // be larger than the number of keys in the leveldb.
+  // be larger than the number of keys in the database.
   optional uint32 rand_skip = 7 [default = 0];
   // Whether or not ImageLayer should shuffle the list of files at every epoch.
   optional bool shuffle = 8 [default = false];
@@ -767,7 +767,7 @@ message V0LayerParameter {
   // The rand_skip variable is for the data layer to skip a few data points
   // to avoid all asynchronous sgd clients to start at the same point. The skip
   // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
-  // be larger than the number of keys in the leveldb.
+  // be larger than the number of keys in the database.
   optional uint32 rand_skip = 53 [default = 0];
 
   // Fields related to detection (det_*)
index 657ffde..d99b5e3 100644 (file)
@@ -2,10 +2,10 @@
 #include <vector>
 
 #include "gtest/gtest.h"
-#include "leveldb/db.h"
 
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
+#include "caffe/database_factory.hpp"
 #include "caffe/filler.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/io.hpp"
@@ -36,16 +36,11 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
   // Fill the LevelDB with data: if unique_pixels, each pixel is unique but
   // all images are the same; else each image is unique but all pixels within
   // an image are the same.
-  void FillLevelDB(const bool unique_pixels) {
-    backend_ = DataParameter_DB_LEVELDB;
-    LOG(INFO) << "Using temporary leveldb " << *filename_;
-    leveldb::DB* db;
-    leveldb::Options options;
-    options.error_if_exists = true;
-    options.create_if_missing = true;
-    leveldb::Status status =
-        leveldb::DB::Open(options, filename_->c_str(), &db);
-    CHECK(status.ok());
+  void Fill(const bool unique_pixels, DataParameter_DB backend) {
+    backend_ = backend;
+    LOG(INFO) << "Using temporary database " << *filename_;
+    shared_ptr<Database> database = DatabaseFactory(backend_);
+    database->open(*filename_, Database::New);
     for (int i = 0; i < 5; ++i) {
       Datum datum;
       datum.set_label(i);
@@ -59,57 +54,9 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
       }
       stringstream ss;
       ss << i;
-      db->Put(leveldb::WriteOptions(), ss.str(), datum.SerializeAsString());
+      database->put(ss.str(), datum.SerializeAsString());
     }
-    delete db;
-  }
-
-  // Fill the LMDB with data: unique_pixels has same meaning as in FillLevelDB.
-  void FillLMDB(const bool unique_pixels) {
-    backend_ = DataParameter_DB_LMDB;
-    LOG(INFO) << "Using temporary lmdb " << *filename_;
-    CHECK_EQ(mkdir(filename_->c_str(), 0744), 0) << "mkdir " << filename_
-                                                 << "failed";
-    MDB_env *env;
-    MDB_dbi dbi;
-    MDB_val mdbkey, mdbdata;
-    MDB_txn *txn;
-    CHECK_EQ(mdb_env_create(&env), MDB_SUCCESS) << "mdb_env_create failed";
-    CHECK_EQ(mdb_env_set_mapsize(env, 1099511627776), MDB_SUCCESS)  // 1TB
-        << "mdb_env_set_mapsize failed";
-    CHECK_EQ(mdb_env_open(env, filename_->c_str(), 0, 0664), MDB_SUCCESS)
-        << "mdb_env_open failed";
-    CHECK_EQ(mdb_txn_begin(env, NULL, 0, &txn), MDB_SUCCESS)
-        << "mdb_txn_begin failed";
-    CHECK_EQ(mdb_open(txn, NULL, 0, &dbi), MDB_SUCCESS) << "mdb_open failed";
-
-    for (int i = 0; i < 5; ++i) {
-      Datum datum;
-      datum.set_label(i);
-      datum.set_channels(2);
-      datum.set_height(3);
-      datum.set_width(4);
-      std::string* data = datum.mutable_data();
-      for (int j = 0; j < 24; ++j) {
-        int datum = unique_pixels ? j : i;
-        data->push_back(static_cast<uint8_t>(datum));
-      }
-      stringstream ss;
-      ss << i;
-
-      string value;
-      datum.SerializeToString(&value);
-      mdbdata.mv_size = value.size();
-      mdbdata.mv_data = reinterpret_cast<void*>(&value[0]);
-      string keystr = ss.str();
-      mdbkey.mv_size = keystr.size();
-      mdbkey.mv_data = reinterpret_cast<void*>(&keystr[0]);
-      CHECK_EQ(mdb_put(txn, dbi, &mdbkey, &mdbdata, 0), MDB_SUCCESS)
-          << "mdb_put failed";
-    }
-    CHECK_EQ(mdb_txn_commit(txn), MDB_SUCCESS) << "mdb_txn_commit failed";
-    mdb_close(env, dbi);
-    mdb_env_close(env);
+    database->close();
   }
 
   void TestRead() {
@@ -234,7 +181,7 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
         }
         crop_sequence.push_back(iter_crop_sequence);
       }
-    }  // destroy 1st data layer and unlock the leveldb
+    }  // destroy 1st data layer and unlock the database
 
     // Get crop sequence after reseeding Caffe with 1701.
     // Check that the sequence is the same as the original.
@@ -289,7 +236,7 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
         }
         crop_sequence.push_back(iter_crop_sequence);
       }
-    }  // destroy 1st data layer and unlock the leveldb
+    }  // destroy 1st data layer and unlock the database
 
     // Get crop sequence continuing from previous Caffe RNG state; reseed
     // srand with 1701. Check that the sequence differs from the original.
@@ -327,14 +274,14 @@ TYPED_TEST_CASE(DataLayerTest, TestDtypesAndDevices);
 
 TYPED_TEST(DataLayerTest, TestReadLevelDB) {
   const bool unique_pixels = false;  // all pixels the same; images different
-  this->FillLevelDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LEVELDB);
   this->TestRead();
 }
 
 TYPED_TEST(DataLayerTest, TestReadCropTrainLevelDB) {
   Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
-  this->FillLevelDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LEVELDB);
   this->TestReadCrop();
 }
 
@@ -343,7 +290,7 @@ TYPED_TEST(DataLayerTest, TestReadCropTrainLevelDB) {
 TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceSeededLevelDB) {
   Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
-  this->FillLevelDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LEVELDB);
   this->TestReadCropTrainSequenceSeeded();
 }
 
@@ -352,27 +299,27 @@ TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceSeededLevelDB) {
 TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceUnseededLevelDB) {
   Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
-  this->FillLevelDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LEVELDB);
   this->TestReadCropTrainSequenceUnseeded();
 }
 
 TYPED_TEST(DataLayerTest, TestReadCropTestLevelDB) {
   Caffe::set_phase(Caffe::TEST);
   const bool unique_pixels = true;  // all images the same; pixels different
-  this->FillLevelDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LEVELDB);
   this->TestReadCrop();
 }
 
 TYPED_TEST(DataLayerTest, TestReadLMDB) {
   const bool unique_pixels = false;  // all pixels the same; images different
-  this->FillLMDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LMDB);
   this->TestRead();
 }
 
 TYPED_TEST(DataLayerTest, TestReadCropTrainLMDB) {
   Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
-  this->FillLMDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LMDB);
   this->TestReadCrop();
 }
 
@@ -381,7 +328,7 @@ TYPED_TEST(DataLayerTest, TestReadCropTrainLMDB) {
 TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceSeededLMDB) {
   Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
-  this->FillLMDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LMDB);
   this->TestReadCropTrainSequenceSeeded();
 }
 
@@ -390,14 +337,14 @@ TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceSeededLMDB) {
 TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceUnseededLMDB) {
   Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
-  this->FillLMDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LMDB);
   this->TestReadCropTrainSequenceUnseeded();
 }
 
 TYPED_TEST(DataLayerTest, TestReadCropTestLMDB) {
   Caffe::set_phase(Caffe::TEST);
   const bool unique_pixels = true;  // all images the same; pixels different
-  this->FillLMDB(unique_pixels);
+  this->Fill(unique_pixels, DataParameter_DB_LMDB);
   this->TestReadCrop();
 }
 
index 41a3a83..8d3b3d1 100644 (file)
@@ -1,8 +1,6 @@
 #include <string>
 #include <vector>
 
-#include "leveldb/db.h"
-
 #include "gtest/gtest.h"
 
 #include "caffe/blob.hpp"
index 36510d6..a4a6627 100644 (file)
@@ -2,7 +2,6 @@
 #include <google/protobuf/io/coded_stream.h>
 #include <google/protobuf/io/zero_copy_stream_impl.h>
 #include <google/protobuf/text_format.h>
-#include <leveldb/db.h>
 #include <opencv2/core/core.hpp>
 #include <opencv2/highgui/highgui.hpp>
 #include <opencv2/highgui/highgui_c.h>
@@ -122,15 +121,6 @@ void CVMatToDatum(const cv::Mat& cv_img, Datum* datum) {
   datum->set_data(buffer);
 }
 
-
-leveldb::Options GetLevelDBOptions() {
-  // In default, we will return the leveldb option and set the max open files
-  // in order to avoid using up the operating system's limit.
-  leveldb::Options options;
-  options.max_open_files = 100;
-  return options;
-}
-
 // Verifies format of data stored in HDF5 file and reshapes blob accordingly.
 template <typename Dtype>
 void hdf5_load_nd_dataset_helper(
index 6adde8b..e59bbf1 100644 (file)
@@ -1,14 +1,14 @@
 #include <glog/logging.h>
-#include <leveldb/db.h>
-#include <lmdb.h>
 #include <stdint.h>
 
 #include <algorithm>
 #include <string>
 
+#include "caffe/database_factory.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/io.hpp"
 
+using caffe::Database;
 using caffe::Datum;
 using caffe::BlobProto;
 using std::max;
@@ -26,57 +26,16 @@ int main(int argc, char** argv) {
     db_backend = std::string(argv[3]);
   }
 
-  // leveldb
-  leveldb::DB* db;
-  leveldb::Options options;
-  options.create_if_missing = false;
-  leveldb::Iterator* it = NULL;
-  // lmdb
-  MDB_env* mdb_env;
-  MDB_dbi mdb_dbi;
-  MDB_val mdb_key, mdb_value;
-  MDB_txn* mdb_txn;
-  MDB_cursor* mdb_cursor;
+  caffe::shared_ptr<Database> database = caffe::DatabaseFactory(db_backend);
 
   // Open db
-  if (db_backend == "leveldb") {  // leveldb
-    LOG(INFO) << "Opening leveldb " << argv[1];
-    leveldb::Status status = leveldb::DB::Open(
-        options, argv[1], &db);
-    CHECK(status.ok()) << "Failed to open leveldb " << argv[1];
-    leveldb::ReadOptions read_options;
-    read_options.fill_cache = false;
-    it = db->NewIterator(read_options);
-    it->SeekToFirst();
-  } else if (db_backend == "lmdb") {  // lmdb
-    LOG(INFO) << "Opening lmdb " << argv[1];
-    CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed";
-    CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS);  // 1TB
-    CHECK_EQ(mdb_env_open(mdb_env, argv[1], MDB_RDONLY, 0664),
-        MDB_SUCCESS) << "mdb_env_open failed";
-    CHECK_EQ(mdb_txn_begin(mdb_env, NULL, MDB_RDONLY, &mdb_txn), MDB_SUCCESS)
-        << "mdb_txn_begin failed";
-    CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS)
-        << "mdb_open failed";
-    CHECK_EQ(mdb_cursor_open(mdb_txn, mdb_dbi, &mdb_cursor), MDB_SUCCESS)
-        << "mdb_cursor_open failed";
-    CHECK_EQ(mdb_cursor_get(mdb_cursor, &mdb_key, &mdb_value, MDB_FIRST),
-        MDB_SUCCESS);
-  } else {
-    LOG(FATAL) << "Unknown db backend " << db_backend;
-  }
+  database->open(argv[1], Database::ReadOnly);
 
   Datum datum;
   BlobProto sum_blob;
   int count = 0;
   // load first datum
-  if (db_backend == "leveldb") {
-    datum.ParseFromString(it->value().ToString());
-  } else if (db_backend == "lmdb") {
-    datum.ParseFromArray(mdb_value.mv_data, mdb_value.mv_size);
-  } else {
-    LOG(FATAL) << "Unknown db backend " << db_backend;
-  }
+  datum.ParseFromString(database->begin()->second);
 
   sum_blob.set_num(1);
   sum_blob.set_channels(datum.channels());
@@ -89,59 +48,29 @@ int main(int argc, char** argv) {
     sum_blob.add_data(0.);
   }
   LOG(INFO) << "Starting Iteration";
-  if (db_backend == "leveldb") {  // leveldb
-    for (it->SeekToFirst(); it->Valid(); it->Next()) {
-      // just a dummy operation
-      datum.ParseFromString(it->value().ToString());
-      const std::string& data = datum.data();
-      size_in_datum = std::max<int>(datum.data().size(),
-          datum.float_data_size());
-      CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " <<
-          size_in_datum;
-      if (data.size() != 0) {
-        for (int i = 0; i < size_in_datum; ++i) {
-          sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);
-        }
-      } else {
-        for (int i = 0; i < size_in_datum; ++i) {
-          sum_blob.set_data(i, sum_blob.data(i) +
-              static_cast<float>(datum.float_data(i)));
-        }
+  for (Database::const_iterator iter = database->begin();
+      iter != database->end(); ++iter) {
+    // just a dummy operation
+    datum.ParseFromString(iter->second);
+    const std::string& data = datum.data();
+    size_in_datum = std::max<int>(datum.data().size(),
+        datum.float_data_size());
+    CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " <<
+        size_in_datum;
+    if (data.size() != 0) {
+      for (int i = 0; i < size_in_datum; ++i) {
+        sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);
       }
-      ++count;
-      if (count % 10000 == 0) {
-        LOG(ERROR) << "Processed " << count << " files.";
+    } else {
+      for (int i = 0; i < size_in_datum; ++i) {
+        sum_blob.set_data(i, sum_blob.data(i) +
+            static_cast<float>(datum.float_data(i)));
       }
     }
-  } else if (db_backend == "lmdb") {  // lmdb
-    CHECK_EQ(mdb_cursor_get(mdb_cursor, &mdb_key, &mdb_value, MDB_FIRST),
-        MDB_SUCCESS);
-    do {
-      // just a dummy operation
-      datum.ParseFromArray(mdb_value.mv_data, mdb_value.mv_size);
-      const std::string& data = datum.data();
-      size_in_datum = std::max<int>(datum.data().size(),
-          datum.float_data_size());
-      CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " <<
-          size_in_datum;
-      if (data.size() != 0) {
-        for (int i = 0; i < size_in_datum; ++i) {
-          sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);
-        }
-      } else {
-        for (int i = 0; i < size_in_datum; ++i) {
-          sum_blob.set_data(i, sum_blob.data(i) +
-              static_cast<float>(datum.float_data(i)));
-        }
-      }
-      ++count;
-      if (count % 10000 == 0) {
-        LOG(ERROR) << "Processed " << count << " files.";
-      }
-    } while (mdb_cursor_get(mdb_cursor, &mdb_key, &mdb_value, MDB_NEXT)
-        == MDB_SUCCESS);
-  } else {
-    LOG(FATAL) << "Unknown db backend " << db_backend;
+    ++count;
+    if (count % 10000 == 0) {
+      LOG(ERROR) << "Processed " << count << " files.";
+    }
   }
 
   if (count % 10000 != 0) {
@@ -155,15 +84,6 @@ int main(int argc, char** argv) {
   WriteProtoToBinaryFile(sum_blob, argv[2]);
 
   // Clean up
-  if (db_backend == "leveldb") {
-    delete db;
-  } else if (db_backend == "lmdb") {
-    mdb_cursor_close(mdb_cursor);
-    mdb_close(mdb_env, mdb_dbi);
-    mdb_txn_abort(mdb_txn);
-    mdb_env_close(mdb_env);
-  } else {
-    LOG(FATAL) << "Unknown db backend " << db_backend;
-  }
+  database->close();
   return 0;
 }
index 7c8c1da..6f03a9d 100644 (file)
 
 #include <gflags/gflags.h>
 #include <glog/logging.h>
-#include <leveldb/db.h>
-#include <leveldb/write_batch.h>
-#include <lmdb.h>
-#include <sys/stat.h>
 
 #include <algorithm>
 #include <fstream>  // NOLINT(readability/streams)
@@ -21,6 +17,7 @@
 #include <utility>
 #include <vector>
 
+#include "caffe/database_factory.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/io.hpp"
 #include "caffe/util/rng.hpp"
@@ -81,43 +78,10 @@ int main(int argc, char** argv) {
   int resize_width = std::max<int>(0, FLAGS_resize_width);
 
   // Open new db
-  // lmdb
-  MDB_env *mdb_env;
-  MDB_dbi mdb_dbi;
-  MDB_val mdb_key, mdb_data;
-  MDB_txn *mdb_txn;
-  // leveldb
-  leveldb::DB* db;
-  leveldb::Options options;
-  options.error_if_exists = true;
-  options.create_if_missing = true;
-  options.write_buffer_size = 268435456;
-  leveldb::WriteBatch* batch = NULL;
+  shared_ptr<Database> database = DatabaseFactory(db_backend);
 
   // Open db
-  if (db_backend == "leveldb") {  // leveldb
-    LOG(INFO) << "Opening leveldb " << db_path;
-    leveldb::Status status = leveldb::DB::Open(
-        options, db_path, &db);
-    CHECK(status.ok()) << "Failed to open leveldb " << db_path
-        << ". Is it already existing?";
-    batch = new leveldb::WriteBatch();
-  } else if (db_backend == "lmdb") {  // lmdb
-    LOG(INFO) << "Opening lmdb " << db_path;
-    CHECK_EQ(mkdir(db_path, 0744), 0)
-        << "mkdir " << db_path << "failed";
-    CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed";
-    CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS)  // 1TB
-        << "mdb_env_set_mapsize failed";
-    CHECK_EQ(mdb_env_open(mdb_env, db_path, 0, 0664), MDB_SUCCESS)
-        << "mdb_env_open failed";
-    CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
-        << "mdb_txn_begin failed";
-    CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS)
-        << "mdb_open failed. Does the lmdb already exist? ";
-  } else {
-    LOG(FATAL) << "Unknown db backend " << db_backend;
-  }
+  database->open(db_path, Database::New);
 
   // Storing to db
   std::string root_folder(argv[1]);
@@ -151,50 +115,19 @@ int main(int argc, char** argv) {
     std::string keystr(key_cstr);
 
     // Put in db
-    if (db_backend == "leveldb") {  // leveldb
-      batch->Put(keystr, value);
-    } else if (db_backend == "lmdb") {  // lmdb
-      mdb_data.mv_size = value.size();
-      mdb_data.mv_data = reinterpret_cast<void*>(&value[0]);
-      mdb_key.mv_size = keystr.size();
-      mdb_key.mv_data = reinterpret_cast<void*>(&keystr[0]);
-      CHECK_EQ(mdb_put(mdb_txn, mdb_dbi, &mdb_key, &mdb_data, 0), MDB_SUCCESS)
-          << "mdb_put failed";
-    } else {
-      LOG(FATAL) << "Unknown db backend " << db_backend;
-    }
+    database->put(keystr, value);
 
     if (++count % 1000 == 0) {
       // Commit txn
-      if (db_backend == "leveldb") {  // leveldb
-        db->Write(leveldb::WriteOptions(), batch);
-        delete batch;
-        batch = new leveldb::WriteBatch();
-      } else if (db_backend == "lmdb") {  // lmdb
-        CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS)
-            << "mdb_txn_commit failed";
-        CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
-            << "mdb_txn_begin failed";
-      } else {
-        LOG(FATAL) << "Unknown db backend " << db_backend;
-      }
+      database->commit();
       LOG(ERROR) << "Processed " << count << " files.";
     }
   }
   // write the last batch
   if (count % 1000 != 0) {
-    if (db_backend == "leveldb") {  // leveldb
-      db->Write(leveldb::WriteOptions(), batch);
-      delete batch;
-      delete db;
-    } else if (db_backend == "lmdb") {  // lmdb
-      CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) << "mdb_txn_commit failed";
-      mdb_close(mdb_env, mdb_dbi);
-      mdb_env_close(mdb_env);
-    } else {
-      LOG(FATAL) << "Unknown db backend " << db_backend;
-    }
+    database->commit();
     LOG(ERROR) << "Processed " << count << " files.";
   }
+  database->close();
   return 0;
 }
index 299b311..b3ad8e6 100644 (file)
@@ -4,11 +4,10 @@
 
 #include "boost/algorithm/string.hpp"
 #include "google/protobuf/text_format.h"
-#include "leveldb/db.h"
-#include "leveldb/write_batch.h"
 
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
+#include "caffe/database_factory.hpp"
 #include "caffe/net.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/io.hpp"
@@ -17,6 +16,8 @@
 using boost::shared_ptr;
 using caffe::Blob;
 using caffe::Caffe;
+using caffe::Database;
+using caffe::DatabaseFactory;
 using caffe::Datum;
 using caffe::Net;
 
@@ -38,12 +39,12 @@ int feature_extraction_pipeline(int argc, char** argv) {
     " extract features of the input data produced by the net.\n"
     "Usage: extract_features  pretrained_net_param"
     "  feature_extraction_proto_file  extract_feature_blob_name1[,name2,...]"
-    "  save_feature_leveldb_name1[,name2,...]  num_mini_batches  [CPU/GPU]"
+    "  save_feature_database_name1[,name2,...]  num_mini_batches  [CPU/GPU]"
     "  [DEVICE_ID=0]\n"
     "Note: you can extract multiple features in one pass by specifying"
-    " multiple feature blob names and leveldb names seperated by ','."
+    " multiple feature blob names and database names seperated by ','."
     " The names cannot contain white space characters and the number of blobs"
-    " and leveldbs must be equal.";
+    " and databases must be equal.";
     return 1;
   }
   int arg_pos = num_required_args;
@@ -104,12 +105,12 @@ int feature_extraction_pipeline(int argc, char** argv) {
   std::vector<std::string> blob_names;
   boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(","));
 
-  std::string save_feature_leveldb_names(argv[++arg_pos]);
-  std::vector<std::string> leveldb_names;
-  boost::split(leveldb_names, save_feature_leveldb_names,
+  std::string save_feature_database_names(argv[++arg_pos]);
+  std::vector<std::string> database_names;
+  boost::split(database_names, save_feature_database_names,
                boost::is_any_of(","));
-  CHECK_EQ(blob_names.size(), leveldb_names.size()) <<
-      " the number of blob names and leveldb names must be equal";
+  CHECK_EQ(blob_names.size(), database_names.size()) <<
+      " the number of blob names and database names must be equal";
   size_t num_features = blob_names.size();
 
   for (size_t i = 0; i < num_features; i++) {
@@ -118,19 +119,12 @@ int feature_extraction_pipeline(int argc, char** argv) {
         << " in the network " << feature_extraction_proto;
   }
 
-  leveldb::Options options;
-  options.error_if_exists = true;
-  options.create_if_missing = true;
-  options.write_buffer_size = 268435456;
-  std::vector<shared_ptr<leveldb::DB> > feature_dbs;
+  std::vector<shared_ptr<Database> > feature_dbs;
   for (size_t i = 0; i < num_features; ++i) {
-    LOG(INFO)<< "Opening leveldb " << leveldb_names[i];
-    leveldb::DB* db;
-    leveldb::Status status = leveldb::DB::Open(options,
-                                               leveldb_names[i].c_str(),
-                                               &db);
-    CHECK(status.ok()) << "Failed to open leveldb " << leveldb_names[i];
-    feature_dbs.push_back(shared_ptr<leveldb::DB>(db));
+    LOG(INFO)<< "Opening database " << database_names[i];
+    shared_ptr<Database> database = DatabaseFactory("leveldb");
+    database->open(database_names.at(i), Database::New);
+    feature_dbs.push_back(database);
   }
 
   int num_mini_batches = atoi(argv[++arg_pos]);
@@ -138,9 +132,6 @@ int feature_extraction_pipeline(int argc, char** argv) {
   LOG(ERROR)<< "Extacting Features";
 
   Datum datum;
-  std::vector<shared_ptr<leveldb::WriteBatch> > feature_batches(
-      num_features,
-      shared_ptr<leveldb::WriteBatch>(new leveldb::WriteBatch()));
   const int kMaxKeyStrLength = 100;
   char key_str[kMaxKeyStrLength];
   std::vector<Blob<float>*> input_vec;
@@ -167,14 +158,12 @@ int feature_extraction_pipeline(int argc, char** argv) {
         std::string value;
         datum.SerializeToString(&value);
         snprintf(key_str, kMaxKeyStrLength, "%d", image_indices[i]);
-        feature_batches[i]->Put(std::string(key_str), value);
+        feature_dbs.at(i)->put(std::string(key_str), value);
         ++image_indices[i];
         if (image_indices[i] % 1000 == 0) {
-          feature_dbs[i]->Write(leveldb::WriteOptions(),
-                                feature_batches[i].get());
+          feature_dbs.at(i)->commit();
           LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
               " query images for feature blob " << blob_names[i];
-          feature_batches[i].reset(new leveldb::WriteBatch());
         }
       }  // for (int n = 0; n < batch_size; ++n)
     }  // for (int i = 0; i < num_features; ++i)
@@ -182,10 +171,11 @@ int feature_extraction_pipeline(int argc, char** argv) {
   // write the last batch
   for (int i = 0; i < num_features; ++i) {
     if (image_indices[i] % 1000 != 0) {
-      feature_dbs[i]->Write(leveldb::WriteOptions(), feature_batches[i].get());
+      feature_dbs.at(i)->commit();
     }
     LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
         " query images for feature blob " << blob_names[i];
+    feature_dbs.at(i)->close();
   }
 
   LOG(ERROR)<< "Successfully extracted the features!";