From e0b572dd1a63d5dda000ee4cc803be04e51597a1 Mon Sep 17 00:00:00 2001 From: Kevin James Matzen Date: Sun, 12 Oct 2014 20:30:29 -0400 Subject: [PATCH] Updated Database interface to take key and value by const reference for put and key by const reference for put. Additional copies are made for get and put in the LMDB implementation. --- examples/cifar10/convert_cifar_data.cpp | 4 +- include/caffe/database.hpp | 4 +- include/caffe/leveldb_database.hpp | 4 +- include/caffe/lmdb_database.hpp | 4 +- src/caffe/leveldb_database.cpp | 10 ++--- src/caffe/lmdb_database.cpp | 22 +++++---- src/caffe/test/test_data_layer.cpp | 2 +- src/caffe/test/test_database.cpp | 80 ++++++++++++++++----------------- tools/convert_imageset.cpp | 2 +- tools/extract_features.cpp | 2 +- 10 files changed, 70 insertions(+), 64 deletions(-) diff --git a/examples/cifar10/convert_cifar_data.cpp b/examples/cifar10/convert_cifar_data.cpp index b29e412..af845ea 100644 --- a/examples/cifar10/convert_cifar_data.cpp +++ b/examples/cifar10/convert_cifar_data.cpp @@ -66,7 +66,7 @@ void convert_dataset(const string& input_folder, const string& output_folder, int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", fileid * kCIFARBatchSize + itemid); Database::buffer_t key(str_buffer, str_buffer + length); - CHECK(train_database->put(&key, &value)); + CHECK(train_database->put(key, value)); } } CHECK(train_database->commit()); @@ -89,7 +89,7 @@ void convert_dataset(const string& input_folder, const string& output_folder, reinterpret_cast(value.data())); int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid); Database::buffer_t key(str_buffer, str_buffer + length); - CHECK(test_database->put(&key, &value)); + CHECK(test_database->put(key, value)); } CHECK(test_database->commit()); test_database->close(); diff --git a/include/caffe/database.hpp b/include/caffe/database.hpp index 148b1ed..3f3970d 100644 --- a/include/caffe/database.hpp +++ b/include/caffe/database.hpp @@ -27,8 +27,8 @@ class Database { }; virtual bool open(const string& filename, Mode mode) = 0; - virtual bool put(buffer_t* key, buffer_t* value) = 0; - virtual bool get(buffer_t* key, buffer_t* value) = 0; + virtual bool put(const buffer_t& key, const buffer_t& value) = 0; + virtual bool get(const buffer_t& key, buffer_t* value) = 0; virtual bool commit() = 0; virtual void close() = 0; diff --git a/include/caffe/leveldb_database.hpp b/include/caffe/leveldb_database.hpp index 64bfa7c..e2558ff 100644 --- a/include/caffe/leveldb_database.hpp +++ b/include/caffe/leveldb_database.hpp @@ -15,8 +15,8 @@ namespace caffe { class LeveldbDatabase : public Database { public: bool open(const string& filename, Mode mode); - bool put(buffer_t* key, buffer_t* value); - bool get(buffer_t* key, buffer_t* value); + bool put(const buffer_t& key, const buffer_t& value); + bool get(const buffer_t& key, buffer_t* value); bool commit(); void close(); diff --git a/include/caffe/lmdb_database.hpp b/include/caffe/lmdb_database.hpp index 69e3ce0..4a0f318 100644 --- a/include/caffe/lmdb_database.hpp +++ b/include/caffe/lmdb_database.hpp @@ -19,8 +19,8 @@ class LmdbDatabase : public Database { txn_(NULL) { } bool open(const string& filename, Mode mode); - bool put(buffer_t* key, buffer_t* value); - bool get(buffer_t* key, buffer_t* value); + bool put(const buffer_t& key, const buffer_t& value); + bool get(const buffer_t& key, buffer_t* value); bool commit(); void close(); diff --git a/src/caffe/leveldb_database.cpp b/src/caffe/leveldb_database.cpp index d7506ed..c09112f 100644 --- a/src/caffe/leveldb_database.cpp +++ b/src/caffe/leveldb_database.cpp @@ -51,7 +51,7 @@ bool LeveldbDatabase::open(const string& filename, Mode mode) { return true; } -bool LeveldbDatabase::put(buffer_t* key, buffer_t* value) { +bool LeveldbDatabase::put(const buffer_t& key, const buffer_t& value) { LOG(INFO) << "LevelDB: Put"; if (read_only_) { @@ -61,18 +61,18 @@ bool LeveldbDatabase::put(buffer_t* key, buffer_t* value) { CHECK_NOTNULL(batch_.get()); - leveldb::Slice key_slice(key->data(), key->size()); - leveldb::Slice value_slice(value->data(), value->size()); + leveldb::Slice key_slice(key.data(), key.size()); + leveldb::Slice value_slice(value.data(), value.size()); batch_->Put(key_slice, value_slice); return true; } -bool LeveldbDatabase::get(buffer_t* key, buffer_t* value) { +bool LeveldbDatabase::get(const buffer_t& key, buffer_t* value) { LOG(INFO) << "LevelDB: Get"; - leveldb::Slice key_slice(key->data(), key->size()); + leveldb::Slice key_slice(key.data(), key.size()); string value_string; leveldb::Status status = diff --git a/src/caffe/lmdb_database.cpp b/src/caffe/lmdb_database.cpp index d71513a..2cb699b 100644 --- a/src/caffe/lmdb_database.cpp +++ b/src/caffe/lmdb_database.cpp @@ -77,14 +77,18 @@ bool LmdbDatabase::open(const string& filename, Mode mode) { return true; } -bool LmdbDatabase::put(buffer_t* key, buffer_t* value) { +bool LmdbDatabase::put(const buffer_t& key, const buffer_t& value) { LOG(INFO) << "LMDB: Put"; + // MDB_val::mv_size is not const, so we need to make a local copy. + buffer_t local_key = key; + buffer_t local_value = value; + MDB_val mdbkey, mdbdata; - mdbdata.mv_size = value->size(); - mdbdata.mv_data = value->data(); - mdbkey.mv_size = key->size(); - mdbkey.mv_data = key->data(); + mdbdata.mv_size = local_value.size(); + mdbdata.mv_data = local_value.data(); + mdbkey.mv_size = local_key.size(); + mdbkey.mv_data = local_key.data(); CHECK_NOTNULL(txn_); CHECK_NE(0, dbi_); @@ -98,12 +102,14 @@ bool LmdbDatabase::put(buffer_t* key, buffer_t* value) { return true; } -bool LmdbDatabase::get(buffer_t* key, buffer_t* value) { +bool LmdbDatabase::get(const buffer_t& key, buffer_t* value) { LOG(INFO) << "LMDB: Get"; + buffer_t local_key = key; + MDB_val mdbkey, mdbdata; - mdbkey.mv_data = key->data(); - mdbkey.mv_size = key->size(); + mdbkey.mv_data = local_key.data(); + mdbkey.mv_size = local_key.size(); int retval; MDB_txn* get_txn; diff --git a/src/caffe/test/test_data_layer.cpp b/src/caffe/test/test_data_layer.cpp index cc9ad20..8ae8f2b 100644 --- a/src/caffe/test/test_data_layer.cpp +++ b/src/caffe/test/test_data_layer.cpp @@ -59,7 +59,7 @@ class DataLayerTest : public MultiDeviceTest { Database::buffer_t value(datum.ByteSize()); datum.SerializeWithCachedSizesToArray( reinterpret_cast(value.data())); - CHECK(database->put(&key, &value)); + CHECK(database->put(key, value)); } CHECK(database->commit()); database->close(); diff --git a/src/caffe/test/test_database.cpp b/src/caffe/test/test_database.cpp index f658650..70e1d96 100644 --- a/src/caffe/test/test_database.cpp +++ b/src/caffe/test/test_database.cpp @@ -126,7 +126,7 @@ TYPED_TEST(DatabaseTest, TestIteratorsLevelDB) { string value = ss.str(); Database::buffer_t key_buf(key.data(), key.data() + key.size()); Database::buffer_t val_buf(value.data(), value.data() + value.size()); - EXPECT_TRUE(database->put(&key_buf, &val_buf)); + EXPECT_TRUE(database->put(key_buf, val_buf)); } EXPECT_TRUE(database->commit()); @@ -151,8 +151,8 @@ TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLevelDB) { Database::buffer_t key2 = this->TestKey(); Database::buffer_t value2 = this->TestValue(); - EXPECT_TRUE(database->put(&key1, &value1)); - EXPECT_TRUE(database->put(&key2, &value2)); + EXPECT_TRUE(database->put(key1, value1)); + EXPECT_TRUE(database->put(key2, value2)); EXPECT_TRUE(database->commit()); Database::const_iterator iter1 = database->begin(); @@ -190,8 +190,8 @@ TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLevelDB) { Database::buffer_t key2 = this->TestKey(); Database::buffer_t value2 = this->TestValue(); - EXPECT_TRUE(database->put(&key1, &value1)); - EXPECT_TRUE(database->put(&key2, &value2)); + EXPECT_TRUE(database->put(key1, value1)); + EXPECT_TRUE(database->put(key2, value2)); EXPECT_TRUE(database->commit()); Database::const_iterator iter1 = database->begin(); @@ -229,7 +229,7 @@ TYPED_TEST(DatabaseTest, TestNewPutLevelDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); @@ -254,13 +254,13 @@ TYPED_TEST(DatabaseTest, TestNewGetLevelDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); Database::buffer_t new_val; - EXPECT_TRUE(database->get(&key, &new_val)); + EXPECT_TRUE(database->get(key, &new_val)); EXPECT_TRUE(this->BufferEq(val, new_val)); @@ -275,11 +275,11 @@ TYPED_TEST(DatabaseTest, TestNewGetNoCommitLevelDBFails) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); Database::buffer_t new_val; - EXPECT_FALSE(database->get(&key, &new_val)); + EXPECT_FALSE(database->get(key, &new_val)); } @@ -291,7 +291,7 @@ TYPED_TEST(DatabaseTest, TestReadWritePutLevelDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); @@ -316,13 +316,13 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetLevelDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); Database::buffer_t new_val; - EXPECT_TRUE(database->get(&key, &new_val)); + EXPECT_TRUE(database->get(key, &new_val)); EXPECT_TRUE(this->BufferEq(val, new_val)); @@ -337,11 +337,11 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitLevelDBFails) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); Database::buffer_t new_val; - EXPECT_FALSE(database->get(&key, &new_val)); + EXPECT_FALSE(database->get(key, &new_val)); } TYPED_TEST(DatabaseTest, TestReadOnlyPutLevelDBFails) { @@ -355,7 +355,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyPutLevelDBFails) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_FALSE(database->put(&key, &val)); + EXPECT_FALSE(database->put(key, val)); } TYPED_TEST(DatabaseTest, TestReadOnlyCommitLevelDBFails) { @@ -377,7 +377,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLevelDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); @@ -387,7 +387,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLevelDBPasses) { Database::buffer_t new_val; - EXPECT_TRUE(database->get(&key, &new_val)); + EXPECT_TRUE(database->get(key, &new_val)); EXPECT_TRUE(this->BufferEq(val, new_val)); } @@ -400,7 +400,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLevelDBFails) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); database->close(); @@ -408,7 +408,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLevelDBFails) { Database::buffer_t new_val; - EXPECT_FALSE(database->get(&key, &new_val)); + EXPECT_FALSE(database->get(key, &new_val)); } TYPED_TEST(DatabaseTest, TestNewDoesntExistLMDBPasses) { @@ -473,7 +473,7 @@ TYPED_TEST(DatabaseTest, TestIteratorsLMDB) { string value = ss.str(); Database::buffer_t key_buf(key.data(), key.data() + key.size()); Database::buffer_t val_buf(value.data(), value.data() + value.size()); - EXPECT_TRUE(database->put(&key_buf, &val_buf)); + EXPECT_TRUE(database->put(key_buf, val_buf)); } EXPECT_TRUE(database->commit()); @@ -498,8 +498,8 @@ TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLMDB) { Database::buffer_t key2 = this->TestKey(); Database::buffer_t value2 = this->TestValue(); - EXPECT_TRUE(database->put(&key1, &value1)); - EXPECT_TRUE(database->put(&key2, &value2)); + EXPECT_TRUE(database->put(key1, value1)); + EXPECT_TRUE(database->put(key2, value2)); EXPECT_TRUE(database->commit()); Database::const_iterator iter1 = database->begin(); @@ -537,8 +537,8 @@ TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLMDB) { Database::buffer_t key2 = this->TestKey(); Database::buffer_t value2 = this->TestValue(); - EXPECT_TRUE(database->put(&key1, &value1)); - EXPECT_TRUE(database->put(&key2, &value2)); + EXPECT_TRUE(database->put(key1, value1)); + EXPECT_TRUE(database->put(key2, value2)); EXPECT_TRUE(database->commit()); Database::const_iterator iter1 = database->begin(); @@ -576,7 +576,7 @@ TYPED_TEST(DatabaseTest, TestNewPutLMDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); @@ -601,13 +601,13 @@ TYPED_TEST(DatabaseTest, TestNewGetLMDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); Database::buffer_t new_val; - EXPECT_TRUE(database->get(&key, &new_val)); + EXPECT_TRUE(database->get(key, &new_val)); EXPECT_TRUE(this->BufferEq(val, new_val)); @@ -622,11 +622,11 @@ TYPED_TEST(DatabaseTest, TestNewGetNoCommitLMDBFails) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); Database::buffer_t new_val; - EXPECT_FALSE(database->get(&key, &new_val)); + EXPECT_FALSE(database->get(key, &new_val)); } TYPED_TEST(DatabaseTest, TestReadWritePutLMDBPasses) { @@ -637,7 +637,7 @@ TYPED_TEST(DatabaseTest, TestReadWritePutLMDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); @@ -662,13 +662,13 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetLMDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); Database::buffer_t new_val; - EXPECT_TRUE(database->get(&key, &new_val)); + EXPECT_TRUE(database->get(key, &new_val)); EXPECT_TRUE(this->BufferEq(val, new_val)); @@ -683,11 +683,11 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitLMDBFails) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); Database::buffer_t new_val; - EXPECT_FALSE(database->get(&key, &new_val)); + EXPECT_FALSE(database->get(key, &new_val)); } TYPED_TEST(DatabaseTest, TestReadOnlyPutLMDBFails) { @@ -701,7 +701,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyPutLMDBFails) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_FALSE(database->put(&key, &val)); + EXPECT_FALSE(database->put(key, val)); } TYPED_TEST(DatabaseTest, TestReadOnlyCommitLMDBFails) { @@ -723,7 +723,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLMDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); @@ -733,7 +733,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLMDBPasses) { Database::buffer_t new_val; - EXPECT_TRUE(database->get(&key, &new_val)); + EXPECT_TRUE(database->get(key, &new_val)); EXPECT_TRUE(this->BufferEq(val, new_val)); } @@ -746,7 +746,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLMDBFails) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); database->close(); @@ -754,7 +754,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLMDBFails) { Database::buffer_t new_val; - EXPECT_FALSE(database->get(&key, &new_val)); + EXPECT_FALSE(database->get(key, &new_val)); } } // namespace caffe diff --git a/tools/convert_imageset.cpp b/tools/convert_imageset.cpp index 3345c9c..5ad2c0b 100644 --- a/tools/convert_imageset.cpp +++ b/tools/convert_imageset.cpp @@ -116,7 +116,7 @@ int main(int argc, char** argv) { Database::buffer_t keystr(key_cstr, key_cstr + length); // Put in db - CHECK(database->put(&keystr, &value)); + CHECK(database->put(keystr, value)); if (++count % 1000 == 0) { // Commit txn diff --git a/tools/extract_features.cpp b/tools/extract_features.cpp index 1560ef6..0c7660d 100644 --- a/tools/extract_features.cpp +++ b/tools/extract_features.cpp @@ -161,7 +161,7 @@ int feature_extraction_pipeline(int argc, char** argv) { int length = snprintf(key_str, kMaxKeyStrLength, "%d", image_indices[i]); Database::buffer_t key(key_str, key_str + length); - CHECK(feature_dbs.at(i)->put(&key, &value)); + CHECK(feature_dbs.at(i)->put(key, value)); ++image_indices[i]; if (image_indices[i] % 1000 == 0) { CHECK(feature_dbs.at(i)->commit()); -- 2.7.4