Don't autocommit on close for the databases. If they were read-only, then they might...
authorKevin James Matzen <kmatzen@cs.cornell.edu>
Wed, 8 Oct 2014 01:46:15 +0000 (21:46 -0400)
committerKevin James Matzen <kmatzen@cs.cornell.edu>
Tue, 14 Oct 2014 23:29:11 +0000 (19:29 -0400)
examples/cifar10/convert_cifar_data.cpp
include/caffe/leveldb_database.hpp
include/caffe/lmdb_database.hpp
src/caffe/leveldb_database.cpp
src/caffe/lmdb_database.cpp
tools/compute_image_mean.cpp

index 90ecb6d..c493087 100644 (file)
 
 #include "glog/logging.h"
 #include "google/protobuf/text_format.h"
-#include "leveldb/db.h"
 #include "stdint.h"
 
+#include "caffe/database_factory.hpp"
 #include "caffe/proto/caffe.pb.h"
 
 using std::string;
 
+using caffe::Database;
+using caffe::DatabaseFactory;
+using caffe::shared_ptr;
+
 const int kCIFARSize = 32;
 const int kCIFARImageNBytes = 3072;
 const int kCIFARBatchSize = 10000;
@@ -31,26 +35,20 @@ void read_image(std::ifstream* file, int* label, char* buffer) {
   return;
 }
 
-void convert_dataset(const string& input_folder, const string& output_folder) {
-  // Leveldb options
-  leveldb::Options options;
-  options.create_if_missing = true;
-  options.error_if_exists = true;
+void convert_dataset(const string& input_folder, const string& output_folder,
+    const string& db_type) {
+  shared_ptr<Database> train_database = DatabaseFactory(db_type);
+  train_database->open(output_folder + "/cifar10_train_" + db_type,
+      Database::New);
   // Data buffer
   int label;
   char str_buffer[kCIFARImageNBytes];
-  string value;
   caffe::Datum datum;
   datum.set_channels(3);
   datum.set_height(kCIFARSize);
   datum.set_width(kCIFARSize);
 
   LOG(INFO) << "Writing Training data";
-  leveldb::DB* train_db;
-  leveldb::Status status;
-  status = leveldb::DB::Open(options, output_folder + "/cifar10_train_leveldb",
-      &train_db);
-  CHECK(status.ok()) << "Failed to open leveldb.";
   for (int fileid = 0; fileid < kCIFARTrainBatches; ++fileid) {
     // Open files
     LOG(INFO) << "Training Batch " << fileid + 1;
@@ -62,17 +60,22 @@ void convert_dataset(const string& input_folder, const string& output_folder) {
       read_image(&data_file, &label, str_buffer);
       datum.set_label(label);
       datum.set_data(str_buffer, kCIFARImageNBytes);
-      datum.SerializeToString(&value);
-      snprintf(str_buffer, kCIFARImageNBytes, "%05d",
+      Database::buffer_t value(datum.ByteSize());
+      datum.SerializeWithCachedSizesToArray(
+          reinterpret_cast<unsigned char*>(value.data()));
+      int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d",
           fileid * kCIFARBatchSize + itemid);
-      train_db->Put(leveldb::WriteOptions(), string(str_buffer), value);
+      Database::buffer_t key(str_buffer, str_buffer + length);
+      train_database->put(&key, &value);
     }
   }
+  train_database->commit();
+  train_database->close();
 
   LOG(INFO) << "Writing Testing data";
-  leveldb::DB* test_db;
-  CHECK(leveldb::DB::Open(options, output_folder + "/cifar10_test_leveldb",
-      &test_db).ok()) << "Failed to open leveldb.";
+  shared_ptr<Database> test_database = DatabaseFactory(db_type);
+  test_database->open(output_folder + "/cifar10_test_" + db_type,
+      Database::New);
   // Open files
   std::ifstream data_file((input_folder + "/test_batch.bin").c_str(),
       std::ios::in | std::ios::binary);
@@ -81,28 +84,30 @@ void convert_dataset(const string& input_folder, const string& output_folder) {
     read_image(&data_file, &label, str_buffer);
     datum.set_label(label);
     datum.set_data(str_buffer, kCIFARImageNBytes);
-    datum.SerializeToString(&value);
-    snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid);
-    test_db->Put(leveldb::WriteOptions(), string(str_buffer), value);
+    Database::buffer_t value(datum.ByteSize());
+    datum.SerializeWithCachedSizesToArray(
+        reinterpret_cast<unsigned char*>(value.data()));
+    int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid);
+    Database::buffer_t key(str_buffer, str_buffer + length);
+    test_database->put(&key, &value);
   }
-
-  delete train_db;
-  delete test_db;
+  test_database->commit();
+  test_database->close();
 }
 
 int main(int argc, char** argv) {
-  if (argc != 3) {
+  if (argc != 4) {
     printf("This script converts the CIFAR dataset to the leveldb format used\n"
            "by caffe to perform classification.\n"
            "Usage:\n"
-           "    convert_cifar_data input_folder output_folder\n"
+           "    convert_cifar_data input_folder output_folder db_type\n"
            "Where the input folder should contain the binary batch files.\n"
            "The CIFAR dataset could be downloaded at\n"
            "    http://www.cs.toronto.edu/~kriz/cifar.html\n"
            "You should gunzip them after downloading.\n");
   } else {
     google::InitGoogleLogging(argv[0]);
-    convert_dataset(string(argv[1]), string(argv[2]));
+    convert_dataset(string(argv[1]), string(argv[2]), string(argv[3]));
   }
   return 0;
 }
index f30273d..dda6697 100644 (file)
@@ -24,8 +24,6 @@ class LeveldbDatabase : public Database {
   const_iterator end() const;
   const_iterator cend() const;
 
-  ~LeveldbDatabase() { this->close(); }
-
  protected:
   class LeveldbState : public Database::DatabaseState {
    public:
index ee3806d..9654222 100644 (file)
@@ -17,7 +17,6 @@ class LmdbDatabase : public Database {
       : env_(NULL),
         dbi_(0),
         txn_(NULL) { }
-  ~LmdbDatabase() { this->close(); }
 
   void open(const string& filename, Mode mode);
   void put(buffer_t* key, buffer_t* value);
index a8fe02a..a5cdaa3 100644 (file)
@@ -66,9 +66,6 @@ void LeveldbDatabase::commit() {
 void LeveldbDatabase::close() {
   LOG(INFO) << "LevelDB: Close";
 
-  if (batch_ && db_) {
-    this->commit();
-  }
   batch_.reset();
   db_.reset();
 }
index 7197a47..0860778 100644 (file)
@@ -14,9 +14,9 @@ void LmdbDatabase::open(const string& filename, Mode mode) {
   CHECK(NULL == txn_);
   CHECK_EQ(0, dbi_);
 
-  if (mode != ReadOnly) {
+  if (mode == New) {
     CHECK_EQ(mkdir(filename.c_str(), 0744), 0) << "mkdir " << filename
-                                                << "failed";
+                                                << " failed";
   }
 
   int retval;
@@ -66,18 +66,19 @@ void LmdbDatabase::commit() {
 
   CHECK_NOTNULL(txn_);
 
-  int retval = mdb_txn_commit(txn_);
+  int retval;
+  retval = mdb_txn_commit(txn_);
   CHECK_EQ(retval, MDB_SUCCESS) << "mdb_txn_commit failed "
       << mdb_strerror(retval);
+
+  retval = mdb_txn_begin(env_, NULL, 0, &txn_);
+  CHECK_EQ(retval, MDB_SUCCESS)
+      << "mdb_txn_begin failed " << mdb_strerror(retval);
 }
 
 void LmdbDatabase::close() {
   LOG(INFO) << "LMDB: Close";
 
-  if (env_ && dbi_ && txn_) {
-    this->commit();
-  }
-
   if (env_ && dbi_) {
     mdb_close(env_, dbi_);
     mdb_env_close(env_);
index f981af4..11f6fb8 100644 (file)
@@ -35,8 +35,10 @@ int main(int argc, char** argv) {
   BlobProto sum_blob;
   int count = 0;
   // load first datum
-  const Database::buffer_t& first_blob = database->begin()->second;
+  Database::const_iterator iter = database->begin();
+  const Database::buffer_t& first_blob = iter->second;
   datum.ParseFromArray(first_blob.data(), first_blob.size());
+  iter = database->end();
 
   sum_blob.set_num(1);
   sum_blob.set_channels(datum.channels());