From: Kevin James Matzen Date: Tue, 7 Oct 2014 22:03:02 +0000 (-0400) Subject: Updated interface to make fewer string copies. X-Git-Tag: submit/tizen/20180823.020014~572^2~109^2~19 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=7e504c0612b23449aa83fd7aac8d49c56b03fd62;p=platform%2Fupstream%2Fcaffeonacl.git Updated interface to make fewer string copies. --- diff --git a/include/caffe/database.hpp b/include/caffe/database.hpp index 4a1a25e..23036a8 100644 --- a/include/caffe/database.hpp +++ b/include/caffe/database.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include "caffe/common.hpp" @@ -18,8 +19,10 @@ class Database { ReadOnly }; + typedef vector buffer_t; + virtual void open(const string& filename, Mode mode) = 0; - virtual void put(const string& key, const string& value) = 0; + virtual void put(buffer_t* key, buffer_t* value) = 0; virtual void commit() = 0; virtual void close() = 0; @@ -38,10 +41,10 @@ class Database { class DatabaseState; public: - class iterator : - public std::iterator > { + class iterator : public std::iterator< + std::forward_iterator_tag, pair > { public: - typedef pair T; + typedef pair T; typedef T value_type; typedef T& reference_type; typedef T* pointer_type; @@ -94,7 +97,7 @@ class Database { virtual bool equal(shared_ptr state1, shared_ptr state2) const = 0; virtual void increment(shared_ptr state) const = 0; - virtual pair& dereference( + virtual pair& dereference( shared_ptr state) const = 0; }; diff --git a/include/caffe/leveldb_database.hpp b/include/caffe/leveldb_database.hpp index ee8c5f2..f30273d 100644 --- a/include/caffe/leveldb_database.hpp +++ b/include/caffe/leveldb_database.hpp @@ -15,7 +15,7 @@ namespace caffe { class LeveldbDatabase : public Database { public: void open(const string& filename, Mode mode); - void put(const string& key, const string& value); + void put(buffer_t* key, buffer_t* value); void commit(); void close(); @@ -34,13 +34,13 @@ class LeveldbDatabase : public Database { iter_(iter) { } shared_ptr iter_; - pair kv_pair_; + pair kv_pair_; }; bool equal(shared_ptr state1, shared_ptr state2) const; void increment(shared_ptr state) const; - pair& dereference(shared_ptr state) const; + pair& dereference(shared_ptr state) const; shared_ptr db_; shared_ptr batch_; diff --git a/include/caffe/lmdb_database.hpp b/include/caffe/lmdb_database.hpp index 7387afd..ee3806d 100644 --- a/include/caffe/lmdb_database.hpp +++ b/include/caffe/lmdb_database.hpp @@ -20,7 +20,7 @@ class LmdbDatabase : public Database { ~LmdbDatabase() { this->close(); } void open(const string& filename, Mode mode); - void put(const string& key, const string& value); + void put(buffer_t* key, buffer_t* value); void commit(); void close(); @@ -37,13 +37,13 @@ class LmdbDatabase : public Database { cursor_(cursor) { } MDB_cursor* cursor_; - pair kv_pair_; + pair kv_pair_; }; bool equal(shared_ptr state1, shared_ptr state2) const; void increment(shared_ptr state) const; - pair& dereference(shared_ptr state) const; + pair& dereference(shared_ptr state) const; MDB_env *env_; MDB_dbi dbi_; diff --git a/include/caffe/util/io.hpp b/include/caffe/util/io.hpp index 1124bef..b64c821 100644 --- a/include/caffe/util/io.hpp +++ b/include/caffe/util/io.hpp @@ -24,27 +24,27 @@ using ::google::protobuf::Message; inline void MakeTempFilename(string* temp_filename) { temp_filename->clear(); *temp_filename = "/tmp/caffe_test.XXXXXX"; - char* temp_filename_cstr = new char[temp_filename->size()]; + char* temp_filename_cstr = new char[temp_filename->size() + 1]; // NOLINT_NEXT_LINE(runtime/printf) strcpy(temp_filename_cstr, temp_filename->c_str()); int fd = mkstemp(temp_filename_cstr); CHECK_GE(fd, 0) << "Failed to open a temporary file at: " << *temp_filename; close(fd); *temp_filename = temp_filename_cstr; - delete temp_filename_cstr; + delete[] temp_filename_cstr; } inline void MakeTempDir(string* temp_dirname) { temp_dirname->clear(); *temp_dirname = "/tmp/caffe_test.XXXXXX"; - char* temp_dirname_cstr = new char[temp_dirname->size()]; + char* temp_dirname_cstr = new char[temp_dirname->size() + 1]; // NOLINT_NEXT_LINE(runtime/printf) strcpy(temp_dirname_cstr, temp_dirname->c_str()); char* mkdtemp_result = mkdtemp(temp_dirname_cstr); CHECK(mkdtemp_result != NULL) << "Failed to create a temporary directory at: " << *temp_dirname; *temp_dirname = temp_dirname_cstr; - delete temp_dirname_cstr; + delete[] temp_dirname_cstr; } bool ReadProtoFromTextFile(const char* filename, Message* proto); diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index 1d37170..4d36b8e 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -40,7 +40,6 @@ void DataLayer::DataLayerSetUp(const vector*>& bottom, this->layer_param_.data_param().rand_skip(); LOG(INFO) << "Skipping first " << skip << " data points."; while (skip-- > 0) { - LOG(INFO) << iter_->first; if (++iter_ == database_->end()) { iter_ = database_->begin(); } @@ -49,7 +48,7 @@ void DataLayer::DataLayerSetUp(const vector*>& bottom, // Read a data point, and use it to initialize the top blob. CHECK(iter_ != database_->end()); Datum datum; - datum.ParseFromString(iter_->second); + datum.ParseFromArray(iter_->second.data(), iter_->second.size()); // image int crop_size = this->layer_param_.transform_param().crop_size(); @@ -95,7 +94,7 @@ void DataLayer::InternalThreadEntry() { for (int item_id = 0; item_id < batch_size; ++item_id) { Datum datum; CHECK(iter_ != database_->end()); - datum.ParseFromString(iter_->second); + datum.ParseFromArray(iter_->second.data(), iter_->second.size()); // Apply data transformations (mirror, scale, crop...) int offset = this->prefetch_data_.offset(item_id); diff --git a/src/caffe/leveldb_database.cpp b/src/caffe/leveldb_database.cpp index be7ac7f..a8fe02a 100644 --- a/src/caffe/leveldb_database.cpp +++ b/src/caffe/leveldb_database.cpp @@ -42,11 +42,15 @@ void LeveldbDatabase::open(const string& filename, Mode mode) { batch_.reset(new leveldb::WriteBatch()); } -void LeveldbDatabase::put(const string& key, const string& value) { - LOG(INFO) << "LevelDB: Put " << key; +void LeveldbDatabase::put(buffer_t* key, buffer_t* value) { + LOG(INFO) << "LevelDB: Put"; CHECK_NOTNULL(batch_.get()); - batch_->Put(key, value); + + leveldb::Slice key_slice(key->data(), key->size()); + leveldb::Slice value_slice(value->data(), value->size()); + + batch_->Put(key_slice, value_slice); } void LeveldbDatabase::commit() { @@ -130,7 +134,7 @@ void LeveldbDatabase::increment(shared_ptr state) const { } } -pair& LeveldbDatabase::dereference( +pair& LeveldbDatabase::dereference( shared_ptr state) const { shared_ptr leveldb_state = boost::dynamic_pointer_cast(state); @@ -143,8 +147,10 @@ pair& LeveldbDatabase::dereference( CHECK(iter->Valid()); - leveldb_state->kv_pair_ = make_pair(iter->key().ToString(), - iter->value().ToString()); + leveldb_state->kv_pair_ = make_pair( + buffer_t(iter->key().data(), iter->key().data() + iter->key().size()), + buffer_t(iter->value().data(), + iter->value().data() + iter->value().size())); return leveldb_state->kv_pair_; } diff --git a/src/caffe/lmdb_database.cpp b/src/caffe/lmdb_database.cpp index 796bbc9..7197a47 100644 --- a/src/caffe/lmdb_database.cpp +++ b/src/caffe/lmdb_database.cpp @@ -19,9 +19,13 @@ void LmdbDatabase::open(const string& filename, Mode mode) { << "failed"; } - CHECK_EQ(mdb_env_create(&env_), MDB_SUCCESS) << "mdb_env_create failed"; - CHECK_EQ(mdb_env_set_mapsize(env_, 1099511627776), MDB_SUCCESS) // 1TB - << "mdb_env_set_mapsize failed"; + int retval; + retval = mdb_env_create(&env_); + CHECK_EQ(retval, MDB_SUCCESS) << "mdb_env_create failed " + << mdb_strerror(retval); + retval = mdb_env_set_mapsize(env_, 1099511627776); + CHECK_EQ(retval, MDB_SUCCESS) // 1TB + << "mdb_env_set_mapsize failed " << mdb_strerror(retval); int flag1 = 0; int flag2 = 0; @@ -30,27 +34,31 @@ void LmdbDatabase::open(const string& filename, Mode mode) { flag2 = MDB_RDONLY; } - CHECK_EQ(mdb_env_open(env_, filename.c_str(), flag1, 0664), MDB_SUCCESS) - << "mdb_env_open failed"; - CHECK_EQ(mdb_txn_begin(env_, NULL, flag2, &txn_), MDB_SUCCESS) - << "mdb_txn_begin failed"; - CHECK_EQ(mdb_open(txn_, NULL, 0, &dbi_), MDB_SUCCESS) << "mdb_open failed"; + retval = mdb_env_open(env_, filename.c_str(), flag1, 0664); + CHECK_EQ(retval, MDB_SUCCESS) + << "mdb_env_open failed " << mdb_strerror(retval); + retval = mdb_txn_begin(env_, NULL, flag2, &txn_); + CHECK_EQ(retval, MDB_SUCCESS) + << "mdb_txn_begin failed " << mdb_strerror(retval); + retval = mdb_open(txn_, NULL, 0, &dbi_); + CHECK_EQ(retval, MDB_SUCCESS) << "mdb_open failed" << mdb_strerror(retval); } -void LmdbDatabase::put(const string& key, const string& value) { - LOG(INFO) << "LMDB: Put " << key; +void LmdbDatabase::put(buffer_t* key, buffer_t* value) { + LOG(INFO) << "LMDB: Put"; MDB_val mdbkey, mdbdata; - mdbdata.mv_size = value.size(); - mdbdata.mv_data = const_cast(&value[0]); - mdbkey.mv_size = key.size(); - mdbkey.mv_data = const_cast(&key[0]); + mdbdata.mv_size = value->size(); + mdbdata.mv_data = value->data(); + mdbkey.mv_size = key->size(); + mdbkey.mv_data = key->data(); CHECK_NOTNULL(txn_); CHECK_NE(0, dbi_); - CHECK_EQ(mdb_put(txn_, dbi_, &mdbkey, &mdbdata, 0), MDB_SUCCESS) - << "mdb_put failed"; + int retval = mdb_put(txn_, dbi_, &mdbkey, &mdbdata, 0); + CHECK_EQ(retval, MDB_SUCCESS) + << "mdb_put failed " << mdb_strerror(retval); } void LmdbDatabase::commit() { @@ -58,7 +66,9 @@ void LmdbDatabase::commit() { CHECK_NOTNULL(txn_); - CHECK_EQ(mdb_txn_commit(txn_), MDB_SUCCESS) << "mdb_txn_commit failed"; + int retval = mdb_txn_commit(txn_); + CHECK_EQ(retval, MDB_SUCCESS) << "mdb_txn_commit failed " + << mdb_strerror(retval); } void LmdbDatabase::close() { @@ -79,10 +89,13 @@ void LmdbDatabase::close() { LmdbDatabase::const_iterator LmdbDatabase::begin() const { MDB_cursor* cursor; - CHECK_EQ(mdb_cursor_open(txn_, dbi_, &cursor), MDB_SUCCESS); + int retval; + retval = mdb_cursor_open(txn_, dbi_, &cursor); + CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); MDB_val key; MDB_val val; - CHECK_EQ(mdb_cursor_get(cursor, &key, &val, MDB_FIRST), MDB_SUCCESS); + retval = mdb_cursor_get(cursor, &key, &val, MDB_FIRST); + CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); shared_ptr state(new LmdbState(cursor)); return const_iterator(this, state); @@ -122,15 +135,20 @@ void LmdbDatabase::increment(shared_ptr state) const { MDB_cursor*& cursor = lmdb_state->cursor_; + CHECK_NOTNULL(cursor); + MDB_val key; MDB_val val; - if (MDB_SUCCESS != mdb_cursor_get(cursor, &key, &val, MDB_NEXT)) { + int retval = mdb_cursor_get(cursor, &key, &val, MDB_NEXT); + if (MDB_NOTFOUND == retval) { mdb_cursor_close(cursor); cursor = NULL; + } else { + CHECK_EQ(MDB_SUCCESS, retval) << mdb_strerror(retval); } } -pair& LmdbDatabase::dereference( +pair& LmdbDatabase::dereference( shared_ptr state) const { shared_ptr lmdb_state = boost::dynamic_pointer_cast(state); @@ -139,14 +157,19 @@ pair& LmdbDatabase::dereference( MDB_cursor*& cursor = lmdb_state->cursor_; + CHECK_NOTNULL(cursor); + MDB_val mdb_key; MDB_val mdb_val; - CHECK_EQ(mdb_cursor_get(cursor, &mdb_key, &mdb_val, MDB_GET_CURRENT), - MDB_SUCCESS); + int retval = mdb_cursor_get(cursor, &mdb_key, &mdb_val, MDB_GET_CURRENT); + CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); + + char* key_data = reinterpret_cast(mdb_key.mv_data); + char* value_data = reinterpret_cast(mdb_val.mv_data); lmdb_state->kv_pair_ = make_pair( - string(reinterpret_cast(mdb_key.mv_data), mdb_key.mv_size), - string(reinterpret_cast(mdb_val.mv_data), mdb_val.mv_size)); + buffer_t(key_data, key_data + mdb_key.mv_size), + buffer_t(value_data, value_data + mdb_val.mv_size)); return lmdb_state->kv_pair_; } diff --git a/src/caffe/test/test_data_layer.cpp b/src/caffe/test/test_data_layer.cpp index d99b5e3..c17f729 100644 --- a/src/caffe/test/test_data_layer.cpp +++ b/src/caffe/test/test_data_layer.cpp @@ -54,7 +54,12 @@ class DataLayerTest : public MultiDeviceTest { } stringstream ss; ss << i; - database->put(ss.str(), datum.SerializeAsString()); + string key_str = ss.str(); + Database::buffer_t key(key_str.c_str(), key_str.c_str() + key_str.size()); + Database::buffer_t value(datum.ByteSize()); + datum.SerializeWithCachedSizesToArray( + reinterpret_cast(value.data())); + database->put(&key, &value); } database->close(); } diff --git a/tools/compute_image_mean.cpp b/tools/compute_image_mean.cpp index e59bbf1..f981af4 100644 --- a/tools/compute_image_mean.cpp +++ b/tools/compute_image_mean.cpp @@ -35,7 +35,8 @@ int main(int argc, char** argv) { BlobProto sum_blob; int count = 0; // load first datum - datum.ParseFromString(database->begin()->second); + const Database::buffer_t& first_blob = database->begin()->second; + datum.ParseFromArray(first_blob.data(), first_blob.size()); sum_blob.set_num(1); sum_blob.set_channels(datum.channels()); @@ -51,7 +52,8 @@ int main(int argc, char** argv) { for (Database::const_iterator iter = database->begin(); iter != database->end(); ++iter) { // just a dummy operation - datum.ParseFromString(iter->second); + const Database::buffer_t& blob = iter->second; + datum.ParseFromArray(blob.data(), blob.size()); const std::string& data = datum.data(); size_in_datum = std::max(datum.data().size(), datum.float_data_size()); diff --git a/tools/convert_imageset.cpp b/tools/convert_imageset.cpp index 6f03a9d..19c87e5 100644 --- a/tools/convert_imageset.cpp +++ b/tools/convert_imageset.cpp @@ -108,14 +108,15 @@ int main(int argc, char** argv) { } } // sequential - snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id, + int length = snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id, lines[line_id].first.c_str()); - std::string value; - datum.SerializeToString(&value); - std::string keystr(key_cstr); + Database::buffer_t value(datum.ByteSize()); + datum.SerializeWithCachedSizesToArray( + reinterpret_cast(value.data())); + Database::buffer_t keystr(key_cstr, key_cstr + length); // Put in db - database->put(keystr, value); + database->put(&keystr, &value); if (++count % 1000 == 0) { // Commit txn diff --git a/tools/extract_features.cpp b/tools/extract_features.cpp index b3ad8e6..1065d44 100644 --- a/tools/extract_features.cpp +++ b/tools/extract_features.cpp @@ -155,10 +155,13 @@ int feature_extraction_pipeline(int argc, char** argv) { for (int d = 0; d < dim_features; ++d) { datum.add_float_data(feature_blob_data[d]); } - std::string value; - datum.SerializeToString(&value); - snprintf(key_str, kMaxKeyStrLength, "%d", image_indices[i]); - feature_dbs.at(i)->put(std::string(key_str), value); + Database::buffer_t value(datum.ByteSize()); + datum.SerializeWithCachedSizesToArray( + reinterpret_cast(value.data())); + int length = snprintf(key_str, kMaxKeyStrLength, "%d", + image_indices[i]); + Database::buffer_t key(key_str, key_str + length); + feature_dbs.at(i)->put(&key, &value); ++image_indices[i]; if (image_indices[i] % 1000 == 0) { feature_dbs.at(i)->commit();