Updated Database interface to take key and value by const reference for put and key...
authorKevin James Matzen <kmatzen@cs.cornell.edu>
Mon, 13 Oct 2014 00:30:29 +0000 (20:30 -0400)
committerKevin James Matzen <kmatzen@cs.cornell.edu>
Tue, 14 Oct 2014 23:31:30 +0000 (19:31 -0400)
examples/cifar10/convert_cifar_data.cpp
include/caffe/database.hpp
include/caffe/leveldb_database.hpp
include/caffe/lmdb_database.hpp
src/caffe/leveldb_database.cpp
src/caffe/lmdb_database.cpp
src/caffe/test/test_data_layer.cpp
src/caffe/test/test_database.cpp
tools/convert_imageset.cpp
tools/extract_features.cpp

index b29e412..af845ea 100644 (file)
@@ -66,7 +66,7 @@ void convert_dataset(const string& input_folder, const string& output_folder,
       int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d",
           fileid * kCIFARBatchSize + itemid);
       Database::buffer_t key(str_buffer, str_buffer + length);
-      CHECK(train_database->put(&key, &value));
+      CHECK(train_database->put(key, value));
     }
   }
   CHECK(train_database->commit());
@@ -89,7 +89,7 @@ void convert_dataset(const string& input_folder, const string& output_folder,
         reinterpret_cast<unsigned char*>(value.data()));
     int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid);
     Database::buffer_t key(str_buffer, str_buffer + length);
-    CHECK(test_database->put(&key, &value));
+    CHECK(test_database->put(key, value));
   }
   CHECK(test_database->commit());
   test_database->close();
index 148b1ed..3f3970d 100644 (file)
@@ -27,8 +27,8 @@ class Database {
   };
 
   virtual bool open(const string& filename, Mode mode) = 0;
-  virtual bool put(buffer_t* key, buffer_t* value) = 0;
-  virtual bool get(buffer_t* key, buffer_t* value) = 0;
+  virtual bool put(const buffer_t& key, const buffer_t& value) = 0;
+  virtual bool get(const buffer_t& key, buffer_t* value) = 0;
   virtual bool commit() = 0;
   virtual void close() = 0;
 
index 64bfa7c..e2558ff 100644 (file)
@@ -15,8 +15,8 @@ namespace caffe {
 class LeveldbDatabase : public Database {
  public:
   bool open(const string& filename, Mode mode);
-  bool put(buffer_t* key, buffer_t* value);
-  bool get(buffer_t* key, buffer_t* value);
+  bool put(const buffer_t& key, const buffer_t& value);
+  bool get(const buffer_t& key, buffer_t* value);
   bool commit();
   void close();
 
index 69e3ce0..4a0f318 100644 (file)
@@ -19,8 +19,8 @@ class LmdbDatabase : public Database {
         txn_(NULL) { }
 
   bool open(const string& filename, Mode mode);
-  bool put(buffer_t* key, buffer_t* value);
-  bool get(buffer_t* key, buffer_t* value);
+  bool put(const buffer_t& key, const buffer_t& value);
+  bool get(const buffer_t& key, buffer_t* value);
   bool commit();
   void close();
 
index d7506ed..c09112f 100644 (file)
@@ -51,7 +51,7 @@ bool LeveldbDatabase::open(const string& filename, Mode mode) {
   return true;
 }
 
-bool LeveldbDatabase::put(buffer_t* key, buffer_t* value) {
+bool LeveldbDatabase::put(const buffer_t& key, const buffer_t& value) {
   LOG(INFO) << "LevelDB: Put";
 
   if (read_only_) {
@@ -61,18 +61,18 @@ bool LeveldbDatabase::put(buffer_t* key, buffer_t* value) {
 
   CHECK_NOTNULL(batch_.get());
 
-  leveldb::Slice key_slice(key->data(), key->size());
-  leveldb::Slice value_slice(value->data(), value->size());
+  leveldb::Slice key_slice(key.data(), key.size());
+  leveldb::Slice value_slice(value.data(), value.size());
 
   batch_->Put(key_slice, value_slice);
 
   return true;
 }
 
-bool LeveldbDatabase::get(buffer_t* key, buffer_t* value) {
+bool LeveldbDatabase::get(const buffer_t& key, buffer_t* value) {
   LOG(INFO) << "LevelDB: Get";
 
-  leveldb::Slice key_slice(key->data(), key->size());
+  leveldb::Slice key_slice(key.data(), key.size());
 
   string value_string;
   leveldb::Status status =
index d71513a..2cb699b 100644 (file)
@@ -77,14 +77,18 @@ bool LmdbDatabase::open(const string& filename, Mode mode) {
   return true;
 }
 
-bool LmdbDatabase::put(buffer_t* key, buffer_t* value) {
+bool LmdbDatabase::put(const buffer_t& key, const buffer_t& value) {
   LOG(INFO) << "LMDB: Put";
 
+  // MDB_val::mv_size is not const, so we need to make a local copy.
+  buffer_t local_key = key;
+  buffer_t local_value = value;
+
   MDB_val mdbkey, mdbdata;
-  mdbdata.mv_size = value->size();
-  mdbdata.mv_data = value->data();
-  mdbkey.mv_size = key->size();
-  mdbkey.mv_data = key->data();
+  mdbdata.mv_size = local_value.size();
+  mdbdata.mv_data = local_value.data();
+  mdbkey.mv_size = local_key.size();
+  mdbkey.mv_data = local_key.data();
 
   CHECK_NOTNULL(txn_);
   CHECK_NE(0, dbi_);
@@ -98,12 +102,14 @@ bool LmdbDatabase::put(buffer_t* key, buffer_t* value) {
   return true;
 }
 
-bool LmdbDatabase::get(buffer_t* key, buffer_t* value) {
+bool LmdbDatabase::get(const buffer_t& key, buffer_t* value) {
   LOG(INFO) << "LMDB: Get";
 
+  buffer_t local_key = key;
+
   MDB_val mdbkey, mdbdata;
-  mdbkey.mv_data = key->data();
-  mdbkey.mv_size = key->size();
+  mdbkey.mv_data = local_key.data();
+  mdbkey.mv_size = local_key.size();
 
   int retval;
   MDB_txn* get_txn;
index cc9ad20..8ae8f2b 100644 (file)
@@ -59,7 +59,7 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
       Database::buffer_t value(datum.ByteSize());
       datum.SerializeWithCachedSizesToArray(
           reinterpret_cast<unsigned char*>(value.data()));
-      CHECK(database->put(&key, &value));
+      CHECK(database->put(key, value));
     }
     CHECK(database->commit());
     database->close();
index f658650..70e1d96 100644 (file)
@@ -126,7 +126,7 @@ TYPED_TEST(DatabaseTest, TestIteratorsLevelDB) {
     string value = ss.str();
     Database::buffer_t key_buf(key.data(), key.data() + key.size());
     Database::buffer_t val_buf(value.data(), value.data() + value.size());
-    EXPECT_TRUE(database->put(&key_buf, &val_buf));
+    EXPECT_TRUE(database->put(key_buf, val_buf));
   }
   EXPECT_TRUE(database->commit());
 
@@ -151,8 +151,8 @@ TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLevelDB) {
   Database::buffer_t key2 = this->TestKey();
   Database::buffer_t value2 = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key1, &value1));
-  EXPECT_TRUE(database->put(&key2, &value2));
+  EXPECT_TRUE(database->put(key1, value1));
+  EXPECT_TRUE(database->put(key2, value2));
   EXPECT_TRUE(database->commit());
 
   Database::const_iterator iter1 = database->begin();
@@ -190,8 +190,8 @@ TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLevelDB) {
   Database::buffer_t key2 = this->TestKey();
   Database::buffer_t value2 = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key1, &value1));
-  EXPECT_TRUE(database->put(&key2, &value2));
+  EXPECT_TRUE(database->put(key1, value1));
+  EXPECT_TRUE(database->put(key2, value2));
   EXPECT_TRUE(database->commit());
 
   Database::const_iterator iter1 = database->begin();
@@ -229,7 +229,7 @@ TYPED_TEST(DatabaseTest, TestNewPutLevelDBPasses) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key, &val));
+  EXPECT_TRUE(database->put(key, val));
 
   EXPECT_TRUE(database->commit());
 
@@ -254,13 +254,13 @@ TYPED_TEST(DatabaseTest, TestNewGetLevelDBPasses) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key, &val));
+  EXPECT_TRUE(database->put(key, val));
 
   EXPECT_TRUE(database->commit());
 
   Database::buffer_t new_val;
 
-  EXPECT_TRUE(database->get(&key, &new_val));
+  EXPECT_TRUE(database->get(key, &new_val));
 
   EXPECT_TRUE(this->BufferEq(val, new_val));
 
@@ -275,11 +275,11 @@ TYPED_TEST(DatabaseTest, TestNewGetNoCommitLevelDBFails) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key, &val));
+  EXPECT_TRUE(database->put(key, val));
 
   Database::buffer_t new_val;
 
-  EXPECT_FALSE(database->get(&key, &new_val));
+  EXPECT_FALSE(database->get(key, &new_val));
 }
 
 
@@ -291,7 +291,7 @@ TYPED_TEST(DatabaseTest, TestReadWritePutLevelDBPasses) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key, &val));
+  EXPECT_TRUE(database->put(key, val));
 
   EXPECT_TRUE(database->commit());
 
@@ -316,13 +316,13 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetLevelDBPasses) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key, &val));
+  EXPECT_TRUE(database->put(key, val));
 
   EXPECT_TRUE(database->commit());
 
   Database::buffer_t new_val;
 
-  EXPECT_TRUE(database->get(&key, &new_val));
+  EXPECT_TRUE(database->get(key, &new_val));
 
   EXPECT_TRUE(this->BufferEq(val, new_val));
 
@@ -337,11 +337,11 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitLevelDBFails) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key, &val));
+  EXPECT_TRUE(database->put(key, val));
 
   Database::buffer_t new_val;
 
-  EXPECT_FALSE(database->get(&key, &new_val));
+  EXPECT_FALSE(database->get(key, &new_val));
 }
 
 TYPED_TEST(DatabaseTest, TestReadOnlyPutLevelDBFails) {
@@ -355,7 +355,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyPutLevelDBFails) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_FALSE(database->put(&key, &val));
+  EXPECT_FALSE(database->put(key, val));
 }
 
 TYPED_TEST(DatabaseTest, TestReadOnlyCommitLevelDBFails) {
@@ -377,7 +377,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLevelDBPasses) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key, &val));
+  EXPECT_TRUE(database->put(key, val));
 
   EXPECT_TRUE(database->commit());
 
@@ -387,7 +387,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLevelDBPasses) {
 
   Database::buffer_t new_val;
 
-  EXPECT_TRUE(database->get(&key, &new_val));
+  EXPECT_TRUE(database->get(key, &new_val));
 
   EXPECT_TRUE(this->BufferEq(val, new_val));
 }
@@ -400,7 +400,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLevelDBFails) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key, &val));
+  EXPECT_TRUE(database->put(key, val));
 
   database->close();
 
@@ -408,7 +408,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLevelDBFails) {
 
   Database::buffer_t new_val;
 
-  EXPECT_FALSE(database->get(&key, &new_val));
+  EXPECT_FALSE(database->get(key, &new_val));
 }
 
 TYPED_TEST(DatabaseTest, TestNewDoesntExistLMDBPasses) {
@@ -473,7 +473,7 @@ TYPED_TEST(DatabaseTest, TestIteratorsLMDB) {
     string value = ss.str();
     Database::buffer_t key_buf(key.data(), key.data() + key.size());
     Database::buffer_t val_buf(value.data(), value.data() + value.size());
-    EXPECT_TRUE(database->put(&key_buf, &val_buf));
+    EXPECT_TRUE(database->put(key_buf, val_buf));
   }
   EXPECT_TRUE(database->commit());
 
@@ -498,8 +498,8 @@ TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLMDB) {
   Database::buffer_t key2 = this->TestKey();
   Database::buffer_t value2 = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key1, &value1));
-  EXPECT_TRUE(database->put(&key2, &value2));
+  EXPECT_TRUE(database->put(key1, value1));
+  EXPECT_TRUE(database->put(key2, value2));
   EXPECT_TRUE(database->commit());
 
   Database::const_iterator iter1 = database->begin();
@@ -537,8 +537,8 @@ TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLMDB) {
   Database::buffer_t key2 = this->TestKey();
   Database::buffer_t value2 = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key1, &value1));
-  EXPECT_TRUE(database->put(&key2, &value2));
+  EXPECT_TRUE(database->put(key1, value1));
+  EXPECT_TRUE(database->put(key2, value2));
   EXPECT_TRUE(database->commit());
 
   Database::const_iterator iter1 = database->begin();
@@ -576,7 +576,7 @@ TYPED_TEST(DatabaseTest, TestNewPutLMDBPasses) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key, &val));
+  EXPECT_TRUE(database->put(key, val));
 
   EXPECT_TRUE(database->commit());
 
@@ -601,13 +601,13 @@ TYPED_TEST(DatabaseTest, TestNewGetLMDBPasses) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key, &val));
+  EXPECT_TRUE(database->put(key, val));
 
   EXPECT_TRUE(database->commit());
 
   Database::buffer_t new_val;
 
-  EXPECT_TRUE(database->get(&key, &new_val));
+  EXPECT_TRUE(database->get(key, &new_val));
 
   EXPECT_TRUE(this->BufferEq(val, new_val));
 
@@ -622,11 +622,11 @@ TYPED_TEST(DatabaseTest, TestNewGetNoCommitLMDBFails) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key, &val));
+  EXPECT_TRUE(database->put(key, val));
 
   Database::buffer_t new_val;
 
-  EXPECT_FALSE(database->get(&key, &new_val));
+  EXPECT_FALSE(database->get(key, &new_val));
 }
 
 TYPED_TEST(DatabaseTest, TestReadWritePutLMDBPasses) {
@@ -637,7 +637,7 @@ TYPED_TEST(DatabaseTest, TestReadWritePutLMDBPasses) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key, &val));
+  EXPECT_TRUE(database->put(key, val));
 
   EXPECT_TRUE(database->commit());
 
@@ -662,13 +662,13 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetLMDBPasses) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key, &val));
+  EXPECT_TRUE(database->put(key, val));
 
   EXPECT_TRUE(database->commit());
 
   Database::buffer_t new_val;
 
-  EXPECT_TRUE(database->get(&key, &new_val));
+  EXPECT_TRUE(database->get(key, &new_val));
 
   EXPECT_TRUE(this->BufferEq(val, new_val));
 
@@ -683,11 +683,11 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitLMDBFails) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key, &val));
+  EXPECT_TRUE(database->put(key, val));
 
   Database::buffer_t new_val;
 
-  EXPECT_FALSE(database->get(&key, &new_val));
+  EXPECT_FALSE(database->get(key, &new_val));
 }
 
 TYPED_TEST(DatabaseTest, TestReadOnlyPutLMDBFails) {
@@ -701,7 +701,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyPutLMDBFails) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_FALSE(database->put(&key, &val));
+  EXPECT_FALSE(database->put(key, val));
 }
 
 TYPED_TEST(DatabaseTest, TestReadOnlyCommitLMDBFails) {
@@ -723,7 +723,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLMDBPasses) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key, &val));
+  EXPECT_TRUE(database->put(key, val));
 
   EXPECT_TRUE(database->commit());
 
@@ -733,7 +733,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLMDBPasses) {
 
   Database::buffer_t new_val;
 
-  EXPECT_TRUE(database->get(&key, &new_val));
+  EXPECT_TRUE(database->get(key, &new_val));
 
   EXPECT_TRUE(this->BufferEq(val, new_val));
 }
@@ -746,7 +746,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLMDBFails) {
   Database::buffer_t key = this->TestKey();
   Database::buffer_t val = this->TestValue();
 
-  EXPECT_TRUE(database->put(&key, &val));
+  EXPECT_TRUE(database->put(key, val));
 
   database->close();
 
@@ -754,7 +754,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLMDBFails) {
 
   Database::buffer_t new_val;
 
-  EXPECT_FALSE(database->get(&key, &new_val));
+  EXPECT_FALSE(database->get(key, &new_val));
 }
 
 }  // namespace caffe
index 3345c9c..5ad2c0b 100644 (file)
@@ -116,7 +116,7 @@ int main(int argc, char** argv) {
     Database::buffer_t keystr(key_cstr, key_cstr + length);
 
     // Put in db
-    CHECK(database->put(&keystr, &value));
+    CHECK(database->put(keystr, value));
 
     if (++count % 1000 == 0) {
       // Commit txn
index 1560ef6..0c7660d 100644 (file)
@@ -161,7 +161,7 @@ int feature_extraction_pipeline(int argc, char** argv) {
         int length = snprintf(key_str, kMaxKeyStrLength, "%d",
             image_indices[i]);
         Database::buffer_t key(key_str, key_str + length);
-        CHECK(feature_dbs.at(i)->put(&key, &value));
+        CHECK(feature_dbs.at(i)->put(key, value));
         ++image_indices[i];
         if (image_indices[i] % 1000 == 0) {
           CHECK(feature_dbs.at(i)->commit());