From: Kevin James Matzen Date: Mon, 13 Oct 2014 18:54:48 +0000 (-0400) Subject: Renamed Database interface to Dataset. X-Git-Tag: submit/tizen/20180823.020014~572^2~109^2~2 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d275f77332e1e4a1457a520ee3deaf4d126ddfdf;p=platform%2Fupstream%2Fcaffeonacl.git Renamed Database interface to Dataset. --- diff --git a/examples/cifar10/convert_cifar_data.cpp b/examples/cifar10/convert_cifar_data.cpp index f86f393..9eecc74 100644 --- a/examples/cifar10/convert_cifar_data.cpp +++ b/examples/cifar10/convert_cifar_data.cpp @@ -13,13 +13,13 @@ #include "google/protobuf/text_format.h" #include "stdint.h" -#include "caffe/database_factory.hpp" +#include "caffe/dataset_factory.hpp" #include "caffe/proto/caffe.pb.h" using std::string; -using caffe::Database; -using caffe::DatabaseFactory; +using caffe::Dataset; +using caffe::DatasetFactory; using caffe::Datum; using caffe::shared_ptr; @@ -38,10 +38,10 @@ void read_image(std::ifstream* file, int* label, char* buffer) { void convert_dataset(const string& input_folder, const string& output_folder, const string& db_type) { - shared_ptr > train_database = - DatabaseFactory(db_type); - CHECK(train_database->open(output_folder + "/cifar10_train_" + db_type, - Database::New)); + shared_ptr > train_dataset = + DatasetFactory(db_type); + CHECK(train_dataset->open(output_folder + "/cifar10_train_" + db_type, + Dataset::New)); // Data buffer int label; char str_buffer[kCIFARImageNBytes]; @@ -64,17 +64,17 @@ void convert_dataset(const string& input_folder, const string& output_folder, datum.set_data(str_buffer, kCIFARImageNBytes); int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", fileid * kCIFARBatchSize + itemid); - CHECK(train_database->put(string(str_buffer, length), datum)); + CHECK(train_dataset->put(string(str_buffer, length), datum)); } } - CHECK(train_database->commit()); - train_database->close(); + CHECK(train_dataset->commit()); + train_dataset->close(); LOG(INFO) << "Writing Testing data"; - shared_ptr > test_database = - DatabaseFactory(db_type); - CHECK(test_database->open(output_folder + "/cifar10_test_" + db_type, - Database::New)); + shared_ptr > test_dataset = + DatasetFactory(db_type); + CHECK(test_dataset->open(output_folder + "/cifar10_test_" + db_type, + Dataset::New)); // Open files std::ifstream data_file((input_folder + "/test_batch.bin").c_str(), std::ios::in | std::ios::binary); @@ -84,10 +84,10 @@ void convert_dataset(const string& input_folder, const string& output_folder, datum.set_label(label); datum.set_data(str_buffer, kCIFARImageNBytes); int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid); - CHECK(test_database->put(string(str_buffer, length), datum)); + CHECK(test_dataset->put(string(str_buffer, length), datum)); } - CHECK(test_database->commit()); - test_database->close(); + CHECK(test_dataset->commit()); + test_dataset->close(); } int main(int argc, char** argv) { diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp index 13ba4e6..a2ea854 100644 --- a/include/caffe/data_layers.hpp +++ b/include/caffe/data_layers.hpp @@ -11,7 +11,7 @@ #include "caffe/blob.hpp" #include "caffe/common.hpp" #include "caffe/data_transformer.hpp" -#include "caffe/database.hpp" +#include "caffe/dataset.hpp" #include "caffe/filler.hpp" #include "caffe/internal_thread.hpp" #include "caffe/layer.hpp" @@ -100,8 +100,8 @@ class DataLayer : public BasePrefetchingDataLayer { protected: virtual void InternalThreadEntry(); - shared_ptr > database_; - Database::const_iterator iter_; + shared_ptr > dataset_; + Dataset::const_iterator iter_; }; /** diff --git a/include/caffe/database_factory.hpp b/include/caffe/database_factory.hpp deleted file mode 100644 index 30f191e..0000000 --- a/include/caffe/database_factory.hpp +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef CAFFE_DATABASE_FACTORY_H_ -#define CAFFE_DATABASE_FACTORY_H_ - -#include - -#include "caffe/common.hpp" -#include "caffe/database.hpp" -#include "caffe/proto/caffe.pb.h" - -namespace caffe { - -template -shared_ptr > DatabaseFactory(const DataParameter_DB& type); - -template -shared_ptr > DatabaseFactory(const string& type); - -} // namespace caffe - -#endif // CAFFE_DATABASE_FACTORY_H_ diff --git a/include/caffe/database.hpp b/include/caffe/dataset.hpp similarity index 81% rename from include/caffe/database.hpp rename to include/caffe/dataset.hpp index 711ebdd..1f07aa9 100644 --- a/include/caffe/database.hpp +++ b/include/caffe/dataset.hpp @@ -1,5 +1,5 @@ -#ifndef CAFFE_DATABASE_H_ -#define CAFFE_DATABASE_H_ +#ifndef CAFFE_DATASET_H_ +#define CAFFE_DATASET_H_ #include #include @@ -11,7 +11,7 @@ namespace caffe { -namespace database_internal { +namespace dataset_internal { template struct Coder { @@ -85,10 +85,10 @@ struct Coder > { } }; -} // namespace database_internal +} // namespace dataset_internal template -class Database { +class Dataset { public: enum Mode { New, @@ -112,8 +112,8 @@ class Database { virtual void keys(vector* keys) = 0; - Database() { } - virtual ~Database() { } + Dataset() { } + virtual ~Dataset() { } class iterator; typedef iterator const_iterator; @@ -124,7 +124,7 @@ class Database { virtual const_iterator cend() const = 0; protected: - class DatabaseState; + class DatasetState; public: class iterator : public std::iterator { @@ -137,7 +137,7 @@ class Database { iterator() : parent_(NULL) { } - iterator(const Database* parent, shared_ptr state) + iterator(const Dataset* parent, shared_ptr state) : parent_(parent), state_(state) { } ~iterator() { } @@ -145,7 +145,7 @@ class Database { iterator(const iterator& other) : parent_(other.parent_), state_(other.state_ ? other.state_->clone() - : shared_ptr()) { } + : shared_ptr()) { } iterator& operator=(iterator copy) { copy.swap(*this); @@ -184,49 +184,49 @@ class Database { } protected: - const Database* parent_; - shared_ptr state_; + const Dataset* parent_; + shared_ptr state_; }; protected: - class DatabaseState { + class DatasetState { public: - virtual ~DatabaseState() { } - virtual shared_ptr clone() = 0; + virtual ~DatasetState() { } + virtual shared_ptr clone() = 0; }; - virtual bool equal(shared_ptr state1, - shared_ptr state2) const = 0; - virtual void increment(shared_ptr* state) const = 0; + virtual bool equal(shared_ptr state1, + shared_ptr state2) const = 0; + virtual void increment(shared_ptr* state) const = 0; virtual KV& dereference( - shared_ptr state) const = 0; + shared_ptr state) const = 0; template static bool serialize(const T& obj, string* serialized) { - return database_internal::Coder::serialize(obj, serialized); + return dataset_internal::Coder::serialize(obj, serialized); } template static bool serialize(const T& obj, vector* serialized) { - return database_internal::Coder::serialize(obj, serialized); + return dataset_internal::Coder::serialize(obj, serialized); } template static bool deserialize(const string& serialized, T* obj) { - return database_internal::Coder::deserialize(serialized, obj); + return dataset_internal::Coder::deserialize(serialized, obj); } template static bool deserialize(const char* data, size_t size, T* obj) { - return database_internal::Coder::deserialize(data, size, obj); + return dataset_internal::Coder::deserialize(data, size, obj); } }; } // namespace caffe -#define INSTANTIATE_DATABASE(type) \ +#define INSTANTIATE_DATASET(type) \ template class type; \ template class type >; \ template class type; -#endif // CAFFE_DATABASE_H_ +#endif // CAFFE_DATASET_H_ diff --git a/include/caffe/dataset_factory.hpp b/include/caffe/dataset_factory.hpp new file mode 100644 index 0000000..57db49b --- /dev/null +++ b/include/caffe/dataset_factory.hpp @@ -0,0 +1,20 @@ +#ifndef CAFFE_DATASET_FACTORY_H_ +#define CAFFE_DATASET_FACTORY_H_ + +#include + +#include "caffe/common.hpp" +#include "caffe/dataset.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +template +shared_ptr > DatasetFactory(const DataParameter_DB& type); + +template +shared_ptr > DatasetFactory(const string& type); + +} // namespace caffe + +#endif // CAFFE_DATASET_FACTORY_H_ diff --git a/include/caffe/leveldb_database.hpp b/include/caffe/leveldb_dataset.hpp similarity index 70% rename from include/caffe/leveldb_database.hpp rename to include/caffe/leveldb_dataset.hpp index cd966c4..eb354d9 100644 --- a/include/caffe/leveldb_database.hpp +++ b/include/caffe/leveldb_dataset.hpp @@ -1,5 +1,5 @@ -#ifndef CAFFE_LEVELDB_DATABASE_H_ -#define CAFFE_LEVELDB_DATABASE_H_ +#ifndef CAFFE_LEVELDB_DATASET_H_ +#define CAFFE_LEVELDB_DATASET_H_ #include #include @@ -9,17 +9,17 @@ #include #include "caffe/common.hpp" -#include "caffe/database.hpp" +#include "caffe/dataset.hpp" namespace caffe { template -class LeveldbDatabase : public Database { +class LeveldbDataset : public Dataset { public: - typedef Database Base; + typedef Dataset Base; typedef typename Base::key_type key_type; typedef typename Base::value_type value_type; - typedef typename Base::DatabaseState DatabaseState; + typedef typename Base::DatasetState DatasetState; typedef typename Base::Mode Mode; typedef typename Base::const_iterator const_iterator; typedef typename Base::KV KV; @@ -38,11 +38,11 @@ class LeveldbDatabase : public Database { const_iterator cend() const; protected: - class LeveldbState : public DatabaseState { + class LeveldbState : public DatasetState { public: explicit LeveldbState(shared_ptr db, shared_ptr iter) - : DatabaseState(), + : DatasetState(), db_(db), iter_(iter) { } @@ -54,7 +54,7 @@ class LeveldbDatabase : public Database { db_.reset(); } - shared_ptr clone() { + shared_ptr clone() { shared_ptr new_iter; if (iter_.get()) { @@ -64,7 +64,7 @@ class LeveldbDatabase : public Database { CHECK(new_iter->Valid()); } - return shared_ptr(new LeveldbState(db_, new_iter)); + return shared_ptr(new LeveldbState(db_, new_iter)); } shared_ptr db_; @@ -72,10 +72,10 @@ class LeveldbDatabase : public Database { KV kv_pair_; }; - bool equal(shared_ptr state1, - shared_ptr state2) const; - void increment(shared_ptr* state) const; - KV& dereference(shared_ptr state) const; + bool equal(shared_ptr state1, + shared_ptr state2) const; + void increment(shared_ptr* state) const; + KV& dereference(shared_ptr state) const; shared_ptr db_; shared_ptr batch_; @@ -84,4 +84,4 @@ class LeveldbDatabase : public Database { } // namespace caffe -#endif // CAFFE_LEVELDB_DATABASE_H_ +#endif // CAFFE_LEVELDB_DATASET_H_ diff --git a/include/caffe/lmdb_database.hpp b/include/caffe/lmdb_dataset.hpp similarity index 71% rename from include/caffe/lmdb_database.hpp rename to include/caffe/lmdb_dataset.hpp index b5a02ca..57da5de 100644 --- a/include/caffe/lmdb_database.hpp +++ b/include/caffe/lmdb_dataset.hpp @@ -1,5 +1,5 @@ -#ifndef CAFFE_LMDB_DATABASE_H_ -#define CAFFE_LMDB_DATABASE_H_ +#ifndef CAFFE_LMDB_DATASET_H_ +#define CAFFE_LMDB_DATASET_H_ #include #include @@ -8,22 +8,22 @@ #include "lmdb.h" #include "caffe/common.hpp" -#include "caffe/database.hpp" +#include "caffe/dataset.hpp" namespace caffe { template -class LmdbDatabase : public Database { +class LmdbDataset : public Dataset { public: - typedef Database Base; + typedef Dataset Base; typedef typename Base::key_type key_type; typedef typename Base::value_type value_type; - typedef typename Base::DatabaseState DatabaseState; + typedef typename Base::DatasetState DatasetState; typedef typename Base::Mode Mode; typedef typename Base::const_iterator const_iterator; typedef typename Base::KV KV; - LmdbDatabase() + LmdbDataset() : env_(NULL), dbi_(0), txn_(NULL) { } @@ -42,15 +42,15 @@ class LmdbDatabase : public Database { const_iterator cend() const; protected: - class LmdbState : public DatabaseState { + class LmdbState : public DatasetState { public: explicit LmdbState(MDB_cursor* cursor, MDB_txn* txn, const MDB_dbi* dbi) - : DatabaseState(), + : DatasetState(), cursor_(cursor), txn_(txn), dbi_(dbi) { } - shared_ptr clone() { + shared_ptr clone() { MDB_cursor* new_cursor; if (cursor_) { @@ -67,7 +67,7 @@ class LmdbDatabase : public Database { new_cursor = cursor_; } - return shared_ptr(new LmdbState(new_cursor, txn_, dbi_)); + return shared_ptr(new LmdbState(new_cursor, txn_, dbi_)); } MDB_cursor* cursor_; @@ -76,10 +76,10 @@ class LmdbDatabase : public Database { KV kv_pair_; }; - bool equal(shared_ptr state1, - shared_ptr state2) const; - void increment(shared_ptr* state) const; - KV& dereference(shared_ptr state) const; + bool equal(shared_ptr state1, + shared_ptr state2) const; + void increment(shared_ptr* state) const; + KV& dereference(shared_ptr state) const; MDB_env* env_; MDB_dbi dbi_; @@ -88,4 +88,4 @@ class LmdbDatabase : public Database { } // namespace caffe -#endif // CAFFE_LMDB_DATABASE_H_ +#endif // CAFFE_LMDB_DATASET_H_ diff --git a/src/caffe/database_factory.cpp b/src/caffe/database_factory.cpp deleted file mode 100644 index 4ccd429..0000000 --- a/src/caffe/database_factory.cpp +++ /dev/null @@ -1,50 +0,0 @@ -#include -#include -#include - -#include "caffe/database_factory.hpp" -#include "caffe/leveldb_database.hpp" -#include "caffe/lmdb_database.hpp" - -namespace caffe { - -template -shared_ptr > DatabaseFactory(const DataParameter_DB& type) { - switch (type) { - case DataParameter_DB_LEVELDB: - return shared_ptr >(new LeveldbDatabase()); - case DataParameter_DB_LMDB: - return shared_ptr >(new LmdbDatabase()); - default: - LOG(FATAL) << "Unknown database type " << type; - return shared_ptr >(); - } -} - -template -shared_ptr > DatabaseFactory(const string& type) { - if ("leveldb" == type) { - return DatabaseFactory(DataParameter_DB_LEVELDB); - } else if ("lmdb" == type) { - return DatabaseFactory(DataParameter_DB_LMDB); - } else { - LOG(FATAL) << "Unknown database type " << type; - return shared_ptr >(); - } -} - -#define REGISTER_DATABASE(key_type, value_type) \ - template shared_ptr > \ - DatabaseFactory(const string& type); \ - template shared_ptr > \ - DatabaseFactory(const DataParameter_DB& type); \ - -REGISTER_DATABASE(string, string); -REGISTER_DATABASE(string, vector); -REGISTER_DATABASE(string, Datum); - -#undef REGISTER_DATABASE - -} // namespace caffe - - diff --git a/src/caffe/dataset_factory.cpp b/src/caffe/dataset_factory.cpp new file mode 100644 index 0000000..3313de3 --- /dev/null +++ b/src/caffe/dataset_factory.cpp @@ -0,0 +1,50 @@ +#include +#include +#include + +#include "caffe/dataset_factory.hpp" +#include "caffe/leveldb_dataset.hpp" +#include "caffe/lmdb_dataset.hpp" + +namespace caffe { + +template +shared_ptr > DatasetFactory(const DataParameter_DB& type) { + switch (type) { + case DataParameter_DB_LEVELDB: + return shared_ptr >(new LeveldbDataset()); + case DataParameter_DB_LMDB: + return shared_ptr >(new LmdbDataset()); + default: + LOG(FATAL) << "Unknown dataset type " << type; + return shared_ptr >(); + } +} + +template +shared_ptr > DatasetFactory(const string& type) { + if ("leveldb" == type) { + return DatasetFactory(DataParameter_DB_LEVELDB); + } else if ("lmdb" == type) { + return DatasetFactory(DataParameter_DB_LMDB); + } else { + LOG(FATAL) << "Unknown dataset type " << type; + return shared_ptr >(); + } +} + +#define REGISTER_DATASET(key_type, value_type) \ + template shared_ptr > \ + DatasetFactory(const string& type); \ + template shared_ptr > \ + DatasetFactory(const DataParameter_DB& type); \ + +REGISTER_DATASET(string, string); +REGISTER_DATASET(string, vector); +REGISTER_DATASET(string, Datum); + +#undef REGISTER_DATASET + +} // namespace caffe + + diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index 6296335..fcf9bb2 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -5,7 +5,7 @@ #include "caffe/common.hpp" #include "caffe/data_layers.hpp" -#include "caffe/database_factory.hpp" +#include "caffe/dataset_factory.hpp" #include "caffe/layer.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/util/io.hpp" @@ -17,20 +17,20 @@ namespace caffe { template DataLayer::~DataLayer() { this->JoinPrefetchThread(); - // clean up the database resources - database_->close(); + // clean up the dataset resources + dataset_->close(); } template void DataLayer::DataLayerSetUp(const vector*>& bottom, const vector*>& top) { // Initialize DB - database_ = DatabaseFactory( + dataset_ = DatasetFactory( this->layer_param_.data_param().backend()); const string& source = this->layer_param_.data_param().source(); - LOG(INFO) << "Opening database " << source; - CHECK(database_->open(source, Database::ReadOnly)); - iter_ = database_->begin(); + LOG(INFO) << "Opening dataset " << source; + CHECK(dataset_->open(source, Dataset::ReadOnly)); + iter_ = dataset_->begin(); // Check if we would need to randomly skip a few data points if (this->layer_param_.data_param().rand_skip()) { @@ -38,13 +38,13 @@ void DataLayer::DataLayerSetUp(const vector*>& bottom, this->layer_param_.data_param().rand_skip(); LOG(INFO) << "Skipping first " << skip << " data points."; while (skip-- > 0) { - if (++iter_ == database_->end()) { - iter_ = database_->begin(); + if (++iter_ == dataset_->end()) { + iter_ = dataset_->begin(); } } } // Read a data point, and use it to initialize the top blob. - CHECK(iter_ != database_->end()); + CHECK(iter_ != dataset_->end()); const Datum& datum = iter_->value; // image @@ -89,7 +89,7 @@ void DataLayer::InternalThreadEntry() { const int batch_size = this->layer_param_.data_param().batch_size(); for (int item_id = 0; item_id < batch_size; ++item_id) { - CHECK(iter_ != database_->end()); + CHECK(iter_ != dataset_->end()); const Datum& datum = iter_->value; // Apply data transformations (mirror, scale, crop...) @@ -102,8 +102,8 @@ void DataLayer::InternalThreadEntry() { // go to the next iter ++iter_; - if (iter_ == database_->end()) { - iter_ = database_->begin(); + if (iter_ == dataset_->end()) { + iter_ = dataset_->begin(); } } } diff --git a/src/caffe/leveldb_database.cpp b/src/caffe/leveldb_dataset.cpp similarity index 78% rename from src/caffe/leveldb_database.cpp rename to src/caffe/leveldb_dataset.cpp index 267ddf0..41c71e4 100644 --- a/src/caffe/leveldb_database.cpp +++ b/src/caffe/leveldb_dataset.cpp @@ -3,12 +3,12 @@ #include #include "caffe/caffe.hpp" -#include "caffe/leveldb_database.hpp" +#include "caffe/leveldb_dataset.hpp" namespace caffe { template -bool LeveldbDatabase::open(const string& filename, Mode mode) { +bool LeveldbDataset::open(const string& filename, Mode mode) { DLOG(INFO) << "LevelDB: Open " << filename; leveldb::Options options; @@ -55,11 +55,11 @@ bool LeveldbDatabase::open(const string& filename, Mode mode) { } template -bool LeveldbDatabase::put(const K& key, const V& value) { +bool LeveldbDataset::put(const K& key, const V& value) { DLOG(INFO) << "LevelDB: Put"; if (read_only_) { - LOG(ERROR) << "put can not be used on a database in ReadOnly mode"; + LOG(ERROR) << "put can not be used on a dataset in ReadOnly mode"; return false; } @@ -81,7 +81,7 @@ bool LeveldbDatabase::put(const K& key, const V& value) { } template -bool LeveldbDatabase::get(const K& key, V* value) { +bool LeveldbDataset::get(const K& key, V* value) { DLOG(INFO) << "LevelDB: Get"; string serialized_key; @@ -106,11 +106,11 @@ bool LeveldbDatabase::get(const K& key, V* value) { } template -bool LeveldbDatabase::commit() { +bool LeveldbDataset::commit() { DLOG(INFO) << "LevelDB: Commit"; if (read_only_) { - LOG(ERROR) << "commit can not be used on a database in ReadOnly mode"; + LOG(ERROR) << "commit can not be used on a dataset in ReadOnly mode"; return false; } @@ -125,7 +125,7 @@ bool LeveldbDatabase::commit() { } template -void LeveldbDatabase::close() { +void LeveldbDataset::close() { DLOG(INFO) << "LevelDB: Close"; batch_.reset(); @@ -133,7 +133,7 @@ void LeveldbDatabase::close() { } template -void LeveldbDatabase::keys(vector* keys) { +void LeveldbDataset::keys(vector* keys) { DLOG(INFO) << "LevelDB: Keys"; keys->clear(); @@ -143,8 +143,8 @@ void LeveldbDatabase::keys(vector* keys) { } template -typename LeveldbDatabase::const_iterator - LeveldbDatabase::begin() const { +typename LeveldbDataset::const_iterator + LeveldbDataset::begin() const { CHECK_NOTNULL(db_.get()); shared_ptr iter(db_->NewIterator(leveldb::ReadOptions())); iter->SeekToFirst(); @@ -152,7 +152,7 @@ typename LeveldbDatabase::const_iterator iter.reset(); } - shared_ptr state; + shared_ptr state; if (iter) { state.reset(new LeveldbState(db_, iter)); } @@ -160,25 +160,25 @@ typename LeveldbDatabase::const_iterator } template -typename LeveldbDatabase::const_iterator - LeveldbDatabase::end() const { - shared_ptr state; +typename LeveldbDataset::const_iterator + LeveldbDataset::end() const { + shared_ptr state; return const_iterator(this, state); } template -typename LeveldbDatabase::const_iterator - LeveldbDatabase::cbegin() const { +typename LeveldbDataset::const_iterator + LeveldbDataset::cbegin() const { return begin(); } template -typename LeveldbDatabase::const_iterator - LeveldbDatabase::cend() const { return end(); } +typename LeveldbDataset::const_iterator + LeveldbDataset::cend() const { return end(); } template -bool LeveldbDatabase::equal(shared_ptr state1, - shared_ptr state2) const { +bool LeveldbDataset::equal(shared_ptr state1, + shared_ptr state2) const { shared_ptr leveldb_state1 = boost::dynamic_pointer_cast(state1); @@ -192,7 +192,7 @@ bool LeveldbDatabase::equal(shared_ptr state1, } template -void LeveldbDatabase::increment(shared_ptr* state) const { +void LeveldbDataset::increment(shared_ptr* state) const { shared_ptr leveldb_state = boost::dynamic_pointer_cast(*state); @@ -210,8 +210,8 @@ void LeveldbDatabase::increment(shared_ptr* state) const { } template -typename Database::KV& LeveldbDatabase::dereference( - shared_ptr state) const { +typename Dataset::KV& LeveldbDataset::dereference( + shared_ptr state) const { shared_ptr leveldb_state = boost::dynamic_pointer_cast(state); @@ -233,6 +233,6 @@ typename Database::KV& LeveldbDatabase::dereference( return leveldb_state->kv_pair_; } -INSTANTIATE_DATABASE(LeveldbDatabase); +INSTANTIATE_DATASET(LeveldbDataset); } // namespace caffe diff --git a/src/caffe/lmdb_database.cpp b/src/caffe/lmdb_dataset.cpp similarity index 85% rename from src/caffe/lmdb_database.cpp rename to src/caffe/lmdb_dataset.cpp index 22a08b2..d2b6fcd 100644 --- a/src/caffe/lmdb_database.cpp +++ b/src/caffe/lmdb_dataset.cpp @@ -5,12 +5,12 @@ #include #include "caffe/caffe.hpp" -#include "caffe/lmdb_database.hpp" +#include "caffe/lmdb_dataset.hpp" namespace caffe { template -bool LmdbDatabase::open(const string& filename, Mode mode) { +bool LmdbDataset::open(const string& filename, Mode mode) { DLOG(INFO) << "LMDB: Open " << filename; CHECK(NULL == env_); @@ -81,7 +81,7 @@ bool LmdbDatabase::open(const string& filename, Mode mode) { } template -bool LmdbDatabase::put(const K& key, const V& value) { +bool LmdbDataset::put(const K& key, const V& value) { DLOG(INFO) << "LMDB: Put"; vector serialized_key; @@ -115,7 +115,7 @@ bool LmdbDatabase::put(const K& key, const V& value) { } template -bool LmdbDatabase::get(const K& key, V* value) { +bool LmdbDataset::get(const K& key, V* value) { DLOG(INFO) << "LMDB: Get"; vector serialized_key; @@ -154,7 +154,7 @@ bool LmdbDatabase::get(const K& key, V* value) { } template -bool LmdbDatabase::commit() { +bool LmdbDataset::commit() { DLOG(INFO) << "LMDB: Commit"; CHECK_NOTNULL(txn_); @@ -176,7 +176,7 @@ bool LmdbDatabase::commit() { } template -void LmdbDatabase::close() { +void LmdbDataset::close() { DLOG(INFO) << "LMDB: Close"; if (env_ && dbi_) { @@ -189,7 +189,7 @@ void LmdbDatabase::close() { } template -void LmdbDatabase::keys(vector* keys) { +void LmdbDataset::keys(vector* keys) { DLOG(INFO) << "LMDB: Keys"; keys->clear(); @@ -199,8 +199,8 @@ void LmdbDatabase::keys(vector* keys) { } template -typename LmdbDatabase::const_iterator - LmdbDatabase::begin() const { +typename LmdbDataset::const_iterator + LmdbDataset::begin() const { int retval; MDB_txn* iter_txn; @@ -219,7 +219,7 @@ typename LmdbDatabase::const_iterator CHECK(MDB_SUCCESS == retval || MDB_NOTFOUND == retval) << mdb_strerror(retval); - shared_ptr state; + shared_ptr state; if (MDB_SUCCESS == retval) { state.reset(new LmdbState(cursor, iter_txn, &dbi_)); } @@ -227,23 +227,23 @@ typename LmdbDatabase::const_iterator } template -typename LmdbDatabase::const_iterator - LmdbDatabase::end() const { - shared_ptr state; +typename LmdbDataset::const_iterator + LmdbDataset::end() const { + shared_ptr state; return const_iterator(this, state); } template -typename LmdbDatabase::const_iterator - LmdbDatabase::cbegin() const { return begin(); } +typename LmdbDataset::const_iterator + LmdbDataset::cbegin() const { return begin(); } template -typename LmdbDatabase::const_iterator - LmdbDatabase::cend() const { return end(); } +typename LmdbDataset::const_iterator + LmdbDataset::cend() const { return end(); } template -bool LmdbDatabase::equal(shared_ptr state1, - shared_ptr state2) const { +bool LmdbDataset::equal(shared_ptr state1, + shared_ptr state2) const { shared_ptr lmdb_state1 = boost::dynamic_pointer_cast(state1); @@ -257,7 +257,7 @@ bool LmdbDatabase::equal(shared_ptr state1, } template -void LmdbDatabase::increment(shared_ptr* state) const { +void LmdbDataset::increment(shared_ptr* state) const { shared_ptr lmdb_state = boost::dynamic_pointer_cast(*state); @@ -279,8 +279,8 @@ void LmdbDatabase::increment(shared_ptr* state) const { } template -typename Database::KV& LmdbDatabase::dereference( - shared_ptr state) const { +typename Dataset::KV& LmdbDataset::dereference( + shared_ptr state) const { shared_ptr lmdb_state = boost::dynamic_pointer_cast(state); @@ -303,6 +303,6 @@ typename Database::KV& LmdbDatabase::dereference( return lmdb_state->kv_pair_; } -INSTANTIATE_DATABASE(LmdbDatabase); +INSTANTIATE_DATASET(LmdbDataset); } // namespace caffe diff --git a/src/caffe/test/test_data_layer.cpp b/src/caffe/test/test_data_layer.cpp index 1c3ec1f..32f5d41 100644 --- a/src/caffe/test/test_data_layer.cpp +++ b/src/caffe/test/test_data_layer.cpp @@ -5,7 +5,7 @@ #include "caffe/blob.hpp" #include "caffe/common.hpp" -#include "caffe/database_factory.hpp" +#include "caffe/dataset_factory.hpp" #include "caffe/filler.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/util/io.hpp" @@ -38,10 +38,10 @@ class DataLayerTest : public MultiDeviceTest { // an image are the same. void Fill(const bool unique_pixels, DataParameter_DB backend) { backend_ = backend; - LOG(INFO) << "Using temporary database " << *filename_; - shared_ptr > database = - DatabaseFactory(backend_); - CHECK(database->open(*filename_, Database::New)); + LOG(INFO) << "Using temporary dataset " << *filename_; + shared_ptr > dataset = + DatasetFactory(backend_); + CHECK(dataset->open(*filename_, Dataset::New)); for (int i = 0; i < 5; ++i) { Datum datum; datum.set_label(i); @@ -55,10 +55,10 @@ class DataLayerTest : public MultiDeviceTest { } stringstream ss; ss << i; - CHECK(database->put(ss.str(), datum)); + CHECK(dataset->put(ss.str(), datum)); } - CHECK(database->commit()); - database->close(); + CHECK(dataset->commit()); + dataset->close(); } void TestRead() { @@ -183,7 +183,7 @@ class DataLayerTest : public MultiDeviceTest { } crop_sequence.push_back(iter_crop_sequence); } - } // destroy 1st data layer and unlock the database + } // destroy 1st data layer and unlock the dataset // Get crop sequence after reseeding Caffe with 1701. // Check that the sequence is the same as the original. @@ -238,7 +238,7 @@ class DataLayerTest : public MultiDeviceTest { } crop_sequence.push_back(iter_crop_sequence); } - } // destroy 1st data layer and unlock the database + } // destroy 1st data layer and unlock the dataset // Get crop sequence continuing from previous Caffe RNG state; reseed // srand with 1701. Check that the sequence differs from the original. diff --git a/src/caffe/test/test_database.cpp b/src/caffe/test/test_database.cpp deleted file mode 100644 index 9a7d6de..0000000 --- a/src/caffe/test/test_database.cpp +++ /dev/null @@ -1,642 +0,0 @@ -#include -#include - -#include "caffe/util/io.hpp" - -#include "gtest/gtest.h" - -#include "caffe/database_factory.hpp" - -#include "caffe/test/test_caffe_main.hpp" - -namespace caffe { - -namespace DatabaseTest_internal { - -template -struct TestData { - static T TestValue(); - static T TestAltValue(); - static bool equals(const T& a, const T& b); -}; - -template <> -string TestData::TestValue() { - return "world"; -} - -template <> -string TestData::TestAltValue() { - return "bar"; -} - -template <> -bool TestData::equals(const string& a, const string& b) { - return a == b; -} - -template <> -vector TestData >::TestValue() { - string str = "world"; - vector val(str.data(), str.data() + str.size()); - return val; -} - -template <> -vector TestData >::TestAltValue() { - string str = "bar"; - vector val(str.data(), str.data() + str.size()); - return val; -} - -template <> -bool TestData >::equals(const vector& a, - const vector& b) { - if (a.size() != b.size()) { - return false; - } - for (size_t i = 0; i < a.size(); ++i) { - if (a.at(i) != b.at(i)) { - return false; - } - } - - return true; -} - -template <> -Datum TestData::TestValue() { - Datum datum; - datum.set_channels(3); - datum.set_height(32); - datum.set_width(32); - datum.set_data(string(32 * 32 * 3 * 4, ' ')); - datum.set_label(0); - return datum; -} - -template <> -Datum TestData::TestAltValue() { - Datum datum; - datum.set_channels(1); - datum.set_height(64); - datum.set_width(64); - datum.set_data(string(64 * 64 * 1 * 4, ' ')); - datum.set_label(1); - return datum; -} - -template <> -bool TestData::equals(const Datum& a, const Datum& b) { - string serialized_a; - a.SerializeToString(&serialized_a); - - string serialized_b; - b.SerializeToString(&serialized_b); - - return serialized_a == serialized_b; -} - -} // namespace DatabaseTest_internal - -#define UNPACK_TYPES \ - typedef typename TypeParam::value_type value_type; \ - const DataParameter_DB backend = TypeParam::backend; - -template -class DatabaseTest : public ::testing::Test { - protected: - typedef typename TypeParam::value_type value_type; - - string DBName() { - string filename; - MakeTempDir(&filename); - filename += "/db"; - return filename; - } - - string TestKey() { - return "hello"; - } - - value_type TestValue() { - return DatabaseTest_internal::TestData::TestValue(); - } - - string TestAltKey() { - return "foo"; - } - - value_type TestAltValue() { - return DatabaseTest_internal::TestData::TestAltValue(); - } - - template - bool equals(const T& a, const T& b) { - return DatabaseTest_internal::TestData::equals(a, b); - } -}; - -struct StringLeveldb { - typedef string value_type; - static const DataParameter_DB backend; -}; -const DataParameter_DB StringLeveldb::backend = DataParameter_DB_LEVELDB; - -struct StringLmdb { - typedef string value_type; - static const DataParameter_DB backend; -}; -const DataParameter_DB StringLmdb::backend = DataParameter_DB_LEVELDB; - -struct VectorLeveldb { - typedef vector value_type; - static const DataParameter_DB backend; -}; -const DataParameter_DB VectorLeveldb::backend = DataParameter_DB_LEVELDB; - -struct VectorLmdb { - typedef vector value_type; - static const DataParameter_DB backend; -}; -const DataParameter_DB VectorLmdb::backend = DataParameter_DB_LEVELDB; - -struct DatumLeveldb { - typedef Datum value_type; - static const DataParameter_DB backend; -}; -const DataParameter_DB DatumLeveldb::backend = DataParameter_DB_LEVELDB; - -struct DatumLmdb { - typedef Datum value_type; - static const DataParameter_DB backend; -}; -const DataParameter_DB DatumLmdb::backend = DataParameter_DB_LEVELDB; - -typedef ::testing::Types TestTypes; - -TYPED_TEST_CASE(DatabaseTest, TestTypes); - -TYPED_TEST(DatabaseTest, TestNewDoesntExistPasses) { - UNPACK_TYPES; - - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(this->DBName(), - Database::New)); - database->close(); -} - -TYPED_TEST(DatabaseTest, TestNewExistsFails) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - database->close(); - - EXPECT_FALSE(database->open(name, Database::New)); -} - -TYPED_TEST(DatabaseTest, TestReadOnlyExistsPasses) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - database->close(); - - EXPECT_TRUE(database->open(name, Database::ReadOnly)); - database->close(); -} - -TYPED_TEST(DatabaseTest, TestReadOnlyDoesntExistFails) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_FALSE(database->open(name, Database::ReadOnly)); -} - -TYPED_TEST(DatabaseTest, TestReadWriteExistsPasses) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - database->close(); - - EXPECT_TRUE(database->open(name, Database::ReadWrite)); - database->close(); -} - -TYPED_TEST(DatabaseTest, TestReadWriteDoesntExistPasses) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::ReadWrite)); - database->close(); -} - -TYPED_TEST(DatabaseTest, TestKeys) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - - string key1 = this->TestKey(); - value_type value1 = this->TestValue(); - - EXPECT_TRUE(database->put(key1, value1)); - - string key2 = this->TestAltKey(); - value_type value2 = this->TestAltValue(); - - EXPECT_TRUE(database->put(key2, value2)); - - EXPECT_TRUE(database->commit()); - - vector keys; - database->keys(&keys); - - EXPECT_EQ(2, keys.size()); - - EXPECT_TRUE(this->equals(keys.at(0), key1) || - this->equals(keys.at(0), key2)); - EXPECT_TRUE(this->equals(keys.at(1), key1) || - this->equals(keys.at(2), key2)); - EXPECT_FALSE(this->equals(keys.at(0), keys.at(1))); -} - -TYPED_TEST(DatabaseTest, TestKeysNoCommit) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - - string key1 = this->TestKey(); - value_type value1 = this->TestValue(); - - EXPECT_TRUE(database->put(key1, value1)); - - string key2 = this->TestAltKey(); - value_type value2 = this->TestAltValue(); - - EXPECT_TRUE(database->put(key2, value2)); - - vector keys; - database->keys(&keys); - - EXPECT_EQ(0, keys.size()); -} - -TYPED_TEST(DatabaseTest, TestIterators) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - - const int kNumExamples = 4; - for (int i = 0; i < kNumExamples; ++i) { - stringstream ss; - ss << i; - string key = ss.str(); - ss << " here be data"; - value_type value = this->TestValue(); - EXPECT_TRUE(database->put(key, value)); - } - EXPECT_TRUE(database->commit()); - - int count = 0; - typedef typename Database::const_iterator Iter; - for (Iter iter = database->begin(); iter != database->end(); ++iter) { - (void)iter; - ++count; - } - - EXPECT_EQ(kNumExamples, count); -} - -TYPED_TEST(DatabaseTest, TestIteratorsPreIncrement) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - - string key1 = this->TestAltKey(); - value_type value1 = this->TestAltValue(); - - string key2 = this->TestKey(); - value_type value2 = this->TestValue(); - - EXPECT_TRUE(database->put(key1, value1)); - EXPECT_TRUE(database->put(key2, value2)); - EXPECT_TRUE(database->commit()); - - typename Database::const_iterator iter1 = - database->begin(); - - EXPECT_FALSE(database->end() == iter1); - - EXPECT_TRUE(this->equals(iter1->key, key1)); - - typename Database::const_iterator iter2 = ++iter1; - - EXPECT_FALSE(database->end() == iter1); - EXPECT_FALSE(database->end() == iter2); - - EXPECT_TRUE(this->equals(iter2->key, key2)); - - typename Database::const_iterator iter3 = ++iter2; - - EXPECT_TRUE(database->end() == iter3); - - database->close(); -} - -TYPED_TEST(DatabaseTest, TestIteratorsPostIncrement) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - - string key1 = this->TestAltKey(); - value_type value1 = this->TestAltValue(); - - string key2 = this->TestKey(); - value_type value2 = this->TestValue(); - - EXPECT_TRUE(database->put(key1, value1)); - EXPECT_TRUE(database->put(key2, value2)); - EXPECT_TRUE(database->commit()); - - typename Database::const_iterator iter1 = - database->begin(); - - EXPECT_FALSE(database->end() == iter1); - - EXPECT_TRUE(this->equals(iter1->key, key1)); - - typename Database::const_iterator iter2 = iter1++; - - EXPECT_FALSE(database->end() == iter1); - EXPECT_FALSE(database->end() == iter2); - - EXPECT_TRUE(this->equals(iter2->key, key1)); - EXPECT_TRUE(this->equals(iter1->key, key2)); - - typename Database::const_iterator iter3 = iter1++; - - EXPECT_FALSE(database->end() == iter3); - EXPECT_TRUE(this->equals(iter3->key, key2)); - EXPECT_TRUE(database->end() == iter1); - - database->close(); -} - -TYPED_TEST(DatabaseTest, TestNewPutPasses) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - - string key = this->TestKey(); - value_type value = this->TestValue(); - - EXPECT_TRUE(database->put(key, value)); - - EXPECT_TRUE(database->commit()); - - database->close(); -} - -TYPED_TEST(DatabaseTest, TestNewCommitPasses) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - - EXPECT_TRUE(database->commit()); - - database->close(); -} - -TYPED_TEST(DatabaseTest, TestNewGetPasses) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - - string key = this->TestKey(); - value_type value = this->TestValue(); - - EXPECT_TRUE(database->put(key, value)); - - EXPECT_TRUE(database->commit()); - - value_type new_value; - - EXPECT_TRUE(database->get(key, &new_value)); - - EXPECT_TRUE(this->equals(value, new_value)); - - database->close(); -} - -TYPED_TEST(DatabaseTest, TestNewGetNoCommitFails) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - - string key = this->TestKey(); - value_type value = this->TestValue(); - - EXPECT_TRUE(database->put(key, value)); - - value_type new_value; - - EXPECT_FALSE(database->get(key, &new_value)); -} - - -TYPED_TEST(DatabaseTest, TestReadWritePutPasses) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::ReadWrite)); - - string key = this->TestKey(); - value_type value = this->TestValue(); - - EXPECT_TRUE(database->put(key, value)); - - EXPECT_TRUE(database->commit()); - - database->close(); -} - -TYPED_TEST(DatabaseTest, TestReadWriteCommitPasses) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::ReadWrite)); - - EXPECT_TRUE(database->commit()); - - database->close(); -} - -TYPED_TEST(DatabaseTest, TestReadWriteGetPasses) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - - string key = this->TestKey(); - value_type value = this->TestValue(); - - EXPECT_TRUE(database->put(key, value)); - - EXPECT_TRUE(database->commit()); - - value_type new_value; - - EXPECT_TRUE(database->get(key, &new_value)); - - EXPECT_TRUE(this->equals(value, new_value)); - - database->close(); -} - -TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitFails) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - - string key = this->TestKey(); - value_type value = this->TestValue(); - - EXPECT_TRUE(database->put(key, value)); - - value_type new_value; - - EXPECT_FALSE(database->get(key, &new_value)); -} - -TYPED_TEST(DatabaseTest, TestReadOnlyPutFails) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - database->close(); - - EXPECT_TRUE(database->open(name, Database::ReadOnly)); - - string key = this->TestKey(); - value_type value = this->TestValue(); - - EXPECT_FALSE(database->put(key, value)); -} - -TYPED_TEST(DatabaseTest, TestReadOnlyCommitFails) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - database->close(); - - EXPECT_TRUE(database->open(name, Database::ReadOnly)); - - EXPECT_FALSE(database->commit()); -} - -TYPED_TEST(DatabaseTest, TestReadOnlyGetPasses) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - - string key = this->TestKey(); - value_type value = this->TestValue(); - - EXPECT_TRUE(database->put(key, value)); - - EXPECT_TRUE(database->commit()); - - database->close(); - - EXPECT_TRUE(database->open(name, Database::ReadOnly)); - - value_type new_value; - - EXPECT_TRUE(database->get(key, &new_value)); - - EXPECT_TRUE(this->equals(value, new_value)); -} - -TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitFails) { - UNPACK_TYPES; - - string name = this->DBName(); - shared_ptr > database = - DatabaseFactory(backend); - EXPECT_TRUE(database->open(name, Database::New)); - - string key = this->TestKey(); - value_type value = this->TestValue(); - - EXPECT_TRUE(database->put(key, value)); - - database->close(); - - EXPECT_TRUE(database->open(name, Database::ReadOnly)); - - value_type new_value; - - EXPECT_FALSE(database->get(key, &new_value)); -} - -#undef UNPACK_TYPES - -} // namespace caffe diff --git a/src/caffe/test/test_dataset.cpp b/src/caffe/test/test_dataset.cpp new file mode 100644 index 0000000..bb6cf4c --- /dev/null +++ b/src/caffe/test/test_dataset.cpp @@ -0,0 +1,642 @@ +#include +#include + +#include "caffe/util/io.hpp" + +#include "gtest/gtest.h" + +#include "caffe/dataset_factory.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +namespace DatasetTest_internal { + +template +struct TestData { + static T TestValue(); + static T TestAltValue(); + static bool equals(const T& a, const T& b); +}; + +template <> +string TestData::TestValue() { + return "world"; +} + +template <> +string TestData::TestAltValue() { + return "bar"; +} + +template <> +bool TestData::equals(const string& a, const string& b) { + return a == b; +} + +template <> +vector TestData >::TestValue() { + string str = "world"; + vector val(str.data(), str.data() + str.size()); + return val; +} + +template <> +vector TestData >::TestAltValue() { + string str = "bar"; + vector val(str.data(), str.data() + str.size()); + return val; +} + +template <> +bool TestData >::equals(const vector& a, + const vector& b) { + if (a.size() != b.size()) { + return false; + } + for (size_t i = 0; i < a.size(); ++i) { + if (a.at(i) != b.at(i)) { + return false; + } + } + + return true; +} + +template <> +Datum TestData::TestValue() { + Datum datum; + datum.set_channels(3); + datum.set_height(32); + datum.set_width(32); + datum.set_data(string(32 * 32 * 3 * 4, ' ')); + datum.set_label(0); + return datum; +} + +template <> +Datum TestData::TestAltValue() { + Datum datum; + datum.set_channels(1); + datum.set_height(64); + datum.set_width(64); + datum.set_data(string(64 * 64 * 1 * 4, ' ')); + datum.set_label(1); + return datum; +} + +template <> +bool TestData::equals(const Datum& a, const Datum& b) { + string serialized_a; + a.SerializeToString(&serialized_a); + + string serialized_b; + b.SerializeToString(&serialized_b); + + return serialized_a == serialized_b; +} + +} // namespace DatasetTest_internal + +#define UNPACK_TYPES \ + typedef typename TypeParam::value_type value_type; \ + const DataParameter_DB backend = TypeParam::backend; + +template +class DatasetTest : public ::testing::Test { + protected: + typedef typename TypeParam::value_type value_type; + + string DBName() { + string filename; + MakeTempDir(&filename); + filename += "/db"; + return filename; + } + + string TestKey() { + return "hello"; + } + + value_type TestValue() { + return DatasetTest_internal::TestData::TestValue(); + } + + string TestAltKey() { + return "foo"; + } + + value_type TestAltValue() { + return DatasetTest_internal::TestData::TestAltValue(); + } + + template + bool equals(const T& a, const T& b) { + return DatasetTest_internal::TestData::equals(a, b); + } +}; + +struct StringLeveldb { + typedef string value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB StringLeveldb::backend = DataParameter_DB_LEVELDB; + +struct StringLmdb { + typedef string value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB StringLmdb::backend = DataParameter_DB_LEVELDB; + +struct VectorLeveldb { + typedef vector value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB VectorLeveldb::backend = DataParameter_DB_LEVELDB; + +struct VectorLmdb { + typedef vector value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB VectorLmdb::backend = DataParameter_DB_LEVELDB; + +struct DatumLeveldb { + typedef Datum value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB DatumLeveldb::backend = DataParameter_DB_LEVELDB; + +struct DatumLmdb { + typedef Datum value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB DatumLmdb::backend = DataParameter_DB_LEVELDB; + +typedef ::testing::Types TestTypes; + +TYPED_TEST_CASE(DatasetTest, TestTypes); + +TYPED_TEST(DatasetTest, TestNewDoesntExistPasses) { + UNPACK_TYPES; + + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(this->DBName(), + Dataset::New)); + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestNewExistsFails) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + dataset->close(); + + EXPECT_FALSE(dataset->open(name, Dataset::New)); +} + +TYPED_TEST(DatasetTest, TestReadOnlyExistsPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + dataset->close(); + + EXPECT_TRUE(dataset->open(name, Dataset::ReadOnly)); + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestReadOnlyDoesntExistFails) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_FALSE(dataset->open(name, Dataset::ReadOnly)); +} + +TYPED_TEST(DatasetTest, TestReadWriteExistsPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + dataset->close(); + + EXPECT_TRUE(dataset->open(name, Dataset::ReadWrite)); + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestReadWriteDoesntExistPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::ReadWrite)); + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestKeys) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key1 = this->TestKey(); + value_type value1 = this->TestValue(); + + EXPECT_TRUE(dataset->put(key1, value1)); + + string key2 = this->TestAltKey(); + value_type value2 = this->TestAltValue(); + + EXPECT_TRUE(dataset->put(key2, value2)); + + EXPECT_TRUE(dataset->commit()); + + vector keys; + dataset->keys(&keys); + + EXPECT_EQ(2, keys.size()); + + EXPECT_TRUE(this->equals(keys.at(0), key1) || + this->equals(keys.at(0), key2)); + EXPECT_TRUE(this->equals(keys.at(1), key1) || + this->equals(keys.at(2), key2)); + EXPECT_FALSE(this->equals(keys.at(0), keys.at(1))); +} + +TYPED_TEST(DatasetTest, TestKeysNoCommit) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key1 = this->TestKey(); + value_type value1 = this->TestValue(); + + EXPECT_TRUE(dataset->put(key1, value1)); + + string key2 = this->TestAltKey(); + value_type value2 = this->TestAltValue(); + + EXPECT_TRUE(dataset->put(key2, value2)); + + vector keys; + dataset->keys(&keys); + + EXPECT_EQ(0, keys.size()); +} + +TYPED_TEST(DatasetTest, TestIterators) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + const int kNumExamples = 4; + for (int i = 0; i < kNumExamples; ++i) { + stringstream ss; + ss << i; + string key = ss.str(); + ss << " here be data"; + value_type value = this->TestValue(); + EXPECT_TRUE(dataset->put(key, value)); + } + EXPECT_TRUE(dataset->commit()); + + int count = 0; + typedef typename Dataset::const_iterator Iter; + for (Iter iter = dataset->begin(); iter != dataset->end(); ++iter) { + (void)iter; + ++count; + } + + EXPECT_EQ(kNumExamples, count); +} + +TYPED_TEST(DatasetTest, TestIteratorsPreIncrement) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key1 = this->TestAltKey(); + value_type value1 = this->TestAltValue(); + + string key2 = this->TestKey(); + value_type value2 = this->TestValue(); + + EXPECT_TRUE(dataset->put(key1, value1)); + EXPECT_TRUE(dataset->put(key2, value2)); + EXPECT_TRUE(dataset->commit()); + + typename Dataset::const_iterator iter1 = + dataset->begin(); + + EXPECT_FALSE(dataset->end() == iter1); + + EXPECT_TRUE(this->equals(iter1->key, key1)); + + typename Dataset::const_iterator iter2 = ++iter1; + + EXPECT_FALSE(dataset->end() == iter1); + EXPECT_FALSE(dataset->end() == iter2); + + EXPECT_TRUE(this->equals(iter2->key, key2)); + + typename Dataset::const_iterator iter3 = ++iter2; + + EXPECT_TRUE(dataset->end() == iter3); + + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestIteratorsPostIncrement) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key1 = this->TestAltKey(); + value_type value1 = this->TestAltValue(); + + string key2 = this->TestKey(); + value_type value2 = this->TestValue(); + + EXPECT_TRUE(dataset->put(key1, value1)); + EXPECT_TRUE(dataset->put(key2, value2)); + EXPECT_TRUE(dataset->commit()); + + typename Dataset::const_iterator iter1 = + dataset->begin(); + + EXPECT_FALSE(dataset->end() == iter1); + + EXPECT_TRUE(this->equals(iter1->key, key1)); + + typename Dataset::const_iterator iter2 = iter1++; + + EXPECT_FALSE(dataset->end() == iter1); + EXPECT_FALSE(dataset->end() == iter2); + + EXPECT_TRUE(this->equals(iter2->key, key1)); + EXPECT_TRUE(this->equals(iter1->key, key2)); + + typename Dataset::const_iterator iter3 = iter1++; + + EXPECT_FALSE(dataset->end() == iter3); + EXPECT_TRUE(this->equals(iter3->key, key2)); + EXPECT_TRUE(dataset->end() == iter1); + + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestNewPutPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_TRUE(dataset->put(key, value)); + + EXPECT_TRUE(dataset->commit()); + + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestNewCommitPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + EXPECT_TRUE(dataset->commit()); + + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestNewGetPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_TRUE(dataset->put(key, value)); + + EXPECT_TRUE(dataset->commit()); + + value_type new_value; + + EXPECT_TRUE(dataset->get(key, &new_value)); + + EXPECT_TRUE(this->equals(value, new_value)); + + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestNewGetNoCommitFails) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_TRUE(dataset->put(key, value)); + + value_type new_value; + + EXPECT_FALSE(dataset->get(key, &new_value)); +} + + +TYPED_TEST(DatasetTest, TestReadWritePutPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::ReadWrite)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_TRUE(dataset->put(key, value)); + + EXPECT_TRUE(dataset->commit()); + + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestReadWriteCommitPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::ReadWrite)); + + EXPECT_TRUE(dataset->commit()); + + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestReadWriteGetPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_TRUE(dataset->put(key, value)); + + EXPECT_TRUE(dataset->commit()); + + value_type new_value; + + EXPECT_TRUE(dataset->get(key, &new_value)); + + EXPECT_TRUE(this->equals(value, new_value)); + + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestReadWriteGetNoCommitFails) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_TRUE(dataset->put(key, value)); + + value_type new_value; + + EXPECT_FALSE(dataset->get(key, &new_value)); +} + +TYPED_TEST(DatasetTest, TestReadOnlyPutFails) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + dataset->close(); + + EXPECT_TRUE(dataset->open(name, Dataset::ReadOnly)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_FALSE(dataset->put(key, value)); +} + +TYPED_TEST(DatasetTest, TestReadOnlyCommitFails) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + dataset->close(); + + EXPECT_TRUE(dataset->open(name, Dataset::ReadOnly)); + + EXPECT_FALSE(dataset->commit()); +} + +TYPED_TEST(DatasetTest, TestReadOnlyGetPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_TRUE(dataset->put(key, value)); + + EXPECT_TRUE(dataset->commit()); + + dataset->close(); + + EXPECT_TRUE(dataset->open(name, Dataset::ReadOnly)); + + value_type new_value; + + EXPECT_TRUE(dataset->get(key, &new_value)); + + EXPECT_TRUE(this->equals(value, new_value)); +} + +TYPED_TEST(DatasetTest, TestReadOnlyGetNoCommitFails) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_TRUE(dataset->put(key, value)); + + dataset->close(); + + EXPECT_TRUE(dataset->open(name, Dataset::ReadOnly)); + + value_type new_value; + + EXPECT_FALSE(dataset->get(key, &new_value)); +} + +#undef UNPACK_TYPES + +} // namespace caffe diff --git a/tools/compute_image_mean.cpp b/tools/compute_image_mean.cpp index f1a7967..a720f16 100644 --- a/tools/compute_image_mean.cpp +++ b/tools/compute_image_mean.cpp @@ -4,11 +4,11 @@ #include #include -#include "caffe/database_factory.hpp" +#include "caffe/dataset_factory.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/util/io.hpp" -using caffe::Database; +using caffe::Dataset; using caffe::Datum; using caffe::BlobProto; using std::max; @@ -26,16 +26,16 @@ int main(int argc, char** argv) { db_backend = std::string(argv[3]); } - caffe::shared_ptr > database = - caffe::DatabaseFactory(db_backend); + caffe::shared_ptr > dataset = + caffe::DatasetFactory(db_backend); // Open db - CHECK(database->open(argv[1], Database::ReadOnly)); + CHECK(dataset->open(argv[1], Dataset::ReadOnly)); BlobProto sum_blob; int count = 0; // load first datum - Database::const_iterator iter = database->begin(); + Dataset::const_iterator iter = dataset->begin(); const Datum& datum = iter->value; sum_blob.set_num(1); @@ -49,8 +49,8 @@ int main(int argc, char** argv) { sum_blob.add_data(0.); } LOG(INFO) << "Starting Iteration"; - for (Database::const_iterator iter = database->begin(); - iter != database->end(); ++iter) { + for (Dataset::const_iterator iter = dataset->begin(); + iter != dataset->end(); ++iter) { // just a dummy operation const Datum& datum = iter->value; const std::string& data = datum.data(); @@ -87,6 +87,6 @@ int main(int argc, char** argv) { WriteProtoToBinaryFile(sum_blob, argv[2]); // Clean up - database->close(); + dataset->close(); return 0; } diff --git a/tools/convert_imageset.cpp b/tools/convert_imageset.cpp index 2ba3e3c..37efa5c 100644 --- a/tools/convert_imageset.cpp +++ b/tools/convert_imageset.cpp @@ -17,7 +17,7 @@ #include #include -#include "caffe/database_factory.hpp" +#include "caffe/dataset_factory.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/util/io.hpp" #include "caffe/util/rng.hpp" @@ -78,11 +78,11 @@ int main(int argc, char** argv) { int resize_width = std::max(0, FLAGS_resize_width); // Open new db - shared_ptr > database = - DatabaseFactory(db_backend); + shared_ptr > dataset = + DatasetFactory(db_backend); // Open db - CHECK(database->open(db_path, Database::New)); + CHECK(dataset->open(db_path, Dataset::New)); // Storing to db std::string root_folder(argv[1]); @@ -113,19 +113,19 @@ int main(int argc, char** argv) { lines[line_id].first.c_str()); // Put in db - CHECK(database->put(string(key_cstr, length), datum)); + CHECK(dataset->put(string(key_cstr, length), datum)); if (++count % 1000 == 0) { // Commit txn - CHECK(database->commit()); + CHECK(dataset->commit()); LOG(ERROR) << "Processed " << count << " files."; } } // write the last batch if (count % 1000 != 0) { - CHECK(database->commit()); + CHECK(dataset->commit()); LOG(ERROR) << "Processed " << count << " files."; } - database->close(); + dataset->close(); return 0; } diff --git a/tools/extract_features.cpp b/tools/extract_features.cpp index 47565a8..ddbce10 100644 --- a/tools/extract_features.cpp +++ b/tools/extract_features.cpp @@ -7,7 +7,7 @@ #include "caffe/blob.hpp" #include "caffe/common.hpp" -#include "caffe/database_factory.hpp" +#include "caffe/dataset_factory.hpp" #include "caffe/net.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/util/io.hpp" @@ -16,8 +16,8 @@ using boost::shared_ptr; using caffe::Blob; using caffe::Caffe; -using caffe::Database; -using caffe::DatabaseFactory; +using caffe::Dataset; +using caffe::DatasetFactory; using caffe::Datum; using caffe::Net; @@ -39,12 +39,12 @@ int feature_extraction_pipeline(int argc, char** argv) { " extract features of the input data produced by the net.\n" "Usage: extract_features pretrained_net_param" " feature_extraction_proto_file extract_feature_blob_name1[,name2,...]" - " save_feature_database_name1[,name2,...] num_mini_batches db_type" + " save_feature_dataset_name1[,name2,...] num_mini_batches db_type" " [CPU/GPU] [DEVICE_ID=0]\n" "Note: you can extract multiple features in one pass by specifying" - " multiple feature blob names and database names seperated by ','." + " multiple feature blob names and dataset names seperated by ','." " The names cannot contain white space characters and the number of blobs" - " and databases must be equal."; + " and datasets must be equal."; return 1; } int arg_pos = num_required_args; @@ -105,12 +105,12 @@ int feature_extraction_pipeline(int argc, char** argv) { std::vector blob_names; boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(",")); - std::string save_feature_database_names(argv[++arg_pos]); - std::vector database_names; - boost::split(database_names, save_feature_database_names, + std::string save_feature_dataset_names(argv[++arg_pos]); + std::vector dataset_names; + boost::split(dataset_names, save_feature_dataset_names, boost::is_any_of(",")); - CHECK_EQ(blob_names.size(), database_names.size()) << - " the number of blob names and database names must be equal"; + CHECK_EQ(blob_names.size(), dataset_names.size()) << + " the number of blob names and dataset names must be equal"; size_t num_features = blob_names.size(); for (size_t i = 0; i < num_features; i++) { @@ -121,14 +121,13 @@ int feature_extraction_pipeline(int argc, char** argv) { int num_mini_batches = atoi(argv[++arg_pos]); - std::vector > > feature_dbs; + std::vector > > feature_dbs; for (size_t i = 0; i < num_features; ++i) { - LOG(INFO)<< "Opening database " << database_names[i]; - shared_ptr > database = - DatabaseFactory(argv[++arg_pos]); - CHECK(database->open(database_names.at(i), - Database::New)); - feature_dbs.push_back(database); + LOG(INFO)<< "Opening dataset " << dataset_names[i]; + shared_ptr > dataset = + DatasetFactory(argv[++arg_pos]); + CHECK(dataset->open(dataset_names.at(i), Dataset::New)); + feature_dbs.push_back(dataset); } LOG(ERROR)<< "Extacting Features";