#ifndef CAFFE_DATASET_H_
#define CAFFE_DATASET_H_
+#include <google/protobuf/message.h>
+
#include <algorithm>
#include <iterator>
#include <string>
#include <vector>
#include "caffe/common.hpp"
+#include "caffe/proto/caffe.pb.h"
namespace caffe {
namespace dataset_internal {
+using google::protobuf::Message;
+
+template<bool condition>  // Compile-time assertion helper keyed on a bool.
+struct static_assertion {};  // Generic case (used for `false`): no members.
+template<>
+struct static_assertion<true> {  // Only the `true` case declares the marker.
+ enum {
+ DEFAULT_CODER_NOT_AVAILABLE  // Marker named by DefaultCoder's primary template.
+ };
+};
+
template <typename T>
-struct Coder {
- static bool serialize(const T& obj, string* serialized) {
+struct DefaultCoder {  // Primary template: used for any T without a specialization.
+ using static_assertion<sizeof(T) == 0>::DEFAULT_CODER_NOT_AVAILABLE;  // sizeof(T) == 0 is false for complete types, so this names a member static_assertion<false> lacks; presumably meant to reject unsupported T at compile time.
+ static bool serialize(const T& obj, string* serialized);  // Declared only; no definition in this view.
+ static bool serialize(const T& obj, vector<char>* serialized);  // Declared only.
+ static bool deserialize(const string& serialized, T* obj);  // Declared only.
+ static bool deserialize(const char* data, size_t size, T* obj);  // Declared only.
+};
+
+template <>
+struct DefaultCoder<Message> {  // Coder for protobuf messages; each call delegates to the Message API.
+ // Serializes obj into *serialized; returns the result of SerializeToString.
+ static bool serialize(const Message& obj, string* serialized) {
return obj.SerializeToString(serialized);
}
- static bool serialize(const T& obj, vector<char>* serialized) {
+ // Serializes obj into *serialized, resizing the vector to the message's
+ // byte size first; returns true on success.
+ static bool serialize(const Message& obj, vector<char>* serialized) {
serialized->resize(obj.ByteSize());
- return obj.SerializeWithCachedSizesToArray(
- reinterpret_cast<unsigned char*>(serialized->data()));
+ // SerializeToArray reports success as a bool. The previous call returned the
+ // end pointer of the written range, so the implicit pointer-to-bool
+ // conversion yielded false whenever the buffer pointer was null (possible
+ // for an empty message serialized into an empty vector).
+ return obj.SerializeToArray(serialized->data(),
+ static_cast<int>(serialized->size()));
}
- static bool deserialize(const string& serialized, T* obj) {
+ // Parses *obj from a serialized string; returns the result of ParseFromString.
+ static bool deserialize(const string& serialized, Message* obj) {
return obj->ParseFromString(serialized);
}
- static bool deserialize(const char* data, size_t size, T* obj) {
+ // Parses *obj from a raw buffer of `size` bytes; returns the result of
+ // ParseFromArray.
+ static bool deserialize(const char* data, size_t size, Message* obj) {
return obj->ParseFromArray(data, size);
}
};
template <>
-struct Coder<string> {
+struct DefaultCoder<caffe::Datum> : public DefaultCoder<Message> { };  // Datum inherits the static serialize/deserialize functions of the Message coder.
+
+template <>
+struct DefaultCoder<string> {
static bool serialize(string obj, string* serialized) {
*serialized = obj;
return true;
};
template <>
-struct Coder<vector<char> > {
+struct DefaultCoder<vector<char> > {
static bool serialize(vector<char> obj, string* serialized) {
string tmp(obj.data(), obj.size());
serialized->swap(tmp);
} // namespace dataset_internal
-template <typename K, typename V>
+template <typename K, typename V,
+ typename KCoder = dataset_internal::DefaultCoder<K>,
+ typename VCoder = dataset_internal::DefaultCoder<V> >
class Dataset {
public:
enum Mode {
virtual void increment(shared_ptr<DatasetState>* state) const = 0;
virtual KV& dereference(
shared_ptr<DatasetState> state) const = 0;
-
- template <typename T>
- static bool serialize(const T& obj, string* serialized) {
- return dataset_internal::Coder<T>::serialize(obj, serialized);
- }
-
- template <typename T>
- static bool serialize(const T& obj, vector<char>* serialized) {
- return dataset_internal::Coder<T>::serialize(obj, serialized);
- }
-
- template <typename T>
- static bool deserialize(const string& serialized, T* obj) {
- return dataset_internal::Coder<T>::deserialize(serialized, obj);
- }
-
- template <typename T>
- static bool deserialize(const char* data, size_t size, T* obj) {
- return dataset_internal::Coder<T>::deserialize(data, size, obj);
- }
};
} // namespace caffe
// Explicitly instantiates the class template `type` for the key/value type
// pairs listed below: <string, string>, <string, vector<char> > and
// <string, caffe::Datum>. Any further template parameters of `type` must have
// defaults, since only two arguments are supplied.
#define INSTANTIATE_DATASET(type) \
template class type<string, string>; \
template class type<string, vector<char> >; \
- template class type<string, Datum>;
+ template class type<string, caffe::Datum>;
#endif // CAFFE_DATASET_H_
namespace caffe {
-template <typename K, typename V>
-class LeveldbDataset : public Dataset<K, V> {
+template <typename K, typename V,
+ typename KCoder = dataset_internal::DefaultCoder<K>,
+ typename VCoder = dataset_internal::DefaultCoder<V> >
+class LeveldbDataset : public Dataset<K, V, KCoder, VCoder> {
public:
- typedef Dataset<K, V> Base;
+ typedef Dataset<K, V, KCoder, VCoder> Base;
typedef typename Base::key_type key_type;
typedef typename Base::value_type value_type;
typedef typename Base::DatasetState DatasetState;
namespace caffe {
-template <typename K, typename V>
-class LmdbDataset : public Dataset<K, V> {
+template <typename K, typename V,
+ typename KCoder = dataset_internal::DefaultCoder<K>,
+ typename VCoder = dataset_internal::DefaultCoder<V> >
+class LmdbDataset : public Dataset<K, V, KCoder, VCoder> {
public:
- typedef Dataset<K, V> Base;
+ typedef Dataset<K, V, KCoder, VCoder> Base;
typedef typename Base::key_type key_type;
typedef typename Base::value_type value_type;
typedef typename Base::DatasetState DatasetState;
namespace caffe {
-template <typename K, typename V>
-bool LeveldbDataset<K, V>::open(const string& filename, Mode mode) {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LeveldbDataset<K, V, KCoder, VCoder>::open(const string& filename,
+ Mode mode) {
DLOG(INFO) << "LevelDB: Open " << filename;
leveldb::Options options;
return true;
}
-template <typename K, typename V>
-bool LeveldbDataset<K, V>::put(const K& key, const V& value) {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LeveldbDataset<K, V, KCoder, VCoder>::put(const K& key, const V& value) {
DLOG(INFO) << "LevelDB: Put";
if (read_only_) {
CHECK_NOTNULL(batch_.get());
string serialized_key;
- if (!Base::serialize(key, &serialized_key)) {
+ if (!KCoder::serialize(key, &serialized_key)) {
return false;
}
string serialized_value;
- if (!Base::serialize(value, &serialized_value)) {
+ if (!VCoder::serialize(value, &serialized_value)) {
return false;
}
return true;
}
-template <typename K, typename V>
-bool LeveldbDataset<K, V>::get(const K& key, V* value) {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LeveldbDataset<K, V, KCoder, VCoder>::get(const K& key, V* value) {
DLOG(INFO) << "LevelDB: Get";
string serialized_key;
- if (!Base::serialize(key, &serialized_key)) {
+ if (!KCoder::serialize(key, &serialized_key)) {
return false;
}
return false;
}
- if (!Base::deserialize(serialized_value, value)) {
+ if (!VCoder::deserialize(serialized_value, value)) {
return false;
}
return true;
}
-template <typename K, typename V>
-bool LeveldbDataset<K, V>::commit() {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LeveldbDataset<K, V, KCoder, VCoder>::commit() {
DLOG(INFO) << "LevelDB: Commit";
if (read_only_) {
return status.ok();
}
-template <typename K, typename V>
-void LeveldbDataset<K, V>::close() {
+template <typename K, typename V, typename KCoder, typename VCoder>
+void LeveldbDataset<K, V, KCoder, VCoder>::close() {
DLOG(INFO) << "LevelDB: Close";
batch_.reset();
db_.reset();
}
-template <typename K, typename V>
-void LeveldbDataset<K, V>::keys(vector<K>* keys) {
+template <typename K, typename V, typename KCoder, typename VCoder>
+void LeveldbDataset<K, V, KCoder, VCoder>::keys(vector<K>* keys) {
DLOG(INFO) << "LevelDB: Keys";
keys->clear();
}
}
-template <typename K, typename V>
-typename LeveldbDataset<K, V>::const_iterator
- LeveldbDataset<K, V>::begin() const {
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename LeveldbDataset<K, V, KCoder, VCoder>::const_iterator
+ LeveldbDataset<K, V, KCoder, VCoder>::begin() const {
CHECK_NOTNULL(db_.get());
shared_ptr<leveldb::Iterator> iter(db_->NewIterator(leveldb::ReadOptions()));
iter->SeekToFirst();
return const_iterator(this, state);
}
-template <typename K, typename V>
-typename LeveldbDataset<K, V>::const_iterator
- LeveldbDataset<K, V>::end() const {
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename LeveldbDataset<K, V, KCoder, VCoder>::const_iterator
+ LeveldbDataset<K, V, KCoder, VCoder>::end() const {
shared_ptr<DatasetState> state;
return const_iterator(this, state);
}
-template <typename K, typename V>
-typename LeveldbDataset<K, V>::const_iterator
- LeveldbDataset<K, V>::cbegin() const {
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename LeveldbDataset<K, V, KCoder, VCoder>::const_iterator
+ LeveldbDataset<K, V, KCoder, VCoder>::cbegin() const {
return begin();
}
-template <typename K, typename V>
-typename LeveldbDataset<K, V>::const_iterator
- LeveldbDataset<K, V>::cend() const { return end(); }
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename LeveldbDataset<K, V, KCoder, VCoder>::const_iterator
+ LeveldbDataset<K, V, KCoder, VCoder>::cend() const { return end(); }
-template <typename K, typename V>
-bool LeveldbDataset<K, V>::equal(shared_ptr<DatasetState> state1,
- shared_ptr<DatasetState> state2) const {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LeveldbDataset<K, V, KCoder, VCoder>::equal(
+ shared_ptr<DatasetState> state1, shared_ptr<DatasetState> state2) const {
shared_ptr<LeveldbState> leveldb_state1 =
boost::dynamic_pointer_cast<LeveldbState>(state1);
return !leveldb_state1 && !leveldb_state2;
}
-template <typename K, typename V>
-void LeveldbDataset<K, V>::increment(shared_ptr<DatasetState>* state) const {
+template <typename K, typename V, typename KCoder, typename VCoder>
+void LeveldbDataset<K, V, KCoder, VCoder>::increment(
+ shared_ptr<DatasetState>* state) const {
shared_ptr<LeveldbState> leveldb_state =
boost::dynamic_pointer_cast<LeveldbState>(*state);
}
}
-template <typename K, typename V>
-typename Dataset<K, V>::KV& LeveldbDataset<K, V>::dereference(
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename Dataset<K, V, KCoder, VCoder>::KV&
+ LeveldbDataset<K, V, KCoder, VCoder>::dereference(
shared_ptr<DatasetState> state) const {
shared_ptr<LeveldbState> leveldb_state =
boost::dynamic_pointer_cast<LeveldbState>(state);
const leveldb::Slice& key = iter->key();
const leveldb::Slice& value = iter->value();
- CHECK(Base::deserialize(key.data(), key.size(),
+ CHECK(KCoder::deserialize(key.data(), key.size(),
&leveldb_state->kv_pair_.key));
- CHECK(Base::deserialize(value.data(), value.size(),
+ CHECK(VCoder::deserialize(value.data(), value.size(),
&leveldb_state->kv_pair_.value));
return leveldb_state->kv_pair_;
namespace caffe {
-template <typename K, typename V>
-bool LmdbDataset<K, V>::open(const string& filename, Mode mode) {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LmdbDataset<K, V, KCoder, VCoder>::open(const string& filename,
+ Mode mode) {
DLOG(INFO) << "LMDB: Open " << filename;
CHECK(NULL == env_);
return true;
}
-template <typename K, typename V>
-bool LmdbDataset<K, V>::put(const K& key, const V& value) {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LmdbDataset<K, V, KCoder, VCoder>::put(const K& key, const V& value) {
DLOG(INFO) << "LMDB: Put";
vector<char> serialized_key;
- if (!Base::serialize(key, &serialized_key)) {
+ if (!KCoder::serialize(key, &serialized_key)) {
LOG(ERROR) << "failed to serialize key";
return false;
}
vector<char> serialized_value;
- if (!Base::serialize(value, &serialized_value)) {
+ if (!VCoder::serialize(value, &serialized_value)) {
LOG(ERROR) << "failed to serialized value";
return false;
}
return true;
}
-template <typename K, typename V>
-bool LmdbDataset<K, V>::get(const K& key, V* value) {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LmdbDataset<K, V, KCoder, VCoder>::get(const K& key, V* value) {
DLOG(INFO) << "LMDB: Get";
vector<char> serialized_key;
- if (!Base::serialize(key, &serialized_key)) {
+ if (!KCoder::serialize(key, &serialized_key)) {
LOG(ERROR) << "failed to serialized key";
return false;
}
mdb_txn_abort(get_txn);
- if (!Base::deserialize(reinterpret_cast<char*>(mdbdata.mv_data),
+ if (!VCoder::deserialize(reinterpret_cast<char*>(mdbdata.mv_data),
mdbdata.mv_size, value)) {
LOG(ERROR) << "failed to deserialize value";
return false;
return true;
}
-template <typename K, typename V>
-bool LmdbDataset<K, V>::commit() {
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LmdbDataset<K, V, KCoder, VCoder>::commit() {
DLOG(INFO) << "LMDB: Commit";
CHECK_NOTNULL(txn_);
return true;
}
-template <typename K, typename V>
-void LmdbDataset<K, V>::close() {
+template <typename K, typename V, typename KCoder, typename VCoder>
+void LmdbDataset<K, V, KCoder, VCoder>::close() {
DLOG(INFO) << "LMDB: Close";
if (env_ && dbi_) {
}
}
-template <typename K, typename V>
-void LmdbDataset<K, V>::keys(vector<K>* keys) {
+template <typename K, typename V, typename KCoder, typename VCoder>
+void LmdbDataset<K, V, KCoder, VCoder>::keys(vector<K>* keys) {
DLOG(INFO) << "LMDB: Keys";
keys->clear();
}
}
-template <typename K, typename V>
-typename LmdbDataset<K, V>::const_iterator
- LmdbDataset<K, V>::begin() const {
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename LmdbDataset<K, V, KCoder, VCoder>::const_iterator
+ LmdbDataset<K, V, KCoder, VCoder>::begin() const {
int retval;
MDB_txn* iter_txn;
return const_iterator(this, state);
}
-template <typename K, typename V>
-typename LmdbDataset<K, V>::const_iterator
- LmdbDataset<K, V>::end() const {
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename LmdbDataset<K, V, KCoder, VCoder>::const_iterator
+ LmdbDataset<K, V, KCoder, VCoder>::end() const {
shared_ptr<DatasetState> state;
return const_iterator(this, state);
}
-template <typename K, typename V>
-typename LmdbDataset<K, V>::const_iterator
- LmdbDataset<K, V>::cbegin() const { return begin(); }
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename LmdbDataset<K, V, KCoder, VCoder>::const_iterator
+ LmdbDataset<K, V, KCoder, VCoder>::cbegin() const { return begin(); }
-template <typename K, typename V>
-typename LmdbDataset<K, V>::const_iterator
- LmdbDataset<K, V>::cend() const { return end(); }
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename LmdbDataset<K, V, KCoder, VCoder>::const_iterator
+ LmdbDataset<K, V, KCoder, VCoder>::cend() const { return end(); }
-template <typename K, typename V>
-bool LmdbDataset<K, V>::equal(shared_ptr<DatasetState> state1,
+template <typename K, typename V, typename KCoder, typename VCoder>
+bool LmdbDataset<K, V, KCoder, VCoder>::equal(shared_ptr<DatasetState> state1,
shared_ptr<DatasetState> state2) const {
shared_ptr<LmdbState> lmdb_state1 =
boost::dynamic_pointer_cast<LmdbState>(state1);
return !lmdb_state1 && !lmdb_state2;
}
-template <typename K, typename V>
-void LmdbDataset<K, V>::increment(shared_ptr<DatasetState>* state) const {
+template <typename K, typename V, typename KCoder, typename VCoder>
+void LmdbDataset<K, V, KCoder, VCoder>::increment(
+ shared_ptr<DatasetState>* state) const {
shared_ptr<LmdbState> lmdb_state =
boost::dynamic_pointer_cast<LmdbState>(*state);
}
}
-template <typename K, typename V>
-typename Dataset<K, V>::KV& LmdbDataset<K, V>::dereference(
+template <typename K, typename V, typename KCoder, typename VCoder>
+typename Dataset<K, V, KCoder, VCoder>::KV&
+ LmdbDataset<K, V, KCoder, VCoder>::dereference(
shared_ptr<DatasetState> state) const {
shared_ptr<LmdbState> lmdb_state =
boost::dynamic_pointer_cast<LmdbState>(state);
int retval = mdb_cursor_get(cursor, &mdb_key, &mdb_val, MDB_GET_CURRENT);
CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval);
- CHECK(Base::deserialize(reinterpret_cast<char*>(mdb_key.mv_data),
+ CHECK(KCoder::deserialize(reinterpret_cast<char*>(mdb_key.mv_data),
mdb_key.mv_size, &lmdb_state->kv_pair_.key));
- CHECK(Base::deserialize(reinterpret_cast<char*>(mdb_val.mv_data),
+ CHECK(VCoder::deserialize(reinterpret_cast<char*>(mdb_val.mv_data),
mdb_val.mv_size, &lmdb_state->kv_pair_.value));
return lmdb_state->kv_pair_;