Updated interface to make fewer string copies.
authorKevin James Matzen <kmatzen@cs.cornell.edu>
Tue, 7 Oct 2014 22:03:02 +0000 (18:03 -0400)
committerKevin James Matzen <kmatzen@cs.cornell.edu>
Tue, 14 Oct 2014 23:29:11 +0000 (19:29 -0400)
include/caffe/database.hpp
include/caffe/leveldb_database.hpp
include/caffe/lmdb_database.hpp
include/caffe/util/io.hpp
src/caffe/layers/data_layer.cpp
src/caffe/leveldb_database.cpp
src/caffe/lmdb_database.cpp
src/caffe/test/test_data_layer.cpp
tools/compute_image_mean.cpp
tools/convert_imageset.cpp
tools/extract_features.cpp

index 4a1a25e..23036a8 100644 (file)
@@ -5,6 +5,7 @@
 #include <iterator>
 #include <string>
 #include <utility>
+#include <vector>
 
 #include "caffe/common.hpp"
 
@@ -18,8 +19,10 @@ class Database {
     ReadOnly
   };
 
+  typedef vector<char> buffer_t;
+
   virtual void open(const string& filename, Mode mode) = 0;
-  virtual void put(const string& key, const string& value) = 0;
+  virtual void put(buffer_t* key, buffer_t* value) = 0;
   virtual void commit() = 0;
   virtual void close() = 0;
 
@@ -38,10 +41,10 @@ class Database {
   class DatabaseState;
 
  public:
-  class iterator :
-      public std::iterator<std::forward_iterator_tag, pair<string, string> > {
+  class iterator : public std::iterator<
+      std::forward_iterator_tag, pair<buffer_t, buffer_t> > {
    public:
-    typedef pair<string, string> T;
+    typedef pair<buffer_t, buffer_t> T;
     typedef T value_type;
     typedef T& reference_type;
     typedef T* pointer_type;
@@ -94,7 +97,7 @@ class Database {
   virtual bool equal(shared_ptr<DatabaseState> state1,
       shared_ptr<DatabaseState> state2) const = 0;
   virtual void increment(shared_ptr<DatabaseState> state) const = 0;
-  virtual pair<string, string>& dereference(
+  virtual pair<buffer_t, buffer_t>& dereference(
       shared_ptr<DatabaseState> state) const = 0;
 };
 
index ee8c5f2..f30273d 100644 (file)
@@ -15,7 +15,7 @@ namespace caffe {
 class LeveldbDatabase : public Database {
  public:
   void open(const string& filename, Mode mode);
-  void put(const string& key, const string& value);
+  void put(buffer_t* key, buffer_t* value);
   void commit();
   void close();
 
@@ -34,13 +34,13 @@ class LeveldbDatabase : public Database {
           iter_(iter) { }
 
     shared_ptr<leveldb::Iterator> iter_;
-    pair<string, string> kv_pair_;
+    pair<buffer_t, buffer_t> kv_pair_;
   };
 
   bool equal(shared_ptr<DatabaseState> state1,
       shared_ptr<DatabaseState> state2) const;
   void increment(shared_ptr<DatabaseState> state) const;
-  pair<string, string>& dereference(shared_ptr<DatabaseState> state) const;
+  pair<buffer_t, buffer_t>& dereference(shared_ptr<DatabaseState> state) const;
 
   shared_ptr<leveldb::DB> db_;
   shared_ptr<leveldb::WriteBatch> batch_;
index 7387afd..ee3806d 100644 (file)
@@ -20,7 +20,7 @@ class LmdbDatabase : public Database {
   ~LmdbDatabase() { this->close(); }
 
   void open(const string& filename, Mode mode);
-  void put(const string& key, const string& value);
+  void put(buffer_t* key, buffer_t* value);
   void commit();
   void close();
 
@@ -37,13 +37,13 @@ class LmdbDatabase : public Database {
           cursor_(cursor) { }
 
     MDB_cursor* cursor_;
-    pair<string, string> kv_pair_;
+    pair<buffer_t, buffer_t> kv_pair_;
   };
 
   bool equal(shared_ptr<DatabaseState> state1,
       shared_ptr<DatabaseState> state2) const;
   void increment(shared_ptr<DatabaseState> state) const;
-  pair<string, string>& dereference(shared_ptr<DatabaseState> state) const;
+  pair<buffer_t, buffer_t>& dereference(shared_ptr<DatabaseState> state) const;
 
   MDB_env *env_;
   MDB_dbi dbi_;
index 1124bef..b64c821 100644 (file)
@@ -24,27 +24,27 @@ using ::google::protobuf::Message;
 inline void MakeTempFilename(string* temp_filename) {
   temp_filename->clear();
   *temp_filename = "/tmp/caffe_test.XXXXXX";
-  char* temp_filename_cstr = new char[temp_filename->size()];
+  char* temp_filename_cstr = new char[temp_filename->size() + 1];
   // NOLINT_NEXT_LINE(runtime/printf)
   strcpy(temp_filename_cstr, temp_filename->c_str());
   int fd = mkstemp(temp_filename_cstr);
   CHECK_GE(fd, 0) << "Failed to open a temporary file at: " << *temp_filename;
   close(fd);
   *temp_filename = temp_filename_cstr;
-  delete temp_filename_cstr;
+  delete[] temp_filename_cstr;
 }
 
 inline void MakeTempDir(string* temp_dirname) {
   temp_dirname->clear();
   *temp_dirname = "/tmp/caffe_test.XXXXXX";
-  char* temp_dirname_cstr = new char[temp_dirname->size()];
+  char* temp_dirname_cstr = new char[temp_dirname->size() + 1];
   // NOLINT_NEXT_LINE(runtime/printf)
   strcpy(temp_dirname_cstr, temp_dirname->c_str());
   char* mkdtemp_result = mkdtemp(temp_dirname_cstr);
   CHECK(mkdtemp_result != NULL)
       << "Failed to create a temporary directory at: " << *temp_dirname;
   *temp_dirname = temp_dirname_cstr;
-  delete temp_dirname_cstr;
+  delete[] temp_dirname_cstr;
 }
 
 bool ReadProtoFromTextFile(const char* filename, Message* proto);
index 1d37170..4d36b8e 100644 (file)
@@ -40,7 +40,6 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
                         this->layer_param_.data_param().rand_skip();
     LOG(INFO) << "Skipping first " << skip << " data points.";
     while (skip-- > 0) {
-      LOG(INFO) << iter_->first;
       if (++iter_ == database_->end()) {
         iter_ = database_->begin();
       }
@@ -49,7 +48,7 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
   // Read a data point, and use it to initialize the top blob.
   CHECK(iter_ != database_->end());
   Datum datum;
-  datum.ParseFromString(iter_->second);
+  datum.ParseFromArray(iter_->second.data(), iter_->second.size());
 
   // image
   int crop_size = this->layer_param_.transform_param().crop_size();
@@ -95,7 +94,7 @@ void DataLayer<Dtype>::InternalThreadEntry() {
   for (int item_id = 0; item_id < batch_size; ++item_id) {
     Datum datum;
     CHECK(iter_ != database_->end());
-    datum.ParseFromString(iter_->second);
+    datum.ParseFromArray(iter_->second.data(), iter_->second.size());
 
     // Apply data transformations (mirror, scale, crop...)
     int offset = this->prefetch_data_.offset(item_id);
index be7ac7f..a8fe02a 100644 (file)
@@ -42,11 +42,15 @@ void LeveldbDatabase::open(const string& filename, Mode mode) {
   batch_.reset(new leveldb::WriteBatch());
 }
 
-void LeveldbDatabase::put(const string& key, const string& value) {
-  LOG(INFO) << "LevelDB: Put " << key;
+void LeveldbDatabase::put(buffer_t* key, buffer_t* value) {
+  LOG(INFO) << "LevelDB: Put";
 
   CHECK_NOTNULL(batch_.get());
-  batch_->Put(key, value);
+
+  leveldb::Slice key_slice(key->data(), key->size());
+  leveldb::Slice value_slice(value->data(), value->size());
+
+  batch_->Put(key_slice, value_slice);
 }
 
 void LeveldbDatabase::commit() {
@@ -130,7 +134,7 @@ void LeveldbDatabase::increment(shared_ptr<DatabaseState> state) const {
   }
 }
 
-pair<string, string>& LeveldbDatabase::dereference(
+pair<Database::buffer_t, Database::buffer_t>& LeveldbDatabase::dereference(
     shared_ptr<DatabaseState> state) const {
   shared_ptr<LeveldbState> leveldb_state =
       boost::dynamic_pointer_cast<LeveldbState>(state);
@@ -143,8 +147,10 @@ pair<string, string>& LeveldbDatabase::dereference(
 
   CHECK(iter->Valid());
 
-  leveldb_state->kv_pair_ = make_pair(iter->key().ToString(),
-      iter->value().ToString());
+  leveldb_state->kv_pair_ = make_pair(
+      buffer_t(iter->key().data(), iter->key().data() + iter->key().size()),
+      buffer_t(iter->value().data(),
+          iter->value().data() + iter->value().size()));
   return leveldb_state->kv_pair_;
 }
 
index 796bbc9..7197a47 100644 (file)
@@ -19,9 +19,13 @@ void LmdbDatabase::open(const string& filename, Mode mode) {
                                                 << "failed";
   }
 
-  CHECK_EQ(mdb_env_create(&env_), MDB_SUCCESS) << "mdb_env_create failed";
-  CHECK_EQ(mdb_env_set_mapsize(env_, 1099511627776), MDB_SUCCESS)  // 1TB
-      << "mdb_env_set_mapsize failed";
+  int retval;
+  retval = mdb_env_create(&env_);
+  CHECK_EQ(retval, MDB_SUCCESS) << "mdb_env_create failed "
+      << mdb_strerror(retval);
+  retval = mdb_env_set_mapsize(env_, 1099511627776);
+  CHECK_EQ(retval, MDB_SUCCESS)  // 1TB
+      << "mdb_env_set_mapsize failed " << mdb_strerror(retval);
 
   int flag1 = 0;
   int flag2 = 0;
@@ -30,27 +34,31 @@ void LmdbDatabase::open(const string& filename, Mode mode) {
     flag2 = MDB_RDONLY;
   }
 
-  CHECK_EQ(mdb_env_open(env_, filename.c_str(), flag1, 0664), MDB_SUCCESS)
-      << "mdb_env_open failed";
-  CHECK_EQ(mdb_txn_begin(env_, NULL, flag2, &txn_), MDB_SUCCESS)
-      << "mdb_txn_begin failed";
-  CHECK_EQ(mdb_open(txn_, NULL, 0, &dbi_), MDB_SUCCESS) << "mdb_open failed";
+  retval = mdb_env_open(env_, filename.c_str(), flag1, 0664);
+  CHECK_EQ(retval, MDB_SUCCESS)
+      << "mdb_env_open failed " << mdb_strerror(retval);
+  retval = mdb_txn_begin(env_, NULL, flag2, &txn_);
+  CHECK_EQ(retval, MDB_SUCCESS)
+      << "mdb_txn_begin failed " << mdb_strerror(retval);
+  retval = mdb_open(txn_, NULL, 0, &dbi_);
+  CHECK_EQ(retval, MDB_SUCCESS) << "mdb_open failed" << mdb_strerror(retval);
 }
 
-void LmdbDatabase::put(const string& key, const string& value) {
-  LOG(INFO) << "LMDB: Put " << key;
+void LmdbDatabase::put(buffer_t* key, buffer_t* value) {
+  LOG(INFO) << "LMDB: Put";
 
   MDB_val mdbkey, mdbdata;
-  mdbdata.mv_size = value.size();
-  mdbdata.mv_data = const_cast<char*>(&value[0]);
-  mdbkey.mv_size = key.size();
-  mdbkey.mv_data = const_cast<char*>(&key[0]);
+  mdbdata.mv_size = value->size();
+  mdbdata.mv_data = value->data();
+  mdbkey.mv_size = key->size();
+  mdbkey.mv_data = key->data();
 
   CHECK_NOTNULL(txn_);
   CHECK_NE(0, dbi_);
 
-  CHECK_EQ(mdb_put(txn_, dbi_, &mdbkey, &mdbdata, 0), MDB_SUCCESS)
-      << "mdb_put failed";
+  int retval = mdb_put(txn_, dbi_, &mdbkey, &mdbdata, 0);
+  CHECK_EQ(retval, MDB_SUCCESS)
+      << "mdb_put failed " << mdb_strerror(retval);
 }
 
 void LmdbDatabase::commit() {
@@ -58,7 +66,9 @@ void LmdbDatabase::commit() {
 
   CHECK_NOTNULL(txn_);
 
-  CHECK_EQ(mdb_txn_commit(txn_), MDB_SUCCESS) << "mdb_txn_commit failed";
+  int retval = mdb_txn_commit(txn_);
+  CHECK_EQ(retval, MDB_SUCCESS) << "mdb_txn_commit failed "
+      << mdb_strerror(retval);
 }
 
 void LmdbDatabase::close() {
@@ -79,10 +89,13 @@ void LmdbDatabase::close() {
 
 LmdbDatabase::const_iterator LmdbDatabase::begin() const {
   MDB_cursor* cursor;
-  CHECK_EQ(mdb_cursor_open(txn_, dbi_, &cursor), MDB_SUCCESS);
+  int retval;
+  retval = mdb_cursor_open(txn_, dbi_, &cursor);
+  CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval);
   MDB_val key;
   MDB_val val;
-  CHECK_EQ(mdb_cursor_get(cursor, &key, &val, MDB_FIRST), MDB_SUCCESS);
+  retval = mdb_cursor_get(cursor, &key, &val, MDB_FIRST);
+  CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval);
 
   shared_ptr<DatabaseState> state(new LmdbState(cursor));
   return const_iterator(this, state);
@@ -122,15 +135,20 @@ void LmdbDatabase::increment(shared_ptr<DatabaseState> state) const {
 
   MDB_cursor*& cursor = lmdb_state->cursor_;
 
+  CHECK_NOTNULL(cursor);
+
   MDB_val key;
   MDB_val val;
-  if (MDB_SUCCESS != mdb_cursor_get(cursor, &key, &val, MDB_NEXT)) {
+  int retval = mdb_cursor_get(cursor, &key, &val, MDB_NEXT);
+  if (MDB_NOTFOUND == retval) {
     mdb_cursor_close(cursor);
     cursor = NULL;
+  } else {
+    CHECK_EQ(MDB_SUCCESS, retval) << mdb_strerror(retval);
   }
 }
 
-pair<string, string>& LmdbDatabase::dereference(
+pair<Database::buffer_t, Database::buffer_t>& LmdbDatabase::dereference(
     shared_ptr<DatabaseState> state) const {
   shared_ptr<LmdbState> lmdb_state =
       boost::dynamic_pointer_cast<LmdbState>(state);
@@ -139,14 +157,19 @@ pair<string, string>& LmdbDatabase::dereference(
 
   MDB_cursor*& cursor = lmdb_state->cursor_;
 
+  CHECK_NOTNULL(cursor);
+
   MDB_val mdb_key;
   MDB_val mdb_val;
-  CHECK_EQ(mdb_cursor_get(cursor, &mdb_key, &mdb_val, MDB_GET_CURRENT),
-      MDB_SUCCESS);
+  int retval = mdb_cursor_get(cursor, &mdb_key, &mdb_val, MDB_GET_CURRENT);
+  CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval);
+
+  char* key_data = reinterpret_cast<char*>(mdb_key.mv_data);
+  char* value_data = reinterpret_cast<char*>(mdb_val.mv_data);
 
   lmdb_state->kv_pair_ = make_pair(
-    string(reinterpret_cast<char*>(mdb_key.mv_data), mdb_key.mv_size),
-    string(reinterpret_cast<char*>(mdb_val.mv_data), mdb_val.mv_size));
+    buffer_t(key_data, key_data + mdb_key.mv_size),
+    buffer_t(value_data, value_data + mdb_val.mv_size));
 
   return lmdb_state->kv_pair_;
 }
index d99b5e3..c17f729 100644 (file)
@@ -54,7 +54,12 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
       }
       stringstream ss;
       ss << i;
-      database->put(ss.str(), datum.SerializeAsString());
+      string key_str = ss.str();
+      Database::buffer_t key(key_str.c_str(), key_str.c_str() + key_str.size());
+      Database::buffer_t value(datum.ByteSize());
+      datum.SerializeWithCachedSizesToArray(
+          reinterpret_cast<unsigned char*>(value.data()));
+      database->put(&key, &value);
     }
     database->close();
   }
index e59bbf1..f981af4 100644 (file)
@@ -35,7 +35,8 @@ int main(int argc, char** argv) {
   BlobProto sum_blob;
   int count = 0;
   // load first datum
-  datum.ParseFromString(database->begin()->second);
+  const Database::buffer_t& first_blob = database->begin()->second;
+  datum.ParseFromArray(first_blob.data(), first_blob.size());
 
   sum_blob.set_num(1);
   sum_blob.set_channels(datum.channels());
@@ -51,7 +52,8 @@ int main(int argc, char** argv) {
   for (Database::const_iterator iter = database->begin();
       iter != database->end(); ++iter) {
     // just a dummy operation
-    datum.ParseFromString(iter->second);
+    const Database::buffer_t& blob = iter->second;
+    datum.ParseFromArray(blob.data(), blob.size());
     const std::string& data = datum.data();
     size_in_datum = std::max<int>(datum.data().size(),
         datum.float_data_size());
index 6f03a9d..19c87e5 100644 (file)
@@ -108,14 +108,15 @@ int main(int argc, char** argv) {
       }
     }
     // sequential
-    snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
+    int length = snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
         lines[line_id].first.c_str());
-    std::string value;
-    datum.SerializeToString(&value);
-    std::string keystr(key_cstr);
+    Database::buffer_t value(datum.ByteSize());
+    datum.SerializeWithCachedSizesToArray(
+        reinterpret_cast<unsigned char*>(value.data()));
+    Database::buffer_t keystr(key_cstr, key_cstr + length);
 
     // Put in db
-    database->put(keystr, value);
+    database->put(&keystr, &value);
 
     if (++count % 1000 == 0) {
       // Commit txn
index b3ad8e6..1065d44 100644 (file)
@@ -155,10 +155,13 @@ int feature_extraction_pipeline(int argc, char** argv) {
         for (int d = 0; d < dim_features; ++d) {
           datum.add_float_data(feature_blob_data[d]);
         }
-        std::string value;
-        datum.SerializeToString(&value);
-        snprintf(key_str, kMaxKeyStrLength, "%d", image_indices[i]);
-        feature_dbs.at(i)->put(std::string(key_str), value);
+        Database::buffer_t value(datum.ByteSize());
+        datum.SerializeWithCachedSizesToArray(
+            reinterpret_cast<unsigned char*>(value.data()));
+        int length = snprintf(key_str, kMaxKeyStrLength, "%d",
+            image_indices[i]);
+        Database::buffer_t key(key_str, key_str + length);
+        feature_dbs.at(i)->put(&key, &value);
         ++image_indices[i];
         if (image_indices[i] % 1000 == 0) {
           feature_dbs.at(i)->commit();