Templated the key and value types for the Database interface. The Database is now...
authorKevin James Matzen <kmatzen@cs.cornell.edu>
Mon, 13 Oct 2014 17:16:04 +0000 (13:16 -0400)
committerKevin James Matzen <kmatzen@cs.cornell.edu>
Tue, 14 Oct 2014 23:35:23 +0000 (19:35 -0400)
15 files changed:
examples/cifar10/convert_cifar_data.cpp
include/caffe/data_layers.hpp
include/caffe/database.hpp
include/caffe/database_factory.hpp
include/caffe/leveldb_database.hpp
include/caffe/lmdb_database.hpp
src/caffe/database_factory.cpp
src/caffe/layers/data_layer.cpp
src/caffe/leveldb_database.cpp
src/caffe/lmdb_database.cpp
src/caffe/test/test_data_layer.cpp
src/caffe/test/test_database.cpp
tools/compute_image_mean.cpp
tools/convert_imageset.cpp
tools/extract_features.cpp

index 46ab155..f86f393 100644 (file)
@@ -20,6 +20,7 @@ using std::string;
 
 using caffe::Database;
 using caffe::DatabaseFactory;
+using caffe::Datum;
 using caffe::shared_ptr;
 
 const int kCIFARSize = 32;
@@ -37,13 +38,14 @@ void read_image(std::ifstream* file, int* label, char* buffer) {
 
 void convert_dataset(const string& input_folder, const string& output_folder,
     const string& db_type) {
-  shared_ptr<Database> train_database = DatabaseFactory(db_type);
+  shared_ptr<Database<string, Datum> > train_database =
+      DatabaseFactory<string, Datum>(db_type);
   CHECK(train_database->open(output_folder + "/cifar10_train_" + db_type,
-      Database::New));
+      Database<string, Datum>::New));
   // Data buffer
   int label;
   char str_buffer[kCIFARImageNBytes];
-  caffe::Datum datum;
+  Datum datum;
   datum.set_channels(3);
   datum.set_height(kCIFARSize);
   datum.set_width(kCIFARSize);
@@ -60,22 +62,19 @@ void convert_dataset(const string& input_folder, const string& output_folder,
       read_image(&data_file, &label, str_buffer);
       datum.set_label(label);
       datum.set_data(str_buffer, kCIFARImageNBytes);
-      Database::value_type value(datum.ByteSize());
-      datum.SerializeWithCachedSizesToArray(
-          reinterpret_cast<unsigned char*>(value.data()));
       int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d",
           fileid * kCIFARBatchSize + itemid);
-      Database::key_type key(str_buffer, str_buffer + length);
-      CHECK(train_database->put(key, value));
+      CHECK(train_database->put(string(str_buffer, length), datum));
     }
   }
   CHECK(train_database->commit());
   train_database->close();
 
   LOG(INFO) << "Writing Testing data";
-  shared_ptr<Database> test_database = DatabaseFactory(db_type);
+  shared_ptr<Database<string, Datum> > test_database =
+      DatabaseFactory<string, Datum>(db_type);
   CHECK(test_database->open(output_folder + "/cifar10_test_" + db_type,
-      Database::New));
+      Database<string, Datum>::New));
   // Open files
   std::ifstream data_file((input_folder + "/test_batch.bin").c_str(),
       std::ios::in | std::ios::binary);
@@ -84,12 +83,8 @@ void convert_dataset(const string& input_folder, const string& output_folder,
     read_image(&data_file, &label, str_buffer);
     datum.set_label(label);
     datum.set_data(str_buffer, kCIFARImageNBytes);
-    Database::value_type value(datum.ByteSize());
-    datum.SerializeWithCachedSizesToArray(
-        reinterpret_cast<unsigned char*>(value.data()));
     int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid);
-    Database::key_type key(str_buffer, str_buffer + length);
-    CHECK(test_database->put(key, value));
+    CHECK(test_database->put(string(str_buffer, length), datum));
   }
   CHECK(test_database->commit());
   test_database->close();
index 810f2bb..13ba4e6 100644 (file)
@@ -100,8 +100,8 @@ class DataLayer : public BasePrefetchingDataLayer<Dtype> {
  protected:
   virtual void InternalThreadEntry();
 
-  shared_ptr<Database> database_;
-  Database::const_iterator iter_;
+  shared_ptr<Database<string, Datum> > database_;
+  Database<string, Datum>::const_iterator iter_;
 };
 
 /**
index 8469e84..711ebdd 100644 (file)
 
 namespace caffe {
 
+namespace database_internal {
+
+template <typename T>
+struct Coder {
+  static bool serialize(const T& obj, string* serialized) {
+    return obj.SerializeToString(serialized);
+  }
+
+  static bool serialize(const T& obj, vector<char>* serialized) {
+    serialized->resize(obj.ByteSize());
+    return obj.SerializeWithCachedSizesToArray(
+        reinterpret_cast<unsigned char*>(serialized->data()));
+  }
+
+  static bool deserialize(const string& serialized, T* obj) {
+    return obj->ParseFromString(serialized);
+  }
+
+  static bool deserialize(const char* data, size_t size, T* obj) {
+    return obj->ParseFromArray(data, size);
+  }
+};
+
+template <>
+struct Coder<string> {
+  static bool serialize(string obj, string* serialized) {
+    *serialized = obj;
+    return true;
+  }
+
+  static bool serialize(const string& obj, vector<char>* serialized) {
+    vector<char> temp(obj.data(), obj.data() + obj.size());
+    serialized->swap(temp);
+    return true;
+  }
+
+  static bool deserialize(const string& serialized, string* obj) {
+    *obj = serialized;
+    return true;
+  }
+
+  static bool deserialize(const char* data, size_t size, string* obj) {
+    string temp_string(data, size);
+    obj->swap(temp_string);
+    return true;
+  }
+};
+
+template <>
+struct Coder<vector<char> > {
+  static bool serialize(vector<char> obj, string* serialized) {
+    string tmp(obj.data(), obj.size());
+    serialized->swap(tmp);
+    return true;
+  }
+
+  static bool serialize(const vector<char>& obj, vector<char>* serialized) {
+    *serialized = obj;
+    return true;
+  }
+
+  static bool deserialize(const string& serialized, vector<char>* obj) {
+    vector<char> tmp(serialized.data(), serialized.data() + serialized.size());
+    obj->swap(tmp);
+    return true;
+  }
+
+  static bool deserialize(const char* data, size_t size, vector<char>* obj) {
+    vector<char> tmp(data, data + size);
+    obj->swap(tmp);
+    return true;
+  }
+};
+
+}  // namespace database_internal
+
+template <typename K, typename V>
 class Database {
  public:
   enum Mode {
@@ -19,21 +96,21 @@ class Database {
     ReadOnly
   };
 
-  typedef vector<char> key_type;
-  typedef vector<char> value_type;
+  typedef K key_type;
+  typedef V value_type;
 
   struct KV {
-    key_type key;
-    value_type value;
+    K key;
+    V value;
   };
 
   virtual bool open(const string& filename, Mode mode) = 0;
-  virtual bool put(const key_type& key, const value_type& value) = 0;
-  virtual bool get(const key_type& key, value_type* value) = 0;
+  virtual bool put(const K& key, const V& value) = 0;
+  virtual bool get(const K& key, V* value) = 0;
   virtual bool commit() = 0;
   virtual void close() = 0;
 
-  virtual void keys(vector<key_type>* keys) = 0;
+  virtual void keys(vector<K>* keys) = 0;
 
   Database() { }
   virtual ~Database() { }
@@ -123,8 +200,33 @@ class Database {
   virtual void increment(shared_ptr<DatabaseState>* state) const = 0;
   virtual KV& dereference(
       shared_ptr<DatabaseState> state) const = 0;
+
+  template <typename T>
+  static bool serialize(const T& obj, string* serialized) {
+    return database_internal::Coder<T>::serialize(obj, serialized);
+  }
+
+  template <typename T>
+  static bool serialize(const T& obj, vector<char>* serialized) {
+    return database_internal::Coder<T>::serialize(obj, serialized);
+  }
+
+  template <typename T>
+  static bool deserialize(const string& serialized, T* obj) {
+    return database_internal::Coder<T>::deserialize(serialized, obj);
+  }
+
+  template <typename T>
+  static bool deserialize(const char* data, size_t size, T* obj) {
+    return database_internal::Coder<T>::deserialize(data, size, obj);
+  }
 };
 
 }  // namespace caffe
 
+#define INSTANTIATE_DATABASE(type) \
+  template class type<string, string>; \
+  template class type<string, vector<char> >; \
+  template class type<string, Datum>;
+
 #endif  // CAFFE_DATABASE_H_
index 91185e9..30f191e 100644 (file)
@@ -9,8 +9,11 @@
 
 namespace caffe {
 
-shared_ptr<Database> DatabaseFactory(const DataParameter_DB& type);
-shared_ptr<Database> DatabaseFactory(const string& type);
+template <typename K, typename V>
+shared_ptr<Database<K, V> > DatabaseFactory(const DataParameter_DB& type);
+
+template <typename K, typename V>
+shared_ptr<Database<K, V> > DatabaseFactory(const string& type);
 
 }  // namespace caffe
 
index 48cf11e..cd966c4 100644 (file)
 
 namespace caffe {
 
-class LeveldbDatabase : public Database {
+template <typename K, typename V>
+class LeveldbDatabase : public Database<K, V> {
  public:
+  typedef Database<K, V> Base;
+  typedef typename Base::key_type key_type;
+  typedef typename Base::value_type value_type;
+  typedef typename Base::DatabaseState DatabaseState;
+  typedef typename Base::Mode Mode;
+  typedef typename Base::const_iterator const_iterator;
+  typedef typename Base::KV KV;
+
   bool open(const string& filename, Mode mode);
-  bool put(const key_type& key, const value_type& value);
-  bool get(const key_type& key, value_type* value);
+  bool put(const K& key, const V& value);
+  bool get(const K& key, V* value);
   bool commit();
   void close();
 
-  void keys(vector<key_type>* keys);
+  void keys(vector<K>* keys);
 
   const_iterator begin() const;
   const_iterator cbegin() const;
@@ -29,11 +38,11 @@ class LeveldbDatabase : public Database {
   const_iterator cend() const;
 
  protected:
-  class LeveldbState : public Database::DatabaseState {
+  class LeveldbState : public DatabaseState {
    public:
     explicit LeveldbState(shared_ptr<leveldb::DB> db,
         shared_ptr<leveldb::Iterator> iter)
-        : Database::DatabaseState(),
+        : DatabaseState(),
           db_(db),
           iter_(iter) { }
 
@@ -66,7 +75,7 @@ class LeveldbDatabase : public Database {
   bool equal(shared_ptr<DatabaseState> state1,
       shared_ptr<DatabaseState> state2) const;
   void increment(shared_ptr<DatabaseState>* state) const;
-  Database::KV& dereference(shared_ptr<DatabaseState> state) const;
+  KV& dereference(shared_ptr<DatabaseState> state) const;
 
   shared_ptr<leveldb::DB> db_;
   shared_ptr<leveldb::WriteBatch> batch_;
index a8ce60d..b5a02ca 100644 (file)
 
 namespace caffe {
 
-class LmdbDatabase : public Database {
+template <typename K, typename V>
+class LmdbDatabase : public Database<K, V> {
  public:
+  typedef Database<K, V> Base;
+  typedef typename Base::key_type key_type;
+  typedef typename Base::value_type value_type;
+  typedef typename Base::DatabaseState DatabaseState;
+  typedef typename Base::Mode Mode;
+  typedef typename Base::const_iterator const_iterator;
+  typedef typename Base::KV KV;
+
   LmdbDatabase()
       : env_(NULL),
         dbi_(0),
         txn_(NULL) { }
 
   bool open(const string& filename, Mode mode);
-  bool put(const key_type& key, const value_type& value);
-  bool get(const key_type& key, value_type* value);
+  bool put(const K& key, const V& value);
+  bool get(const K& key, V* value);
   bool commit();
   void close();
 
-  void keys(vector<key_type>* keys);
+  void keys(vector<K>* keys);
 
   const_iterator begin() const;
   const_iterator cbegin() const;
@@ -33,10 +42,10 @@ class LmdbDatabase : public Database {
   const_iterator cend() const;
 
  protected:
-  class LmdbState : public Database::DatabaseState {
+  class LmdbState : public DatabaseState {
    public:
     explicit LmdbState(MDB_cursor* cursor, MDB_txn* txn, const MDB_dbi* dbi)
-        : Database::DatabaseState(),
+        : DatabaseState(),
           cursor_(cursor),
           txn_(txn),
           dbi_(dbi) { }
@@ -70,7 +79,7 @@ class LmdbDatabase : public Database {
   bool equal(shared_ptr<DatabaseState> state1,
       shared_ptr<DatabaseState> state2) const;
   void increment(shared_ptr<DatabaseState>* state) const;
-  Database::KV& dereference(shared_ptr<DatabaseState> state) const;
+  KV& dereference(shared_ptr<DatabaseState> state) const;
 
   MDB_env* env_;
   MDB_dbi dbi_;
index 062de8c..4ccd429 100644 (file)
@@ -1,5 +1,6 @@
 #include <string>
 #include <utility>
+#include <vector>
 
 #include "caffe/database_factory.hpp"
 #include "caffe/leveldb_database.hpp"
@@ -7,29 +8,43 @@
 
 namespace caffe {
 
-shared_ptr<Database> DatabaseFactory(const DataParameter_DB& type) {
+template <typename K, typename V>
+shared_ptr<Database<K, V> > DatabaseFactory(const DataParameter_DB& type) {
   switch (type) {
   case DataParameter_DB_LEVELDB:
-    return shared_ptr<Database>(new LeveldbDatabase());
+    return shared_ptr<Database<K, V> >(new LeveldbDatabase<K, V>());
   case DataParameter_DB_LMDB:
-    return shared_ptr<Database>(new LmdbDatabase());
+    return shared_ptr<Database<K, V> >(new LmdbDatabase<K, V>());
   default:
     LOG(FATAL) << "Unknown database type " << type;
-    return shared_ptr<Database>();
+    return shared_ptr<Database<K, V> >();
   }
 }
 
-shared_ptr<Database> DatabaseFactory(const string& type) {
+template <typename K, typename V>
+shared_ptr<Database<K, V> > DatabaseFactory(const string& type) {
   if ("leveldb" == type) {
-    return DatabaseFactory(DataParameter_DB_LEVELDB);
+    return DatabaseFactory<K, V>(DataParameter_DB_LEVELDB);
   } else if ("lmdb" == type) {
-    return DatabaseFactory(DataParameter_DB_LMDB);
+    return DatabaseFactory<K, V>(DataParameter_DB_LMDB);
   } else {
     LOG(FATAL) << "Unknown database type " << type;
-    return shared_ptr<Database>();
+    return shared_ptr<Database<K, V> >();
   }
 }
 
+#define REGISTER_DATABASE(key_type, value_type) \
+  template shared_ptr<Database<key_type, value_type> > \
+      DatabaseFactory(const string& type); \
+  template shared_ptr<Database<key_type, value_type> > \
+      DatabaseFactory(const DataParameter_DB& type); \
+
+REGISTER_DATABASE(string, string);
+REGISTER_DATABASE(string, vector<char>);
+REGISTER_DATABASE(string, Datum);
+
+#undef REGISTER_DATABASE
+
 }  // namespace caffe
 
 
index c379d78..6296335 100644 (file)
@@ -25,10 +25,11 @@ template <typename Dtype>
 void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
   // Initialize DB
-  database_ = DatabaseFactory(this->layer_param_.data_param().backend());
+  database_ = DatabaseFactory<string, Datum>(
+      this->layer_param_.data_param().backend());
   const string& source = this->layer_param_.data_param().source();
   LOG(INFO) << "Opening database " << source;
-  CHECK(database_->open(source, Database::ReadOnly));
+  CHECK(database_->open(source, Database<string, Datum>::ReadOnly));
   iter_ = database_->begin();
 
   // Check if we would need to randomly skip a few data points
@@ -44,8 +45,7 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
   }
   // Read a data point, and use it to initialize the top blob.
   CHECK(iter_ != database_->end());
-  Datum datum;
-  datum.ParseFromArray(iter_->value.data(), iter_->value.size());
+  const Datum& datum = iter_->value;
 
   // image
   int crop_size = this->layer_param_.transform_param().crop_size();
@@ -89,9 +89,8 @@ void DataLayer<Dtype>::InternalThreadEntry() {
   const int batch_size = this->layer_param_.data_param().batch_size();
 
   for (int item_id = 0; item_id < batch_size; ++item_id) {
-    Datum datum;
     CHECK(iter_ != database_->end());
-    datum.ParseFromArray(iter_->value.data(), iter_->value.size());
+    const Datum& datum = iter_->value;
 
     // Apply data transformations (mirror, scale, crop...)
     int offset = this->prefetch_data_.offset(item_id);
index a163abc..267ddf0 100644 (file)
@@ -2,28 +2,30 @@
 #include <utility>
 #include <vector>
 
+#include "caffe/caffe.hpp"
 #include "caffe/leveldb_database.hpp"
 
 namespace caffe {
 
-bool LeveldbDatabase::open(const string& filename, Mode mode) {
+template <typename K, typename V>
+bool LeveldbDatabase<K, V>::open(const string& filename, Mode mode) {
   DLOG(INFO) << "LevelDB: Open " << filename;
 
   leveldb::Options options;
   switch (mode) {
-  case New:
+  case Base::New:
     DLOG(INFO) << " mode NEW";
     options.error_if_exists = true;
     options.create_if_missing = true;
     read_only_ = false;
     break;
-  case ReadWrite:
+  case Base::ReadWrite:
     DLOG(INFO) << " mode RW";
     options.error_if_exists = false;
     options.create_if_missing = true;
     read_only_ = false;
     break;
-  case ReadOnly:
+  case Base::ReadOnly:
     DLOG(INFO) << " mode RO";
     options.error_if_exists = false;
     options.create_if_missing = false;
@@ -52,7 +54,8 @@ bool LeveldbDatabase::open(const string& filename, Mode mode) {
   return true;
 }
 
-bool LeveldbDatabase::put(const key_type& key, const value_type& value) {
+template <typename K, typename V>
+bool LeveldbDatabase<K, V>::put(const K& key, const V& value) {
   DLOG(INFO) << "LevelDB: Put";
 
   if (read_only_) {
@@ -62,36 +65,48 @@ bool LeveldbDatabase::put(const key_type& key, const value_type& value) {
 
   CHECK_NOTNULL(batch_.get());
 
-  leveldb::Slice key_slice(key.data(), key.size());
-  leveldb::Slice value_slice(value.data(), value.size());
+  string serialized_key;
+  if (!Base::serialize(key, &serialized_key)) {
+    return false;
+  }
+
+  string serialized_value;
+  if (!Base::serialize(value, &serialized_value)) {
+    return false;
+  }
 
-  batch_->Put(key_slice, value_slice);
+  batch_->Put(serialized_key, serialized_value);
 
   return true;
 }
 
-bool LeveldbDatabase::get(const key_type& key, value_type* value) {
+template <typename K, typename V>
+bool LeveldbDatabase<K, V>::get(const K& key, V* value) {
   DLOG(INFO) << "LevelDB: Get";
 
-  leveldb::Slice key_slice(key.data(), key.size());
+  string serialized_key;
+  if (!Base::serialize(key, &serialized_key)) {
+    return false;
+  }
 
-  string value_string;
+  string serialized_value;
   leveldb::Status status =
-      db_->Get(leveldb::ReadOptions(), key_slice, &value_string);
+      db_->Get(leveldb::ReadOptions(), serialized_key, &serialized_value);
 
   if (!status.ok()) {
     LOG(ERROR) << "leveldb get failed";
     return false;
   }
 
-  Database::value_type temp_value(value_string.data(),
-      value_string.data() + value_string.size());
-  value->swap(temp_value);
+  if (!Base::deserialize(serialized_value, value)) {
+    return false;
+  }
 
   return true;
 }
 
-bool LeveldbDatabase::commit() {
+template <typename K, typename V>
+bool LeveldbDatabase<K, V>::commit() {
   DLOG(INFO) << "LevelDB: Commit";
 
   if (read_only_) {
@@ -109,23 +124,27 @@ bool LeveldbDatabase::commit() {
   return status.ok();
 }
 
-void LeveldbDatabase::close() {
+template <typename K, typename V>
+void LeveldbDatabase<K, V>::close() {
   DLOG(INFO) << "LevelDB: Close";
 
   batch_.reset();
   db_.reset();
 }
 
-void LeveldbDatabase::keys(vector<key_type>* keys) {
+template <typename K, typename V>
+void LeveldbDatabase<K, V>::keys(vector<K>* keys) {
   DLOG(INFO) << "LevelDB: Keys";
 
   keys->clear();
-  for (Database::const_iterator iter = begin(); iter != end(); ++iter) {
+  for (const_iterator iter = begin(); iter != end(); ++iter) {
     keys->push_back(iter->key);
   }
 }
 
-LeveldbDatabase::const_iterator LeveldbDatabase::begin() const {
+template <typename K, typename V>
+typename LeveldbDatabase<K, V>::const_iterator
+    LeveldbDatabase<K, V>::begin() const {
   CHECK_NOTNULL(db_.get());
   shared_ptr<leveldb::Iterator> iter(db_->NewIterator(leveldb::ReadOptions()));
   iter->SeekToFirst();
@@ -140,18 +159,25 @@ LeveldbDatabase::const_iterator LeveldbDatabase::begin() const {
   return const_iterator(this, state);
 }
 
-LeveldbDatabase::const_iterator LeveldbDatabase::end() const {
+template <typename K, typename V>
+typename LeveldbDatabase<K, V>::const_iterator
+    LeveldbDatabase<K, V>::end() const {
   shared_ptr<DatabaseState> state;
   return const_iterator(this, state);
 }
 
-LeveldbDatabase::const_iterator LeveldbDatabase::cbegin() const {
+template <typename K, typename V>
+typename LeveldbDatabase<K, V>::const_iterator
+    LeveldbDatabase<K, V>::cbegin() const {
   return begin();
 }
 
-LeveldbDatabase::const_iterator LeveldbDatabase::cend() const { return end(); }
+template <typename K, typename V>
+typename LeveldbDatabase<K, V>::const_iterator
+    LeveldbDatabase<K, V>::cend() const { return end(); }
 
-bool LeveldbDatabase::equal(shared_ptr<DatabaseState> state1,
+template <typename K, typename V>
+bool LeveldbDatabase<K, V>::equal(shared_ptr<DatabaseState> state1,
     shared_ptr<DatabaseState> state2) const {
   shared_ptr<LeveldbState> leveldb_state1 =
       boost::dynamic_pointer_cast<LeveldbState>(state1);
@@ -165,7 +191,8 @@ bool LeveldbDatabase::equal(shared_ptr<DatabaseState> state1,
   return !leveldb_state1 && !leveldb_state2;
 }
 
-void LeveldbDatabase::increment(shared_ptr<DatabaseState>* state) const {
+template <typename K, typename V>
+void LeveldbDatabase<K, V>::increment(shared_ptr<DatabaseState>* state) const {
   shared_ptr<LeveldbState> leveldb_state =
       boost::dynamic_pointer_cast<LeveldbState>(*state);
 
@@ -182,7 +209,8 @@ void LeveldbDatabase::increment(shared_ptr<DatabaseState>* state) const {
   }
 }
 
-Database::KV& LeveldbDatabase::dereference(
+template <typename K, typename V>
+typename Database<K, V>::KV& LeveldbDatabase<K, V>::dereference(
     shared_ptr<DatabaseState> state) const {
   shared_ptr<LeveldbState> leveldb_state =
       boost::dynamic_pointer_cast<LeveldbState>(state);
@@ -195,15 +223,16 @@ Database::KV& LeveldbDatabase::dereference(
 
   CHECK(iter->Valid());
 
-  Database::key_type temp_key(key_type(iter->key().data(),
-      iter->key().data() + iter->key().size()));
+  const leveldb::Slice& key = iter->key();
+  const leveldb::Slice& value = iter->value();
+  CHECK(Base::deserialize(key.data(), key.size(),
+      &leveldb_state->kv_pair_.key));
+  CHECK(Base::deserialize(value.data(), value.size(),
+      &leveldb_state->kv_pair_.value));
 
-  Database::value_type temp_value(value_type(iter->value().data(),
-      iter->value().data() + iter->value().size()));
-
-  leveldb_state->kv_pair_.key.swap(temp_key);
-  leveldb_state->kv_pair_.value.swap(temp_value);
   return leveldb_state->kv_pair_;
 }
 
+INSTANTIATE_DATABASE(LeveldbDatabase);
+
 }  // namespace caffe
index de94391..22a08b2 100644 (file)
@@ -4,11 +4,13 @@
 #include <utility>
 #include <vector>
 
+#include "caffe/caffe.hpp"
 #include "caffe/lmdb_database.hpp"
 
 namespace caffe {
 
-bool LmdbDatabase::open(const string& filename, Mode mode) {
+template <typename K, typename V>
+bool LmdbDatabase<K, V>::open(const string& filename, Mode mode) {
   DLOG(INFO) << "LMDB: Open " << filename;
 
   CHECK(NULL == env_);
@@ -16,16 +18,16 @@ bool LmdbDatabase::open(const string& filename, Mode mode) {
   CHECK_EQ(0, dbi_);
 
   int retval;
-  if (mode != ReadOnly) {
+  if (mode != Base::ReadOnly) {
     retval = mkdir(filename.c_str(), 0744);
     switch (mode) {
-    case New:
+    case Base::New:
       if (0 != retval) {
         LOG(ERROR) << "mkdir " << filename << " failed";
         return false;
       }
       break;
-    case ReadWrite:
+    case Base::ReadWrite:
       if (-1 == retval && EEXIST != errno) {
         LOG(ERROR) << "mkdir " << filename << " failed ("
             << strerror(errno) << ")";
@@ -52,7 +54,7 @@ bool LmdbDatabase::open(const string& filename, Mode mode) {
 
   int flag1 = 0;
   int flag2 = 0;
-  if (mode == ReadOnly) {
+  if (mode == Base::ReadOnly) {
     flag1 = MDB_RDONLY | MDB_NOTLS;
     flag2 = MDB_RDONLY;
   }
@@ -78,18 +80,27 @@ bool LmdbDatabase::open(const string& filename, Mode mode) {
   return true;
 }
 
-bool LmdbDatabase::put(const key_type& key, const value_type& value) {
+template <typename K, typename V>
+bool LmdbDatabase<K, V>::put(const K& key, const V& value) {
   DLOG(INFO) << "LMDB: Put";
 
-  // MDB_val::mv_size is not const, so we need to make a local copy.
-  key_type local_key = key;
-  value_type local_value = value;
+  vector<char> serialized_key;
+  if (!Base::serialize(key, &serialized_key)) {
+    LOG(ERROR) << "failed to serialize key";
+    return false;
+  }
+
+  vector<char> serialized_value;
+  if (!Base::serialize(value, &serialized_value)) {
+    LOG(ERROR) << "failed to serialized value";
+    return false;
+  }
 
   MDB_val mdbkey, mdbdata;
-  mdbdata.mv_size = local_value.size();
-  mdbdata.mv_data = local_value.data();
-  mdbkey.mv_size = local_key.size();
-  mdbkey.mv_data = local_key.data();
+  mdbdata.mv_size = serialized_value.size();
+  mdbdata.mv_data = serialized_value.data();
+  mdbkey.mv_size = serialized_key.size();
+  mdbkey.mv_data = serialized_key.data();
 
   CHECK_NOTNULL(txn_);
   CHECK_NE(0, dbi_);
@@ -103,14 +114,19 @@ bool LmdbDatabase::put(const key_type& key, const value_type& value) {
   return true;
 }
 
-bool LmdbDatabase::get(const key_type& key, value_type* value) {
+template <typename K, typename V>
+bool LmdbDatabase<K, V>::get(const K& key, V* value) {
   DLOG(INFO) << "LMDB: Get";
 
-  key_type local_key = key;
+  vector<char> serialized_key;
+  if (!Base::serialize(key, &serialized_key)) {
+    LOG(ERROR) << "failed to serialized key";
+    return false;
+  }
 
   MDB_val mdbkey, mdbdata;
-  mdbkey.mv_data = local_key.data();
-  mdbkey.mv_size = local_key.size();
+  mdbkey.mv_data = serialized_key.data();
+  mdbkey.mv_size = serialized_key.size();
 
   int retval;
   MDB_txn* get_txn;
@@ -128,15 +144,17 @@ bool LmdbDatabase::get(const key_type& key, value_type* value) {
 
   mdb_txn_abort(get_txn);
 
-  Database::value_type temp_value(reinterpret_cast<char*>(mdbdata.mv_data),
-      reinterpret_cast<char*>(mdbdata.mv_data) + mdbdata.mv_size);
-
-  value->swap(temp_value);
+  if (!Base::deserialize(reinterpret_cast<char*>(mdbdata.mv_data),
+      mdbdata.mv_size, value)) {
+    LOG(ERROR) << "failed to deserialize value";
+    return false;
+  }
 
   return true;
 }
 
-bool LmdbDatabase::commit() {
+template <typename K, typename V>
+bool LmdbDatabase<K, V>::commit() {
   DLOG(INFO) << "LMDB: Commit";
 
   CHECK_NOTNULL(txn_);
@@ -157,7 +175,8 @@ bool LmdbDatabase::commit() {
   return true;
 }
 
-void LmdbDatabase::close() {
+template <typename K, typename V>
+void LmdbDatabase<K, V>::close() {
   DLOG(INFO) << "LMDB: Close";
 
   if (env_ && dbi_) {
@@ -169,16 +188,19 @@ void LmdbDatabase::close() {
   }
 }
 
-void LmdbDatabase::keys(vector<key_type>* keys) {
+template <typename K, typename V>
+void LmdbDatabase<K, V>::keys(vector<K>* keys) {
   DLOG(INFO) << "LMDB: Keys";
 
   keys->clear();
-  for (Database::const_iterator iter = begin(); iter != end(); ++iter) {
+  for (const_iterator iter = begin(); iter != end(); ++iter) {
     keys->push_back(iter->key);
   }
 }
 
-LmdbDatabase::const_iterator LmdbDatabase::begin() const {
+template <typename K, typename V>
+typename LmdbDatabase<K, V>::const_iterator
+    LmdbDatabase<K, V>::begin() const {
   int retval;
 
   MDB_txn* iter_txn;
@@ -204,15 +226,23 @@ LmdbDatabase::const_iterator LmdbDatabase::begin() const {
   return const_iterator(this, state);
 }
 
-LmdbDatabase::const_iterator LmdbDatabase::end() const {
+template <typename K, typename V>
+typename LmdbDatabase<K, V>::const_iterator
+    LmdbDatabase<K, V>::end() const {
   shared_ptr<DatabaseState> state;
   return const_iterator(this, state);
 }
 
-LmdbDatabase::const_iterator LmdbDatabase::cbegin() const { return begin(); }
-LmdbDatabase::const_iterator LmdbDatabase::cend() const { return end(); }
+template <typename K, typename V>
+typename LmdbDatabase<K, V>::const_iterator
+    LmdbDatabase<K, V>::cbegin() const { return begin(); }
 
-bool LmdbDatabase::equal(shared_ptr<DatabaseState> state1,
+template <typename K, typename V>
+typename LmdbDatabase<K, V>::const_iterator
+    LmdbDatabase<K, V>::cend() const { return end(); }
+
+template <typename K, typename V>
+bool LmdbDatabase<K, V>::equal(shared_ptr<DatabaseState> state1,
     shared_ptr<DatabaseState> state2) const {
   shared_ptr<LmdbState> lmdb_state1 =
       boost::dynamic_pointer_cast<LmdbState>(state1);
@@ -226,7 +256,8 @@ bool LmdbDatabase::equal(shared_ptr<DatabaseState> state1,
   return !lmdb_state1 && !lmdb_state2;
 }
 
-void LmdbDatabase::increment(shared_ptr<DatabaseState>* state) const {
+template <typename K, typename V>
+void LmdbDatabase<K, V>::increment(shared_ptr<DatabaseState>* state) const {
   shared_ptr<LmdbState> lmdb_state =
       boost::dynamic_pointer_cast<LmdbState>(*state);
 
@@ -247,7 +278,9 @@ void LmdbDatabase::increment(shared_ptr<DatabaseState>* state) const {
   }
 }
 
-Database::KV& LmdbDatabase::dereference(shared_ptr<DatabaseState> state) const {
+template <typename K, typename V>
+typename Database<K, V>::KV& LmdbDatabase<K, V>::dereference(
+    shared_ptr<DatabaseState> state) const {
   shared_ptr<LmdbState> lmdb_state =
       boost::dynamic_pointer_cast<LmdbState>(state);
 
@@ -262,18 +295,14 @@ Database::KV& LmdbDatabase::dereference(shared_ptr<DatabaseState> state) const {
   int retval = mdb_cursor_get(cursor, &mdb_key, &mdb_val, MDB_GET_CURRENT);
   CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval);
 
-  char* key_data = reinterpret_cast<char*>(mdb_key.mv_data);
-  char* value_data = reinterpret_cast<char*>(mdb_val.mv_data);
-
-  Database::key_type temp_key(key_data, key_data + mdb_key.mv_size);
-
-  Database::value_type temp_value(value_data,
-      value_data + mdb_val.mv_size);
-
-  lmdb_state->kv_pair_.key.swap(temp_key);
-  lmdb_state->kv_pair_.value.swap(temp_value);
+  CHECK(Base::deserialize(reinterpret_cast<char*>(mdb_key.mv_data),
+      mdb_key.mv_size, &lmdb_state->kv_pair_.key));
+  CHECK(Base::deserialize(reinterpret_cast<char*>(mdb_val.mv_data),
+      mdb_val.mv_size, &lmdb_state->kv_pair_.value));
 
   return lmdb_state->kv_pair_;
 }
 
+INSTANTIATE_DATABASE(LmdbDatabase);
+
 }  // namespace caffe
index 5f98206..1c3ec1f 100644 (file)
@@ -39,8 +39,9 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
   void Fill(const bool unique_pixels, DataParameter_DB backend) {
     backend_ = backend;
     LOG(INFO) << "Using temporary database " << *filename_;
-    shared_ptr<Database> database = DatabaseFactory(backend_);
-    CHECK(database->open(*filename_, Database::New));
+    shared_ptr<Database<string, Datum> > database =
+        DatabaseFactory<string, Datum>(backend_);
+    CHECK(database->open(*filename_, Database<string, Datum>::New));
     for (int i = 0; i < 5; ++i) {
       Datum datum;
       datum.set_label(i);
@@ -54,12 +55,7 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
       }
       stringstream ss;
       ss << i;
-      string key_str = ss.str();
-      Database::key_type key(key_str.c_str(), key_str.c_str() + key_str.size());
-      Database::value_type value(datum.ByteSize());
-      datum.SerializeWithCachedSizesToArray(
-          reinterpret_cast<unsigned char*>(value.data()));
-      CHECK(database->put(key, value));
+      CHECK(database->put(ss.str(), datum));
     }
     CHECK(database->commit());
     database->close();
index 9c56910..9a7d6de 100644 (file)
 
 namespace caffe {
 
-template <typename TypeParam>
-class DatabaseTest : public MultiDeviceTest<TypeParam> {
-  typedef typename TypeParam::Dtype Dtype;
-
- protected:
-  string DBName() {
-    string filename;
-    MakeTempDir(&filename);
-    filename += "/db";
-    return filename;
-  }
-
-  Database::key_type TestKey() {
-    const char* kKey = "hello";
-    Database::key_type key(kKey, kKey + 5);
-    return key;
-  }
-
-  Database::value_type TestValue() {
-    const char* kValue = "world";
-    Database::value_type value(kValue, kValue + 5);
-    return value;
-  }
-
-  Database::key_type TestAltKey() {
-    const char* kKey = "foo";
-    Database::key_type key(kKey, kKey + 3);
-    return key;
-  }
-
-  Database::value_type TestAltValue() {
-    const char* kValue = "bar";
-    Database::value_type value(kValue, kValue + 3);
-    return value;
-  }
-
-  template <typename T>
-  bool BufferEq(const T& buf1, const T& buf2) {
-    if (buf1.size() != buf2.size()) {
-      return false;
-    }
-    for (size_t i = 0; i < buf1.size(); ++i) {
-      if (buf1.at(i) != buf2.at(i)) {
-        return false;
-      }
-    }
+namespace DatabaseTest_internal {
 
-    return true;
-  }
+template <typename T>
+struct TestData {
+  static T TestValue();
+  static T TestAltValue();
+  static bool equals(const T& a, const T& b);
 };
 
-TYPED_TEST_CASE(DatabaseTest, TestDtypesAndDevices);
-
-TYPED_TEST(DatabaseTest, TestNewDoesntExistLevelDBPasses) {
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(this->DBName(), Database::New));
-  database->close();
+template <>
+string TestData<string>::TestValue() {
+  return "world";
 }
 
-TYPED_TEST(DatabaseTest, TestNewExistsFailsLevelDB) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
-  database->close();
-
-  EXPECT_FALSE(database->open(name, Database::New));
+template <>
+string TestData<string>::TestAltValue() {
+  return "bar";
 }
 
-TYPED_TEST(DatabaseTest, TestReadOnlyExistsLevelDBPasses) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
-  database->close();
-
-  EXPECT_TRUE(database->open(name, Database::ReadOnly));
-  database->close();
+template <>
+bool TestData<string>::equals(const string& a, const string& b) {
+  return a == b;
 }
 
-TYPED_TEST(DatabaseTest, TestReadOnlyDoesntExistFailsLevelDB) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_FALSE(database->open(name, Database::ReadOnly));
-}
-
-TYPED_TEST(DatabaseTest, TestReadWriteExistsLevelDBPasses) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
-  database->close();
-
-  EXPECT_TRUE(database->open(name, Database::ReadWrite));
-  database->close();
+template <>
+vector<char> TestData<vector<char> >::TestValue() {
+  string str = "world";
+  vector<char> val(str.data(), str.data() + str.size());
+  return val;
 }
 
-TYPED_TEST(DatabaseTest, TestReadWriteDoesntExistLevelDBPasses) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::ReadWrite));
-  database->close();
+template <>
+vector<char> TestData<vector<char> >::TestAltValue() {
+  string str = "bar";
+  vector<char> val(str.data(), str.data() + str.size());
+  return val;
 }
 
-TYPED_TEST(DatabaseTest, TestKeysLevelDB) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
-
-  Database::key_type key1 = this->TestKey();
-  Database::value_type value1 = this->TestValue();
-
-  EXPECT_TRUE(database->put(key1, value1));
-
-  Database::key_type key2 = this->TestAltKey();
-  Database::value_type value2 = this->TestAltValue();
-
-  EXPECT_TRUE(database->put(key2, value2));
-
-  EXPECT_TRUE(database->commit());
-
-  vector<Database::key_type> keys;
-  database->keys(&keys);
-
-  EXPECT_EQ(2, keys.size());
-
-  EXPECT_TRUE(this->BufferEq(keys.at(0), key1) ||
-      this->BufferEq(keys.at(0), key2));
-  EXPECT_TRUE(this->BufferEq(keys.at(1), key1) ||
-      this->BufferEq(keys.at(2), key2));
-  EXPECT_FALSE(this->BufferEq(keys.at(0), keys.at(1)));
-}
-
-TYPED_TEST(DatabaseTest, TestKeysNoCommitLevelDB) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
-
-  Database::key_type key1 = this->TestKey();
-  Database::value_type value1 = this->TestValue();
-
-  EXPECT_TRUE(database->put(key1, value1));
-
-  Database::key_type key2 = this->TestAltKey();
-  Database::value_type value2 = this->TestAltValue();
-
-  EXPECT_TRUE(database->put(key2, value2));
-
-  vector<Database::key_type> keys;
-  database->keys(&keys);
-
-  EXPECT_EQ(0, keys.size());
-}
-
-TYPED_TEST(DatabaseTest, TestIteratorsLevelDB) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
-
-  const int kNumExamples = 4;
-  for (int i = 0; i < kNumExamples; ++i) {
-    stringstream ss;
-    ss << i;
-    string key = ss.str();
-    ss << " here be data";
-    string value = ss.str();
-    Database::key_type key_buf(key.data(), key.data() + key.size());
-    Database::value_type val_buf(value.data(), value.data() + value.size());
-    EXPECT_TRUE(database->put(key_buf, val_buf));
+template <>
+bool TestData<vector<char> >::equals(const vector<char>& a,
+    const vector<char>& b) {
+  if (a.size() != b.size()) {
+    return false;
   }
-  EXPECT_TRUE(database->commit());
-
-  int count = 0;
-  for (Database::const_iterator iter = database->begin();
-      iter != database->end(); ++iter) {
-    (void)iter;
-    ++count;
+  for (size_t i = 0; i < a.size(); ++i) {
+    if (a.at(i) != b.at(i)) {
+      return false;
+    }
   }
 
-  EXPECT_EQ(kNumExamples, count);
+  return true;
 }
 
-TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLevelDB) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
-
-  Database::key_type key1 = this->TestAltKey();
-  Database::value_type value1 = this->TestAltValue();
-
-  Database::key_type key2 = this->TestKey();
-  Database::value_type value2 = this->TestValue();
-
-  EXPECT_TRUE(database->put(key1, value1));
-  EXPECT_TRUE(database->put(key2, value2));
-  EXPECT_TRUE(database->commit());
-
-  Database::const_iterator iter1 = database->begin();
-
-  EXPECT_FALSE(database->end() == iter1);
-
-  EXPECT_TRUE(this->BufferEq(iter1->key, key1));
-
-  Database::const_iterator iter2 = ++iter1;
-
-  EXPECT_FALSE(database->end() == iter1);
-  EXPECT_FALSE(database->end() == iter2);
-
-  EXPECT_TRUE(this->BufferEq(iter2->key, key2));
-
-  Database::const_iterator iter3 = ++iter2;
-
-  EXPECT_TRUE(database->end() == iter3);
-
-  database->close();
+template <>
+Datum TestData<Datum>::TestValue() {
+  Datum datum;
+  datum.set_channels(3);
+  datum.set_height(32);
+  datum.set_width(32);
+  datum.set_data(string(32 * 32 * 3 * 4, ' '));
+  datum.set_label(0);
+  return datum;
 }
 
-TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLevelDB) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
-
-  Database::key_type key1 = this->TestAltKey();
-  Database::value_type value1 = this->TestAltValue();
-
-  Database::key_type key2 = this->TestKey();
-  Database::value_type value2 = this->TestValue();
-
-  EXPECT_TRUE(database->put(key1, value1));
-  EXPECT_TRUE(database->put(key2, value2));
-  EXPECT_TRUE(database->commit());
-
-  Database::const_iterator iter1 = database->begin();
-
-  EXPECT_FALSE(database->end() == iter1);
-
-  EXPECT_TRUE(this->BufferEq(iter1->key, key1));
-
-  Database::const_iterator iter2 = iter1++;
-
-  EXPECT_FALSE(database->end() == iter1);
-  EXPECT_FALSE(database->end() == iter2);
-
-  EXPECT_TRUE(this->BufferEq(iter2->key, key1));
-  EXPECT_TRUE(this->BufferEq(iter1->key, key2));
-
-  Database::const_iterator iter3 = iter1++;
-
-  EXPECT_FALSE(database->end() == iter3);
-  EXPECT_TRUE(this->BufferEq(iter3->key, key2));
-  EXPECT_TRUE(database->end() == iter1);
-
-  database->close();
-}
-
-TYPED_TEST(DatabaseTest, TestNewPutLevelDBPasses) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
-
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
-
-  EXPECT_TRUE(database->put(key, val));
-
-  EXPECT_TRUE(database->commit());
-
-  database->close();
+template <>
+Datum TestData<Datum>::TestAltValue() {
+  Datum datum;
+  datum.set_channels(1);
+  datum.set_height(64);
+  datum.set_width(64);
+  datum.set_data(string(64 * 64 * 1 * 4, ' '));
+  datum.set_label(1);
+  return datum;
 }
 
-TYPED_TEST(DatabaseTest, TestNewCommitLevelDBPasses) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
+template <>
+bool TestData<Datum>::equals(const Datum& a, const Datum& b) {
+  string serialized_a;
+  a.SerializeToString(&serialized_a);
 
-  EXPECT_TRUE(database->commit());
+  string serialized_b;
+  b.SerializeToString(&serialized_b);
 
-  database->close();
+  return serialized_a == serialized_b;
 }
 
-TYPED_TEST(DatabaseTest, TestNewGetLevelDBPasses) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
-
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
-
-  EXPECT_TRUE(database->put(key, val));
-
-  EXPECT_TRUE(database->commit());
-
-  Database::value_type new_val;
-
-  EXPECT_TRUE(database->get(key, &new_val));
-
-  EXPECT_TRUE(this->BufferEq(val, new_val));
-
-  database->close();
-}
-
-TYPED_TEST(DatabaseTest, TestNewGetNoCommitLevelDBFails) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
-
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
-
-  EXPECT_TRUE(database->put(key, val));
+}  // namespace DatabaseTest_internal
 
-  Database::value_type new_val;
-
-  EXPECT_FALSE(database->get(key, &new_val));
-}
+#define UNPACK_TYPES \
+  typedef typename TypeParam::value_type value_type; \
+  const DataParameter_DB backend = TypeParam::backend;
 
+template <typename TypeParam>
+class DatabaseTest : public ::testing::Test {
+ protected:
+  typedef typename TypeParam::value_type value_type;
 
-TYPED_TEST(DatabaseTest, TestReadWritePutLevelDBPasses) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::ReadWrite));
+  string DBName() {
+    string filename;
+    MakeTempDir(&filename);
+    filename += "/db";
+    return filename;
+  }
 
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
+  string TestKey() {
+    return "hello";
+  }
 
-  EXPECT_TRUE(database->put(key, val));
+  value_type TestValue() {
+    return DatabaseTest_internal::TestData<value_type>::TestValue();
+  }
 
-  EXPECT_TRUE(database->commit());
+  string TestAltKey() {
+    return "foo";
+  }
 
-  database->close();
-}
+  value_type TestAltValue() {
+    return DatabaseTest_internal::TestData<value_type>::TestAltValue();
+  }
 
-TYPED_TEST(DatabaseTest, TestReadWriteCommitLevelDBPasses) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::ReadWrite));
+  template <typename T>
+  bool equals(const T& a, const T& b) {
+    return DatabaseTest_internal::TestData<T>::equals(a, b);
+  }
+};
 
-  EXPECT_TRUE(database->commit());
+struct StringLeveldb {
+  typedef string value_type;
+  static const DataParameter_DB backend;
+};
+const DataParameter_DB StringLeveldb::backend = DataParameter_DB_LEVELDB;
 
-  database->close();
-}
+struct StringLmdb {
+  typedef string value_type;
+  static const DataParameter_DB backend;
+};
+const DataParameter_DB StringLmdb::backend = DataParameter_DB_LEVELDB;
 
-TYPED_TEST(DatabaseTest, TestReadWriteGetLevelDBPasses) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
+struct VectorLeveldb {
+  typedef vector<char> value_type;
+  static const DataParameter_DB backend;
+};
+const DataParameter_DB VectorLeveldb::backend = DataParameter_DB_LEVELDB;
 
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
+struct VectorLmdb {
+  typedef vector<char> value_type;
+  static const DataParameter_DB backend;
+};
+const DataParameter_DB VectorLmdb::backend = DataParameter_DB_LEVELDB;
 
-  EXPECT_TRUE(database->put(key, val));
+struct DatumLeveldb {
+  typedef Datum value_type;
+  static const DataParameter_DB backend;
+};
+const DataParameter_DB DatumLeveldb::backend = DataParameter_DB_LEVELDB;
 
-  EXPECT_TRUE(database->commit());
+struct DatumLmdb {
+  typedef Datum value_type;
+  static const DataParameter_DB backend;
+};
+const DataParameter_DB DatumLmdb::backend = DataParameter_DB_LEVELDB;
 
-  Database::value_type new_val;
+typedef ::testing::Types<StringLeveldb, StringLmdb, VectorLeveldb, VectorLmdb,
+    DatumLeveldb, DatumLmdb> TestTypes;
 
-  EXPECT_TRUE(database->get(key, &new_val));
+TYPED_TEST_CASE(DatabaseTest, TestTypes);
 
-  EXPECT_TRUE(this->BufferEq(val, new_val));
+TYPED_TEST(DatabaseTest, TestNewDoesntExistPasses) {
+  UNPACK_TYPES;
 
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(this->DBName(),
+      Database<string, value_type>::New));
   database->close();
 }
 
-TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitLevelDBFails) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
-
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
-
-  EXPECT_TRUE(database->put(key, val));
-
-  Database::value_type new_val;
-
-  EXPECT_FALSE(database->get(key, &new_val));
-}
+TYPED_TEST(DatabaseTest, TestNewExistsFails) {
+  UNPACK_TYPES;
 
-TYPED_TEST(DatabaseTest, TestReadOnlyPutLevelDBFails) {
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
   database->close();
 
-  EXPECT_TRUE(database->open(name, Database::ReadOnly));
-
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
-
-  EXPECT_FALSE(database->put(key, val));
+  EXPECT_FALSE(database->open(name, Database<string, value_type>::New));
 }
 
-TYPED_TEST(DatabaseTest, TestReadOnlyCommitLevelDBFails) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
-  database->close();
-
-  EXPECT_TRUE(database->open(name, Database::ReadOnly));
-
-  EXPECT_FALSE(database->commit());
-}
+TYPED_TEST(DatabaseTest, TestReadOnlyExistsPasses) {
+  UNPACK_TYPES;
 
-TYPED_TEST(DatabaseTest, TestReadOnlyGetLevelDBPasses) {
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
-
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
-
-  EXPECT_TRUE(database->put(key, val));
-
-  EXPECT_TRUE(database->commit());
-
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
   database->close();
 
-  EXPECT_TRUE(database->open(name, Database::ReadOnly));
-
-  Database::value_type new_val;
-
-  EXPECT_TRUE(database->get(key, &new_val));
-
-  EXPECT_TRUE(this->BufferEq(val, new_val));
-}
-
-TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLevelDBFails) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("leveldb");
-  EXPECT_TRUE(database->open(name, Database::New));
-
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
-
-  EXPECT_TRUE(database->put(key, val));
-
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadOnly));
   database->close();
-
-  EXPECT_TRUE(database->open(name, Database::ReadOnly));
-
-  Database::value_type new_val;
-
-  EXPECT_FALSE(database->get(key, &new_val));
 }
 
-TYPED_TEST(DatabaseTest, TestNewDoesntExistLMDBPasses) {
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(this->DBName(), Database::New));
-  database->close();
-}
+TYPED_TEST(DatabaseTest, TestReadOnlyDoesntExistFails) {
+  UNPACK_TYPES;
 
-TYPED_TEST(DatabaseTest, TestNewExistsFailsLMDB) {
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
-  database->close();
-
-  EXPECT_FALSE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_FALSE(database->open(name, Database<string, value_type>::ReadOnly));
 }
 
-TYPED_TEST(DatabaseTest, TestReadOnlyExistsLMDBPasses) {
+TYPED_TEST(DatabaseTest, TestReadWriteExistsPasses) {
+  UNPACK_TYPES;
+
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
   database->close();
 
-  EXPECT_TRUE(database->open(name, Database::ReadOnly));
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadWrite));
   database->close();
 }
 
-TYPED_TEST(DatabaseTest, TestReadOnlyDoesntExistFailsLMDB) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_FALSE(database->open(name, Database::ReadOnly));
-}
+TYPED_TEST(DatabaseTest, TestReadWriteDoesntExistPasses) {
+  UNPACK_TYPES;
 
-TYPED_TEST(DatabaseTest, TestReadWriteExistsLMDBPasses) {
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
-  database->close();
-
-  EXPECT_TRUE(database->open(name, Database::ReadWrite));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadWrite));
   database->close();
 }
 
-TYPED_TEST(DatabaseTest, TestReadWriteDoesntExistLMDBPasses) {
-  string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::ReadWrite));
-  database->close();
-}
+TYPED_TEST(DatabaseTest, TestKeys) {
+  UNPACK_TYPES;
 
-TYPED_TEST(DatabaseTest, TestKeysLMDB) {
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
 
-  Database::key_type key1 = this->TestKey();
-  Database::value_type value1 = this->TestValue();
+  string key1 = this->TestKey();
+  value_type value1 = this->TestValue();
 
   EXPECT_TRUE(database->put(key1, value1));
 
-  Database::key_type key2 = this->TestAltKey();
-  Database::value_type value2 = this->TestAltValue();
+  string key2 = this->TestAltKey();
+  value_type value2 = this->TestAltValue();
 
   EXPECT_TRUE(database->put(key2, value2));
 
   EXPECT_TRUE(database->commit());
 
-  vector<Database::key_type> keys;
+  vector<string> keys;
   database->keys(&keys);
 
   EXPECT_EQ(2, keys.size());
 
-  EXPECT_TRUE(this->BufferEq(keys.at(0), key1) ||
-      this->BufferEq(keys.at(0), key2));
-  EXPECT_TRUE(this->BufferEq(keys.at(1), key1) ||
-      this->BufferEq(keys.at(2), key2));
-  EXPECT_FALSE(this->BufferEq(keys.at(0), keys.at(1)));
+  EXPECT_TRUE(this->equals(keys.at(0), key1) ||
+      this->equals(keys.at(0), key2));
+  EXPECT_TRUE(this->equals(keys.at(1), key1) ||
+      this->equals(keys.at(2), key2));
+  EXPECT_FALSE(this->equals(keys.at(0), keys.at(1)));
 }
 
-TYPED_TEST(DatabaseTest, TestKeysNoCommitLMDB) {
+TYPED_TEST(DatabaseTest, TestKeysNoCommit) {
+  UNPACK_TYPES;
+
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
 
-  Database::key_type key1 = this->TestKey();
-  Database::value_type value1 = this->TestValue();
+  string key1 = this->TestKey();
+  value_type value1 = this->TestValue();
 
   EXPECT_TRUE(database->put(key1, value1));
 
-  Database::key_type key2 = this->TestAltKey();
-  Database::value_type value2 = this->TestAltValue();
+  string key2 = this->TestAltKey();
+  value_type value2 = this->TestAltValue();
 
   EXPECT_TRUE(database->put(key2, value2));
 
-  vector<Database::key_type> keys;
+  vector<string> keys;
   database->keys(&keys);
 
   EXPECT_EQ(0, keys.size());
 }
 
+TYPED_TEST(DatabaseTest, TestIterators) {
+  UNPACK_TYPES;
 
-TYPED_TEST(DatabaseTest, TestIteratorsLMDB) {
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
 
   const int kNumExamples = 4;
   for (int i = 0; i < kNumExamples; ++i) {
@@ -563,16 +315,14 @@ TYPED_TEST(DatabaseTest, TestIteratorsLMDB) {
     ss << i;
     string key = ss.str();
     ss << " here be data";
-    string value = ss.str();
-    Database::key_type key_buf(key.data(), key.data() + key.size());
-    Database::value_type val_buf(value.data(), value.data() + value.size());
-    EXPECT_TRUE(database->put(key_buf, val_buf));
+    value_type value = this->TestValue();
+    EXPECT_TRUE(database->put(key, value));
   }
   EXPECT_TRUE(database->commit());
 
   int count = 0;
-  for (Database::const_iterator iter = database->begin();
-      iter != database->end(); ++iter) {
+  typedef typename Database<string, value_type>::const_iterator Iter;
+  for (Iter iter = database->begin(); iter != database->end(); ++iter) {
     (void)iter;
     ++count;
   }
@@ -580,266 +330,313 @@ TYPED_TEST(DatabaseTest, TestIteratorsLMDB) {
   EXPECT_EQ(kNumExamples, count);
 }
 
-TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLMDB) {
+TYPED_TEST(DatabaseTest, TestIteratorsPreIncrement) {
+  UNPACK_TYPES;
+
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
 
-  Database::key_type key1 = this->TestAltKey();
-  Database::value_type value1 = this->TestAltValue();
+  string key1 = this->TestAltKey();
+  value_type value1 = this->TestAltValue();
 
-  Database::key_type key2 = this->TestKey();
-  Database::value_type value2 = this->TestValue();
+  string key2 = this->TestKey();
+  value_type value2 = this->TestValue();
 
   EXPECT_TRUE(database->put(key1, value1));
   EXPECT_TRUE(database->put(key2, value2));
   EXPECT_TRUE(database->commit());
 
-  Database::const_iterator iter1 = database->begin();
+  typename Database<string, value_type>::const_iterator iter1 =
+      database->begin();
 
   EXPECT_FALSE(database->end() == iter1);
 
-  EXPECT_TRUE(this->BufferEq(iter1->key, key1));
+  EXPECT_TRUE(this->equals(iter1->key, key1));
 
-  Database::const_iterator iter2 = ++iter1;
+  typename Database<string, value_type>::const_iterator iter2 = ++iter1;
 
   EXPECT_FALSE(database->end() == iter1);
   EXPECT_FALSE(database->end() == iter2);
 
-  EXPECT_TRUE(this->BufferEq(iter2->key, key2));
+  EXPECT_TRUE(this->equals(iter2->key, key2));
 
-  Database::const_iterator iter3 = ++iter2;
+  typename Database<string, value_type>::const_iterator iter3 = ++iter2;
 
   EXPECT_TRUE(database->end() == iter3);
 
   database->close();
 }
 
-TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLMDB) {
+TYPED_TEST(DatabaseTest, TestIteratorsPostIncrement) {
+  UNPACK_TYPES;
+
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
 
-  Database::key_type key1 = this->TestAltKey();
-  Database::value_type value1 = this->TestAltValue();
+  string key1 = this->TestAltKey();
+  value_type value1 = this->TestAltValue();
 
-  Database::key_type key2 = this->TestKey();
-  Database::value_type value2 = this->TestValue();
+  string key2 = this->TestKey();
+  value_type value2 = this->TestValue();
 
   EXPECT_TRUE(database->put(key1, value1));
   EXPECT_TRUE(database->put(key2, value2));
   EXPECT_TRUE(database->commit());
 
-  Database::const_iterator iter1 = database->begin();
+  typename Database<string, value_type>::const_iterator iter1 =
+      database->begin();
 
   EXPECT_FALSE(database->end() == iter1);
 
-  EXPECT_TRUE(this->BufferEq(iter1->key, key1));
+  EXPECT_TRUE(this->equals(iter1->key, key1));
 
-  Database::const_iterator iter2 = iter1++;
+  typename Database<string, value_type>::const_iterator iter2 = iter1++;
 
   EXPECT_FALSE(database->end() == iter1);
   EXPECT_FALSE(database->end() == iter2);
 
-  EXPECT_TRUE(this->BufferEq(iter2->key, key1));
-  EXPECT_TRUE(this->BufferEq(iter1->key, key2));
+  EXPECT_TRUE(this->equals(iter2->key, key1));
+  EXPECT_TRUE(this->equals(iter1->key, key2));
 
-  Database::const_iterator iter3 = iter1++;
+  typename Database<string, value_type>::const_iterator iter3 = iter1++;
 
   EXPECT_FALSE(database->end() == iter3);
-  EXPECT_TRUE(this->BufferEq(iter3->key, key2));
+  EXPECT_TRUE(this->equals(iter3->key, key2));
   EXPECT_TRUE(database->end() == iter1);
 
   database->close();
 }
 
-TYPED_TEST(DatabaseTest, TestNewPutLMDBPasses) {
+TYPED_TEST(DatabaseTest, TestNewPutPasses) {
+  UNPACK_TYPES;
+
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
 
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
+  string key = this->TestKey();
+  value_type value = this->TestValue();
 
-  EXPECT_TRUE(database->put(key, val));
+  EXPECT_TRUE(database->put(key, value));
 
   EXPECT_TRUE(database->commit());
 
   database->close();
 }
 
-TYPED_TEST(DatabaseTest, TestNewCommitLMDBPasses) {
+TYPED_TEST(DatabaseTest, TestNewCommitPasses) {
+  UNPACK_TYPES;
+
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
 
   EXPECT_TRUE(database->commit());
 
   database->close();
 }
 
-TYPED_TEST(DatabaseTest, TestNewGetLMDBPasses) {
+TYPED_TEST(DatabaseTest, TestNewGetPasses) {
+  UNPACK_TYPES;
+
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
 
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
+  string key = this->TestKey();
+  value_type value = this->TestValue();
 
-  EXPECT_TRUE(database->put(key, val));
+  EXPECT_TRUE(database->put(key, value));
 
   EXPECT_TRUE(database->commit());
 
-  Database::value_type new_val;
+  value_type new_value;
 
-  EXPECT_TRUE(database->get(key, &new_val));
+  EXPECT_TRUE(database->get(key, &new_value));
 
-  EXPECT_TRUE(this->BufferEq(val, new_val));
+  EXPECT_TRUE(this->equals(value, new_value));
 
   database->close();
 }
 
-TYPED_TEST(DatabaseTest, TestNewGetNoCommitLMDBFails) {
+TYPED_TEST(DatabaseTest, TestNewGetNoCommitFails) {
+  UNPACK_TYPES;
+
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
 
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
+  string key = this->TestKey();
+  value_type value = this->TestValue();
 
-  EXPECT_TRUE(database->put(key, val));
+  EXPECT_TRUE(database->put(key, value));
 
-  Database::value_type new_val;
+  value_type new_value;
 
-  EXPECT_FALSE(database->get(key, &new_val));
+  EXPECT_FALSE(database->get(key, &new_value));
 }
 
-TYPED_TEST(DatabaseTest, TestReadWritePutLMDBPasses) {
+
+TYPED_TEST(DatabaseTest, TestReadWritePutPasses) {
+  UNPACK_TYPES;
+
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::ReadWrite));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadWrite));
 
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
+  string key = this->TestKey();
+  value_type value = this->TestValue();
 
-  EXPECT_TRUE(database->put(key, val));
+  EXPECT_TRUE(database->put(key, value));
 
   EXPECT_TRUE(database->commit());
 
   database->close();
 }
 
-TYPED_TEST(DatabaseTest, TestReadWriteCommitLMDBPasses) {
+TYPED_TEST(DatabaseTest, TestReadWriteCommitPasses) {
+  UNPACK_TYPES;
+
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::ReadWrite));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadWrite));
 
   EXPECT_TRUE(database->commit());
 
   database->close();
 }
 
-TYPED_TEST(DatabaseTest, TestReadWriteGetLMDBPasses) {
+TYPED_TEST(DatabaseTest, TestReadWriteGetPasses) {
+  UNPACK_TYPES;
+
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
 
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
+  string key = this->TestKey();
+  value_type value = this->TestValue();
 
-  EXPECT_TRUE(database->put(key, val));
+  EXPECT_TRUE(database->put(key, value));
 
   EXPECT_TRUE(database->commit());
 
-  Database::value_type new_val;
+  value_type new_value;
 
-  EXPECT_TRUE(database->get(key, &new_val));
+  EXPECT_TRUE(database->get(key, &new_value));
 
-  EXPECT_TRUE(this->BufferEq(val, new_val));
+  EXPECT_TRUE(this->equals(value, new_value));
 
   database->close();
 }
 
-TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitLMDBFails) {
+TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitFails) {
+  UNPACK_TYPES;
+
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
 
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
+  string key = this->TestKey();
+  value_type value = this->TestValue();
 
-  EXPECT_TRUE(database->put(key, val));
+  EXPECT_TRUE(database->put(key, value));
 
-  Database::value_type new_val;
+  value_type new_value;
 
-  EXPECT_FALSE(database->get(key, &new_val));
+  EXPECT_FALSE(database->get(key, &new_value));
 }
 
-TYPED_TEST(DatabaseTest, TestReadOnlyPutLMDBFails) {
+TYPED_TEST(DatabaseTest, TestReadOnlyPutFails) {
+  UNPACK_TYPES;
+
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
   database->close();
 
-  EXPECT_TRUE(database->open(name, Database::ReadOnly));
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadOnly));
 
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
+  string key = this->TestKey();
+  value_type value = this->TestValue();
 
-  EXPECT_FALSE(database->put(key, val));
+  EXPECT_FALSE(database->put(key, value));
 }
 
-TYPED_TEST(DatabaseTest, TestReadOnlyCommitLMDBFails) {
+TYPED_TEST(DatabaseTest, TestReadOnlyCommitFails) {
+  UNPACK_TYPES;
+
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
   database->close();
 
-  EXPECT_TRUE(database->open(name, Database::ReadOnly));
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadOnly));
 
   EXPECT_FALSE(database->commit());
 }
 
-TYPED_TEST(DatabaseTest, TestReadOnlyGetLMDBPasses) {
+TYPED_TEST(DatabaseTest, TestReadOnlyGetPasses) {
+  UNPACK_TYPES;
+
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
 
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
+  string key = this->TestKey();
+  value_type value = this->TestValue();
 
-  EXPECT_TRUE(database->put(key, val));
+  EXPECT_TRUE(database->put(key, value));
 
   EXPECT_TRUE(database->commit());
 
   database->close();
 
-  EXPECT_TRUE(database->open(name, Database::ReadOnly));
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadOnly));
 
-  Database::value_type new_val;
+  value_type new_value;
 
-  EXPECT_TRUE(database->get(key, &new_val));
+  EXPECT_TRUE(database->get(key, &new_value));
 
-  EXPECT_TRUE(this->BufferEq(val, new_val));
+  EXPECT_TRUE(this->equals(value, new_value));
 }
 
-TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLMDBFails) {
+TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitFails) {
+  UNPACK_TYPES;
+
   string name = this->DBName();
-  shared_ptr<Database> database = DatabaseFactory("lmdb");
-  EXPECT_TRUE(database->open(name, Database::New));
+  shared_ptr<Database<string, value_type> > database =
+      DatabaseFactory<string, value_type>(backend);
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
 
-  Database::key_type key = this->TestKey();
-  Database::value_type val = this->TestValue();
+  string key = this->TestKey();
+  value_type value = this->TestValue();
 
-  EXPECT_TRUE(database->put(key, val));
+  EXPECT_TRUE(database->put(key, value));
 
   database->close();
 
-  EXPECT_TRUE(database->open(name, Database::ReadOnly));
+  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadOnly));
 
-  Database::value_type new_val;
+  value_type new_value;
 
-  EXPECT_FALSE(database->get(key, &new_val));
+  EXPECT_FALSE(database->get(key, &new_value));
 }
 
+#undef UNPACK_TYPES
+
 }  // namespace caffe
index d13c4a0..f1a7967 100644 (file)
@@ -26,18 +26,17 @@ int main(int argc, char** argv) {
     db_backend = std::string(argv[3]);
   }
 
-  caffe::shared_ptr<Database> database = caffe::DatabaseFactory(db_backend);
+  caffe::shared_ptr<Database<std::string, Datum> > database =
+      caffe::DatabaseFactory<std::string, Datum>(db_backend);
 
   // Open db
-  CHECK(database->open(argv[1], Database::ReadOnly));
+  CHECK(database->open(argv[1], Database<std::string, Datum>::ReadOnly));
 
-  Datum datum;
   BlobProto sum_blob;
   int count = 0;
   // load first datum
-  Database::const_iterator iter = database->begin();
-  const Database::value_type& first_blob = iter->value;
-  datum.ParseFromArray(first_blob.data(), first_blob.size());
+  Database<std::string, Datum>::const_iterator iter = database->begin();
+  const Datum& datum = iter->value;
 
   sum_blob.set_num(1);
   sum_blob.set_channels(datum.channels());
@@ -50,11 +49,10 @@ int main(int argc, char** argv) {
     sum_blob.add_data(0.);
   }
   LOG(INFO) << "Starting Iteration";
-  for (Database::const_iterator iter = database->begin();
+  for (Database<std::string, Datum>::const_iterator iter = database->begin();
       iter != database->end(); ++iter) {
     // just a dummy operation
-    const Database::value_type& blob = iter->value;
-    datum.ParseFromArray(blob.data(), blob.size());
+    const Datum& datum = iter->value;
     const std::string& data = datum.data();
     size_in_datum = std::max<int>(datum.data().size(),
         datum.float_data_size());
index 1cdca7e..2ba3e3c 100644 (file)
@@ -78,10 +78,11 @@ int main(int argc, char** argv) {
   int resize_width = std::max<int>(0, FLAGS_resize_width);
 
   // Open new db
-  shared_ptr<Database> database = DatabaseFactory(db_backend);
+  shared_ptr<Database<string, Datum> > database =
+      DatabaseFactory<string, Datum>(db_backend);
 
   // Open db
-  CHECK(database->open(db_path, Database::New));
+  CHECK(database->open(db_path, Database<string, Datum>::New));
 
   // Storing to db
   std::string root_folder(argv[1]);
@@ -110,13 +111,9 @@ int main(int argc, char** argv) {
     // sequential
     int length = snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
         lines[line_id].first.c_str());
-    Database::value_type value(datum.ByteSize());
-    datum.SerializeWithCachedSizesToArray(
-        reinterpret_cast<unsigned char*>(value.data()));
-    Database::key_type keystr(key_cstr, key_cstr + length);
 
     // Put in db
-    CHECK(database->put(keystr, value));
+    CHECK(database->put(string(key_cstr, length), datum));
 
     if (++count % 1000 == 0) {
       // Commit txn
index 1340192..47565a8 100644 (file)
@@ -121,11 +121,13 @@ int feature_extraction_pipeline(int argc, char** argv) {
 
   int num_mini_batches = atoi(argv[++arg_pos]);
 
-  std::vector<shared_ptr<Database> > feature_dbs;
+  std::vector<shared_ptr<Database<std::string, Datum> > > feature_dbs;
   for (size_t i = 0; i < num_features; ++i) {
     LOG(INFO)<< "Opening database " << database_names[i];
-    shared_ptr<Database> database = DatabaseFactory(argv[++arg_pos]);
-    CHECK(database->open(database_names.at(i), Database::New));
+    shared_ptr<Database<std::string, Datum> > database =
+        DatabaseFactory<std::string, Datum>(argv[++arg_pos]);
+    CHECK(database->open(database_names.at(i),
+        Database<std::string, Datum>::New));
     feature_dbs.push_back(database);
   }
 
@@ -155,13 +157,9 @@ int feature_extraction_pipeline(int argc, char** argv) {
         for (int d = 0; d < dim_features; ++d) {
           datum.add_float_data(feature_blob_data[d]);
         }
-        Database::value_type value(datum.ByteSize());
-        datum.SerializeWithCachedSizesToArray(
-            reinterpret_cast<unsigned char*>(value.data()));
         int length = snprintf(key_str, kMaxKeyStrLength, "%d",
             image_indices[i]);
-        Database::key_type key(key_str, key_str + length);
-        CHECK(feature_dbs.at(i)->put(key, value));
+        CHECK(feature_dbs.at(i)->put(std::string(key_str, length), datum));
         ++image_indices[i];
         if (image_indices[i] % 1000 == 0) {
           CHECK(feature_dbs.at(i)->commit());