Renamed Database interface to Dataset.
author     Kevin James Matzen <kmatzen@cs.cornell.edu>
           Mon, 13 Oct 2014 18:54:48 +0000 (14:54 -0400)
committer  Kevin James Matzen <kmatzen@cs.cornell.edu>
           Tue, 14 Oct 2014 23:40:37 +0000 (19:40 -0400)
18 files changed:
examples/cifar10/convert_cifar_data.cpp
include/caffe/data_layers.hpp
include/caffe/database_factory.hpp [deleted file]
include/caffe/dataset.hpp [moved from include/caffe/database.hpp with 81% similarity]
include/caffe/dataset_factory.hpp [new file with mode: 0644]
include/caffe/leveldb_dataset.hpp [moved from include/caffe/leveldb_database.hpp with 70% similarity]
include/caffe/lmdb_dataset.hpp [moved from include/caffe/lmdb_database.hpp with 71% similarity]
src/caffe/database_factory.cpp [deleted file]
src/caffe/dataset_factory.cpp [new file with mode: 0644]
src/caffe/layers/data_layer.cpp
src/caffe/leveldb_dataset.cpp [moved from src/caffe/leveldb_database.cpp with 78% similarity]
src/caffe/lmdb_dataset.cpp [moved from src/caffe/lmdb_database.cpp with 85% similarity]
src/caffe/test/test_data_layer.cpp
src/caffe/test/test_database.cpp [deleted file]
src/caffe/test/test_dataset.cpp [new file with mode: 0644]
tools/compute_image_mean.cpp
tools/convert_imageset.cpp
tools/extract_features.cpp

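The rename is mechanical at every call site: Database and DatabaseFactory become Dataset and DatasetFactory, while the open/put/commit/close contract is unchanged. Purely as an illustrative sketch (the backend string, path, and key below are invented for this example, not taken from the commit), a write through the renamed API looks like this:

    #include <string>

    #include "glog/logging.h"

    #include "caffe/dataset_factory.hpp"
    #include "caffe/proto/caffe.pb.h"

    using caffe::Dataset;
    using caffe::DatasetFactory;
    using caffe::Datum;
    using caffe::shared_ptr;
    using std::string;

    int main() {
      // Pick a backend by name ("leveldb" or "lmdb"); the converted tools
      // below do the same via a string or a DataParameter_DB enum.
      shared_ptr<Dataset<string, Datum> > dataset =
          DatasetFactory<string, Datum>("leveldb");

      // New creates the dataset; open() signals failure by returning false.
      CHECK(dataset->open("/tmp/example_dataset", Dataset<string, Datum>::New));

      Datum datum;
      datum.set_label(0);
      CHECK(dataset->put("00000", datum));  // staged until commit()
      CHECK(dataset->commit());             // flush staged writes
      dataset->close();
      return 0;
    }
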
diff --git a/examples/cifar10/convert_cifar_data.cpp b/examples/cifar10/convert_cifar_data.cpp
index f86f393..9eecc74 100644 (file)
 #include "google/protobuf/text_format.h"
 #include "stdint.h"
 
-#include "caffe/database_factory.hpp"
+#include "caffe/dataset_factory.hpp"
 #include "caffe/proto/caffe.pb.h"
 
 using std::string;
 
-using caffe::Database;
-using caffe::DatabaseFactory;
+using caffe::Dataset;
+using caffe::DatasetFactory;
 using caffe::Datum;
 using caffe::shared_ptr;
 
@@ -38,10 +38,10 @@ void read_image(std::ifstream* file, int* label, char* buffer) {
 
 void convert_dataset(const string& input_folder, const string& output_folder,
     const string& db_type) {
-  shared_ptr<Database<string, Datum> > train_database =
-      DatabaseFactory<string, Datum>(db_type);
-  CHECK(train_database->open(output_folder + "/cifar10_train_" + db_type,
-      Database<string, Datum>::New));
+  shared_ptr<Dataset<string, Datum> > train_dataset =
+      DatasetFactory<string, Datum>(db_type);
+  CHECK(train_dataset->open(output_folder + "/cifar10_train_" + db_type,
+      Dataset<string, Datum>::New));
   // Data buffer
   int label;
   char str_buffer[kCIFARImageNBytes];
@@ -64,17 +64,17 @@ void convert_dataset(const string& input_folder, const string& output_folder,
       datum.set_data(str_buffer, kCIFARImageNBytes);
       int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d",
           fileid * kCIFARBatchSize + itemid);
-      CHECK(train_database->put(string(str_buffer, length), datum));
+      CHECK(train_dataset->put(string(str_buffer, length), datum));
     }
   }
-  CHECK(train_database->commit());
-  train_database->close();
+  CHECK(train_dataset->commit());
+  train_dataset->close();
 
   LOG(INFO) << "Writing Testing data";
-  shared_ptr<Database<string, Datum> > test_database =
-      DatabaseFactory<string, Datum>(db_type);
-  CHECK(test_database->open(output_folder + "/cifar10_test_" + db_type,
-      Database<string, Datum>::New));
+  shared_ptr<Dataset<string, Datum> > test_dataset =
+      DatasetFactory<string, Datum>(db_type);
+  CHECK(test_dataset->open(output_folder + "/cifar10_test_" + db_type,
+      Dataset<string, Datum>::New));
   // Open files
   std::ifstream data_file((input_folder + "/test_batch.bin").c_str(),
       std::ios::in | std::ios::binary);
@@ -84,10 +84,10 @@ void convert_dataset(const string& input_folder, const string& output_folder,
     datum.set_label(label);
     datum.set_data(str_buffer, kCIFARImageNBytes);
     int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid);
-    CHECK(test_database->put(string(str_buffer, length), datum));
+    CHECK(test_dataset->put(string(str_buffer, length), datum));
   }
-  CHECK(test_database->commit());
-  test_database->close();
+  CHECK(test_dataset->commit());
+  test_dataset->close();
 }
 
 int main(int argc, char** argv) {
diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp
index 13ba4e6..a2ea854 100644 (file)
@@ -11,7 +11,7 @@
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
 #include "caffe/data_transformer.hpp"
-#include "caffe/database.hpp"
+#include "caffe/dataset.hpp"
 #include "caffe/filler.hpp"
 #include "caffe/internal_thread.hpp"
 #include "caffe/layer.hpp"
@@ -100,8 +100,8 @@ class DataLayer : public BasePrefetchingDataLayer<Dtype> {
  protected:
   virtual void InternalThreadEntry();
 
-  shared_ptr<Database<string, Datum> > database_;
-  Database<string, Datum>::const_iterator iter_;
+  shared_ptr<Dataset<string, Datum> > dataset_;
+  Dataset<string, Datum>::const_iterator iter_;
 };
 
 /**
diff --git a/include/caffe/database_factory.hpp b/include/caffe/database_factory.hpp
deleted file mode 100644 (file)
index 30f191e..0000000
+++ /dev/null
@@ -1,20 +0,0 @@
-#ifndef CAFFE_DATABASE_FACTORY_H_
-#define CAFFE_DATABASE_FACTORY_H_
-
-#include <string>
-
-#include "caffe/common.hpp"
-#include "caffe/database.hpp"
-#include "caffe/proto/caffe.pb.h"
-
-namespace caffe {
-
-template <typename K, typename V>
-shared_ptr<Database<K, V> > DatabaseFactory(const DataParameter_DB& type);
-
-template <typename K, typename V>
-shared_ptr<Database<K, V> > DatabaseFactory(const string& type);
-
-}  // namespace caffe
-
-#endif  // CAFFE_DATABASE_FACTORY_H_
similarity index 81%
rename from include/caffe/database.hpp
rename to include/caffe/dataset.hpp
index 711ebdd..1f07aa9 100644 (file)
@@ -1,5 +1,5 @@
-#ifndef CAFFE_DATABASE_H_
-#define CAFFE_DATABASE_H_
+#ifndef CAFFE_DATASET_H_
+#define CAFFE_DATASET_H_
 
 #include <algorithm>
 #include <iterator>
@@ -11,7 +11,7 @@
 
 namespace caffe {
 
-namespace database_internal {
+namespace dataset_internal {
 
 template <typename T>
 struct Coder {
@@ -85,10 +85,10 @@ struct Coder<vector<char> > {
   }
 };
 
-}  // namespace database_internal
+}  // namespace dataset_internal
 
 template <typename K, typename V>
-class Database {
+class Dataset {
  public:
   enum Mode {
     New,
@@ -112,8 +112,8 @@ class Database {
 
   virtual void keys(vector<K>* keys) = 0;
 
-  Database() { }
-  virtual ~Database() { }
+  Dataset() { }
+  virtual ~Dataset() { }
 
   class iterator;
   typedef iterator const_iterator;
@@ -124,7 +124,7 @@ class Database {
   virtual const_iterator cend() const = 0;
 
  protected:
-  class DatabaseState;
+  class DatasetState;
 
  public:
   class iterator : public std::iterator<std::forward_iterator_tag, KV> {
@@ -137,7 +137,7 @@ class Database {
     iterator()
         : parent_(NULL) { }
 
-    iterator(const Database* parent, shared_ptr<DatabaseState> state)
+    iterator(const Dataset* parent, shared_ptr<DatasetState> state)
         : parent_(parent),
           state_(state) { }
     ~iterator() { }
@@ -145,7 +145,7 @@ class Database {
     iterator(const iterator& other)
         : parent_(other.parent_),
           state_(other.state_ ? other.state_->clone()
-              : shared_ptr<DatabaseState>()) { }
+              : shared_ptr<DatasetState>()) { }
 
     iterator& operator=(iterator copy) {
       copy.swap(*this);
@@ -184,49 +184,49 @@ class Database {
     }
 
    protected:
-    const Database* parent_;
-    shared_ptr<DatabaseState> state_;
+    const Dataset* parent_;
+    shared_ptr<DatasetState> state_;
   };
 
  protected:
-  class DatabaseState {
+  class DatasetState {
    public:
-    virtual ~DatabaseState() { }
-    virtual shared_ptr<DatabaseState> clone() = 0;
+    virtual ~DatasetState() { }
+    virtual shared_ptr<DatasetState> clone() = 0;
   };
 
-  virtual bool equal(shared_ptr<DatabaseState> state1,
-      shared_ptr<DatabaseState> state2) const = 0;
-  virtual void increment(shared_ptr<DatabaseState>* state) const = 0;
+  virtual bool equal(shared_ptr<DatasetState> state1,
+      shared_ptr<DatasetState> state2) const = 0;
+  virtual void increment(shared_ptr<DatasetState>* state) const = 0;
   virtual KV& dereference(
-      shared_ptr<DatabaseState> state) const = 0;
+      shared_ptr<DatasetState> state) const = 0;
 
   template <typename T>
   static bool serialize(const T& obj, string* serialized) {
-    return database_internal::Coder<T>::serialize(obj, serialized);
+    return dataset_internal::Coder<T>::serialize(obj, serialized);
   }
 
   template <typename T>
   static bool serialize(const T& obj, vector<char>* serialized) {
-    return database_internal::Coder<T>::serialize(obj, serialized);
+    return dataset_internal::Coder<T>::serialize(obj, serialized);
   }
 
   template <typename T>
   static bool deserialize(const string& serialized, T* obj) {
-    return database_internal::Coder<T>::deserialize(serialized, obj);
+    return dataset_internal::Coder<T>::deserialize(serialized, obj);
   }
 
   template <typename T>
   static bool deserialize(const char* data, size_t size, T* obj) {
-    return database_internal::Coder<T>::deserialize(data, size, obj);
+    return dataset_internal::Coder<T>::deserialize(data, size, obj);
   }
 };
 
 }  // namespace caffe
 
-#define INSTANTIATE_DATABASE(type) \
+#define INSTANTIATE_DATASET(type) \
   template class type<string, string>; \
   template class type<string, vector<char> >; \
   template class type<string, Datum>;
 
-#endif  // CAFFE_DATABASE_H_
+#endif  // CAFFE_DATASET_H_
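Apart from the renamed types, the forward-iterator surface of the interface carries over unchanged. As a hedged sketch only (DumpDataset is a hypothetical helper, not part of this commit), read-only traversal of a Dataset<string, Datum> would look like:

    #include <string>

    #include "glog/logging.h"

    #include "caffe/dataset_factory.hpp"

    // Hypothetical helper: log every key/label pair in an already-open
    // Dataset<string, Datum>.
    void DumpDataset(
        const caffe::shared_ptr<
            caffe::Dataset<std::string, caffe::Datum> >& dataset) {
      typedef caffe::Dataset<std::string, caffe::Datum> DatumDataset;
      for (DatumDataset::const_iterator it = dataset->begin();
           it != dataset->end(); ++it) {
        // Each element is a KV pair exposing .key and .value.
        LOG(INFO) << "key " << it->key << " label " << it->value.label();
      }
    }
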
diff --git a/include/caffe/dataset_factory.hpp b/include/caffe/dataset_factory.hpp
new file mode 100644 (file)
index 0000000..57db49b
--- /dev/null
@@ -0,0 +1,20 @@
+#ifndef CAFFE_DATASET_FACTORY_H_
+#define CAFFE_DATASET_FACTORY_H_
+
+#include <string>
+
+#include "caffe/common.hpp"
+#include "caffe/dataset.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe {
+
+template <typename K, typename V>
+shared_ptr<Dataset<K, V> > DatasetFactory(const DataParameter_DB& type);
+
+template <typename K, typename V>
+shared_ptr<Dataset<K, V> > DatasetFactory(const string& type);
+
+}  // namespace caffe
+
+#endif  // CAFFE_DATASET_FACTORY_H_
similarity index 70%
rename from include/caffe/leveldb_database.hpp
rename to include/caffe/leveldb_dataset.hpp
index cd966c4..eb354d9 100644 (file)
@@ -1,5 +1,5 @@
-#ifndef CAFFE_LEVELDB_DATABASE_H_
-#define CAFFE_LEVELDB_DATABASE_H_
+#ifndef CAFFE_LEVELDB_DATASET_H_
+#define CAFFE_LEVELDB_DATASET_H_
 
 #include <leveldb/db.h>
 #include <leveldb/write_batch.h>
@@ -9,17 +9,17 @@
 #include <vector>
 
 #include "caffe/common.hpp"
-#include "caffe/database.hpp"
+#include "caffe/dataset.hpp"
 
 namespace caffe {
 
 template <typename K, typename V>
-class LeveldbDatabase : public Database<K, V> {
+class LeveldbDataset : public Dataset<K, V> {
  public:
-  typedef Database<K, V> Base;
+  typedef Dataset<K, V> Base;
   typedef typename Base::key_type key_type;
   typedef typename Base::value_type value_type;
-  typedef typename Base::DatabaseState DatabaseState;
+  typedef typename Base::DatasetState DatasetState;
   typedef typename Base::Mode Mode;
   typedef typename Base::const_iterator const_iterator;
   typedef typename Base::KV KV;
@@ -38,11 +38,11 @@ class LeveldbDatabase : public Database<K, V> {
   const_iterator cend() const;
 
  protected:
-  class LeveldbState : public DatabaseState {
+  class LeveldbState : public DatasetState {
    public:
     explicit LeveldbState(shared_ptr<leveldb::DB> db,
         shared_ptr<leveldb::Iterator> iter)
-        : DatabaseState(),
+        : DatasetState(),
           db_(db),
           iter_(iter) { }
 
@@ -54,7 +54,7 @@ class LeveldbDatabase : public Database<K, V> {
       db_.reset();
     }
 
-    shared_ptr<DatabaseState> clone() {
+    shared_ptr<DatasetState> clone() {
       shared_ptr<leveldb::Iterator> new_iter;
 
       if (iter_.get()) {
@@ -64,7 +64,7 @@ class LeveldbDatabase : public Database<K, V> {
         CHECK(new_iter->Valid());
       }
 
-      return shared_ptr<DatabaseState>(new LeveldbState(db_, new_iter));
+      return shared_ptr<DatasetState>(new LeveldbState(db_, new_iter));
     }
 
     shared_ptr<leveldb::DB> db_;
@@ -72,10 +72,10 @@ class LeveldbDatabase : public Database<K, V> {
     KV kv_pair_;
   };
 
-  bool equal(shared_ptr<DatabaseState> state1,
-      shared_ptr<DatabaseState> state2) const;
-  void increment(shared_ptr<DatabaseState>* state) const;
-  KV& dereference(shared_ptr<DatabaseState> state) const;
+  bool equal(shared_ptr<DatasetState> state1,
+      shared_ptr<DatasetState> state2) const;
+  void increment(shared_ptr<DatasetState>* state) const;
+  KV& dereference(shared_ptr<DatasetState> state) const;
 
   shared_ptr<leveldb::DB> db_;
   shared_ptr<leveldb::WriteBatch> batch_;
@@ -84,4 +84,4 @@ class LeveldbDatabase : public Database<K, V> {
 
 }  // namespace caffe
 
-#endif  // CAFFE_LEVELDB_DATABASE_H_
+#endif  // CAFFE_LEVELDB_DATASET_H_
similarity index 71%
rename from include/caffe/lmdb_database.hpp
rename to include/caffe/lmdb_dataset.hpp
index b5a02ca..57da5de 100644 (file)
@@ -1,5 +1,5 @@
-#ifndef CAFFE_LMDB_DATABASE_H_
-#define CAFFE_LMDB_DATABASE_H_
+#ifndef CAFFE_LMDB_DATASET_H_
+#define CAFFE_LMDB_DATASET_H_
 
 #include <string>
 #include <utility>
@@ -8,22 +8,22 @@
 #include "lmdb.h"
 
 #include "caffe/common.hpp"
-#include "caffe/database.hpp"
+#include "caffe/dataset.hpp"
 
 namespace caffe {
 
 template <typename K, typename V>
-class LmdbDatabase : public Database<K, V> {
+class LmdbDataset : public Dataset<K, V> {
  public:
-  typedef Database<K, V> Base;
+  typedef Dataset<K, V> Base;
   typedef typename Base::key_type key_type;
   typedef typename Base::value_type value_type;
-  typedef typename Base::DatabaseState DatabaseState;
+  typedef typename Base::DatasetState DatasetState;
   typedef typename Base::Mode Mode;
   typedef typename Base::const_iterator const_iterator;
   typedef typename Base::KV KV;
 
-  LmdbDatabase()
+  LmdbDataset()
       : env_(NULL),
         dbi_(0),
         txn_(NULL) { }
@@ -42,15 +42,15 @@ class LmdbDatabase : public Database<K, V> {
   const_iterator cend() const;
 
  protected:
-  class LmdbState : public DatabaseState {
+  class LmdbState : public DatasetState {
    public:
     explicit LmdbState(MDB_cursor* cursor, MDB_txn* txn, const MDB_dbi* dbi)
-        : DatabaseState(),
+        : DatasetState(),
           cursor_(cursor),
           txn_(txn),
           dbi_(dbi) { }
 
-    shared_ptr<DatabaseState> clone() {
+    shared_ptr<DatasetState> clone() {
       MDB_cursor* new_cursor;
 
       if (cursor_) {
@@ -67,7 +67,7 @@ class LmdbDatabase : public Database<K, V> {
         new_cursor = cursor_;
       }
 
-      return shared_ptr<DatabaseState>(new LmdbState(new_cursor, txn_, dbi_));
+      return shared_ptr<DatasetState>(new LmdbState(new_cursor, txn_, dbi_));
     }
 
     MDB_cursor* cursor_;
@@ -76,10 +76,10 @@ class LmdbDatabase : public Database<K, V> {
     KV kv_pair_;
   };
 
-  bool equal(shared_ptr<DatabaseState> state1,
-      shared_ptr<DatabaseState> state2) const;
-  void increment(shared_ptr<DatabaseState>* state) const;
-  KV& dereference(shared_ptr<DatabaseState> state) const;
+  bool equal(shared_ptr<DatasetState> state1,
+      shared_ptr<DatasetState> state2) const;
+  void increment(shared_ptr<DatasetState>* state) const;
+  KV& dereference(shared_ptr<DatasetState> state) const;
 
   MDB_env* env_;
   MDB_dbi dbi_;
@@ -88,4 +88,4 @@ class LmdbDatabase : public Database<K, V> {
 
 }  // namespace caffe
 
-#endif  // CAFFE_LMDB_DATABASE_H_
+#endif  // CAFFE_LMDB_DATASET_H_
diff --git a/src/caffe/database_factory.cpp b/src/caffe/database_factory.cpp
deleted file mode 100644 (file)
index 4ccd429..0000000
+++ /dev/null
@@ -1,50 +0,0 @@
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "caffe/database_factory.hpp"
-#include "caffe/leveldb_database.hpp"
-#include "caffe/lmdb_database.hpp"
-
-namespace caffe {
-
-template <typename K, typename V>
-shared_ptr<Database<K, V> > DatabaseFactory(const DataParameter_DB& type) {
-  switch (type) {
-  case DataParameter_DB_LEVELDB:
-    return shared_ptr<Database<K, V> >(new LeveldbDatabase<K, V>());
-  case DataParameter_DB_LMDB:
-    return shared_ptr<Database<K, V> >(new LmdbDatabase<K, V>());
-  default:
-    LOG(FATAL) << "Unknown database type " << type;
-    return shared_ptr<Database<K, V> >();
-  }
-}
-
-template <typename K, typename V>
-shared_ptr<Database<K, V> > DatabaseFactory(const string& type) {
-  if ("leveldb" == type) {
-    return DatabaseFactory<K, V>(DataParameter_DB_LEVELDB);
-  } else if ("lmdb" == type) {
-    return DatabaseFactory<K, V>(DataParameter_DB_LMDB);
-  } else {
-    LOG(FATAL) << "Unknown database type " << type;
-    return shared_ptr<Database<K, V> >();
-  }
-}
-
-#define REGISTER_DATABASE(key_type, value_type) \
-  template shared_ptr<Database<key_type, value_type> > \
-      DatabaseFactory(const string& type); \
-  template shared_ptr<Database<key_type, value_type> > \
-      DatabaseFactory(const DataParameter_DB& type); \
-
-REGISTER_DATABASE(string, string);
-REGISTER_DATABASE(string, vector<char>);
-REGISTER_DATABASE(string, Datum);
-
-#undef REGISTER_DATABASE
-
-}  // namespace caffe
-
-
diff --git a/src/caffe/dataset_factory.cpp b/src/caffe/dataset_factory.cpp
new file mode 100644 (file)
index 0000000..3313de3
--- /dev/null
@@ -0,0 +1,50 @@
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "caffe/dataset_factory.hpp"
+#include "caffe/leveldb_dataset.hpp"
+#include "caffe/lmdb_dataset.hpp"
+
+namespace caffe {
+
+template <typename K, typename V>
+shared_ptr<Dataset<K, V> > DatasetFactory(const DataParameter_DB& type) {
+  switch (type) {
+  case DataParameter_DB_LEVELDB:
+    return shared_ptr<Dataset<K, V> >(new LeveldbDataset<K, V>());
+  case DataParameter_DB_LMDB:
+    return shared_ptr<Dataset<K, V> >(new LmdbDataset<K, V>());
+  default:
+    LOG(FATAL) << "Unknown dataset type " << type;
+    return shared_ptr<Dataset<K, V> >();
+  }
+}
+
+template <typename K, typename V>
+shared_ptr<Dataset<K, V> > DatasetFactory(const string& type) {
+  if ("leveldb" == type) {
+    return DatasetFactory<K, V>(DataParameter_DB_LEVELDB);
+  } else if ("lmdb" == type) {
+    return DatasetFactory<K, V>(DataParameter_DB_LMDB);
+  } else {
+    LOG(FATAL) << "Unknown dataset type " << type;
+    return shared_ptr<Dataset<K, V> >();
+  }
+}
+
+#define REGISTER_DATASET(key_type, value_type) \
+  template shared_ptr<Dataset<key_type, value_type> > \
+      DatasetFactory(const string& type); \
+  template shared_ptr<Dataset<key_type, value_type> > \
+      DatasetFactory(const DataParameter_DB& type); \
+
+REGISTER_DATASET(string, string);
+REGISTER_DATASET(string, vector<char>);
+REGISTER_DATASET(string, Datum);
+
+#undef REGISTER_DATASET
+
+}  // namespace caffe
+
+
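Both factory overloads resolve to the same switch: layers hand in the DataParameter_DB enum from their prototxt, while the conversion tools hand in the backend name as a string. A minimal sketch of the two call paths (two statements only, assuming the headers from the earlier example are included; the variable names are illustrative):

    // Enum overload, as data_layer.cpp uses it:
    caffe::shared_ptr<caffe::Dataset<std::string, caffe::Datum> > from_param =
        caffe::DatasetFactory<std::string, caffe::Datum>(
            caffe::DataParameter_DB_LMDB);

    // String overload, as the conversion tools use it:
    caffe::shared_ptr<caffe::Dataset<std::string, caffe::Datum> > from_flag =
        caffe::DatasetFactory<std::string, caffe::Datum>("lmdb");
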
diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp
index 6296335..fcf9bb2 100644 (file)
@@ -5,7 +5,7 @@
 
 #include "caffe/common.hpp"
 #include "caffe/data_layers.hpp"
-#include "caffe/database_factory.hpp"
+#include "caffe/dataset_factory.hpp"
 #include "caffe/layer.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/io.hpp"
@@ -17,20 +17,20 @@ namespace caffe {
 template <typename Dtype>
 DataLayer<Dtype>::~DataLayer<Dtype>() {
   this->JoinPrefetchThread();
-  // clean up the database resources
-  database_->close();
+  // clean up the dataset resources
+  dataset_->close();
 }
 
 template <typename Dtype>
 void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
   // Initialize DB
-  database_ = DatabaseFactory<string, Datum>(
+  dataset_ = DatasetFactory<string, Datum>(
       this->layer_param_.data_param().backend());
   const string& source = this->layer_param_.data_param().source();
-  LOG(INFO) << "Opening database " << source;
-  CHECK(database_->open(source, Database<string, Datum>::ReadOnly));
-  iter_ = database_->begin();
+  LOG(INFO) << "Opening dataset " << source;
+  CHECK(dataset_->open(source, Dataset<string, Datum>::ReadOnly));
+  iter_ = dataset_->begin();
 
   // Check if we would need to randomly skip a few data points
   if (this->layer_param_.data_param().rand_skip()) {
@@ -38,13 +38,13 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
                         this->layer_param_.data_param().rand_skip();
     LOG(INFO) << "Skipping first " << skip << " data points.";
     while (skip-- > 0) {
-      if (++iter_ == database_->end()) {
-        iter_ = database_->begin();
+      if (++iter_ == dataset_->end()) {
+        iter_ = dataset_->begin();
       }
     }
   }
   // Read a data point, and use it to initialize the top blob.
-  CHECK(iter_ != database_->end());
+  CHECK(iter_ != dataset_->end());
   const Datum& datum = iter_->value;
 
   // image
@@ -89,7 +89,7 @@ void DataLayer<Dtype>::InternalThreadEntry() {
   const int batch_size = this->layer_param_.data_param().batch_size();
 
   for (int item_id = 0; item_id < batch_size; ++item_id) {
-    CHECK(iter_ != database_->end());
+    CHECK(iter_ != dataset_->end());
     const Datum& datum = iter_->value;
 
     // Apply data transformations (mirror, scale, crop...)
@@ -102,8 +102,8 @@ void DataLayer<Dtype>::InternalThreadEntry() {
 
     // go to the next iter
     ++iter_;
-    if (iter_ == database_->end()) {
-      iter_ = database_->begin();
+    if (iter_ == dataset_->end()) {
+      iter_ = dataset_->begin();
     }
   }
 }
similarity index 78%
rename from src/caffe/leveldb_database.cpp
rename to src/caffe/leveldb_dataset.cpp
index 267ddf0..41c71e4 100644 (file)
@@ -3,12 +3,12 @@
 #include <vector>
 
 #include "caffe/caffe.hpp"
-#include "caffe/leveldb_database.hpp"
+#include "caffe/leveldb_dataset.hpp"
 
 namespace caffe {
 
 template <typename K, typename V>
-bool LeveldbDatabase<K, V>::open(const string& filename, Mode mode) {
+bool LeveldbDataset<K, V>::open(const string& filename, Mode mode) {
   DLOG(INFO) << "LevelDB: Open " << filename;
 
   leveldb::Options options;
@@ -55,11 +55,11 @@ bool LeveldbDatabase<K, V>::open(const string& filename, Mode mode) {
 }
 
 template <typename K, typename V>
-bool LeveldbDatabase<K, V>::put(const K& key, const V& value) {
+bool LeveldbDataset<K, V>::put(const K& key, const V& value) {
   DLOG(INFO) << "LevelDB: Put";
 
   if (read_only_) {
-    LOG(ERROR) << "put can not be used on a database in ReadOnly mode";
+    LOG(ERROR) << "put can not be used on a dataset in ReadOnly mode";
     return false;
   }
 
@@ -81,7 +81,7 @@ bool LeveldbDatabase<K, V>::put(const K& key, const V& value) {
 }
 
 template <typename K, typename V>
-bool LeveldbDatabase<K, V>::get(const K& key, V* value) {
+bool LeveldbDataset<K, V>::get(const K& key, V* value) {
   DLOG(INFO) << "LevelDB: Get";
 
   string serialized_key;
@@ -106,11 +106,11 @@ bool LeveldbDatabase<K, V>::get(const K& key, V* value) {
 }
 
 template <typename K, typename V>
-bool LeveldbDatabase<K, V>::commit() {
+bool LeveldbDataset<K, V>::commit() {
   DLOG(INFO) << "LevelDB: Commit";
 
   if (read_only_) {
-    LOG(ERROR) << "commit can not be used on a database in ReadOnly mode";
+    LOG(ERROR) << "commit can not be used on a dataset in ReadOnly mode";
     return false;
   }
 
@@ -125,7 +125,7 @@ bool LeveldbDatabase<K, V>::commit() {
 }
 
 template <typename K, typename V>
-void LeveldbDatabase<K, V>::close() {
+void LeveldbDataset<K, V>::close() {
   DLOG(INFO) << "LevelDB: Close";
 
   batch_.reset();
@@ -133,7 +133,7 @@ void LeveldbDatabase<K, V>::close() {
 }
 
 template <typename K, typename V>
-void LeveldbDatabase<K, V>::keys(vector<K>* keys) {
+void LeveldbDataset<K, V>::keys(vector<K>* keys) {
   DLOG(INFO) << "LevelDB: Keys";
 
   keys->clear();
@@ -143,8 +143,8 @@ void LeveldbDatabase<K, V>::keys(vector<K>* keys) {
 }
 
 template <typename K, typename V>
-typename LeveldbDatabase<K, V>::const_iterator
-    LeveldbDatabase<K, V>::begin() const {
+typename LeveldbDataset<K, V>::const_iterator
+    LeveldbDataset<K, V>::begin() const {
   CHECK_NOTNULL(db_.get());
   shared_ptr<leveldb::Iterator> iter(db_->NewIterator(leveldb::ReadOptions()));
   iter->SeekToFirst();
@@ -152,7 +152,7 @@ typename LeveldbDatabase<K, V>::const_iterator
     iter.reset();
   }
 
-  shared_ptr<DatabaseState> state;
+  shared_ptr<DatasetState> state;
   if (iter) {
     state.reset(new LeveldbState(db_, iter));
   }
@@ -160,25 +160,25 @@ typename LeveldbDatabase<K, V>::const_iterator
 }
 
 template <typename K, typename V>
-typename LeveldbDatabase<K, V>::const_iterator
-    LeveldbDatabase<K, V>::end() const {
-  shared_ptr<DatabaseState> state;
+typename LeveldbDataset<K, V>::const_iterator
+    LeveldbDataset<K, V>::end() const {
+  shared_ptr<DatasetState> state;
   return const_iterator(this, state);
 }
 
 template <typename K, typename V>
-typename LeveldbDatabase<K, V>::const_iterator
-    LeveldbDatabase<K, V>::cbegin() const {
+typename LeveldbDataset<K, V>::const_iterator
+    LeveldbDataset<K, V>::cbegin() const {
   return begin();
 }
 
 template <typename K, typename V>
-typename LeveldbDatabase<K, V>::const_iterator
-    LeveldbDatabase<K, V>::cend() const { return end(); }
+typename LeveldbDataset<K, V>::const_iterator
+    LeveldbDataset<K, V>::cend() const { return end(); }
 
 template <typename K, typename V>
-bool LeveldbDatabase<K, V>::equal(shared_ptr<DatabaseState> state1,
-    shared_ptr<DatabaseState> state2) const {
+bool LeveldbDataset<K, V>::equal(shared_ptr<DatasetState> state1,
+    shared_ptr<DatasetState> state2) const {
   shared_ptr<LeveldbState> leveldb_state1 =
       boost::dynamic_pointer_cast<LeveldbState>(state1);
 
@@ -192,7 +192,7 @@ bool LeveldbDatabase<K, V>::equal(shared_ptr<DatabaseState> state1,
 }
 
 template <typename K, typename V>
-void LeveldbDatabase<K, V>::increment(shared_ptr<DatabaseState>* state) const {
+void LeveldbDataset<K, V>::increment(shared_ptr<DatasetState>* state) const {
   shared_ptr<LeveldbState> leveldb_state =
       boost::dynamic_pointer_cast<LeveldbState>(*state);
 
@@ -210,8 +210,8 @@ void LeveldbDatabase<K, V>::increment(shared_ptr<DatabaseState>* state) const {
 }
 
 template <typename K, typename V>
-typename Database<K, V>::KV& LeveldbDatabase<K, V>::dereference(
-    shared_ptr<DatabaseState> state) const {
+typename Dataset<K, V>::KV& LeveldbDataset<K, V>::dereference(
+    shared_ptr<DatasetState> state) const {
   shared_ptr<LeveldbState> leveldb_state =
       boost::dynamic_pointer_cast<LeveldbState>(state);
 
@@ -233,6 +233,6 @@ typename Database<K, V>::KV& LeveldbDatabase<K, V>::dereference(
   return leveldb_state->kv_pair_;
 }
 
-INSTANTIATE_DATABASE(LeveldbDatabase);
+INSTANTIATE_DATASET(LeveldbDataset);
 
 }  // namespace caffe
similarity index 85%
rename from src/caffe/lmdb_database.cpp
rename to src/caffe/lmdb_dataset.cpp
index 22a08b2..d2b6fcd 100644 (file)
@@ -5,12 +5,12 @@
 #include <vector>
 
 #include "caffe/caffe.hpp"
-#include "caffe/lmdb_database.hpp"
+#include "caffe/lmdb_dataset.hpp"
 
 namespace caffe {
 
 template <typename K, typename V>
-bool LmdbDatabase<K, V>::open(const string& filename, Mode mode) {
+bool LmdbDataset<K, V>::open(const string& filename, Mode mode) {
   DLOG(INFO) << "LMDB: Open " << filename;
 
   CHECK(NULL == env_);
@@ -81,7 +81,7 @@ bool LmdbDatabase<K, V>::open(const string& filename, Mode mode) {
 }
 
 template <typename K, typename V>
-bool LmdbDatabase<K, V>::put(const K& key, const V& value) {
+bool LmdbDataset<K, V>::put(const K& key, const V& value) {
   DLOG(INFO) << "LMDB: Put";
 
   vector<char> serialized_key;
@@ -115,7 +115,7 @@ bool LmdbDatabase<K, V>::put(const K& key, const V& value) {
 }
 
 template <typename K, typename V>
-bool LmdbDatabase<K, V>::get(const K& key, V* value) {
+bool LmdbDataset<K, V>::get(const K& key, V* value) {
   DLOG(INFO) << "LMDB: Get";
 
   vector<char> serialized_key;
@@ -154,7 +154,7 @@ bool LmdbDatabase<K, V>::get(const K& key, V* value) {
 }
 
 template <typename K, typename V>
-bool LmdbDatabase<K, V>::commit() {
+bool LmdbDataset<K, V>::commit() {
   DLOG(INFO) << "LMDB: Commit";
 
   CHECK_NOTNULL(txn_);
@@ -176,7 +176,7 @@ bool LmdbDatabase<K, V>::commit() {
 }
 
 template <typename K, typename V>
-void LmdbDatabase<K, V>::close() {
+void LmdbDataset<K, V>::close() {
   DLOG(INFO) << "LMDB: Close";
 
   if (env_ && dbi_) {
@@ -189,7 +189,7 @@ void LmdbDatabase<K, V>::close() {
 }
 
 template <typename K, typename V>
-void LmdbDatabase<K, V>::keys(vector<K>* keys) {
+void LmdbDataset<K, V>::keys(vector<K>* keys) {
   DLOG(INFO) << "LMDB: Keys";
 
   keys->clear();
@@ -199,8 +199,8 @@ void LmdbDatabase<K, V>::keys(vector<K>* keys) {
 }
 
 template <typename K, typename V>
-typename LmdbDatabase<K, V>::const_iterator
-    LmdbDatabase<K, V>::begin() const {
+typename LmdbDataset<K, V>::const_iterator
+    LmdbDataset<K, V>::begin() const {
   int retval;
 
   MDB_txn* iter_txn;
@@ -219,7 +219,7 @@ typename LmdbDatabase<K, V>::const_iterator
   CHECK(MDB_SUCCESS == retval || MDB_NOTFOUND == retval)
       << mdb_strerror(retval);
 
-  shared_ptr<DatabaseState> state;
+  shared_ptr<DatasetState> state;
   if (MDB_SUCCESS == retval) {
     state.reset(new LmdbState(cursor, iter_txn, &dbi_));
   }
@@ -227,23 +227,23 @@ typename LmdbDatabase<K, V>::const_iterator
 }
 
 template <typename K, typename V>
-typename LmdbDatabase<K, V>::const_iterator
-    LmdbDatabase<K, V>::end() const {
-  shared_ptr<DatabaseState> state;
+typename LmdbDataset<K, V>::const_iterator
+    LmdbDataset<K, V>::end() const {
+  shared_ptr<DatasetState> state;
   return const_iterator(this, state);
 }
 
 template <typename K, typename V>
-typename LmdbDatabase<K, V>::const_iterator
-    LmdbDatabase<K, V>::cbegin() const { return begin(); }
+typename LmdbDataset<K, V>::const_iterator
+    LmdbDataset<K, V>::cbegin() const { return begin(); }
 
 template <typename K, typename V>
-typename LmdbDatabase<K, V>::const_iterator
-    LmdbDatabase<K, V>::cend() const { return end(); }
+typename LmdbDataset<K, V>::const_iterator
+    LmdbDataset<K, V>::cend() const { return end(); }
 
 template <typename K, typename V>
-bool LmdbDatabase<K, V>::equal(shared_ptr<DatabaseState> state1,
-    shared_ptr<DatabaseState> state2) const {
+bool LmdbDataset<K, V>::equal(shared_ptr<DatasetState> state1,
+    shared_ptr<DatasetState> state2) const {
   shared_ptr<LmdbState> lmdb_state1 =
       boost::dynamic_pointer_cast<LmdbState>(state1);
 
@@ -257,7 +257,7 @@ bool LmdbDatabase<K, V>::equal(shared_ptr<DatabaseState> state1,
 }
 
 template <typename K, typename V>
-void LmdbDatabase<K, V>::increment(shared_ptr<DatabaseState>* state) const {
+void LmdbDataset<K, V>::increment(shared_ptr<DatasetState>* state) const {
   shared_ptr<LmdbState> lmdb_state =
       boost::dynamic_pointer_cast<LmdbState>(*state);
 
@@ -279,8 +279,8 @@ void LmdbDatabase<K, V>::increment(shared_ptr<DatabaseState>* state) const {
 }
 
 template <typename K, typename V>
-typename Database<K, V>::KV& LmdbDatabase<K, V>::dereference(
-    shared_ptr<DatabaseState> state) const {
+typename Dataset<K, V>::KV& LmdbDataset<K, V>::dereference(
+    shared_ptr<DatasetState> state) const {
   shared_ptr<LmdbState> lmdb_state =
       boost::dynamic_pointer_cast<LmdbState>(state);
 
@@ -303,6 +303,6 @@ typename Database<K, V>::KV& LmdbDatabase<K, V>::dereference(
   return lmdb_state->kv_pair_;
 }
 
-INSTANTIATE_DATABASE(LmdbDatabase);
+INSTANTIATE_DATASET(LmdbDataset);
 
 }  // namespace caffe
diff --git a/src/caffe/test/test_data_layer.cpp b/src/caffe/test/test_data_layer.cpp
index 1c3ec1f..32f5d41 100644 (file)
@@ -5,7 +5,7 @@
 
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
-#include "caffe/database_factory.hpp"
+#include "caffe/dataset_factory.hpp"
 #include "caffe/filler.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/io.hpp"
@@ -38,10 +38,10 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
   // an image are the same.
   void Fill(const bool unique_pixels, DataParameter_DB backend) {
     backend_ = backend;
-    LOG(INFO) << "Using temporary database " << *filename_;
-    shared_ptr<Database<string, Datum> > database =
-        DatabaseFactory<string, Datum>(backend_);
-    CHECK(database->open(*filename_, Database<string, Datum>::New));
+    LOG(INFO) << "Using temporary dataset " << *filename_;
+    shared_ptr<Dataset<string, Datum> > dataset =
+        DatasetFactory<string, Datum>(backend_);
+    CHECK(dataset->open(*filename_, Dataset<string, Datum>::New));
     for (int i = 0; i < 5; ++i) {
       Datum datum;
       datum.set_label(i);
@@ -55,10 +55,10 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
       }
       stringstream ss;
       ss << i;
-      CHECK(database->put(ss.str(), datum));
+      CHECK(dataset->put(ss.str(), datum));
     }
-    CHECK(database->commit());
-    database->close();
+    CHECK(dataset->commit());
+    dataset->close();
   }
 
   void TestRead() {
@@ -183,7 +183,7 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
         }
         crop_sequence.push_back(iter_crop_sequence);
       }
-    }  // destroy 1st data layer and unlock the database
+    }  // destroy 1st data layer and unlock the dataset
 
     // Get crop sequence after reseeding Caffe with 1701.
     // Check that the sequence is the same as the original.
@@ -238,7 +238,7 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
         }
         crop_sequence.push_back(iter_crop_sequence);
       }
-    }  // destroy 1st data layer and unlock the database
+    }  // destroy 1st data layer and unlock the dataset
 
     // Get crop sequence continuing from previous Caffe RNG state; reseed
     // srand with 1701. Check that the sequence differs from the original.
diff --git a/src/caffe/test/test_database.cpp b/src/caffe/test/test_database.cpp
deleted file mode 100644 (file)
index 9a7d6de..0000000
+++ /dev/null
@@ -1,642 +0,0 @@
-#include <string>
-#include <vector>
-
-#include "caffe/util/io.hpp"
-
-#include "gtest/gtest.h"
-
-#include "caffe/database_factory.hpp"
-
-#include "caffe/test/test_caffe_main.hpp"
-
-namespace caffe {
-
-namespace DatabaseTest_internal {
-
-template <typename T>
-struct TestData {
-  static T TestValue();
-  static T TestAltValue();
-  static bool equals(const T& a, const T& b);
-};
-
-template <>
-string TestData<string>::TestValue() {
-  return "world";
-}
-
-template <>
-string TestData<string>::TestAltValue() {
-  return "bar";
-}
-
-template <>
-bool TestData<string>::equals(const string& a, const string& b) {
-  return a == b;
-}
-
-template <>
-vector<char> TestData<vector<char> >::TestValue() {
-  string str = "world";
-  vector<char> val(str.data(), str.data() + str.size());
-  return val;
-}
-
-template <>
-vector<char> TestData<vector<char> >::TestAltValue() {
-  string str = "bar";
-  vector<char> val(str.data(), str.data() + str.size());
-  return val;
-}
-
-template <>
-bool TestData<vector<char> >::equals(const vector<char>& a,
-    const vector<char>& b) {
-  if (a.size() != b.size()) {
-    return false;
-  }
-  for (size_t i = 0; i < a.size(); ++i) {
-    if (a.at(i) != b.at(i)) {
-      return false;
-    }
-  }
-
-  return true;
-}
-
-template <>
-Datum TestData<Datum>::TestValue() {
-  Datum datum;
-  datum.set_channels(3);
-  datum.set_height(32);
-  datum.set_width(32);
-  datum.set_data(string(32 * 32 * 3 * 4, ' '));
-  datum.set_label(0);
-  return datum;
-}
-
-template <>
-Datum TestData<Datum>::TestAltValue() {
-  Datum datum;
-  datum.set_channels(1);
-  datum.set_height(64);
-  datum.set_width(64);
-  datum.set_data(string(64 * 64 * 1 * 4, ' '));
-  datum.set_label(1);
-  return datum;
-}
-
-template <>
-bool TestData<Datum>::equals(const Datum& a, const Datum& b) {
-  string serialized_a;
-  a.SerializeToString(&serialized_a);
-
-  string serialized_b;
-  b.SerializeToString(&serialized_b);
-
-  return serialized_a == serialized_b;
-}
-
-}  // namespace DatabaseTest_internal
-
-#define UNPACK_TYPES \
-  typedef typename TypeParam::value_type value_type; \
-  const DataParameter_DB backend = TypeParam::backend;
-
-template <typename TypeParam>
-class DatabaseTest : public ::testing::Test {
- protected:
-  typedef typename TypeParam::value_type value_type;
-
-  string DBName() {
-    string filename;
-    MakeTempDir(&filename);
-    filename += "/db";
-    return filename;
-  }
-
-  string TestKey() {
-    return "hello";
-  }
-
-  value_type TestValue() {
-    return DatabaseTest_internal::TestData<value_type>::TestValue();
-  }
-
-  string TestAltKey() {
-    return "foo";
-  }
-
-  value_type TestAltValue() {
-    return DatabaseTest_internal::TestData<value_type>::TestAltValue();
-  }
-
-  template <typename T>
-  bool equals(const T& a, const T& b) {
-    return DatabaseTest_internal::TestData<T>::equals(a, b);
-  }
-};
-
-struct StringLeveldb {
-  typedef string value_type;
-  static const DataParameter_DB backend;
-};
-const DataParameter_DB StringLeveldb::backend = DataParameter_DB_LEVELDB;
-
-struct StringLmdb {
-  typedef string value_type;
-  static const DataParameter_DB backend;
-};
-const DataParameter_DB StringLmdb::backend = DataParameter_DB_LEVELDB;
-
-struct VectorLeveldb {
-  typedef vector<char> value_type;
-  static const DataParameter_DB backend;
-};
-const DataParameter_DB VectorLeveldb::backend = DataParameter_DB_LEVELDB;
-
-struct VectorLmdb {
-  typedef vector<char> value_type;
-  static const DataParameter_DB backend;
-};
-const DataParameter_DB VectorLmdb::backend = DataParameter_DB_LEVELDB;
-
-struct DatumLeveldb {
-  typedef Datum value_type;
-  static const DataParameter_DB backend;
-};
-const DataParameter_DB DatumLeveldb::backend = DataParameter_DB_LEVELDB;
-
-struct DatumLmdb {
-  typedef Datum value_type;
-  static const DataParameter_DB backend;
-};
-const DataParameter_DB DatumLmdb::backend = DataParameter_DB_LEVELDB;
-
-typedef ::testing::Types<StringLeveldb, StringLmdb, VectorLeveldb, VectorLmdb,
-    DatumLeveldb, DatumLmdb> TestTypes;
-
-TYPED_TEST_CASE(DatabaseTest, TestTypes);
-
-TYPED_TEST(DatabaseTest, TestNewDoesntExistPasses) {
-  UNPACK_TYPES;
-
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(this->DBName(),
-      Database<string, value_type>::New));
-  database->close();
-}
-
-TYPED_TEST(DatabaseTest, TestNewExistsFails) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-  database->close();
-
-  EXPECT_FALSE(database->open(name, Database<string, value_type>::New));
-}
-
-TYPED_TEST(DatabaseTest, TestReadOnlyExistsPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-  database->close();
-
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadOnly));
-  database->close();
-}
-
-TYPED_TEST(DatabaseTest, TestReadOnlyDoesntExistFails) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_FALSE(database->open(name, Database<string, value_type>::ReadOnly));
-}
-
-TYPED_TEST(DatabaseTest, TestReadWriteExistsPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-  database->close();
-
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadWrite));
-  database->close();
-}
-
-TYPED_TEST(DatabaseTest, TestReadWriteDoesntExistPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadWrite));
-  database->close();
-}
-
-TYPED_TEST(DatabaseTest, TestKeys) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-
-  string key1 = this->TestKey();
-  value_type value1 = this->TestValue();
-
-  EXPECT_TRUE(database->put(key1, value1));
-
-  string key2 = this->TestAltKey();
-  value_type value2 = this->TestAltValue();
-
-  EXPECT_TRUE(database->put(key2, value2));
-
-  EXPECT_TRUE(database->commit());
-
-  vector<string> keys;
-  database->keys(&keys);
-
-  EXPECT_EQ(2, keys.size());
-
-  EXPECT_TRUE(this->equals(keys.at(0), key1) ||
-      this->equals(keys.at(0), key2));
-  EXPECT_TRUE(this->equals(keys.at(1), key1) ||
-      this->equals(keys.at(2), key2));
-  EXPECT_FALSE(this->equals(keys.at(0), keys.at(1)));
-}
-
-TYPED_TEST(DatabaseTest, TestKeysNoCommit) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-
-  string key1 = this->TestKey();
-  value_type value1 = this->TestValue();
-
-  EXPECT_TRUE(database->put(key1, value1));
-
-  string key2 = this->TestAltKey();
-  value_type value2 = this->TestAltValue();
-
-  EXPECT_TRUE(database->put(key2, value2));
-
-  vector<string> keys;
-  database->keys(&keys);
-
-  EXPECT_EQ(0, keys.size());
-}
-
-TYPED_TEST(DatabaseTest, TestIterators) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-
-  const int kNumExamples = 4;
-  for (int i = 0; i < kNumExamples; ++i) {
-    stringstream ss;
-    ss << i;
-    string key = ss.str();
-    ss << " here be data";
-    value_type value = this->TestValue();
-    EXPECT_TRUE(database->put(key, value));
-  }
-  EXPECT_TRUE(database->commit());
-
-  int count = 0;
-  typedef typename Database<string, value_type>::const_iterator Iter;
-  for (Iter iter = database->begin(); iter != database->end(); ++iter) {
-    (void)iter;
-    ++count;
-  }
-
-  EXPECT_EQ(kNumExamples, count);
-}
-
-TYPED_TEST(DatabaseTest, TestIteratorsPreIncrement) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-
-  string key1 = this->TestAltKey();
-  value_type value1 = this->TestAltValue();
-
-  string key2 = this->TestKey();
-  value_type value2 = this->TestValue();
-
-  EXPECT_TRUE(database->put(key1, value1));
-  EXPECT_TRUE(database->put(key2, value2));
-  EXPECT_TRUE(database->commit());
-
-  typename Database<string, value_type>::const_iterator iter1 =
-      database->begin();
-
-  EXPECT_FALSE(database->end() == iter1);
-
-  EXPECT_TRUE(this->equals(iter1->key, key1));
-
-  typename Database<string, value_type>::const_iterator iter2 = ++iter1;
-
-  EXPECT_FALSE(database->end() == iter1);
-  EXPECT_FALSE(database->end() == iter2);
-
-  EXPECT_TRUE(this->equals(iter2->key, key2));
-
-  typename Database<string, value_type>::const_iterator iter3 = ++iter2;
-
-  EXPECT_TRUE(database->end() == iter3);
-
-  database->close();
-}
-
-TYPED_TEST(DatabaseTest, TestIteratorsPostIncrement) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-
-  string key1 = this->TestAltKey();
-  value_type value1 = this->TestAltValue();
-
-  string key2 = this->TestKey();
-  value_type value2 = this->TestValue();
-
-  EXPECT_TRUE(database->put(key1, value1));
-  EXPECT_TRUE(database->put(key2, value2));
-  EXPECT_TRUE(database->commit());
-
-  typename Database<string, value_type>::const_iterator iter1 =
-      database->begin();
-
-  EXPECT_FALSE(database->end() == iter1);
-
-  EXPECT_TRUE(this->equals(iter1->key, key1));
-
-  typename Database<string, value_type>::const_iterator iter2 = iter1++;
-
-  EXPECT_FALSE(database->end() == iter1);
-  EXPECT_FALSE(database->end() == iter2);
-
-  EXPECT_TRUE(this->equals(iter2->key, key1));
-  EXPECT_TRUE(this->equals(iter1->key, key2));
-
-  typename Database<string, value_type>::const_iterator iter3 = iter1++;
-
-  EXPECT_FALSE(database->end() == iter3);
-  EXPECT_TRUE(this->equals(iter3->key, key2));
-  EXPECT_TRUE(database->end() == iter1);
-
-  database->close();
-}
-
-TYPED_TEST(DatabaseTest, TestNewPutPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_TRUE(database->put(key, value));
-
-  EXPECT_TRUE(database->commit());
-
-  database->close();
-}
-
-TYPED_TEST(DatabaseTest, TestNewCommitPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-
-  EXPECT_TRUE(database->commit());
-
-  database->close();
-}
-
-TYPED_TEST(DatabaseTest, TestNewGetPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_TRUE(database->put(key, value));
-
-  EXPECT_TRUE(database->commit());
-
-  value_type new_value;
-
-  EXPECT_TRUE(database->get(key, &new_value));
-
-  EXPECT_TRUE(this->equals(value, new_value));
-
-  database->close();
-}
-
-TYPED_TEST(DatabaseTest, TestNewGetNoCommitFails) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_TRUE(database->put(key, value));
-
-  value_type new_value;
-
-  EXPECT_FALSE(database->get(key, &new_value));
-}
-
-
-TYPED_TEST(DatabaseTest, TestReadWritePutPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadWrite));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_TRUE(database->put(key, value));
-
-  EXPECT_TRUE(database->commit());
-
-  database->close();
-}
-
-TYPED_TEST(DatabaseTest, TestReadWriteCommitPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadWrite));
-
-  EXPECT_TRUE(database->commit());
-
-  database->close();
-}
-
-TYPED_TEST(DatabaseTest, TestReadWriteGetPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_TRUE(database->put(key, value));
-
-  EXPECT_TRUE(database->commit());
-
-  value_type new_value;
-
-  EXPECT_TRUE(database->get(key, &new_value));
-
-  EXPECT_TRUE(this->equals(value, new_value));
-
-  database->close();
-}
-
-TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitFails) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_TRUE(database->put(key, value));
-
-  value_type new_value;
-
-  EXPECT_FALSE(database->get(key, &new_value));
-}
-
-TYPED_TEST(DatabaseTest, TestReadOnlyPutFails) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-  database->close();
-
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadOnly));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_FALSE(database->put(key, value));
-}
-
-TYPED_TEST(DatabaseTest, TestReadOnlyCommitFails) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-  database->close();
-
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadOnly));
-
-  EXPECT_FALSE(database->commit());
-}
-
-TYPED_TEST(DatabaseTest, TestReadOnlyGetPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_TRUE(database->put(key, value));
-
-  EXPECT_TRUE(database->commit());
-
-  database->close();
-
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadOnly));
-
-  value_type new_value;
-
-  EXPECT_TRUE(database->get(key, &new_value));
-
-  EXPECT_TRUE(this->equals(value, new_value));
-}
-
-TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitFails) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Database<string, value_type> > database =
-      DatabaseFactory<string, value_type>(backend);
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::New));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_TRUE(database->put(key, value));
-
-  database->close();
-
-  EXPECT_TRUE(database->open(name, Database<string, value_type>::ReadOnly));
-
-  value_type new_value;
-
-  EXPECT_FALSE(database->get(key, &new_value));
-}
-
-#undef UNPACK_TYPES
-
-}  // namespace caffe
diff --git a/src/caffe/test/test_dataset.cpp b/src/caffe/test/test_dataset.cpp
new file mode 100644 (file)
index 0000000..bb6cf4c
--- /dev/null
@@ -0,0 +1,642 @@
+#include <string>
+#include <vector>
+
+#include "caffe/util/io.hpp"
+
+#include "gtest/gtest.h"
+
+#include "caffe/dataset_factory.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+namespace DatasetTest_internal {
+
+template <typename T>
+struct TestData {
+  static T TestValue();
+  static T TestAltValue();
+  static bool equals(const T& a, const T& b);
+};
+
+template <>
+string TestData<string>::TestValue() {
+  return "world";
+}
+
+template <>
+string TestData<string>::TestAltValue() {
+  return "bar";
+}
+
+template <>
+bool TestData<string>::equals(const string& a, const string& b) {
+  return a == b;
+}
+
+template <>
+vector<char> TestData<vector<char> >::TestValue() {
+  string str = "world";
+  vector<char> val(str.data(), str.data() + str.size());
+  return val;
+}
+
+template <>
+vector<char> TestData<vector<char> >::TestAltValue() {
+  string str = "bar";
+  vector<char> val(str.data(), str.data() + str.size());
+  return val;
+}
+
+template <>
+bool TestData<vector<char> >::equals(const vector<char>& a,
+    const vector<char>& b) {
+  if (a.size() != b.size()) {
+    return false;
+  }
+  for (size_t i = 0; i < a.size(); ++i) {
+    if (a.at(i) != b.at(i)) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+template <>
+Datum TestData<Datum>::TestValue() {
+  Datum datum;
+  datum.set_channels(3);
+  datum.set_height(32);
+  datum.set_width(32);
+  datum.set_data(string(32 * 32 * 3 * 4, ' '));
+  datum.set_label(0);
+  return datum;
+}
+
+template <>
+Datum TestData<Datum>::TestAltValue() {
+  Datum datum;
+  datum.set_channels(1);
+  datum.set_height(64);
+  datum.set_width(64);
+  datum.set_data(string(64 * 64 * 1 * 4, ' '));
+  datum.set_label(1);
+  return datum;
+}
+
+template <>
+bool TestData<Datum>::equals(const Datum& a, const Datum& b) {
+  string serialized_a;
+  a.SerializeToString(&serialized_a);
+
+  string serialized_b;
+  b.SerializeToString(&serialized_b);
+
+  return serialized_a == serialized_b;
+}
+
+}  // namespace DatasetTest_internal
+
+#define UNPACK_TYPES \
+  typedef typename TypeParam::value_type value_type; \
+  const DataParameter_DB backend = TypeParam::backend;
+
+template <typename TypeParam>
+class DatasetTest : public ::testing::Test {
+ protected:
+  typedef typename TypeParam::value_type value_type;
+
+  string DBName() {
+    string filename;
+    MakeTempDir(&filename);
+    filename += "/db";
+    return filename;
+  }
+
+  string TestKey() {
+    return "hello";
+  }
+
+  value_type TestValue() {
+    return DatasetTest_internal::TestData<value_type>::TestValue();
+  }
+
+  string TestAltKey() {
+    return "foo";
+  }
+
+  value_type TestAltValue() {
+    return DatasetTest_internal::TestData<value_type>::TestAltValue();
+  }
+
+  template <typename T>
+  bool equals(const T& a, const T& b) {
+    return DatasetTest_internal::TestData<T>::equals(a, b);
+  }
+};
+
+struct StringLeveldb {
+  typedef string value_type;
+  static const DataParameter_DB backend;
+};
+const DataParameter_DB StringLeveldb::backend = DataParameter_DB_LEVELDB;
+
+struct StringLmdb {
+  typedef string value_type;
+  static const DataParameter_DB backend;
+};
+const DataParameter_DB StringLmdb::backend = DataParameter_DB_LEVELDB;
+
+struct VectorLeveldb {
+  typedef vector<char> value_type;
+  static const DataParameter_DB backend;
+};
+const DataParameter_DB VectorLeveldb::backend = DataParameter_DB_LEVELDB;
+
+struct VectorLmdb {
+  typedef vector<char> value_type;
+  static const DataParameter_DB backend;
+};
+const DataParameter_DB VectorLmdb::backend = DataParameter_DB_LEVELDB;
+
+struct DatumLeveldb {
+  typedef Datum value_type;
+  static const DataParameter_DB backend;
+};
+const DataParameter_DB DatumLeveldb::backend = DataParameter_DB_LEVELDB;
+
+struct DatumLmdb {
+  typedef Datum value_type;
+  static const DataParameter_DB backend;
+};
+const DataParameter_DB DatumLmdb::backend = DataParameter_DB_LEVELDB;
+
+typedef ::testing::Types<StringLeveldb, StringLmdb, VectorLeveldb, VectorLmdb,
+    DatumLeveldb, DatumLmdb> TestTypes;
+
+TYPED_TEST_CASE(DatasetTest, TestTypes);
+
+TYPED_TEST(DatasetTest, TestNewDoesntExistPasses) {
+  UNPACK_TYPES;
+
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(this->DBName(),
+      Dataset<string, value_type>::New));
+  dataset->close();
+}
+
+TYPED_TEST(DatasetTest, TestNewExistsFails) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+  dataset->close();
+
+  EXPECT_FALSE(dataset->open(name, Dataset<string, value_type>::New));
+}
+
+TYPED_TEST(DatasetTest, TestReadOnlyExistsPasses) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+  dataset->close();
+
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadOnly));
+  dataset->close();
+}
+
+TYPED_TEST(DatasetTest, TestReadOnlyDoesntExistFails) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_FALSE(dataset->open(name, Dataset<string, value_type>::ReadOnly));
+}
+
+TYPED_TEST(DatasetTest, TestReadWriteExistsPasses) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+  dataset->close();
+
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadWrite));
+  dataset->close();
+}
+
+TYPED_TEST(DatasetTest, TestReadWriteDoesntExistPasses) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadWrite));
+  dataset->close();
+}
+
+TYPED_TEST(DatasetTest, TestKeys) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+
+  string key1 = this->TestKey();
+  value_type value1 = this->TestValue();
+
+  EXPECT_TRUE(dataset->put(key1, value1));
+
+  string key2 = this->TestAltKey();
+  value_type value2 = this->TestAltValue();
+
+  EXPECT_TRUE(dataset->put(key2, value2));
+
+  EXPECT_TRUE(dataset->commit());
+
+  vector<string> keys;
+  dataset->keys(&keys);
+
+  EXPECT_EQ(2, keys.size());
+
+  EXPECT_TRUE(this->equals(keys.at(0), key1) ||
+      this->equals(keys.at(0), key2));
+  EXPECT_TRUE(this->equals(keys.at(1), key1) ||
+      this->equals(keys.at(1), key2));
+  EXPECT_FALSE(this->equals(keys.at(0), keys.at(1)));
+}
+
+TYPED_TEST(DatasetTest, TestKeysNoCommit) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+
+  string key1 = this->TestKey();
+  value_type value1 = this->TestValue();
+
+  EXPECT_TRUE(dataset->put(key1, value1));
+
+  string key2 = this->TestAltKey();
+  value_type value2 = this->TestAltValue();
+
+  EXPECT_TRUE(dataset->put(key2, value2));
+
+  vector<string> keys;
+  dataset->keys(&keys);
+
+  EXPECT_EQ(0, keys.size());
+}
+
+TYPED_TEST(DatasetTest, TestIterators) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+
+  const int kNumExamples = 4;
+  for (int i = 0; i < kNumExamples; ++i) {
+    stringstream ss;
+    ss << i;
+    string key = ss.str();
+    ss << " here be data";
+    value_type value = this->TestValue();
+    EXPECT_TRUE(dataset->put(key, value));
+  }
+  EXPECT_TRUE(dataset->commit());
+
+  int count = 0;
+  typedef typename Dataset<string, value_type>::const_iterator Iter;
+  for (Iter iter = dataset->begin(); iter != dataset->end(); ++iter) {
+    (void)iter;
+    ++count;
+  }
+
+  EXPECT_EQ(kNumExamples, count);
+}
+
+TYPED_TEST(DatasetTest, TestIteratorsPreIncrement) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+
+  string key1 = this->TestAltKey();
+  value_type value1 = this->TestAltValue();
+
+  string key2 = this->TestKey();
+  value_type value2 = this->TestValue();
+
+  EXPECT_TRUE(dataset->put(key1, value1));
+  EXPECT_TRUE(dataset->put(key2, value2));
+  EXPECT_TRUE(dataset->commit());
+
+  typename Dataset<string, value_type>::const_iterator iter1 =
+      dataset->begin();
+
+  EXPECT_FALSE(dataset->end() == iter1);
+
+  EXPECT_TRUE(this->equals(iter1->key, key1));
+
+  typename Dataset<string, value_type>::const_iterator iter2 = ++iter1;
+
+  EXPECT_FALSE(dataset->end() == iter1);
+  EXPECT_FALSE(dataset->end() == iter2);
+
+  EXPECT_TRUE(this->equals(iter2->key, key2));
+
+  typename Dataset<string, value_type>::const_iterator iter3 = ++iter2;
+
+  EXPECT_TRUE(dataset->end() == iter3);
+
+  dataset->close();
+}
+
+TYPED_TEST(DatasetTest, TestIteratorsPostIncrement) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+
+  string key1 = this->TestAltKey();
+  value_type value1 = this->TestAltValue();
+
+  string key2 = this->TestKey();
+  value_type value2 = this->TestValue();
+
+  EXPECT_TRUE(dataset->put(key1, value1));
+  EXPECT_TRUE(dataset->put(key2, value2));
+  EXPECT_TRUE(dataset->commit());
+
+  typename Dataset<string, value_type>::const_iterator iter1 =
+      dataset->begin();
+
+  EXPECT_FALSE(dataset->end() == iter1);
+
+  EXPECT_TRUE(this->equals(iter1->key, key1));
+
+  typename Dataset<string, value_type>::const_iterator iter2 = iter1++;
+
+  EXPECT_FALSE(dataset->end() == iter1);
+  EXPECT_FALSE(dataset->end() == iter2);
+
+  EXPECT_TRUE(this->equals(iter2->key, key1));
+  EXPECT_TRUE(this->equals(iter1->key, key2));
+
+  typename Dataset<string, value_type>::const_iterator iter3 = iter1++;
+
+  EXPECT_FALSE(dataset->end() == iter3);
+  EXPECT_TRUE(this->equals(iter3->key, key2));
+  EXPECT_TRUE(dataset->end() == iter1);
+
+  dataset->close();
+}
+
+TYPED_TEST(DatasetTest, TestNewPutPasses) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+
+  string key = this->TestKey();
+  value_type value = this->TestValue();
+
+  EXPECT_TRUE(dataset->put(key, value));
+
+  EXPECT_TRUE(dataset->commit());
+
+  dataset->close();
+}
+
+TYPED_TEST(DatasetTest, TestNewCommitPasses) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+
+  EXPECT_TRUE(dataset->commit());
+
+  dataset->close();
+}
+
+TYPED_TEST(DatasetTest, TestNewGetPasses) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+
+  string key = this->TestKey();
+  value_type value = this->TestValue();
+
+  EXPECT_TRUE(dataset->put(key, value));
+
+  EXPECT_TRUE(dataset->commit());
+
+  value_type new_value;
+
+  EXPECT_TRUE(dataset->get(key, &new_value));
+
+  EXPECT_TRUE(this->equals(value, new_value));
+
+  dataset->close();
+}
+
+TYPED_TEST(DatasetTest, TestNewGetNoCommitFails) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+
+  string key = this->TestKey();
+  value_type value = this->TestValue();
+
+  EXPECT_TRUE(dataset->put(key, value));
+
+  value_type new_value;
+
+  EXPECT_FALSE(dataset->get(key, &new_value));
+}
+
+TYPED_TEST(DatasetTest, TestReadWritePutPasses) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadWrite));
+
+  string key = this->TestKey();
+  value_type value = this->TestValue();
+
+  EXPECT_TRUE(dataset->put(key, value));
+
+  EXPECT_TRUE(dataset->commit());
+
+  dataset->close();
+}
+
+TYPED_TEST(DatasetTest, TestReadWriteCommitPasses) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadWrite));
+
+  EXPECT_TRUE(dataset->commit());
+
+  dataset->close();
+}
+
+TYPED_TEST(DatasetTest, TestReadWriteGetPasses) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+
+  string key = this->TestKey();
+  value_type value = this->TestValue();
+
+  EXPECT_TRUE(dataset->put(key, value));
+
+  EXPECT_TRUE(dataset->commit());
+
+  value_type new_value;
+
+  EXPECT_TRUE(dataset->get(key, &new_value));
+
+  EXPECT_TRUE(this->equals(value, new_value));
+
+  dataset->close();
+}
+
+TYPED_TEST(DatasetTest, TestReadWriteGetNoCommitFails) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+
+  string key = this->TestKey();
+  value_type value = this->TestValue();
+
+  EXPECT_TRUE(dataset->put(key, value));
+
+  value_type new_value;
+
+  EXPECT_FALSE(dataset->get(key, &new_value));
+}
+
+TYPED_TEST(DatasetTest, TestReadOnlyPutFails) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+  dataset->close();
+
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadOnly));
+
+  string key = this->TestKey();
+  value_type value = this->TestValue();
+
+  EXPECT_FALSE(dataset->put(key, value));
+}
+
+TYPED_TEST(DatasetTest, TestReadOnlyCommitFails) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+  dataset->close();
+
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadOnly));
+
+  EXPECT_FALSE(dataset->commit());
+}
+
+TYPED_TEST(DatasetTest, TestReadOnlyGetPasses) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+
+  string key = this->TestKey();
+  value_type value = this->TestValue();
+
+  EXPECT_TRUE(dataset->put(key, value));
+
+  EXPECT_TRUE(dataset->commit());
+
+  dataset->close();
+
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadOnly));
+
+  value_type new_value;
+
+  EXPECT_TRUE(dataset->get(key, &new_value));
+
+  EXPECT_TRUE(this->equals(value, new_value));
+}
+
+TYPED_TEST(DatasetTest, TestReadOnlyGetNoCommitFails) {
+  UNPACK_TYPES;
+
+  string name = this->DBName();
+  shared_ptr<Dataset<string, value_type> > dataset =
+      DatasetFactory<string, value_type>(backend);
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
+
+  string key = this->TestKey();
+  value_type value = this->TestValue();
+
+  EXPECT_TRUE(dataset->put(key, value));
+
+  dataset->close();
+
+  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadOnly));
+
+  value_type new_value;
+
+  EXPECT_FALSE(dataset->get(key, &new_value));
+}
+
+#undef UNPACK_TYPES
+
+}  // namespace caffe
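
The tests above exercise the full Dataset lifecycle: open with New/ReadOnly/ReadWrite, put, commit, get, keys, and const_iterator traversal, then close. As a quick orientation before the tool diffs below, here is a minimal usage sketch; it is not part of this commit, it uses the string-typed DatasetFactory overload seen in the tools (and whatever headers those pull in transitively, e.g. for CHECK and caffe::shared_ptr), and the output path "/tmp/example_db" is a hypothetical stand-in.

#include <string>

#include "caffe/dataset_factory.hpp"
#include "caffe/proto/caffe.pb.h"

int main() {
  using caffe::Dataset;
  using caffe::Datum;

  // Construct a backend-specific Dataset ("leveldb" or "lmdb").
  caffe::shared_ptr<Dataset<std::string, Datum> > dataset =
      caffe::DatasetFactory<std::string, Datum>("leveldb");

  // New fails if the path already exists (see TestNewExistsFails above).
  CHECK(dataset->open("/tmp/example_db", Dataset<std::string, Datum>::New));

  Datum datum;
  CHECK(dataset->put("key0", datum));  // staged; not visible until commit()
  CHECK(dataset->commit());            // makes puts visible to get()/keys()

  Datum fetched;
  CHECK(dataset->get("key0", &fetched));

  // Iterate over committed entries.
  for (Dataset<std::string, Datum>::const_iterator iter = dataset->begin();
       iter != dataset->end(); ++iter) {
    const Datum& value = iter->value;
    (void)value;
  }

  dataset->close();
  return 0;
}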
index f1a7967..a720f16 100644 (file)
@@ -4,11 +4,11 @@
 #include <algorithm>
 #include <string>
 
-#include "caffe/database_factory.hpp"
+#include "caffe/dataset_factory.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/io.hpp"
 
-using caffe::Database;
+using caffe::Dataset;
 using caffe::Datum;
 using caffe::BlobProto;
 using std::max;
@@ -26,16 +26,16 @@ int main(int argc, char** argv) {
     db_backend = std::string(argv[3]);
   }
 
-  caffe::shared_ptr<Database<std::string, Datum> > database =
-      caffe::DatabaseFactory<std::string, Datum>(db_backend);
+  caffe::shared_ptr<Dataset<std::string, Datum> > dataset =
+      caffe::DatasetFactory<std::string, Datum>(db_backend);
 
   // Open db
-  CHECK(database->open(argv[1], Database<std::string, Datum>::ReadOnly));
+  CHECK(dataset->open(argv[1], Dataset<std::string, Datum>::ReadOnly));
 
   BlobProto sum_blob;
   int count = 0;
   // load first datum
-  Database<std::string, Datum>::const_iterator iter = database->begin();
+  Dataset<std::string, Datum>::const_iterator iter = dataset->begin();
   const Datum& datum = iter->value;
 
   sum_blob.set_num(1);
@@ -49,8 +49,8 @@ int main(int argc, char** argv) {
     sum_blob.add_data(0.);
   }
   LOG(INFO) << "Starting Iteration";
-  for (Database<std::string, Datum>::const_iterator iter = database->begin();
-      iter != database->end(); ++iter) {
+  for (Dataset<std::string, Datum>::const_iterator iter = dataset->begin();
+      iter != dataset->end(); ++iter) {
     // just a dummy operation
     const Datum& datum = iter->value;
     const std::string& data = datum.data();
@@ -87,6 +87,6 @@ int main(int argc, char** argv) {
   WriteProtoToBinaryFile(sum_blob, argv[2]);
 
   // Clean up
-  database->close();
+  dataset->close();
   return 0;
 }
index 2ba3e3c..37efa5c 100644 (file)
@@ -17,7 +17,7 @@
 #include <utility>
 #include <vector>
 
-#include "caffe/database_factory.hpp"
+#include "caffe/dataset_factory.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/io.hpp"
 #include "caffe/util/rng.hpp"
@@ -78,11 +78,11 @@ int main(int argc, char** argv) {
   int resize_width = std::max<int>(0, FLAGS_resize_width);
 
   // Open new db
-  shared_ptr<Database<string, Datum> > database =
-      DatabaseFactory<string, Datum>(db_backend);
+  shared_ptr<Dataset<string, Datum> > dataset =
+      DatasetFactory<string, Datum>(db_backend);
 
   // Open db
-  CHECK(database->open(db_path, Database<string, Datum>::New));
+  CHECK(dataset->open(db_path, Dataset<string, Datum>::New));
 
   // Storing to db
   std::string root_folder(argv[1]);
@@ -113,19 +113,19 @@ int main(int argc, char** argv) {
         lines[line_id].first.c_str());
 
     // Put in db
-    CHECK(database->put(string(key_cstr, length), datum));
+    CHECK(dataset->put(string(key_cstr, length), datum));
 
     if (++count % 1000 == 0) {
       // Commit txn
-      CHECK(database->commit());
+      CHECK(dataset->commit());
       LOG(ERROR) << "Processed " << count << " files.";
     }
   }
   // write the last batch
   if (count % 1000 != 0) {
-    CHECK(database->commit());
+    CHECK(dataset->commit());
     LOG(ERROR) << "Processed " << count << " files.";
   }
-  database->close();
+  dataset->close();
   return 0;
 }
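
The convert_imageset loop above batches writes: puts are staged in the current transaction and flushed with commit() every 1000 entries, with one final commit() for any remainder before close(). A condensed sketch of that pattern, where `dataset`, the `datums` vector, and the zero-padded key format are hypothetical stand-ins for the tool's real inputs:

#include <cstdio>
#include <string>
#include <vector>

#include "caffe/dataset_factory.hpp"
#include "caffe/proto/caffe.pb.h"

// Writes all datums to an already-open Dataset, committing in batches of 1000.
void write_batched(
    const caffe::shared_ptr<
        caffe::Dataset<std::string, caffe::Datum> >& dataset,
    const std::vector<caffe::Datum>& datums) {
  int count = 0;
  for (size_t i = 0; i < datums.size(); ++i) {
    char key_cstr[32];
    const int length = snprintf(key_cstr, sizeof(key_cstr), "%08zu", i);
    CHECK(dataset->put(std::string(key_cstr, length), datums[i]));
    if (++count % 1000 == 0) {
      CHECK(dataset->commit());  // flush a batch of staged puts
    }
  }
  if (count % 1000 != 0) {
    CHECK(dataset->commit());  // write the final partial batch
  }
}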
index 47565a8..ddbce10 100644 (file)
@@ -7,7 +7,7 @@
 
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
-#include "caffe/database_factory.hpp"
+#include "caffe/dataset_factory.hpp"
 #include "caffe/net.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/io.hpp"
@@ -16,8 +16,8 @@
 using boost::shared_ptr;
 using caffe::Blob;
 using caffe::Caffe;
-using caffe::Database;
-using caffe::DatabaseFactory;
+using caffe::Dataset;
+using caffe::DatasetFactory;
 using caffe::Datum;
 using caffe::Net;
 
@@ -39,12 +39,12 @@ int feature_extraction_pipeline(int argc, char** argv) {
     " extract features of the input data produced by the net.\n"
     "Usage: extract_features  pretrained_net_param"
     "  feature_extraction_proto_file  extract_feature_blob_name1[,name2,...]"
-    "  save_feature_database_name1[,name2,...]  num_mini_batches  db_type"
+    "  save_feature_dataset_name1[,name2,...]  num_mini_batches  db_type"
     "  [CPU/GPU] [DEVICE_ID=0]\n"
     "Note: you can extract multiple features in one pass by specifying"
-    " multiple feature blob names and database names seperated by ','."
+    " multiple feature blob names and dataset names seperated by ','."
     " The names cannot contain white space characters and the number of blobs"
-    " and databases must be equal.";
+    " and datasets must be equal.";
     return 1;
   }
   int arg_pos = num_required_args;
@@ -105,12 +105,12 @@ int feature_extraction_pipeline(int argc, char** argv) {
   std::vector<std::string> blob_names;
   boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(","));
 
-  std::string save_feature_database_names(argv[++arg_pos]);
-  std::vector<std::string> database_names;
-  boost::split(database_names, save_feature_database_names,
+  std::string save_feature_dataset_names(argv[++arg_pos]);
+  std::vector<std::string> dataset_names;
+  boost::split(dataset_names, save_feature_dataset_names,
                boost::is_any_of(","));
-  CHECK_EQ(blob_names.size(), database_names.size()) <<
-      " the number of blob names and database names must be equal";
+  CHECK_EQ(blob_names.size(), dataset_names.size()) <<
+      " the number of blob names and dataset names must be equal";
   size_t num_features = blob_names.size();
 
   for (size_t i = 0; i < num_features; i++) {
@@ -121,14 +121,13 @@ int feature_extraction_pipeline(int argc, char** argv) {
 
   int num_mini_batches = atoi(argv[++arg_pos]);
 
-  std::vector<shared_ptr<Database<std::string, Datum> > > feature_dbs;
+  std::vector<shared_ptr<Dataset<std::string, Datum> > > feature_dbs;
   for (size_t i = 0; i < num_features; ++i) {
-    LOG(INFO)<< "Opening database " << database_names[i];
-    shared_ptr<Database<std::string, Datum> > database =
-        DatabaseFactory<std::string, Datum>(argv[++arg_pos]);
-    CHECK(database->open(database_names.at(i),
-        Database<std::string, Datum>::New));
-    feature_dbs.push_back(database);
+    LOG(INFO)<< "Opening dataset " << dataset_names[i];
+    shared_ptr<Dataset<std::string, Datum> > dataset =
+        DatasetFactory<std::string, Datum>(argv[++arg_pos]);
+    CHECK(dataset->open(dataset_names.at(i), Dataset<std::string, Datum>::New));
+    feature_dbs.push_back(dataset);
   }
 
   LOG(ERROR)<< "Extacting Features";