add option for lmdb
authorlinmin <mavenlin@gmail.com>
Wed, 21 May 2014 16:15:32 +0000 (00:15 +0800)
committerJeff Donahue <jeff.donahue@gmail.com>
Fri, 13 Jun 2014 05:23:49 +0000 (22:23 -0700)
Makefile
include/caffe/data_layers.hpp
src/caffe/layers/data_layer.cpp
src/caffe/proto/caffe.proto

index c43b30e..943165a 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -129,6 +129,7 @@ LIBRARY_DIRS += $(CUDA_LIB_DIR)
 LIBRARIES := cudart cublas curand \
        pthread \
        glog protobuf leveldb snappy \
+       lmdb \
        boost_system \
        hdf5_hl hdf5 \
        opencv_core opencv_highgui opencv_imgproc
index 3d3ce94..2b4c278 100644 (file)
@@ -8,6 +8,7 @@
 #include <vector>
 
 #include "leveldb/db.h"
+#include "lmdb.h"
 #include "pthread.h"
 #include "hdf5.h"
 #include "boost/scoped_ptr.hpp"
@@ -133,8 +134,17 @@ class DataLayer : public Layer<Dtype> {
   virtual unsigned int PrefetchRand();
 
   shared_ptr<Caffe::RNG> prefetch_rng_;
+
+  // LEVELDB
   shared_ptr<leveldb::DB> db_;
   shared_ptr<leveldb::Iterator> iter_;
+  // LMDB
+  MDB_env* mdb_env_;
+  MDB_dbi mdb_dbi_;
+  MDB_txn* mdb_txn_;
+  MDB_cursor* mdb_cursor_;
+  MDB_val mdb_key_, mdb_value_;
+
   int datum_channels_;
   int datum_height_;
   int datum_width_;
index 348753e..3b307ec 100644 (file)
@@ -12,6 +12,7 @@
 #include "caffe/util/math_functions.hpp"
 #include "caffe/util/rng.hpp"
 #include "caffe/vision_layers.hpp"
+#include "caffe/proto/caffe.pb.h"
 
 using std::string;
 
@@ -46,9 +47,22 @@ void* DataLayerPrefetch(void* layer_pointer) {
   const Dtype* mean = layer->data_mean_.cpu_data();
   for (int item_id = 0; item_id < batch_size; ++item_id) {
     // get a blob
-    CHECK(layer->iter_);
-    CHECK(layer->iter_->Valid());
-    datum.ParseFromString(layer->iter_->value().ToString());
+    switch (layer->layer_param_.data_param().backend()) {
+    case DataParameter_DB_LEVELDB:
+      CHECK(layer->iter_);
+      CHECK(layer->iter_->Valid());
+      datum.ParseFromString(layer->iter_->value().ToString());
+      break;
+    case DataParameter_DB_LMDB:
+      CHECK_EQ(mdb_cursor_get(layer->mdb_cursor_, &layer->mdb_key_,
+              &layer->mdb_value_, MDB_GET_CURRENT), MDB_SUCCESS);
+      datum.ParseFromArray(layer->mdb_value_.mv_data,
+          layer->mdb_value_.mv_size);
+      break;
+    default:
+      LOG(FATAL) << "Unknown database backend";
+    }
+
     const string& data = datum.data();
     if (crop_size) {
       CHECK(data.size()) << "Image cropping only support uint8 data";
@@ -110,11 +124,26 @@ void* DataLayerPrefetch(void* layer_pointer) {
       top_label[item_id] = datum.label();
     }
     // go to the next iter
-    layer->iter_->Next();
-    if (!layer->iter_->Valid()) {
-      // We have reached the end. Restart from the first.
-      DLOG(INFO) << "Restarting data prefetching from start.";
-      layer->iter_->SeekToFirst();
+    switch (layer->layer_param_.data_param().backend()) {
+    case DataParameter_DB_LEVELDB:
+      layer->iter_->Next();
+      if (!layer->iter_->Valid()) {
+        // We have reached the end. Restart from the first.
+        DLOG(INFO) << "Restarting data prefetching from start.";
+        layer->iter_->SeekToFirst();
+      }
+      break;
+    case DataParameter_DB_LMDB:
+      if (mdb_cursor_get(layer->mdb_cursor_, &layer->mdb_key_,
+              &layer->mdb_value_, MDB_NEXT) != MDB_SUCCESS) {
+        // We have reached the end. Restart from the first.
+        DLOG(INFO) << "Restarting data prefetching from start.";
+        CHECK_EQ(mdb_cursor_get(layer->mdb_cursor_, &layer->mdb_key_,
+                &layer->mdb_value_, MDB_FIRST), MDB_SUCCESS);
+      }
+      break;
+    default:
+      LOG(FATAL) << "Unknown database backend";
     }
   }
 
@@ -124,6 +153,19 @@ void* DataLayerPrefetch(void* layer_pointer) {
 template <typename Dtype>
 DataLayer<Dtype>::~DataLayer<Dtype>() {
   JoinPrefetchThread();
+  // clean up the database resources
+  switch (this->layer_param_.data_param().backend()) {
+  case DataParameter_DB_LEVELDB:
+    break; // do nothing
+  case DataParameter_DB_LMDB:
+    mdb_cursor_close(mdb_cursor_);
+    mdb_close(mdb_env_, mdb_dbi_);
+    mdb_txn_abort(mdb_txn_);
+    mdb_env_close(mdb_env_);
+    break;
+  default:
+    LOG(FATAL) << "Unknown database backend";
+  }
 }
 
 template <typename Dtype>
@@ -135,35 +177,83 @@ void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   } else {
     output_labels_ = true;
   }
-  // Initialize the leveldb
-  leveldb::DB* db_temp;
-  leveldb::Options options;
-  options.create_if_missing = false;
-  options.max_open_files = 100;
-  LOG(INFO) << "Opening leveldb " << this->layer_param_.data_param().source();
-  leveldb::Status status = leveldb::DB::Open(
-      options, this->layer_param_.data_param().source(), &db_temp);
-  CHECK(status.ok()) << "Failed to open leveldb "
-      << this->layer_param_.data_param().source() << std::endl
-      << status.ToString();
-  db_.reset(db_temp);
-  iter_.reset(db_->NewIterator(leveldb::ReadOptions()));
-  iter_->SeekToFirst();
+  // Initialize DB
+  switch (this->layer_param_.data_param().backend()) {
+  case DataParameter_DB_LEVELDB:
+    {
+    leveldb::DB* db_temp;
+    leveldb::Options options;
+    options.create_if_missing = false;
+    options.max_open_files = 100;
+    LOG(INFO) << "Opening leveldb " << this->layer_param_.data_param().source();
+    leveldb::Status status = leveldb::DB::Open(
+        options, this->layer_param_.data_param().source(), &db_temp);
+    CHECK(status.ok()) << "Failed to open leveldb "
+                       << this->layer_param_.data_param().source() << std::endl
+                       << status.ToString();
+    db_.reset(db_temp);
+    iter_.reset(db_->NewIterator(leveldb::ReadOptions()));
+    iter_->SeekToFirst();
+    }
+    break;
+  case DataParameter_DB_LMDB:
+    CHECK_EQ(mdb_env_create(&mdb_env_), MDB_SUCCESS) << "mdb_env_create failed";
+    CHECK_EQ(mdb_env_set_mapsize(mdb_env_, 1099511627776), MDB_SUCCESS); // 1TB
+    CHECK_EQ(mdb_env_open(mdb_env_,
+            this->layer_param_.data_param().source().c_str(),
+            MDB_RDONLY|MDB_NOTLS, 0664), MDB_SUCCESS) << "mdb_env_open failed";
+    CHECK_EQ(mdb_txn_begin(mdb_env_, NULL, MDB_RDONLY, &mdb_txn_), MDB_SUCCESS)
+        << "mdb_txn_begin failed";
+    CHECK_EQ(mdb_open(mdb_txn_, NULL, 0, &mdb_dbi_), MDB_SUCCESS)
+        << "mdb_open failed";
+    CHECK_EQ(mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_), MDB_SUCCESS)
+        << "mdb_cursor_open failed";
+    LOG(INFO) << "Opening lmdb " << this->layer_param_.data_param().source();
+    CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST),
+        MDB_SUCCESS) << "mdb_cursor_get failed";
+    break;
+  default:
+    LOG(FATAL) << "Unknown database backend";
+  }
+
   // Check if we would need to randomly skip a few data points
   if (this->layer_param_.data_param().rand_skip()) {
     unsigned int skip = caffe_rng_rand() %
                         this->layer_param_.data_param().rand_skip();
     LOG(INFO) << "Skipping first " << skip << " data points.";
     while (skip-- > 0) {
-      iter_->Next();
-      if (!iter_->Valid()) {
-        iter_->SeekToFirst();
+      switch (this->layer_param_.data_param().backend()) {
+      case DataParameter_DB_LEVELDB:
+        iter_->Next();
+        if (!iter_->Valid()) {
+          iter_->SeekToFirst();
+        }
+        break;
+      case DataParameter_DB_LMDB:
+        if(mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT)
+            != MDB_SUCCESS) {
+          CHECK_EQ(mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_,
+                  MDB_FIRST), MDB_SUCCESS);
+        }
+        break;
+      default:
+        LOG(FATAL) << "Unknown database backend";
       }
     }
   }
   // Read a data point, and use it to initialize the top blob.
   Datum datum;
-  datum.ParseFromString(iter_->value().ToString());
+  switch (this->layer_param_.data_param().backend()) {
+  case DataParameter_DB_LEVELDB:
+    datum.ParseFromString(iter_->value().ToString());
+    break;
+  case DataParameter_DB_LMDB:
+    datum.ParseFromArray(mdb_value_.mv_data, mdb_value_.mv_size);
+    break;
+  default:
+    LOG(FATAL) << "Unknown database backend";
+  }
+
   // image
   int crop_size = this->layer_param_.data_param().crop_size();
   if (crop_size > 0) {
index 60c7daa..b85b49c 100644 (file)
@@ -229,6 +229,10 @@ message ConvolutionParameter {
 
 // Message that stores parameters used by DataLayer
 message DataParameter {
+  enum DB {
+    LEVELDB = 0;
+    LMDB = 1;
+  }
   // Specify the data source.
   optional string source = 1;
   // For data pre-processing, we can do simple scaling and subtracting the
@@ -247,6 +251,7 @@ message DataParameter {
   // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
   // be larger than the number of keys in the leveldb.
   optional uint32 rand_skip = 7 [default = 0];
+  optional DB backend = 8 [default = LEVELDB];
 }
 
 // Message that stores parameters used by DropoutLayer