Reworked the Coder interface such that a Dataset now has both user-definable KCoder...
authorKevin James Matzen <kmatzen@cs.cornell.edu>
Tue, 14 Oct 2014 18:00:29 +0000 (14:00 -0400)
committerKevin James Matzen <kmatzen@cs.cornell.edu>
Tue, 14 Oct 2014 23:40:37 +0000 (19:40 -0400)
include/caffe/dataset.hpp
include/caffe/leveldb_dataset.hpp
include/caffe/lmdb_dataset.hpp
src/caffe/leveldb_dataset.cpp
src/caffe/lmdb_dataset.cpp

index 1f07aa9..efe3ffd 100644 (file)
@@ -1,6 +1,8 @@
 #ifndef CAFFE_DATASET_H_
 #define CAFFE_DATASET_H_
 
+#include <google/protobuf/message.h>
+
 #include <algorithm>
 #include <iterator>
 #include <string>
 #include <vector>
 
 #include "caffe/common.hpp"
+#include "caffe/proto/caffe.pb.h"
 
 namespace caffe {
 
 namespace dataset_internal {
 
+using google::protobuf::Message;
+
+template<bool condition>
+struct static_assertion {};
+template<>
+struct static_assertion<true> {
+  enum {
+    DEFAULT_CODER_NOT_AVAILABLE
+  };
+};
+
 template <typename T>
-struct Coder {
-  static bool serialize(const T& obj, string* serialized) {
+struct DefaultCoder {
+  using static_assertion<sizeof(T) == 0>::DEFAULT_CODER_NOT_AVAILABLE;
+  static bool serialize(const T& obj, string* serialized);
+  static bool serialize(const T& obj, vector<char>* serialized);
+  static bool deserialize(const string& serialized, T* obj);
+  static bool deserialize(const char* data, size_t size, T* obj);
+};
+
+template <>
+struct DefaultCoder<Message> {
+  static bool serialize(const Message& obj, string* serialized) {
     return obj.SerializeToString(serialized);
   }
 
-  static bool serialize(const T& obj, vector<char>* serialized) {
+  static bool serialize(const Message& obj, vector<char>* serialized) {
     serialized->resize(obj.ByteSize());
     return obj.SerializeWithCachedSizesToArray(
         reinterpret_cast<unsigned char*>(serialized->data()));
   }
 
-  static bool deserialize(const string& serialized, T* obj) {
+  static bool deserialize(const string& serialized, Message* obj) {
     return obj->ParseFromString(serialized);
   }
 
-  static bool deserialize(const char* data, size_t size, T* obj) {
+  static bool deserialize(const char* data, size_t size, Message* obj) {
     return obj->ParseFromArray(data, size);
   }
 };
 
 template <>
-struct Coder<string> {
+struct DefaultCoder<caffe::Datum> : public DefaultCoder<Message> { };
+
+template <>
+struct DefaultCoder<string> {
   static bool serialize(string obj, string* serialized) {
     *serialized = obj;
     return true;
@@ -60,7 +86,7 @@ struct Coder<string> {
 };
 
 template <>
-struct Coder<vector<char> > {
+struct DefaultCoder<vector<char> > {
   static bool serialize(vector<char> obj, string* serialized) {
     string tmp(obj.data(), obj.size());
     serialized->swap(tmp);
@@ -87,7 +113,9 @@ struct Coder<vector<char> > {
 
 }  // namespace dataset_internal
 
-template <typename K, typename V>
+template <typename K, typename V,
+          typename KCoder = dataset_internal::DefaultCoder<K>,
+          typename VCoder = dataset_internal::DefaultCoder<V> >
 class Dataset {
  public:
   enum Mode {
@@ -200,26 +228,6 @@ class Dataset {
   virtual void increment(shared_ptr<DatasetState>* state) const = 0;
   virtual KV& dereference(
       shared_ptr<DatasetState> state) const = 0;
-
-  template <typename T>
-  static bool serialize(const T& obj, string* serialized) {
-    return dataset_internal::Coder<T>::serialize(obj, serialized);
-  }
-
-  template <typename T>
-  static bool serialize(const T& obj, vector<char>* serialized) {
-    return dataset_internal::Coder<T>::serialize(obj, serialized);
-  }
-
-  template <typename T>
-  static bool deserialize(const string& serialized, T* obj) {
-    return dataset_internal::Coder<T>::deserialize(serialized, obj);
-  }
-
-  template <typename T>
-  static bool deserialize(const char* data, size_t size, T* obj) {
-    return dataset_internal::Coder<T>::deserialize(data, size, obj);
-  }
 };
 
 }  // namespace caffe
@@ -227,6 +235,6 @@ class Dataset {
 #define INSTANTIATE_DATASET(type) \
   template class type<string, string>; \
   template class type<string, vector<char> >; \
-  template class type<string, Datum>;
+  template class type<string, caffe::Datum>;
 
 #endif  // CAFFE_DATASET_H_
index eb354d9..90f92d9 100644 (file)
 
 namespace caffe {
 
-template <typename K, typename V>
-class LeveldbDataset : public Dataset<K, V> {
+template <typename K, typename V,
+          typename KCoder = dataset_internal::DefaultCoder<K>,
+          typename VCoder = dataset_internal::DefaultCoder<V> >
+class LeveldbDataset : public Dataset<K, V, KCoder, VCoder> {
  public:
-  typedef Dataset<K, V> Base;
+  typedef Dataset<K, V, KCoder, VCoder> Base;
   typedef typename Base::key_type key_type;
   typedef typename Base::value_type value_type;
   typedef typename Base::DatasetState DatasetState;
index 57da5de..8817df7 100644 (file)
 
 namespace caffe {
 
-template <typename K, typename V>
-class LmdbDataset : public Dataset<K, V> {
+template <typename K, typename V,
+          typename KCoder = dataset_internal::DefaultCoder<K>,
+          typename VCoder = dataset_internal::DefaultCoder<V> >
+class LmdbDataset : public Dataset<K, V, KCoder, VCoder> {
  public:
-  typedef Dataset<K, V> Base;
+  typedef Dataset<K, V, KCoder, VCoder> Base;
   typedef typename Base::key_type key_type;
   typedef typename Base::value_type value_type;
   typedef typename Base::DatasetState DatasetState;
index 41c71e4..af95604 100644 (file)
@@ -7,8 +7,9 @@
 
 namespace caffe {
 
-template <typename K, typename V>
-bool LeveldbDataset<K, V>::open(const string& filename, Mode mode) {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LeveldbDataset<K, V, KCoder, VCoder>::open(const string& filename,
+    Mode mode) {
   DLOG(INFO) << "LevelDB: Open " << filename;
 
   leveldb::Options options;
@@ -54,8 +55,8 @@ bool LeveldbDataset<K, V>::open(const string& filename, Mode mode) {
   return true;
 }
 
-template <typename K, typename V>
-bool LeveldbDataset<K, V>::put(const K& key, const V& value) {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LeveldbDataset<K, V, KCoder, VCoder>::put(const K& key, const V& value) {
   DLOG(INFO) << "LevelDB: Put";
 
   if (read_only_) {
@@ -66,12 +67,12 @@ bool LeveldbDataset<K, V>::put(const K& key, const V& value) {
   CHECK_NOTNULL(batch_.get());
 
   string serialized_key;
-  if (!Base::serialize(key, &serialized_key)) {
+  if (!KCoder::serialize(key, &serialized_key)) {
     return false;
   }
 
   string serialized_value;
-  if (!Base::serialize(value, &serialized_value)) {
+  if (!VCoder::serialize(value, &serialized_value)) {
     return false;
   }
 
@@ -80,12 +81,12 @@ bool LeveldbDataset<K, V>::put(const K& key, const V& value) {
   return true;
 }
 
-template <typename K, typename V>
-bool LeveldbDataset<K, V>::get(const K& key, V* value) {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LeveldbDataset<K, V, KCoder, VCoder>::get(const K& key, V* value) {
   DLOG(INFO) << "LevelDB: Get";
 
   string serialized_key;
-  if (!Base::serialize(key, &serialized_key)) {
+  if (!KCoder::serialize(key, &serialized_key)) {
     return false;
   }
 
@@ -98,15 +99,15 @@ bool LeveldbDataset<K, V>::get(const K& key, V* value) {
     return false;
   }
 
-  if (!Base::deserialize(serialized_value, value)) {
+  if (!VCoder::deserialize(serialized_value, value)) {
     return false;
   }
 
   return true;
 }
 
-template <typename K, typename V>
-bool LeveldbDataset<K, V>::commit() {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LeveldbDataset<K, V, KCoder, VCoder>::commit() {
   DLOG(INFO) << "LevelDB: Commit";
 
   if (read_only_) {
@@ -124,16 +125,16 @@ bool LeveldbDataset<K, V>::commit() {
   return status.ok();
 }
 
-template <typename K, typename V>
-void LeveldbDataset<K, V>::close() {
+template <typename K, typename V, typename KCoder, typename VCoder>
+void LeveldbDataset<K, V, KCoder, VCoder>::close() {
   DLOG(INFO) << "LevelDB: Close";
 
   batch_.reset();
   db_.reset();
 }
 
-template <typename K, typename V>
-void LeveldbDataset<K, V>::keys(vector<K>* keys) {
+template <typename K, typename V, typename KCoder, typename VCoder>
+void LeveldbDataset<K, V, KCoder, VCoder>::keys(vector<K>* keys) {
   DLOG(INFO) << "LevelDB: Keys";
 
   keys->clear();
@@ -142,9 +143,9 @@ void LeveldbDataset<K, V>::keys(vector<K>* keys) {
   }
 }
 
-template <typename K, typename V>
-typename LeveldbDataset<K, V>::const_iterator
-    LeveldbDataset<K, V>::begin() const {
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename LeveldbDataset<K, V, KCoder, VCoder>::const_iterator
+    LeveldbDataset<K, V, KCoder, VCoder>::begin() const {
   CHECK_NOTNULL(db_.get());
   shared_ptr<leveldb::Iterator> iter(db_->NewIterator(leveldb::ReadOptions()));
   iter->SeekToFirst();
@@ -159,26 +160,26 @@ typename LeveldbDataset<K, V>::const_iterator
   return const_iterator(this, state);
 }
 
-template <typename K, typename V>
-typename LeveldbDataset<K, V>::const_iterator
-    LeveldbDataset<K, V>::end() const {
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename LeveldbDataset<K, V, KCoder, VCoder>::const_iterator
+    LeveldbDataset<K, V, KCoder, VCoder>::end() const {
   shared_ptr<DatasetState> state;
   return const_iterator(this, state);
 }
 
-template <typename K, typename V>
-typename LeveldbDataset<K, V>::const_iterator
-    LeveldbDataset<K, V>::cbegin() const {
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename LeveldbDataset<K, V, KCoder, VCoder>::const_iterator
+    LeveldbDataset<K, V, KCoder, VCoder>::cbegin() const {
   return begin();
 }
 
-template <typename K, typename V>
-typename LeveldbDataset<K, V>::const_iterator
-    LeveldbDataset<K, V>::cend() const { return end(); }
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename LeveldbDataset<K, V, KCoder, VCoder>::const_iterator
+    LeveldbDataset<K, V, KCoder, VCoder>::cend() const { return end(); }
 
-template <typename K, typename V>
-bool LeveldbDataset<K, V>::equal(shared_ptr<DatasetState> state1,
-    shared_ptr<DatasetState> state2) const {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LeveldbDataset<K, V, KCoder, VCoder>::equal(
+    shared_ptr<DatasetState> state1, shared_ptr<DatasetState> state2) const {
   shared_ptr<LeveldbState> leveldb_state1 =
       boost::dynamic_pointer_cast<LeveldbState>(state1);
 
@@ -191,8 +192,9 @@ bool LeveldbDataset<K, V>::equal(shared_ptr<DatasetState> state1,
   return !leveldb_state1 && !leveldb_state2;
 }
 
-template <typename K, typename V>
-void LeveldbDataset<K, V>::increment(shared_ptr<DatasetState>* state) const {
+template <typename K, typename V, typename KCoder, typename VCoder>
+void LeveldbDataset<K, V, KCoder, VCoder>::increment(
+    shared_ptr<DatasetState>* state) const {
   shared_ptr<LeveldbState> leveldb_state =
       boost::dynamic_pointer_cast<LeveldbState>(*state);
 
@@ -209,8 +211,9 @@ void LeveldbDataset<K, V>::increment(shared_ptr<DatasetState>* state) const {
   }
 }
 
-template <typename K, typename V>
-typename Dataset<K, V>::KV& LeveldbDataset<K, V>::dereference(
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename Dataset<K, V, KCoder, VCoder>::KV&
+    LeveldbDataset<K, V, KCoder, VCoder>::dereference(
     shared_ptr<DatasetState> state) const {
   shared_ptr<LeveldbState> leveldb_state =
       boost::dynamic_pointer_cast<LeveldbState>(state);
@@ -225,9 +228,9 @@ typename Dataset<K, V>::KV& LeveldbDataset<K, V>::dereference(
 
   const leveldb::Slice& key = iter->key();
   const leveldb::Slice& value = iter->value();
-  CHECK(Base::deserialize(key.data(), key.size(),
+  CHECK(KCoder::deserialize(key.data(), key.size(),
       &leveldb_state->kv_pair_.key));
-  CHECK(Base::deserialize(value.data(), value.size(),
+  CHECK(VCoder::deserialize(value.data(), value.size(),
       &leveldb_state->kv_pair_.value));
 
   return leveldb_state->kv_pair_;
index d2b6fcd..ca96843 100644 (file)
@@ -9,8 +9,9 @@
 
 namespace caffe {
 
-template <typename K, typename V>
-bool LmdbDataset<K, V>::open(const string& filename, Mode mode) {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LmdbDataset<K, V, KCoder, VCoder>::open(const string& filename,
+    Mode mode) {
   DLOG(INFO) << "LMDB: Open " << filename;
 
   CHECK(NULL == env_);
@@ -80,18 +81,18 @@ bool LmdbDataset<K, V>::open(const string& filename, Mode mode) {
   return true;
 }
 
-template <typename K, typename V>
-bool LmdbDataset<K, V>::put(const K& key, const V& value) {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LmdbDataset<K, V, KCoder, VCoder>::put(const K& key, const V& value) {
   DLOG(INFO) << "LMDB: Put";
 
   vector<char> serialized_key;
-  if (!Base::serialize(key, &serialized_key)) {
+  if (!KCoder::serialize(key, &serialized_key)) {
     LOG(ERROR) << "failed to serialize key";
     return false;
   }
 
   vector<char> serialized_value;
-  if (!Base::serialize(value, &serialized_value)) {
+  if (!VCoder::serialize(value, &serialized_value)) {
     LOG(ERROR) << "failed to serialized value";
     return false;
   }
@@ -114,12 +115,12 @@ bool LmdbDataset<K, V>::put(const K& key, const V& value) {
   return true;
 }
 
-template <typename K, typename V>
-bool LmdbDataset<K, V>::get(const K& key, V* value) {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LmdbDataset<K, V, KCoder, VCoder>::get(const K& key, V* value) {
   DLOG(INFO) << "LMDB: Get";
 
   vector<char> serialized_key;
-  if (!Base::serialize(key, &serialized_key)) {
+  if (!KCoder::serialize(key, &serialized_key)) {
     LOG(ERROR) << "failed to serialized key";
     return false;
   }
@@ -144,7 +145,7 @@ bool LmdbDataset<K, V>::get(const K& key, V* value) {
 
   mdb_txn_abort(get_txn);
 
-  if (!Base::deserialize(reinterpret_cast<char*>(mdbdata.mv_data),
+  if (!VCoder::deserialize(reinterpret_cast<char*>(mdbdata.mv_data),
       mdbdata.mv_size, value)) {
     LOG(ERROR) << "failed to deserialize value";
     return false;
@@ -153,8 +154,8 @@ bool LmdbDataset<K, V>::get(const K& key, V* value) {
   return true;
 }
 
-template <typename K, typename V>
-bool LmdbDataset<K, V>::commit() {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LmdbDataset<K, V, KCoder, VCoder>::commit() {
   DLOG(INFO) << "LMDB: Commit";
 
   CHECK_NOTNULL(txn_);
@@ -175,8 +176,8 @@ bool LmdbDataset<K, V>::commit() {
   return true;
 }
 
-template <typename K, typename V>
-void LmdbDataset<K, V>::close() {
+template <typename K, typename V, typename KCoder, typename VCoder>
+void LmdbDataset<K, V, KCoder, VCoder>::close() {
   DLOG(INFO) << "LMDB: Close";
 
   if (env_ && dbi_) {
@@ -188,8 +189,8 @@ void LmdbDataset<K, V>::close() {
   }
 }
 
-template <typename K, typename V>
-void LmdbDataset<K, V>::keys(vector<K>* keys) {
+template <typename K, typename V, typename KCoder, typename VCoder>
+void LmdbDataset<K, V, KCoder, VCoder>::keys(vector<K>* keys) {
   DLOG(INFO) << "LMDB: Keys";
 
   keys->clear();
@@ -198,9 +199,9 @@ void LmdbDataset<K, V>::keys(vector<K>* keys) {
   }
 }
 
-template <typename K, typename V>
-typename LmdbDataset<K, V>::const_iterator
-    LmdbDataset<K, V>::begin() const {
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename LmdbDataset<K, V, KCoder, VCoder>::const_iterator
+    LmdbDataset<K, V, KCoder, VCoder>::begin() const {
   int retval;
 
   MDB_txn* iter_txn;
@@ -226,23 +227,23 @@ typename LmdbDataset<K, V>::const_iterator
   return const_iterator(this, state);
 }
 
-template <typename K, typename V>
-typename LmdbDataset<K, V>::const_iterator
-    LmdbDataset<K, V>::end() const {
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename LmdbDataset<K, V, KCoder, VCoder>::const_iterator
+    LmdbDataset<K, V, KCoder, VCoder>::end() const {
   shared_ptr<DatasetState> state;
   return const_iterator(this, state);
 }
 
-template <typename K, typename V>
-typename LmdbDataset<K, V>::const_iterator
-    LmdbDataset<K, V>::cbegin() const { return begin(); }
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename LmdbDataset<K, V, KCoder, VCoder>::const_iterator
+    LmdbDataset<K, V, KCoder, VCoder>::cbegin() const { return begin(); }
 
-template <typename K, typename V>
-typename LmdbDataset<K, V>::const_iterator
-    LmdbDataset<K, V>::cend() const { return end(); }
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename LmdbDataset<K, V, KCoder, VCoder>::const_iterator
+    LmdbDataset<K, V, KCoder, VCoder>::cend() const { return end(); }
 
-template <typename K, typename V>
-bool LmdbDataset<K, V>::equal(shared_ptr<DatasetState> state1,
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LmdbDataset<K, V, KCoder, VCoder>::equal(shared_ptr<DatasetState> state1,
     shared_ptr<DatasetState> state2) const {
   shared_ptr<LmdbState> lmdb_state1 =
       boost::dynamic_pointer_cast<LmdbState>(state1);
@@ -256,8 +257,9 @@ bool LmdbDataset<K, V>::equal(shared_ptr<DatasetState> state1,
   return !lmdb_state1 && !lmdb_state2;
 }
 
-template <typename K, typename V>
-void LmdbDataset<K, V>::increment(shared_ptr<DatasetState>* state) const {
+template <typename K, typename V, typename KCoder, typename VCoder>
+void LmdbDataset<K, V, KCoder, VCoder>::increment(
+    shared_ptr<DatasetState>* state) const {
   shared_ptr<LmdbState> lmdb_state =
       boost::dynamic_pointer_cast<LmdbState>(*state);
 
@@ -278,8 +280,9 @@ void LmdbDataset<K, V>::increment(shared_ptr<DatasetState>* state) const {
   }
 }
 
-template <typename K, typename V>
-typename Dataset<K, V>::KV& LmdbDataset<K, V>::dereference(
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename Dataset<K, V, KCoder, VCoder>::KV&
+    LmdbDataset<K, V, KCoder, VCoder>::dereference(
     shared_ptr<DatasetState> state) const {
   shared_ptr<LmdbState> lmdb_state =
       boost::dynamic_pointer_cast<LmdbState>(state);
@@ -295,9 +298,9 @@ typename Dataset<K, V>::KV& LmdbDataset<K, V>::dereference(
   int retval = mdb_cursor_get(cursor, &mdb_key, &mdb_val, MDB_GET_CURRENT);
   CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval);
 
-  CHECK(Base::deserialize(reinterpret_cast<char*>(mdb_key.mv_data),
+  CHECK(KCoder::deserialize(reinterpret_cast<char*>(mdb_key.mv_data),
       mdb_key.mv_size, &lmdb_state->kv_pair_.key));
-  CHECK(Base::deserialize(reinterpret_cast<char*>(mdb_val.mv_data),
+  CHECK(VCoder::deserialize(reinterpret_cast<char*>(mdb_val.mv_data),
       mdb_val.mv_size, &lmdb_state->kv_pair_.value));
 
   return lmdb_state->kv_pair_;