Updated Database interface so that rather than CHECKing for certain conditions inside...
authorKevin James Matzen <kmatzen@cs.cornell.edu>
Mon, 13 Oct 2014 00:20:31 +0000 (20:20 -0400)
committerKevin James Matzen <kmatzen@cs.cornell.edu>
Tue, 14 Oct 2014 23:31:30 +0000 (19:31 -0400)
12 files changed:
examples/cifar10/convert_cifar_data.cpp
include/caffe/database.hpp
include/caffe/leveldb_database.hpp
include/caffe/lmdb_database.hpp
src/caffe/layers/data_layer.cpp
src/caffe/leveldb_database.cpp
src/caffe/lmdb_database.cpp
src/caffe/test/test_data_layer.cpp
src/caffe/test/test_database.cpp
tools/compute_image_mean.cpp
tools/convert_imageset.cpp
tools/extract_features.cpp

index c493087..b29e412 100644 (file)
@@ -38,8 +38,8 @@ void read_image(std::ifstream* file, int* label, char* buffer) {
 void convert_dataset(const string& input_folder, const string& output_folder,
     const string& db_type) {
   shared_ptr<Database> train_database = DatabaseFactory(db_type);
-  train_database->open(output_folder + "/cifar10_train_" + db_type,
-      Database::New);
+  CHECK(train_database->open(output_folder + "/cifar10_train_" + db_type,
+      Database::New));
   // Data buffer
   int label;
   char str_buffer[kCIFARImageNBytes];
@@ -66,16 +66,16 @@ void convert_dataset(const string& input_folder, const string& output_folder,
       int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d",
           fileid * kCIFARBatchSize + itemid);
       Database::buffer_t key(str_buffer, str_buffer + length);
-      train_database->put(&key, &value);
+      CHECK(train_database->put(&key, &value));
     }
   }
-  train_database->commit();
+  CHECK(train_database->commit());
   train_database->close();
 
   LOG(INFO) << "Writing Testing data";
   shared_ptr<Database> test_database = DatabaseFactory(db_type);
-  test_database->open(output_folder + "/cifar10_test_" + db_type,
-      Database::New);
+  CHECK(test_database->open(output_folder + "/cifar10_test_" + db_type,
+      Database::New));
   // Open files
   std::ifstream data_file((input_folder + "/test_batch.bin").c_str(),
       std::ios::in | std::ios::binary);
@@ -89,9 +89,9 @@ void convert_dataset(const string& input_folder, const string& output_folder,
         reinterpret_cast<unsigned char*>(value.data()));
     int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid);
     Database::buffer_t key(str_buffer, str_buffer + length);
-    test_database->put(&key, &value);
+    CHECK(test_database->put(&key, &value));
   }
-  test_database->commit();
+  CHECK(test_database->commit());
   test_database->close();
 }
 
index 08baf2f..148b1ed 100644 (file)
@@ -26,10 +26,10 @@ class Database {
     buffer_t value;
   };
 
-  virtual void open(const string& filename, Mode mode) = 0;
-  virtual void put(buffer_t* key, buffer_t* value) = 0;
-  virtual void get(buffer_t* key, buffer_t* value) = 0;
-  virtual void commit() = 0;
+  virtual bool open(const string& filename, Mode mode) = 0;
+  virtual bool put(buffer_t* key, buffer_t* value) = 0;
+  virtual bool get(buffer_t* key, buffer_t* value) = 0;
+  virtual bool commit() = 0;
   virtual void close() = 0;
 
   Database() { }
index 03bfd38..64bfa7c 100644 (file)
@@ -14,10 +14,10 @@ namespace caffe {
 
 class LeveldbDatabase : public Database {
  public:
-  void open(const string& filename, Mode mode);
-  void put(buffer_t* key, buffer_t* value);
-  void get(buffer_t* key, buffer_t* value);
-  void commit();
+  bool open(const string& filename, Mode mode);
+  bool put(buffer_t* key, buffer_t* value);
+  bool get(buffer_t* key, buffer_t* value);
+  bool commit();
   void close();
 
   const_iterator begin() const;
index 7b532f8..69e3ce0 100644 (file)
@@ -18,10 +18,10 @@ class LmdbDatabase : public Database {
         dbi_(0),
         txn_(NULL) { }
 
-  void open(const string& filename, Mode mode);
-  void put(buffer_t* key, buffer_t* value);
-  void get(buffer_t* key, buffer_t* value);
-  void commit();
+  bool open(const string& filename, Mode mode);
+  bool put(buffer_t* key, buffer_t* value);
+  bool get(buffer_t* key, buffer_t* value);
+  bool commit();
   void close();
 
   const_iterator begin() const;
index 998c00c..dcba10f 100644 (file)
@@ -30,8 +30,9 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
   // Initialize DB
   database_ = DatabaseFactory(this->layer_param_.data_param().backend());
-  LOG(INFO) << "Opening database " << this->layer_param_.data_param().source();
-  database_->open(this->layer_param_.data_param().source(), Database::ReadOnly);
+  const string& source = this->layer_param_.data_param().source();
+  LOG(INFO) << "Opening database " << source;
+  CHECK(database_->open(source, Database::ReadOnly));
   iter_ = database_->begin();
 
   // Check if we would need to randomly skip a few data points
index 51d50cc..d7506ed 100644 (file)
@@ -5,7 +5,7 @@
 
 namespace caffe {
 
-void LeveldbDatabase::open(const string& filename, Mode mode) {
+bool LeveldbDatabase::open(const string& filename, Mode mode) {
   LOG(INFO) << "LevelDB: Open " << filename;
 
   leveldb::Options options;
@@ -40,15 +40,24 @@ void LeveldbDatabase::open(const string& filename, Mode mode) {
   leveldb::Status status = leveldb::DB::Open(
       options, filename, &db);
   db_.reset(db);
-  CHECK(status.ok()) << "Failed to open leveldb " << filename
-      << ". Is it already existing?";
+
+  if (!status.ok()) {
+    LOG(ERROR) << "Failed to open leveldb " << filename
+        << ". Is it already existing?";
+    return false;
+  }
+
   batch_.reset(new leveldb::WriteBatch());
+  return true;
 }
 
-void LeveldbDatabase::put(buffer_t* key, buffer_t* value) {
+bool LeveldbDatabase::put(buffer_t* key, buffer_t* value) {
   LOG(INFO) << "LevelDB: Put";
 
-  CHECK(!read_only_);
+  if (read_only_) {
+    LOG(ERROR) << "put can not be used on a database in ReadOnly mode";
+    return false;
+  }
 
   CHECK_NOTNULL(batch_.get());
 
@@ -56,9 +65,11 @@ void LeveldbDatabase::put(buffer_t* key, buffer_t* value) {
   leveldb::Slice value_slice(value->data(), value->size());
 
   batch_->Put(key_slice, value_slice);
+
+  return true;
 }
 
-void LeveldbDatabase::get(buffer_t* key, buffer_t* value) {
+bool LeveldbDatabase::get(buffer_t* key, buffer_t* value) {
   LOG(INFO) << "LevelDB: Get";
 
   leveldb::Slice key_slice(key->data(), key->size());
@@ -66,23 +77,35 @@ void LeveldbDatabase::get(buffer_t* key, buffer_t* value) {
   string value_string;
   leveldb::Status status =
       db_->Get(leveldb::ReadOptions(), key_slice, &value_string);
-  CHECK(status.ok()) << "leveldb get failed";
+
+  if (!status.ok()) {
+    LOG(ERROR) << "leveldb get failed";
+    return false;
+  }
 
   Database::buffer_t temp_value(value_string.data(),
       value_string.data() + value_string.size());
   value->swap(temp_value);
+
+  return true;
 }
 
-void LeveldbDatabase::commit() {
+bool LeveldbDatabase::commit() {
   LOG(INFO) << "LevelDB: Commit";
 
-  CHECK(!read_only_);
+  if (read_only_) {
+    LOG(ERROR) << "commit can not be used on a database in ReadOnly mode";
+    return false;
+  }
 
   CHECK_NOTNULL(db_.get());
   CHECK_NOTNULL(batch_.get());
 
-  db_->Write(leveldb::WriteOptions(), batch_.get());
+  leveldb::Status status = db_->Write(leveldb::WriteOptions(), batch_.get());
+
   batch_.reset(new leveldb::WriteBatch());
+
+  return status.ok();
 }
 
 void LeveldbDatabase::close() {
index a546c8c..d71513a 100644 (file)
@@ -7,7 +7,7 @@
 
 namespace caffe {
 
-void LmdbDatabase::open(const string& filename, Mode mode) {
+bool LmdbDatabase::open(const string& filename, Mode mode) {
   LOG(INFO) << "LMDB: Open " << filename;
 
   CHECK(NULL == env_);
@@ -19,12 +19,16 @@ void LmdbDatabase::open(const string& filename, Mode mode) {
     retval = mkdir(filename.c_str(), 0744);
     switch (mode) {
     case New:
-      CHECK_EQ(0, retval) << "mkdir " << filename << " failed";
+      if (0 != retval) {
+        LOG(ERROR) << "mkdir " << filename << " failed";
+        return false;
+      }
       break;
     case ReadWrite:
-      if (-1 == retval) {
-        CHECK_EQ(EEXIST, errno) << "mkdir " << filename << " failed ("
+      if (-1 == retval && EEXIST != errno) {
+        LOG(ERROR) << "mkdir " << filename << " failed ("
             << strerror(errno) << ")";
+        return false;
       }
       break;
     default:
@@ -33,11 +37,17 @@ void LmdbDatabase::open(const string& filename, Mode mode) {
   }
 
   retval = mdb_env_create(&env_);
-  CHECK_EQ(retval, MDB_SUCCESS) << "mdb_env_create failed "
-      << mdb_strerror(retval);
+  if (MDB_SUCCESS != retval) {
+    LOG(ERROR) << "mdb_env_create failed "
+        << mdb_strerror(retval);
+    return false;
+  }
+
   retval = mdb_env_set_mapsize(env_, 1099511627776);
-  CHECK_EQ(retval, MDB_SUCCESS)  // 1TB
-      << "mdb_env_set_mapsize failed " << mdb_strerror(retval);
+  if (MDB_SUCCESS != retval) {
+    LOG(ERROR) << "mdb_env_set_mapsize failed " << mdb_strerror(retval);
+    return false;
+  }
 
   int flag1 = 0;
   int flag2 = 0;
@@ -47,16 +57,27 @@ void LmdbDatabase::open(const string& filename, Mode mode) {
   }
 
   retval = mdb_env_open(env_, filename.c_str(), flag1, 0664);
-  CHECK_EQ(retval, MDB_SUCCESS)
-      << "mdb_env_open failed " << mdb_strerror(retval);
+  if (MDB_SUCCESS != retval) {
+    LOG(ERROR) << "mdb_env_open failed " << mdb_strerror(retval);
+    return false;
+  }
+
   retval = mdb_txn_begin(env_, NULL, flag2, &txn_);
-  CHECK_EQ(retval, MDB_SUCCESS)
-      << "mdb_txn_begin failed " << mdb_strerror(retval);
+  if (MDB_SUCCESS != retval) {
+    LOG(ERROR) << "mdb_txn_begin failed " << mdb_strerror(retval);
+    return false;
+  }
+
   retval = mdb_open(txn_, NULL, 0, &dbi_);
-  CHECK_EQ(retval, MDB_SUCCESS) << "mdb_open failed" << mdb_strerror(retval);
+  if (MDB_SUCCESS != retval) {
+    LOG(ERROR) << "mdb_open failed" << mdb_strerror(retval);
+    return false;
+  }
+
+  return true;
 }
 
-void LmdbDatabase::put(buffer_t* key, buffer_t* value) {
+bool LmdbDatabase::put(buffer_t* key, buffer_t* value) {
   LOG(INFO) << "LMDB: Put";
 
   MDB_val mdbkey, mdbdata;
@@ -69,11 +90,15 @@ void LmdbDatabase::put(buffer_t* key, buffer_t* value) {
   CHECK_NE(0, dbi_);
 
   int retval = mdb_put(txn_, dbi_, &mdbkey, &mdbdata, 0);
-  CHECK_EQ(retval, MDB_SUCCESS)
-      << "mdb_put failed " << mdb_strerror(retval);
+  if (MDB_SUCCESS != retval) {
+    LOG(ERROR) << "mdb_put failed " << mdb_strerror(retval);
+    return false;
+  }
+
+  return true;
 }
 
-void LmdbDatabase::get(buffer_t* key, buffer_t* value) {
+bool LmdbDatabase::get(buffer_t* key, buffer_t* value) {
   LOG(INFO) << "LMDB: Get";
 
   MDB_val mdbkey, mdbdata;
@@ -83,11 +108,16 @@ void LmdbDatabase::get(buffer_t* key, buffer_t* value) {
   int retval;
   MDB_txn* get_txn;
   retval = mdb_txn_begin(env_, NULL, MDB_RDONLY, &get_txn);
-  CHECK_EQ(MDB_SUCCESS, retval) << "mdb_txn_begin failed "
-      << mdb_strerror(retval);
+  if (MDB_SUCCESS != retval) {
+    LOG(ERROR) << "mdb_txn_begin failed " << mdb_strerror(retval);
+    return false;
+  }
 
   retval = mdb_get(get_txn, dbi_, &mdbkey, &mdbdata);
-  CHECK_EQ(MDB_SUCCESS, retval) << "mdb_get failed " << mdb_strerror(retval);
+  if (MDB_SUCCESS != retval) {
+    LOG(ERROR) << "mdb_get failed " << mdb_strerror(retval);
+    return false;
+  }
 
   mdb_txn_abort(get_txn);
 
@@ -95,21 +125,29 @@ void LmdbDatabase::get(buffer_t* key, buffer_t* value) {
       reinterpret_cast<char*>(mdbdata.mv_data) + mdbdata.mv_size);
 
   value->swap(temp_value);
+
+  return true;
 }
 
-void LmdbDatabase::commit() {
+bool LmdbDatabase::commit() {
   LOG(INFO) << "LMDB: Commit";
 
   CHECK_NOTNULL(txn_);
 
   int retval;
   retval = mdb_txn_commit(txn_);
-  CHECK_EQ(retval, MDB_SUCCESS) << "mdb_txn_commit failed "
-      << mdb_strerror(retval);
+  if (MDB_SUCCESS != retval) {
+    LOG(ERROR) << "mdb_txn_commit failed " << mdb_strerror(retval);
+    return false;
+  }
 
   retval = mdb_txn_begin(env_, NULL, 0, &txn_);
-  CHECK_EQ(retval, MDB_SUCCESS)
-      << "mdb_txn_begin failed " << mdb_strerror(retval);
+  if (MDB_SUCCESS != retval) {
+    LOG(ERROR) << "mdb_txn_begin failed " << mdb_strerror(retval);
+    return false;
+  }
+
+  return true;
 }
 
 void LmdbDatabase::close() {
index 98ef1b9..cc9ad20 100644 (file)
@@ -40,7 +40,7 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
     backend_ = backend;
     LOG(INFO) << "Using temporary database " << *filename_;
     shared_ptr<Database> database = DatabaseFactory(backend_);
-    database->open(*filename_, Database::New);
+    CHECK(database->open(*filename_, Database::New));
     for (int i = 0; i < 5; ++i) {
       Datum datum;
       datum.set_label(i);
@@ -59,9 +59,9 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
       Database::buffer_t value(datum.ByteSize());
       datum.SerializeWithCachedSizesToArray(
           reinterpret_cast<unsigned char*>(value.data()));
-      database->put(&key, &value);
+      CHECK(database->put(&key, &value));
     }
-    database->commit();
+    CHECK(database->commit());
     database->close();
   }
 
index 5d5f4ea..f658650 100644 (file)
@@ -66,56 +66,56 @@ TYPED_TEST_CASE(DatabaseTest, TestDtypesAndDevices);
 
 TYPED_TEST(DatabaseTest, TestNewDoesntExistLevelDBPasses) {
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(this->DBName(), Database::New);
+  EXPECT_TRUE(database->open(this->DBName(), Database::New));
   database->close();
 }
 
 TYPED_TEST(DatabaseTest, TestNewExistsFailsLevelDB) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
   database->close();
 
-  EXPECT_DEATH(database->open(name, Database::New), "");
+  EXPECT_FALSE(database->open(name, Database::New));
 }
 
 TYPED_TEST(DatabaseTest, TestReadOnlyExistsLevelDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
   database->close();
 
-  database->open(name, Database::ReadOnly);
+  EXPECT_TRUE(database->open(name, Database::ReadOnly));
   database->close();
 }
 
 TYPED_TEST(DatabaseTest, TestReadOnlyDoesntExistFailsLevelDB) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_DEATH(database->open(name, Database::ReadOnly), "");
+  EXPECT_FALSE(database->open(name, Database::ReadOnly));
 }
 
 TYPED_TEST(DatabaseTest, TestReadWriteExistsLevelDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
   database->close();
 
-  database->open(name, Database::ReadWrite);
+  EXPECT_TRUE(database->open(name, Database::ReadWrite));
   database->close();
 }
 
 TYPED_TEST(DatabaseTest, TestReadWriteDoesntExistLevelDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::ReadWrite);
+  EXPECT_TRUE(database->open(name, Database::ReadWrite));
   database->close();
 }
 
 TYPED_TEST(DatabaseTest, TestIteratorsLevelDB) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   const int kNumExamples = 4;
   for (int i = 0; i < kNumExamples; ++i) {
@@ -126,9 +126,9 @@ TYPED_TEST(DatabaseTest, TestIteratorsLevelDB) {
     string value = ss.str();
     Database::buffer_t key_buf(key.data(), key.data() + key.size());
     Database::buffer_t val_buf(value.data(), value.data() + value.size());
-    database->put(&key_buf, &val_buf);
+    EXPECT_TRUE(database->put(&key_buf, &val_buf));
   }
-  database->commit();
+  EXPECT_TRUE(database->commit());
 
   int count = 0;
   for (Database::const_iterator iter = database->begin();
@@ -143,7 +143,7 @@ TYPED_TEST(DatabaseTest, TestIteratorsLevelDB) {
 TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLevelDB) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key1 = this->TestAltKey();
   Database::buffer_t value1 = this->TestAltValue();
@@ -151,9 +151,9 @@ TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLevelDB) {
   Database::buffer_t key2 = this->TestKey();
   Database::buffer_t value2 = this->TestValue();
 
-  database->put(&key1, &value1);
-  database->put(&key2, &value2);
-  database->commit();
+  EXPECT_TRUE(database->put(&key1, &value1));
+  EXPECT_TRUE(database->put(&key2, &value2));
+  EXPECT_TRUE(database->commit());
 
   Database::const_iterator iter1 = database->begin();
 
@@ -182,7 +182,7 @@ TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLevelDB) {
 TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLevelDB) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key1 = this->TestAltKey();
   Database::buffer_t value1 = this->TestAltValue();
@@ -190,9 +190,9 @@ TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLevelDB) {
   Database::buffer_t key2 = this->TestKey();
   Database::buffer_t value2 = this->TestValue();
 
-  database->put(&key1, &value1);
-  database->put(&key2, &value2);
-  database->commit();
+  EXPECT_TRUE(database->put(&key1, &value1));
+  EXPECT_TRUE(database->put(&key2, &value2));
+  EXPECT_TRUE(database->commit());
 
   Database::const_iterator iter1 = database->begin();
 
@@ -224,14 +224,14 @@ TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLevelDB) {
 TYPED_TEST(DatabaseTest, TestNewPutLevelDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  database->put(&key, &val);
+  EXPECT_TRUE(database->put(&key, &val));
 
-  database->commit();
+  EXPECT_TRUE(database->commit());
 
   database->close();
 }
@@ -239,9 +239,9 @@ TYPED_TEST(DatabaseTest, TestNewPutLevelDBPasses) {
 TYPED_TEST(DatabaseTest, TestNewCommitLevelDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
-  database->commit();
+  EXPECT_TRUE(database->commit());
 
   database->close();
 }
@@ -249,18 +249,18 @@ TYPED_TEST(DatabaseTest, TestNewCommitLevelDBPasses) {
 TYPED_TEST(DatabaseTest, TestNewGetLevelDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  database->put(&key, &val);
+  EXPECT_TRUE(database->put(&key, &val));
 
-  database->commit();
+  EXPECT_TRUE(database->commit());
 
   Database::buffer_t new_val;
 
-  database->get(&key, &new_val);
+  EXPECT_TRUE(database->get(&key, &new_val));
 
   EXPECT_TRUE(this->BufferEq(val, new_val));
 
@@ -270,30 +270,30 @@ TYPED_TEST(DatabaseTest, TestNewGetLevelDBPasses) {
 TYPED_TEST(DatabaseTest, TestNewGetNoCommitLevelDBFails) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  database->put(&key, &val);
+  EXPECT_TRUE(database->put(&key, &val));
 
   Database::buffer_t new_val;
 
-  EXPECT_DEATH(database->get(&key, &new_val), "");
+  EXPECT_FALSE(database->get(&key, &new_val));
 }
 
 
 TYPED_TEST(DatabaseTest, TestReadWritePutLevelDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::ReadWrite);
+  EXPECT_TRUE(database->open(name, Database::ReadWrite));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  database->put(&key, &val);
+  EXPECT_TRUE(database->put(&key, &val));
 
-  database->commit();
+  EXPECT_TRUE(database->commit());
 
   database->close();
 }
@@ -301,9 +301,9 @@ TYPED_TEST(DatabaseTest, TestReadWritePutLevelDBPasses) {
 TYPED_TEST(DatabaseTest, TestReadWriteCommitLevelDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::ReadWrite);
+  EXPECT_TRUE(database->open(name, Database::ReadWrite));
 
-  database->commit();
+  EXPECT_TRUE(database->commit());
 
   database->close();
 }
@@ -311,18 +311,18 @@ TYPED_TEST(DatabaseTest, TestReadWriteCommitLevelDBPasses) {
 TYPED_TEST(DatabaseTest, TestReadWriteGetLevelDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  database->put(&key, &val);
+  EXPECT_TRUE(database->put(&key, &val));
 
-  database->commit();
+  EXPECT_TRUE(database->commit());
 
   Database::buffer_t new_val;
 
-  database->get(&key, &new_val);
+  EXPECT_TRUE(database->get(&key, &new_val));
 
   EXPECT_TRUE(this->BufferEq(val, new_val));
 
@@ -332,62 +332,62 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetLevelDBPasses) {
 TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitLevelDBFails) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  database->put(&key, &val);
+  EXPECT_TRUE(database->put(&key, &val));
 
   Database::buffer_t new_val;
 
-  EXPECT_DEATH(database->get(&key, &new_val), "");
+  EXPECT_FALSE(database->get(&key, &new_val));
 }
 
 TYPED_TEST(DatabaseTest, TestReadOnlyPutLevelDBFails) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
   database->close();
 
-  database->open(name, Database::ReadOnly);
+  EXPECT_TRUE(database->open(name, Database::ReadOnly));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_DEATH(database->put(&key, &val), "");
+  EXPECT_FALSE(database->put(&key, &val));
 }
 
 TYPED_TEST(DatabaseTest, TestReadOnlyCommitLevelDBFails) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
   database->close();
 
-  database->open(name, Database::ReadOnly);
+  EXPECT_TRUE(database->open(name, Database::ReadOnly));
 
-  EXPECT_DEATH(database->commit(), "");
+  EXPECT_FALSE(database->commit());
 }
 
 TYPED_TEST(DatabaseTest, TestReadOnlyGetLevelDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  database->put(&key, &val);
+  EXPECT_TRUE(database->put(&key, &val));
 
-  database->commit();
+  EXPECT_TRUE(database->commit());
 
   database->close();
 
-  database->open(name, Database::ReadOnly);
+  EXPECT_TRUE(database->open(name, Database::ReadOnly));
 
   Database::buffer_t new_val;
 
-  database->get(&key, &new_val);
+  EXPECT_TRUE(database->get(&key, &new_val));
 
   EXPECT_TRUE(this->BufferEq(val, new_val));
 }
@@ -395,74 +395,74 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLevelDBPasses) {
 TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLevelDBFails) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("leveldb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  database->put(&key, &val);
+  EXPECT_TRUE(database->put(&key, &val));
 
   database->close();
 
-  database->open(name, Database::ReadOnly);
+  EXPECT_TRUE(database->open(name, Database::ReadOnly));
 
   Database::buffer_t new_val;
 
-  EXPECT_DEATH(database->get(&key, &new_val), "");
+  EXPECT_FALSE(database->get(&key, &new_val));
 }
 
 TYPED_TEST(DatabaseTest, TestNewDoesntExistLMDBPasses) {
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(this->DBName(), Database::New);
+  EXPECT_TRUE(database->open(this->DBName(), Database::New));
   database->close();
 }
 
 TYPED_TEST(DatabaseTest, TestNewExistsFailsLMDB) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
   database->close();
 
-  EXPECT_DEATH(database->open(name, Database::New), "");
+  EXPECT_FALSE(database->open(name, Database::New));
 }
 
 TYPED_TEST(DatabaseTest, TestReadOnlyExistsLMDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
   database->close();
 
-  database->open(name, Database::ReadOnly);
+  EXPECT_TRUE(database->open(name, Database::ReadOnly));
   database->close();
 }
 
 TYPED_TEST(DatabaseTest, TestReadOnlyDoesntExistFailsLMDB) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_DEATH(database->open(name, Database::ReadOnly), "");
+  EXPECT_FALSE(database->open(name, Database::ReadOnly));
 }
 
 TYPED_TEST(DatabaseTest, TestReadWriteExistsLMDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
   database->close();
 
-  database->open(name, Database::ReadWrite);
+  EXPECT_TRUE(database->open(name, Database::ReadWrite));
   database->close();
 }
 
 TYPED_TEST(DatabaseTest, TestReadWriteDoesntExistLMDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::ReadWrite);
+  EXPECT_TRUE(database->open(name, Database::ReadWrite));
   database->close();
 }
 
 TYPED_TEST(DatabaseTest, TestIteratorsLMDB) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   const int kNumExamples = 4;
   for (int i = 0; i < kNumExamples; ++i) {
@@ -473,9 +473,9 @@ TYPED_TEST(DatabaseTest, TestIteratorsLMDB) {
     string value = ss.str();
     Database::buffer_t key_buf(key.data(), key.data() + key.size());
     Database::buffer_t val_buf(value.data(), value.data() + value.size());
-    database->put(&key_buf, &val_buf);
+    EXPECT_TRUE(database->put(&key_buf, &val_buf));
   }
-  database->commit();
+  EXPECT_TRUE(database->commit());
 
   int count = 0;
   for (Database::const_iterator iter = database->begin();
@@ -490,7 +490,7 @@ TYPED_TEST(DatabaseTest, TestIteratorsLMDB) {
 TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLMDB) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key1 = this->TestAltKey();
   Database::buffer_t value1 = this->TestAltValue();
@@ -498,9 +498,9 @@ TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLMDB) {
   Database::buffer_t key2 = this->TestKey();
   Database::buffer_t value2 = this->TestValue();
 
-  database->put(&key1, &value1);
-  database->put(&key2, &value2);
-  database->commit();
+  EXPECT_TRUE(database->put(&key1, &value1));
+  EXPECT_TRUE(database->put(&key2, &value2));
+  EXPECT_TRUE(database->commit());
 
   Database::const_iterator iter1 = database->begin();
 
@@ -529,7 +529,7 @@ TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLMDB) {
 TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLMDB) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key1 = this->TestAltKey();
   Database::buffer_t value1 = this->TestAltValue();
@@ -537,9 +537,9 @@ TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLMDB) {
   Database::buffer_t key2 = this->TestKey();
   Database::buffer_t value2 = this->TestValue();
 
-  database->put(&key1, &value1);
-  database->put(&key2, &value2);
-  database->commit();
+  EXPECT_TRUE(database->put(&key1, &value1));
+  EXPECT_TRUE(database->put(&key2, &value2));
+  EXPECT_TRUE(database->commit());
 
   Database::const_iterator iter1 = database->begin();
 
@@ -571,14 +571,14 @@ TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLMDB) {
 TYPED_TEST(DatabaseTest, TestNewPutLMDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  database->put(&key, &val);
+  EXPECT_TRUE(database->put(&key, &val));
 
-  database->commit();
+  EXPECT_TRUE(database->commit());
 
   database->close();
 }
@@ -586,9 +586,9 @@ TYPED_TEST(DatabaseTest, TestNewPutLMDBPasses) {
 TYPED_TEST(DatabaseTest, TestNewCommitLMDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
-  database->commit();
+  EXPECT_TRUE(database->commit());
 
   database->close();
 }
@@ -596,18 +596,18 @@ TYPED_TEST(DatabaseTest, TestNewCommitLMDBPasses) {
 TYPED_TEST(DatabaseTest, TestNewGetLMDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  database->put(&key, &val);
+  EXPECT_TRUE(database->put(&key, &val));
 
-  database->commit();
+  EXPECT_TRUE(database->commit());
 
   Database::buffer_t new_val;
 
-  database->get(&key, &new_val);
+  EXPECT_TRUE(database->get(&key, &new_val));
 
   EXPECT_TRUE(this->BufferEq(val, new_val));
 
@@ -617,29 +617,29 @@ TYPED_TEST(DatabaseTest, TestNewGetLMDBPasses) {
 TYPED_TEST(DatabaseTest, TestNewGetNoCommitLMDBFails) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  database->put(&key, &val);
+  EXPECT_TRUE(database->put(&key, &val));
 
   Database::buffer_t new_val;
 
-  EXPECT_DEATH(database->get(&key, &new_val), "");
+  EXPECT_FALSE(database->get(&key, &new_val));
 }
 
 TYPED_TEST(DatabaseTest, TestReadWritePutLMDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::ReadWrite);
+  EXPECT_TRUE(database->open(name, Database::ReadWrite));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  database->put(&key, &val);
+  EXPECT_TRUE(database->put(&key, &val));
 
-  database->commit();
+  EXPECT_TRUE(database->commit());
 
   database->close();
 }
@@ -647,9 +647,9 @@ TYPED_TEST(DatabaseTest, TestReadWritePutLMDBPasses) {
 TYPED_TEST(DatabaseTest, TestReadWriteCommitLMDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::ReadWrite);
+  EXPECT_TRUE(database->open(name, Database::ReadWrite));
 
-  database->commit();
+  EXPECT_TRUE(database->commit());
 
   database->close();
 }
@@ -657,18 +657,18 @@ TYPED_TEST(DatabaseTest, TestReadWriteCommitLMDBPasses) {
 TYPED_TEST(DatabaseTest, TestReadWriteGetLMDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  database->put(&key, &val);
+  EXPECT_TRUE(database->put(&key, &val));
 
-  database->commit();
+  EXPECT_TRUE(database->commit());
 
   Database::buffer_t new_val;
 
-  database->get(&key, &new_val);
+  EXPECT_TRUE(database->get(&key, &new_val));
 
   EXPECT_TRUE(this->BufferEq(val, new_val));
 
@@ -678,62 +678,62 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetLMDBPasses) {
 TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitLMDBFails) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  database->put(&key, &val);
+  EXPECT_TRUE(database->put(&key, &val));
 
   Database::buffer_t new_val;
 
-  EXPECT_DEATH(database->get(&key, &new_val), "");
+  EXPECT_FALSE(database->get(&key, &new_val));
 }
 
 TYPED_TEST(DatabaseTest, TestReadOnlyPutLMDBFails) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
   database->close();
 
-  database->open(name, Database::ReadOnly);
+  EXPECT_TRUE(database->open(name, Database::ReadOnly));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_DEATH(database->put(&key, &val), "");
+  EXPECT_FALSE(database->put(&key, &val));
 }
 
 TYPED_TEST(DatabaseTest, TestReadOnlyCommitLMDBFails) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
   database->close();
 
-  database->open(name, Database::ReadOnly);
+  EXPECT_TRUE(database->open(name, Database::ReadOnly));
 
-  EXPECT_DEATH(database->commit(), "");
+  EXPECT_FALSE(database->commit());
 }
 
 TYPED_TEST(DatabaseTest, TestReadOnlyGetLMDBPasses) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  database->put(&key, &val);
+  EXPECT_TRUE(database->put(&key, &val));
 
-  database->commit();
+  EXPECT_TRUE(database->commit());
 
   database->close();
 
-  database->open(name, Database::ReadOnly);
+  EXPECT_TRUE(database->open(name, Database::ReadOnly));
 
   Database::buffer_t new_val;
 
-  database->get(&key, &new_val);
+  EXPECT_TRUE(database->get(&key, &new_val));
 
   EXPECT_TRUE(this->BufferEq(val, new_val));
 }
@@ -741,20 +741,20 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLMDBPasses) {
 TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLMDBFails) {
   string name = this->DBName();
   shared_ptr<Database> database = DatabaseFactory("lmdb");
-  database->open(name, Database::New);
+  EXPECT_TRUE(database->open(name, Database::New));
 
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  database->put(&key, &val);
+  EXPECT_TRUE(database->put(&key, &val));
 
   database->close();
 
-  database->open(name, Database::ReadOnly);
+  EXPECT_TRUE(database->open(name, Database::ReadOnly));
 
   Database::buffer_t new_val;
 
-  EXPECT_DEATH(database->get(&key, &new_val), "");
+  EXPECT_FALSE(database->get(&key, &new_val));
 }
 
 }  // namespace caffe
index aaa324a..b973c52 100644 (file)
@@ -29,7 +29,7 @@ int main(int argc, char** argv) {
   caffe::shared_ptr<Database> database = caffe::DatabaseFactory(db_backend);
 
   // Open db
-  database->open(argv[1], Database::ReadOnly);
+  CHECK(database->open(argv[1], Database::ReadOnly));
 
   Datum datum;
   BlobProto sum_blob;
index 19c87e5..3345c9c 100644 (file)
@@ -81,7 +81,7 @@ int main(int argc, char** argv) {
   shared_ptr<Database> database = DatabaseFactory(db_backend);
 
   // Open db
-  database->open(db_path, Database::New);
+  CHECK(database->open(db_path, Database::New));
 
   // Storing to db
   std::string root_folder(argv[1]);
@@ -116,17 +116,17 @@ int main(int argc, char** argv) {
     Database::buffer_t keystr(key_cstr, key_cstr + length);
 
     // Put in db
-    database->put(&keystr, &value);
+    CHECK(database->put(&keystr, &value));
 
     if (++count % 1000 == 0) {
       // Commit txn
-      database->commit();
+      CHECK(database->commit());
       LOG(ERROR) << "Processed " << count << " files.";
     }
   }
   // write the last batch
   if (count % 1000 != 0) {
-    database->commit();
+    CHECK(database->commit());
     LOG(ERROR) << "Processed " << count << " files.";
   }
   database->close();
index c4d1a39..1560ef6 100644 (file)
@@ -125,7 +125,7 @@ int feature_extraction_pipeline(int argc, char** argv) {
   for (size_t i = 0; i < num_features; ++i) {
     LOG(INFO)<< "Opening database " << database_names[i];
     shared_ptr<Database> database = DatabaseFactory(argv[++arg_pos]);
-    database->open(database_names.at(i), Database::New);
+    CHECK(database->open(database_names.at(i), Database::New));
     feature_dbs.push_back(database);
   }
 
@@ -161,10 +161,10 @@ int feature_extraction_pipeline(int argc, char** argv) {
         int length = snprintf(key_str, kMaxKeyStrLength, "%d",
             image_indices[i]);
         Database::buffer_t key(key_str, key_str + length);
-        feature_dbs.at(i)->put(&key, &value);
+        CHECK(feature_dbs.at(i)->put(&key, &value));
         ++image_indices[i];
         if (image_indices[i] % 1000 == 0) {
-          feature_dbs.at(i)->commit();
+          CHECK(feature_dbs.at(i)->commit());
           LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
               " query images for feature blob " << blob_names[i];
         }
@@ -174,7 +174,7 @@ int feature_extraction_pipeline(int argc, char** argv) {
   // write the last batch
   for (int i = 0; i < num_features; ++i) {
     if (image_indices[i] % 1000 != 0) {
-      feature_dbs.at(i)->commit();
+      CHECK(feature_dbs.at(i)->commit());
     }
     LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
         " query images for feature blob " << blob_names[i];