Update MNIST example to use new DB classes
authorLuke Yeager <luke.yeager@gmail.com>
Fri, 26 Feb 2016 04:14:02 +0000 (20:14 -0800)
committerLuke Yeager <luke.yeager@gmail.com>
Wed, 20 Apr 2016 22:49:57 +0000 (15:49 -0700)
examples/mnist/convert_mnist_data.cpp

index 32bee52..57ddef7 100644 (file)
 #include <fstream>  // NOLINT(readability/streams)
 #include <string>
 
+#include "boost/scoped_ptr.hpp"
 #include "caffe/proto/caffe.pb.h"
+#include "caffe/util/db.hpp"
 #include "caffe/util/format.hpp"
 
 #if defined(USE_LEVELDB) && defined(USE_LMDB)
 
 using namespace caffe;  // NOLINT(build/namespaces)
+using boost::scoped_ptr;
 using std::string;
 
 DEFINE_string(backend, "lmdb", "The backend for storing the result");
@@ -67,43 +70,10 @@ void convert_dataset(const char* image_filename, const char* label_filename,
   image_file.read(reinterpret_cast<char*>(&cols), 4);
   cols = swap_endian(cols);
 
-  // lmdb
-  MDB_env *mdb_env;
-  MDB_dbi mdb_dbi;
-  MDB_val mdb_key, mdb_data;
-  MDB_txn *mdb_txn;
-  // leveldb
-  leveldb::DB* db;
-  leveldb::Options options;
-  options.error_if_exists = true;
-  options.create_if_missing = true;
-  options.write_buffer_size = 268435456;
-  leveldb::WriteBatch* batch = NULL;
-
-  // Open db
-  if (db_backend == "leveldb") {  // leveldb
-    LOG(INFO) << "Opening leveldb " << db_path;
-    leveldb::Status status = leveldb::DB::Open(
-        options, db_path, &db);
-    CHECK(status.ok()) << "Failed to open leveldb " << db_path
-        << ". Is it already existing?";
-    batch = new leveldb::WriteBatch();
-  } else if (db_backend == "lmdb") {  // lmdb
-    LOG(INFO) << "Opening lmdb " << db_path;
-    CHECK_EQ(mkdir(db_path, 0744), 0)
-        << "mkdir " << db_path << "failed";
-    CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed";
-    CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS)  // 1TB
-        << "mdb_env_set_mapsize failed";
-    CHECK_EQ(mdb_env_open(mdb_env, db_path, 0, 0664), MDB_SUCCESS)
-        << "mdb_env_open failed";
-    CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
-        << "mdb_txn_begin failed";
-    CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS)
-        << "mdb_open failed. Does the lmdb already exist? ";
-  } else {
-    LOG(FATAL) << "Unknown db backend " << db_backend;
-  }
+
+  scoped_ptr<db::DB> db(db::GetDB(db_backend));
+  db->Open(db_path, db::NEW);
+  scoped_ptr<db::Transaction> txn(db->NewTransaction());
 
   // Storing to db
   char label;
@@ -125,52 +95,19 @@ void convert_dataset(const char* image_filename, const char* label_filename,
     string key_str = caffe::format_int(item_id, 8);
     datum.SerializeToString(&value);
 
-    // Put in db
-    if (db_backend == "leveldb") {  // leveldb
-      batch->Put(key_str, value);
-    } else if (db_backend == "lmdb") {  // lmdb
-      mdb_data.mv_size = value.size();
-      mdb_data.mv_data = reinterpret_cast<void*>(&value[0]);
-      mdb_key.mv_size = key_str.size();
-      mdb_key.mv_data = reinterpret_cast<void*>(&key_str[0]);
-      CHECK_EQ(mdb_put(mdb_txn, mdb_dbi, &mdb_key, &mdb_data, 0), MDB_SUCCESS)
-          << "mdb_put failed";
-    } else {
-      LOG(FATAL) << "Unknown db backend " << db_backend;
-    }
+    txn->Put(key_str, value);
 
     if (++count % 1000 == 0) {
-      // Commit txn
-      if (db_backend == "leveldb") {  // leveldb
-        db->Write(leveldb::WriteOptions(), batch);
-        delete batch;
-        batch = new leveldb::WriteBatch();
-      } else if (db_backend == "lmdb") {  // lmdb
-        CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS)
-            << "mdb_txn_commit failed";
-        CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
-            << "mdb_txn_begin failed";
-      } else {
-        LOG(FATAL) << "Unknown db backend " << db_backend;
-      }
+      txn->Commit();
     }
   }
   // write the last batch
   if (count % 1000 != 0) {
-    if (db_backend == "leveldb") {  // leveldb
-      db->Write(leveldb::WriteOptions(), batch);
-      delete batch;
-      delete db;
-    } else if (db_backend == "lmdb") {  // lmdb
-      CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) << "mdb_txn_commit failed";
-      mdb_close(mdb_env, mdb_dbi);
-      mdb_env_close(mdb_env);
-    } else {
-      LOG(FATAL) << "Unknown db backend " << db_backend;
-    }
-    LOG(ERROR) << "Processed " << count << " files.";
+      txn->Commit();
   }
+  LOG(INFO) << "Processed " << count << " files.";
   delete[] pixels;
+  db->Close();
 }
 
 int main(int argc, char** argv) {